from paddle.io import Dataset, BatchSampler, DistributedBatchSampler, DataLoader
from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer
from paddlenlp.transformers import LinearDecayWithWarmup
-from paddlenlp.datasets import load_dataset
-from paddlenlp.data import Stack, Tuple, Pad
+from datasets import load_dataset
from paddle.metric import Accuracy
from paddlenlp.ops.optimizer import layerwise_lr_decay
from paddle.optimizer import AdamW
+from paddlenlp.data import DataCollatorWithPadding

all_languages = [
    "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
@@ -137,6 +137,9 @@ def parse_args():
                        type=str,
                        choices=["cpu", "gpu", "xpu"],
                        help="The device to select to train the model, is must be cpu/gpu/xpu.")
+    parser.add_argument("--overwrite_cache",
+                        action="store_true",
+                        help="Whether to overwrite cache for dataset.")
    parser.add_argument("--use_amp",
                        type=distutils.util.strtobool,
                        default=False,
@@ -164,8 +167,8 @@ def evaluate(model, loss_fct, metric, data_loader, language):
    model.eval()
    metric.reset()
    for batch in data_loader:
-        input_ids, position_ids, attention_mask, labels = batch
-        logits = model(input_ids, position_ids, attention_mask)
+        labels = batch.pop("labels")
+        logits = model(**batch)
        loss = loss_fct(logits, labels)
        correct = metric.compute(logits, labels)
        metric.update(correct)
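The tuple unpacking disappears because the collated batch is now a dict; popping "labels" and splatting the rest with ** keeps the forward call independent of field order. A Paddle-free toy illustration (dummy_model is purely hypothetical):

def dummy_model(input_ids, position_ids=None, attention_mask=None):
    # stand-in for the real classifier: one fake logit row per sample
    return [[0.1, 0.2, 0.7] for _ in input_ids]

batch = {
    "input_ids": [[101, 7, 102], [101, 9, 102]],
    "position_ids": [[0, 1, 2], [0, 1, 2]],
    "attention_mask": [[1, 1, 1], [1, 1, 1]],
    "labels": [2, 0],
}
labels = batch.pop("labels")   # labels stay out of the forward pass
logits = dummy_model(**batch)  # same as dummy_model(input_ids=..., position_ids=..., attention_mask=...)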
@@ -178,21 +181,25 @@ def convert_example(example, tokenizer, max_seq_length=256):

def convert_example(example, tokenizer, max_seq_length=256):
    """convert a example into necessary features"""
-    # Get the label
-    label = example["label"]
-    premise = example["premise"]
-    hypothesis = example["hypothesis"]
    # Convert raw text to feature
-    example = tokenizer(premise,
-                        text_pair=hypothesis,
-                        max_seq_len=max_seq_length)
-    return example["input_ids"], example["position_ids"], example[
-        "attention_mask"], label
-
-
-def get_test_dataloader(args, language, batchify_fn, trans_func):
-    test_ds = load_dataset("xnli", language, splits="test")
-    test_ds = test_ds.map(trans_func, lazy=True)
+    tokenized_example = tokenizer(example["premise"],
+                                  text_pair=example["hypothesis"],
+                                  max_length=max_seq_length,
+                                  padding=False,
+                                  truncation=True,
+                                  return_position_ids=True,
+                                  return_attention_mask=True,
+                                  return_token_type_ids=False)
+    return tokenized_example
+
+
+def get_test_dataloader(args, language, batchify_fn, trans_func,
+                        remove_columns):
+    test_ds = load_dataset("xnli", language, split="test")
+    test_ds = test_ds.map(trans_func,
+                          batched=True,
+                          remove_columns=remove_columns,
+                          load_from_cache_file=not args.overwrite_cache)
    test_batch_sampler = BatchSampler(test_ds,
                                      batch_size=args.batch_size,
                                      shuffle=False)
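Dataset.map with batched=True hands convert_example whole columns at a time, remove_columns drops the raw text fields once they are tokenized, and the result is cached on disk unless --overwrite_cache forces load_from_cache_file=False. A self-contained toy of that mapping, with fake_trans_func standing in for the real tokenizer call:

from datasets import Dataset

toy = Dataset.from_dict({
    "premise": ["He is sleeping.", "The cat sat down."],
    "hypothesis": ["He is awake.", "An animal sat down."],
    "label": [2, 0],
})

def fake_trans_func(batch):
    # batch["premise"] is a list of strings here, not a single example
    return {"input_ids": [[1] * len(text.split()) for text in batch["premise"]]}

processed = toy.map(fake_trans_func,
                    batched=True,
                    remove_columns=["premise", "hypothesis"])
print(processed.column_names)  # the text columns are gone; "label" and "input_ids" remain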
@@ -220,11 +227,7 @@ def __getitem__(self, idx):
        last = language_idx - 1 if language_idx > 0 else language_idx
        sample_idx = idx - self.cumsum_len[last] if idx >= self.cumsum_len[
            last] else idx
-        input_ids = self.datasets[language_idx][sample_idx][0]
-        position_ids = self.datasets[language_idx][sample_idx][1]
-        attention_mask = self.datasets[language_idx][sample_idx][2]
-        label = self.datasets[language_idx][sample_idx][3]
-        return input_ids, position_ids, attention_mask, label
+        return self.datasets[int(language_idx)][int(sample_idx)]

    def __len__(self):
        return self.cumsum_len[-1]
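For reference, the cumulative-length arithmetic that __getitem__ relies on, re-implemented standalone; the np.searchsorted call is one illustrative way to pick language_idx, not necessarily how the class computes it:

import numpy as np

lengths = [3, 5, 2]              # e.g. sizes of the per-language datasets
cumsum_len = np.cumsum(lengths)  # [3, 8, 10]

def locate(idx):
    # map a global index onto (which dataset, index within that dataset)
    language_idx = int(np.searchsorted(cumsum_len, idx, side="right"))
    last = language_idx - 1 if language_idx > 0 else language_idx
    sample_idx = idx - cumsum_len[last] if idx >= cumsum_len[last] else idx
    return language_idx, sample_idx

print(locate(0))   # (0, 0)  first sample of the first dataset
print(locate(4))   # (1, 1)  second sample of the second dataset
print(locate(9))   # (2, 1)  last sample of the third dataset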
@@ -240,25 +243,28 @@ def do_train(args):
    trans_func = partial(convert_example,
                         tokenizer=tokenizer,
                         max_seq_length=args.max_seq_length)
+    remove_columns = ["premise", "hypothesis"]
    if args.task_type == "cross-lingual-transfer":
-        train_ds = load_dataset("xnli", "en", splits="train")
-        train_ds = train_ds.map(trans_func, lazy=True)
+        train_ds = load_dataset("xnli", "en", split="train")
+        train_ds = train_ds.map(trans_func,
+                                batched=True,
+                                remove_columns=remove_columns,
+                                load_from_cache_file=not args.overwrite_cache)
    elif args.task_type == "translate-train-all":
        all_train_ds = []
        for language in all_languages:
-            train_ds = load_dataset("xnli", language, splits="train")
-            all_train_ds.append(train_ds.map(trans_func, lazy=True))
+            train_ds = load_dataset("xnli", language, split="train")
+            all_train_ds.append(
+                train_ds.map(trans_func,
+                             batched=True,
+                             remove_columns=remove_columns,
+                             load_from_cache_file=not args.overwrite_cache))
        train_ds = XnliDataset(all_train_ds)
    train_batch_sampler = DistributedBatchSampler(train_ds,
                                                  batch_size=args.batch_size,
                                                  shuffle=True)
-    batchify_fn = lambda samples, fn=Tuple(
-        Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"),  # input_ids
-        Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"
-            ),  # position_ids
-        Pad(axis=0, pad_val=0, dtype="int64"),  # attention_mask
-        Stack(dtype="int64")  # labels
-    ): fn(samples)
+    batchify_fn = DataCollatorWithPadding(tokenizer)
+
    train_data_loader = DataLoader(dataset=train_ds,
                                   batch_sampler=train_batch_sampler,
                                   collate_fn=batchify_fn,
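DataCollatorWithPadding takes over what the removed Tuple(Pad, Pad, Pad, Stack) lambda did by hand: pad each field to the longest sequence in the batch, stack the labels, and return a dict keyed by feature name. A toy stand-in (not the PaddleNLP implementation) showing that behavior:

import numpy as np

def toy_collate(features, pad_token_id=1):
    # pad input_ids/position_ids with the pad token, attention_mask with 0,
    # and stack labels -- mirroring the old Pad/Pad/Pad/Stack batchify_fn
    max_len = max(len(f["input_ids"]) for f in features)
    pad_vals = {"input_ids": pad_token_id, "position_ids": pad_token_id,
                "attention_mask": 0}
    batch = {
        key: np.array([f[key] + [val] * (max_len - len(f[key])) for f in features],
                      dtype="int64")
        for key, val in pad_vals.items()
    }
    batch["labels"] = np.array([f["labels"] for f in features], dtype="int64")
    return batch

features = [
    {"input_ids": [101, 7, 8, 102], "position_ids": [0, 1, 2, 3],
     "attention_mask": [1, 1, 1, 1], "labels": 0},
    {"input_ids": [101, 9, 102], "position_ids": [0, 1, 2],
     "attention_mask": [1, 1, 1], "labels": 2},
]
print(toy_collate(features)["input_ids"])  # second row padded out to length 4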
@@ -318,11 +324,11 @@ def do_train(args):
    for epoch in range(num_train_epochs):
        for step, batch in enumerate(train_data_loader):
            global_step += 1
-            input_ids, position_ids, attention_mask, labels = batch
+            labels = batch.pop("labels")
            with paddle.amp.auto_cast(
                    args.use_amp,
                    custom_white_list=["layer_norm", "softmax", "gelu"]):
-                logits = model(input_ids, position_ids, attention_mask)
+                logits = model(**batch)
                loss = loss_fct(logits, labels)
            if args.use_amp:
                scaled_loss = scaler.scale(loss)
@@ -344,7 +350,7 @@ def do_train(args):
                for language in all_languages:
                    tic_eval = time.time()
                    test_data_loader = get_test_dataloader(
-                        args, language, batchify_fn, trans_func)
+                        args, language, batchify_fn, trans_func, remove_columns)
                    evaluate(model, loss_fct, metric, test_data_loader,
                             language)
                    print("eval done total : %s s" % (time.time() - tic_eval))