
Commit b34c74f

Update llama datacollator (PaddlePaddle#6301)
1 parent: 9b993b2

1 file changed: +6 −8 lines changed


examples/language_model/llama/finetune_generation.py

Lines changed: 6 additions & 8 deletions
@@ -17,12 +17,7 @@
 from functools import partial
 
 import paddle
-from data import (
-    DataCollatorForSupervisedDataset,
-    convert_example,
-    custom_instruction_convert_example,
-    reader,
-)
+from data import convert_example, custom_instruction_convert_example, reader
 from modeling_pp import LlamaForCausalLMPipe
 from utils import (
     LlamaTrainer,
@@ -31,6 +26,7 @@
     save_infer_result,
 )
 
+from paddlenlp.data import DataCollatorForSeq2Seq
 from paddlenlp.datasets import load_dataset
 from paddlenlp.peft import LoRAConfig, LoRAModel, PrefixConfig, PrefixModelForCausalLM
 from paddlenlp.peft.prefix import llama_postprocess_past_key_value
@@ -221,8 +217,10 @@ def main():
     dev_ds = dev_ds.map(partial(trans_func, is_test=is_test))
 
     model_max_length = 1024 if not training_args.benchmark else 512
-    collate_fn = DataCollatorForSupervisedDataset(
-        tokenizer, max_length=model_max_length if data_args.always_pad_to_max_length else -1
+    collate_fn = DataCollatorForSeq2Seq(
+        return_tensors="pd",
+        tokenizer=tokenizer,
+        max_length=model_max_length if data_args.always_pad_to_max_length else -1,
     )
 
     def compute_metrics_trainer(eval_preds, tokenizer):
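
For context, a minimal sketch of how the new collator behaves, not part of the commit itself: it pads a batch of tokenized examples the way the updated finetune_generation.py now does. The tokenizer checkpoint and the toy features below are assumptions for illustration only.

# Sketch only: pad a small batch with paddlenlp.data.DataCollatorForSeq2Seq,
# using the same keyword arguments the commit passes.
from paddlenlp.data import DataCollatorForSeq2Seq
from paddlenlp.transformers import AutoTokenizer

# Example checkpoint (an assumption; any PaddleNLP llama tokenizer works here).
tokenizer = AutoTokenizer.from_pretrained("facebook/llama-7b")

collate_fn = DataCollatorForSeq2Seq(
    return_tensors="pd",   # return Paddle tensors, as in the commit
    tokenizer=tokenizer,
    max_length=-1,         # mirrors the branch where always_pad_to_max_length is off
)

# Each feature mimics what convert_example() would emit: input_ids plus labels.
features = [
    {"input_ids": [1, 2, 3, 4, 5], "labels": [2, 3, 4, 5, 6]},
    {"input_ids": [1, 2], "labels": [2, 3]},
]
batch = collate_fn(features)
print(batch["input_ids"].shape)  # both sequences padded to the longer length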

0 commit comments