@@ -126,7 +126,7 @@ def gru_decoder_with_attention(enc_vec, enc_proj, current_word):
126
126
127
127
def main ():
128
128
paddle .init (use_gpu = False , trainer_count = 1 )
129
- is_generating = True
129
+ is_generating = False
130
130
131
131
# source and target dict dim.
132
132
dict_size = 30000
@@ -167,16 +167,47 @@ def event_handler(event):
167
167
168
168
# generate a english sequence to french
169
169
else :
170
- gen_creator = paddle .dataset .wmt14 .test (dict_size )
170
+ # use the first 3 samples for generation
171
+ gen_creator = paddle .dataset .wmt14 .gen (dict_size )
171
172
gen_data = []
173
+ gen_num = 3
172
174
for item in gen_creator ():
173
175
gen_data .append ((item [0 ], ))
174
- if len (gen_data ) == 3 :
176
+ if len (gen_data ) == gen_num :
175
177
break
176
178
177
179
beam_gen = seqToseq_net (source_dict_dim , target_dict_dim , is_generating )
180
+ # get the pretrained model, whose bleu = 26.92
178
181
parameters = paddle .dataset .wmt14 .model ()
179
- trg_dict = paddle .dataset .wmt14 .trg_dict (dict_size )
182
+ # prob is the prediction probabilities, and id is the prediction word.
183
+ beam_result = paddle .infer (
184
+ output_layer = beam_gen ,
185
+ parameters = parameters ,
186
+ input = gen_data ,
187
+ field = ['prob' , 'id' ])
188
+
189
+ # get the dictionary
190
+ src_dict , trg_dict = paddle .dataset .wmt14 .get_dict (dict_size )
191
+
192
+ # the delimited element of generated sequences is -1,
193
+ # the first element of each generated sequence is the sequence length
194
+ seq_list = []
195
+ seq = []
196
+ for w in beam_result [1 ]:
197
+ if w != - 1 :
198
+ seq .append (w )
199
+ else :
200
+ seq_list .append (' ' .join ([trg_dict .get (w ) for w in seq [1 :]]))
201
+ seq = []
202
+
203
+ prob = beam_result [0 ]
204
+ beam_size = 3
205
+ for i in xrange (gen_num ):
206
+ print "\n *******************************************************\n "
207
+ print "src:" , ' ' .join (
208
+ [src_dict .get (w ) for w in gen_data [i ][0 ]]), "\n "
209
+ for j in xrange (beam_size ):
210
+ print "prob = %f:" % (prob [i ][j ]), seq_list [i * beam_size + j ]
180
211
181
212
182
213
if __name__ == '__main__' :
0 commit comments