@@ -948,15 +948,15 @@ def generate_export(
         b, t = start_tokens.shape
 
         self.net.eval()
-        out_tmp = start_tokens
+        out = start_tokens
         mask = kwargs.pop("mask", None)
 
         if mask is None:
-            mask = paddle.full_like(out_tmp, True, dtype=paddle.bool)
+            mask = paddle.full_like(out, True, dtype=paddle.bool)
 
         i_idx = paddle.full([], 0)
         while i_idx < paddle.to_tensor(seq_len):
-            x = out_tmp[:, -self.max_seq_len:]
+            x = out[:, -self.max_seq_len:]
             paddle.jit.api.set_dynamic_shape(x, [-1, -1])
             mask = mask[:, -self.max_seq_len:]
             paddle.jit.api.set_dynamic_shape(mask, [-1, -1])
@@ -969,7 +969,7 @@ def generate_export(
             probs = F.softmax(filtered_logits / temperature, axis=-1)
 
             sample = paddle.multinomial(probs, 1)
-            out = paddle.concat((out_tmp, sample), axis=-1)
+            out = paddle.concat((out, sample), axis=-1)
 
             pad_mask = paddle.full(shape=[mask.shape[0], 1], fill_value=1, dtype="bool")
             mask = paddle.concat((mask, pad_mask), axis=1)
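For context, these hunks sit in an autoregressive decoding loop, and the rename is more than cosmetic: the old code sliced `out_tmp` at the top of each iteration but concatenated each new `sample` onto `out`, so the window fed to the network never picked up previously sampled tokens. A minimal standalone sketch of the accumulation pattern the fixed code follows (the `net` call signature, the plain softmax sampling, and the argument names here are assumptions for illustration, not the PR's actual module):

```python
import paddle
import paddle.nn.functional as F

def sample_loop(net, start_tokens, seq_len, max_seq_len, temperature=1.0):
    """Sketch: append one sampled token per step, always re-slicing
    the same accumulating `out` buffer (the invariant the PR restores)."""
    out = start_tokens  # [batch, t] running sequence, grown each step
    mask = paddle.full_like(out, True, dtype=paddle.bool)
    for _ in range(seq_len):
        x = out[:, -max_seq_len:]  # slice the *accumulated* sequence
        # assumption: net returns [batch, t, vocab] logits
        logits = net(x, mask=mask[:, -max_seq_len:])[:, -1]
        probs = F.softmax(logits / temperature, axis=-1)
        sample = paddle.multinomial(probs, 1)
        out = paddle.concat((out, sample), axis=-1)  # grow the same buffer
        pad_mask = paddle.full([mask.shape[0], 1], True, dtype="bool")
        mask = paddle.concat((mask[:, -max_seq_len:], pad_mask), axis=1)
    return out[:, start_tokens.shape[1]:]  # newly generated tokens only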