Commit 3dd0bd4 (parent f61a294)

fix unified transformer dtype problem (PaddlePaddle#747)

* fix unified transformer dtype problem
* fix win dtype bug

File tree

examples/dialogue/unified_transformer — 1 file changed: +4 −2 lines

examples/dialogue/unified_transformer/utils.py (+4 −2)

@@ -124,8 +124,10 @@ def pad_mask(batch_attention_mask):
            (max_len - example['seq_len']) + i * max_len
            for i, example in enumerate(batch_examples)
        ])
-        labels = np.concatenate(
-            [np.array(example['labels']) for example in batch_examples])
+        labels = np.concatenate([
+            np.array(
+                example['labels'], dtype='int64') for example in batch_examples
+        ])
        return input_ids, token_type_ids, position_ids, attention_mask, masked_positions, labels
    else:
        return input_ids, token_type_ids, position_ids, attention_mask
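The motivation for the change above is that `np.array` on a list of Python ints picks a platform-dependent default integer dtype (historically `int32` on Windows, `int64` on Linux/macOS), so the concatenated label batch could differ across platforms. A minimal sketch of the fixed behavior, using a made-up `batch_examples` list for illustration:

```python
import numpy as np

# Hypothetical batch: each example carries a variable-length label list,
# mirroring the structure used in utils.py.
batch_examples = [{'labels': [1, 2, 3]}, {'labels': [4, 5]}]

# Forcing dtype='int64' makes the result identical on every platform,
# instead of inheriting NumPy's platform-dependent default int dtype.
labels = np.concatenate([
    np.array(example['labels'], dtype='int64') for example in batch_examples
])

print(labels.dtype)  # int64
print(labels)        # [1 2 3 4 5]
```

Without the explicit dtype, a Windows run could yield `int32` labels and trip dtype checks further down the pipeline, which is the bug this commit addresses.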
