
Commit a316a6a

[seq2seq docs] Move evaluation down, fix typo (#5365)
1 parent 4bcc35c commit a316a6a

File tree

1 file changed: +48 −46 lines changed

examples/seq2seq/README.md

Lines changed: 48 additions & 46 deletions
@@ -3,6 +3,7 @@ Summarization support is more mature than translation support.
 Please tag @sshleifer with any issues/unexpected behaviors, or send a PR!
 For `bertabs` instructions, see `bertabs/README.md`.
 
+
 ### Data
 
 CNN/DailyMail data
@@ -37,50 +38,6 @@ export ENRO_DIR=${PWD}/wmt_en_ro
 If you are using your own data, it must be formatted as one directory with 6 files: train.source, train.target, val.source, val.target, test.source, test.target.
 The `.source` files are the input, the `.target` files are the desired output.

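A quick sanity check for that layout (a minimal sketch, assuming one example per line with source and target aligned by line number; `wmt_en_ro` is a hypothetical directory name):

```python
from pathlib import Path

data_dir = Path("wmt_en_ro")  # hypothetical; point this at your own dataset directory
for split in ["train", "val", "test"]:
    src = (data_dir / f"{split}.source").read_text(encoding="utf-8").splitlines()
    tgt = (data_dir / f"{split}.target").read_text(encoding="utf-8").splitlines()
    # assumption: .source and .target pair up line by line
    assert len(src) == len(tgt), f"{split}: {len(src)} sources vs {len(tgt)} targets"
```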
-### Evaluation Commands
-
-To create summaries for each article in the dataset, we use `run_eval.py`; here are a few commands that run eval for different tasks and models.
-If 'translation' is in your task name, the computed metric will be BLEU. Otherwise, ROUGE will be used.
-
-For T5, you need to specify `--task translation_{src}_to_{tgt}` as follows:
-```bash
-export DATA_DIR=wmt_en_ro
-python run_eval.py t5_base \
-    $DATA_DIR/val.source mbart_val_generations.txt \
-    --reference_path $DATA_DIR/val.target \
-    --score_path enro_bleu.json \
-    --task translation_en_to_ro \
-    --n_obs 100 \
-    --device cuda \
-    --fp16 \
-    --bs 32
-```
-
-This command works for MBART, although the BLEU score is suspiciously low.
-```bash
-export DATA_DIR=wmt_en_ro
-python run_eval.py facebook/mbart-large-en-ro $DATA_DIR/val.source mbart_val_generations.txt \
-    --reference_path $DATA_DIR/val.target \
-    --score_path enro_bleu.json \
-    --task translation \
-    --n_obs 100 \
-    --device cuda \
-    --fp16 \
-    --bs 32
-```
-
-Summarization (xsum will be very similar):
-```bash
-export DATA_DIR=cnn_dm
-python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_generations.txt \
-    --reference_path $DATA_DIR/val.target \
-    --score_path cnn_rouge.json \
-    --task summarization \
-    --n_obs 100 \
-    --device cuda \
-    --fp16 \
-    --bs 32
-```
 
 
 ### Summarization Finetuning
@@ -147,8 +104,7 @@ from transformers import AutoModelForSeq2SeqLM
 model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
 ```
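The `best_tfmr` checkpoint can then be used for generation directly; a minimal sketch, assuming the tokenizer was saved alongside the model and with a hypothetical `output_dir`:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

output_dir = "xsum_results"  # hypothetical; use the directory finetuning wrote to
tokenizer = AutoTokenizer.from_pretrained(f"{output_dir}/best_tfmr")
model = AutoModelForSeq2SeqLM.from_pretrained(f"{output_dir}/best_tfmr")

# summarize one article with the model's default generation settings
batch = tokenizer(["Some long news article ..."], return_tensors="pt", truncation=True)
summary_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"])
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))
```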
 
-
-### XSUM Shared Task
+#### XSUM Shared Task
 Compare XSUM results with others by using `--logger wandb_shared`. This requires `wandb` registration.
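Registration amounts to a one-time login; a minimal sketch, assuming the `wandb` package is installed:

```python
import wandb

# prompts for (or reuses) the API key tied to your wandb account
wandb.login()
```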
 
 Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier!
@@ -165,6 +121,52 @@ Here is an example command, but you can do whatever you want. Hopefully this will
 
 You can see your wandb logs [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-)
 
+### Evaluation Commands
+
+To create summaries for each article in the dataset, we use `run_eval.py`; here are a few commands that run eval for different tasks and models.
+If 'translation' is in your task name, the computed metric will be BLEU. Otherwise, ROUGE will be used.
+
+For T5, you need to specify `--task translation_{src}_to_{tgt}` as follows:
+```bash
+export DATA_DIR=wmt_en_ro
+python run_eval.py t5_base \
+    $DATA_DIR/val.source t5_val_generations.txt \
+    --reference_path $DATA_DIR/val.target \
+    --score_path enro_bleu.json \
+    --task translation_en_to_ro \
+    --n_obs 100 \
+    --device cuda \
+    --fp16 \
+    --bs 32
+```
+
+This command works for MBART, although the BLEU score is suspiciously low.
+```bash
+export DATA_DIR=wmt_en_ro
+python run_eval.py facebook/mbart-large-en-ro $DATA_DIR/val.source mbart_val_generations.txt \
+    --reference_path $DATA_DIR/val.target \
+    --score_path enro_bleu.json \
+    --task translation \
+    --n_obs 100 \
+    --device cuda \
+    --fp16 \
+    --bs 32
+```
+
+Summarization (xsum will be very similar):
+```bash
+export DATA_DIR=cnn_dm
+python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_generations.txt \
+    --reference_path $DATA_DIR/val.target \
+    --score_path cnn_rouge.json \
+    --task summarization \
+    --n_obs 100 \
+    --device cuda \
+    --fp16 \
+    --bs 32
+```
+
+
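The metric dispatch described above (BLEU for translation task names, ROUGE otherwise) can be approximated outside `run_eval.py`; a sketch using `sacrebleu` and `rouge_score` as stand-ins, since the script's own metric helpers may differ:

```python
import sacrebleu
from rouge_score import rouge_scorer

def score_generations(task, preds, refs):
    # BLEU when 'translation' is in the task name, ROUGE otherwise
    if "translation" in task:
        return {"bleu": sacrebleu.corpus_bleu(preds, [refs]).score}
    scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
    scores = [scorer.score(ref, pred) for ref, pred in zip(refs, preds)]
    return {k: sum(s[k].fmeasure for s in scores) / len(scores)
            for k in ["rouge1", "rouge2", "rougeL"]}
```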
 ### DistilBART
 
 For the CNN/DailyMail dataset (relatively longer, more extractive summaries), we found a simple technique that works:
