Skip to content

Commit f496909

Browse files
LiuChiachiZeyuChen
andauthored
Fix seq2seq windows dtype bug (PaddlePaddle#1198)
* fix seq2seq windows dtype bug * fix Co-authored-by: Zeyu Chen <chenzeyu01@baidu.com>
1 parent f9f5537 commit f496909

File tree

1 file changed

+1
-1
lines changed
  • examples/machine_translation/seq2seq

1 file changed

+1
-1
lines changed

examples/machine_translation/seq2seq/data.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,6 @@ def prepare_train_input(insts, bos_id, eos_id, pad_id):
118118
src, src_length = Pad(pad_val=pad_id, ret_length=True)(
119119
[inst[0] for inst in insts])
120120
tgt, tgt_length = Pad(pad_val=pad_id, ret_length=True)(
121-
[inst[1] for inst in insts])
121+
[inst[1] for inst in insts], dtype="int64")
122122
tgt_mask = (tgt[:, :-1] != pad_id).astype("float32")
123123
return src, src_length, tgt[:, :-1], tgt[:, 1:, np.newaxis], tgt_mask

0 commit comments

Comments
 (0)