
Commit 123c4c7

Author: gongenlei
add hf ds and upgrade example (#2925)
1 parent 7932dd2 commit 123c4c7

File tree
2 files changed: +294 -37 lines changed


model_zoo/ernie-m/run_classifier.py

Lines changed: 43 additions & 37 deletions
@@ -26,11 +26,11 @@
 from paddle.io import Dataset, BatchSampler, DistributedBatchSampler, DataLoader
 from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer
 from paddlenlp.transformers import LinearDecayWithWarmup
-from paddlenlp.datasets import load_dataset
-from paddlenlp.data import Stack, Tuple, Pad
+from datasets import load_dataset
 from paddle.metric import Accuracy
 from paddlenlp.ops.optimizer import layerwise_lr_decay
 from paddle.optimizer import AdamW
+from paddlenlp.data import DataCollatorWithPadding

 all_languages = [
     "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
@@ -137,6 +137,9 @@ def parse_args():
                         type=str,
                         choices=["cpu", "gpu", "xpu"],
                         help="The device to select to train the model, is must be cpu/gpu/xpu.")
+    parser.add_argument("--overwrite_cache",
+                        action="store_true",
+                        help="Whether to overwrite cache for dataset.")
     parser.add_argument("--use_amp",
                         type=distutils.util.strtobool,
                         default=False,
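
The new --overwrite_cache flag matters because Hugging Face datasets caches the result of Dataset.map on disk; the script feeds it into load_from_cache_file=not args.overwrite_cache further down. A small illustrative sketch of that interaction (not part of the commit):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--overwrite_cache",
                        action="store_true",
                        help="Whether to overwrite cache for dataset.")

    args = parser.parse_args(["--overwrite_cache"])
    # Passing the flag disables reuse of the cached .map() result, forcing re-tokenization.
    print(not args.overwrite_cache)  # False -> load_from_cache_file=False
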
@@ -164,8 +167,8 @@ def evaluate(model, loss_fct, metric, data_loader, language):
     model.eval()
     metric.reset()
     for batch in data_loader:
-        input_ids, position_ids, attention_mask, labels = batch
-        logits = model(input_ids, position_ids, attention_mask)
+        labels = batch.pop("labels")
+        logits = model(**batch)
         loss = loss_fct(logits, labels)
         correct = metric.compute(logits, labels)
         metric.update(correct)
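
Because the collator now yields a dict of tensors instead of a tuple, evaluate pops the labels out and unpacks the remaining keys as keyword arguments. A toy illustration of that dict-unpacking pattern (the names are placeholders, not the real model):

    def fake_model(input_ids=None, position_ids=None, attention_mask=None):
        # Stands in for the ERNIE-M forward pass; returns one "logit" per example.
        return [0.0] * len(input_ids)

    batch = {"input_ids": [[1, 2, 3], [4, 5, 0]],
             "position_ids": [[0, 1, 2], [0, 1, 2]],
             "attention_mask": [[1, 1, 1], [1, 1, 0]],
             "labels": [0, 2]}
    labels = batch.pop("labels")  # remove labels before the forward pass
    logits = fake_model(**batch)  # same as fake_model(input_ids=..., position_ids=..., ...)
    print(labels, logits)
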
@@ -178,21 +181,25 @@ def evaluate(model, loss_fct, metric, data_loader, language):

 def convert_example(example, tokenizer, max_seq_length=256):
     """convert a example into necessary features"""
-    # Get the label
-    label = example["label"]
-    premise = example["premise"]
-    hypothesis = example["hypothesis"]
     # Convert raw text to feature
-    example = tokenizer(premise,
-                        text_pair=hypothesis,
-                        max_seq_len=max_seq_length)
-    return example["input_ids"], example["position_ids"], example[
-        "attention_mask"], label
-
-
-def get_test_dataloader(args, language, batchify_fn, trans_func):
-    test_ds = load_dataset("xnli", language, splits="test")
-    test_ds = test_ds.map(trans_func, lazy=True)
+    tokenized_example = tokenizer(example["premise"],
+                                  text_pair=example["hypothesis"],
+                                  max_length=max_seq_length,
+                                  padding=False,
+                                  truncation=True,
+                                  return_position_ids=True,
+                                  return_attention_mask=True,
+                                  return_token_type_ids=False)
+    return tokenized_example
+
+
+def get_test_dataloader(args, language, batchify_fn, trans_func,
+                        remove_columns):
+    test_ds = load_dataset("xnli", language, split="test")
+    test_ds = test_ds.map(trans_func,
+                          batched=True,
+                          remove_columns=remove_columns,
+                          load_from_cache_file=not args.overwrite_cache)
     test_batch_sampler = BatchSampler(test_ds,
                                       batch_size=args.batch_size,
                                       shuffle=False)
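
convert_example now returns the tokenizer's dict directly, and Dataset.map is called with batched=True (so the tokenizer receives lists of premises and hypotheses) plus remove_columns so the raw text columns are dropped after tokenization. A rough end-to-end sketch of that pipeline, assuming the ernie-m-base checkpoint name resolves in PaddleNLP (the checkpoint name is an assumption, not taken from the commit):

    from functools import partial

    from datasets import load_dataset
    from paddlenlp.transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("ernie-m-base")  # assumed checkpoint name

    def convert_example(example, tokenizer, max_seq_length=256):
        # With batched=True, example["premise"] and example["hypothesis"] are lists of strings.
        return tokenizer(example["premise"],
                         text_pair=example["hypothesis"],
                         max_length=max_seq_length,
                         padding=False,
                         truncation=True,
                         return_position_ids=True,
                         return_attention_mask=True,
                         return_token_type_ids=False)

    test_ds = load_dataset("xnli", "en", split="test")
    test_ds = test_ds.map(partial(convert_example, tokenizer=tokenizer),
                          batched=True,
                          remove_columns=["premise", "hypothesis"])
    print(test_ds.column_names)  # label plus the tokenizer outputs, no raw text left
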
@@ -220,11 +227,7 @@ def __getitem__(self, idx):
         last = language_idx - 1 if language_idx > 0 else language_idx
         sample_idx = idx - self.cumsum_len[last] if idx >= self.cumsum_len[
             last] else idx
-        input_ids = self.datasets[language_idx][sample_idx][0]
-        position_ids = self.datasets[language_idx][sample_idx][1]
-        attention_mask = self.datasets[language_idx][sample_idx][2]
-        label = self.datasets[language_idx][sample_idx][3]
-        return input_ids, position_ids, attention_mask, label
+        return self.datasets[int(language_idx)][int(sample_idx)]

     def __len__(self):
         return self.cumsum_len[-1]
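
With Hugging Face datasets, each per-language sample is already a feature dict, so XnliDataset.__getitem__ can return it directly instead of re-packing a tuple. The cumulative-length bookkeeping is unchanged; a toy sketch of how a global index maps to (language dataset, local offset) under that scheme (the locate helper is illustrative, not from the source):

    import numpy as np

    lengths = [3, 5, 2]              # sizes of three per-language datasets
    cumsum_len = np.cumsum(lengths)  # array([ 3,  8, 10])

    def locate(idx):
        # The first dataset whose cumulative length exceeds idx owns the sample.
        language_idx = int(np.argmax(cumsum_len > idx))
        last = language_idx - 1 if language_idx > 0 else language_idx
        sample_idx = idx - cumsum_len[last] if idx >= cumsum_len[last] else idx
        return language_idx, int(sample_idx)

    print(locate(0), locate(4), locate(9))  # (0, 0) (1, 1) (2, 1)
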
@@ -240,25 +243,28 @@ def do_train(args):
     trans_func = partial(convert_example,
                          tokenizer=tokenizer,
                          max_seq_length=args.max_seq_length)
+    remove_columns = ["premise", "hypothesis"]
     if args.task_type == "cross-lingual-transfer":
-        train_ds = load_dataset("xnli", "en", splits="train")
-        train_ds = train_ds.map(trans_func, lazy=True)
+        train_ds = load_dataset("xnli", "en", split="train")
+        train_ds = train_ds.map(trans_func,
+                                batched=True,
+                                remove_columns=remove_columns,
+                                load_from_cache_file=not args.overwrite_cache)
     elif args.task_type == "translate-train-all":
         all_train_ds = []
         for language in all_languages:
-            train_ds = load_dataset("xnli", language, splits="train")
-            all_train_ds.append(train_ds.map(trans_func, lazy=True))
+            train_ds = load_dataset("xnli", language, split="train")
+            all_train_ds.append(
+                train_ds.map(trans_func,
+                             batched=True,
+                             remove_columns=remove_columns,
+                             load_from_cache_file=not args.overwrite_cache))
         train_ds = XnliDataset(all_train_ds)
     train_batch_sampler = DistributedBatchSampler(train_ds,
                                                   batch_size=args.batch_size,
                                                   shuffle=True)
-    batchify_fn = lambda samples, fn=Tuple(
-        Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"),  # input_ids
-        Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"
-            ),  # position_ids
-        Pad(axis=0, pad_val=0, dtype="int64"),  # attention_mask
-        Stack(dtype="int64")  # labels
-    ): fn(samples)
+    batchify_fn = DataCollatorWithPadding(tokenizer)
+
     train_data_loader = DataLoader(dataset=train_ds,
                                    batch_sampler=train_batch_sampler,
                                    collate_fn=batchify_fn,
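
DataCollatorWithPadding replaces the hand-written Tuple(Pad, Pad, Pad, Stack) lambda: given a list of variable-length feature dicts it pads every sequence field to the longest example in the batch and batches the labels. A minimal sketch of the collator on two toy features (checkpoint name assumed, as above; the padded-shape print is illustrative):

    from paddlenlp.data import DataCollatorWithPadding
    from paddlenlp.transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("ernie-m-base")  # assumed checkpoint name
    batchify_fn = DataCollatorWithPadding(tokenizer)

    def encode(premise, hypothesis):
        return tokenizer(premise, text_pair=hypothesis,
                         return_position_ids=True,
                         return_attention_mask=True,
                         return_token_type_ids=False)

    features = [encode("a short premise", "a hypothesis"),
                encode("a noticeably longer premise sentence", "another hypothesis")]

    batch = batchify_fn(features)
    # Every field is padded to the length of the longest example in the batch.
    print(batch["input_ids"].shape, batch["attention_mask"].shape)
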
@@ -318,11 +324,11 @@ def do_train(args):
     for epoch in range(num_train_epochs):
         for step, batch in enumerate(train_data_loader):
             global_step += 1
-            input_ids, position_ids, attention_mask, labels = batch
+            labels = batch.pop("labels")
             with paddle.amp.auto_cast(
                     args.use_amp,
                     custom_white_list=["layer_norm", "softmax", "gelu"]):
-                logits = model(input_ids, position_ids, attention_mask)
+                logits = model(**batch)
                 loss = loss_fct(logits, labels)
             if args.use_amp:
                 scaled_loss = scaler.scale(loss)
@@ -344,7 +350,7 @@ def do_train(args):
             for language in all_languages:
                 tic_eval = time.time()
                 test_data_loader = get_test_dataloader(
-                    args, language, batchify_fn, trans_func)
+                    args, language, batchify_fn, trans_func, remove_columns)
                 evaluate(model, loss_fct, metric, test_data_loader,
                          language)
                 print("eval done total : %s s" % (time.time() - tic_eval))
