Skip to content

Commit c51ab42

Browse files
authored
Merge pull request #1784 from luotao1/beam
add seqtext_print for seqToseq demo
2 parents 92edc2d + 555b2df commit c51ab42

File tree

2 files changed

+48
-7
lines changed

2 files changed

+48
-7
lines changed

demo/seqToseq/api_train_v2.py

+35-4
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def gru_decoder_with_attention(enc_vec, enc_proj, current_word):
126126

127127
def main():
128128
paddle.init(use_gpu=False, trainer_count=1)
129-
is_generating = True
129+
is_generating = False
130130

131131
# source and target dict dim.
132132
dict_size = 30000
@@ -167,16 +167,47 @@ def event_handler(event):
167167

168168
# generate an English sequence to French
169169
else:
170-
gen_creator = paddle.dataset.wmt14.test(dict_size)
170+
# use the first 3 samples for generation
171+
gen_creator = paddle.dataset.wmt14.gen(dict_size)
171172
gen_data = []
173+
gen_num = 3
172174
for item in gen_creator():
173175
gen_data.append((item[0], ))
174-
if len(gen_data) == 3:
176+
if len(gen_data) == gen_num:
175177
break
176178

177179
beam_gen = seqToseq_net(source_dict_dim, target_dict_dim, is_generating)
180+
# get the pretrained model, whose BLEU = 26.92
178181
parameters = paddle.dataset.wmt14.model()
179-
trg_dict = paddle.dataset.wmt14.trg_dict(dict_size)
182+
# prob holds the prediction probabilities, and id holds the predicted words.
183+
beam_result = paddle.infer(
184+
output_layer=beam_gen,
185+
parameters=parameters,
186+
input=gen_data,
187+
field=['prob', 'id'])
188+
189+
# get the dictionary
190+
src_dict, trg_dict = paddle.dataset.wmt14.get_dict(dict_size)
191+
192+
# the delimiter between generated sequences is -1,
193+
# the first element of each generated sequence is the sequence length
194+
seq_list = []
195+
seq = []
196+
for w in beam_result[1]:
197+
if w != -1:
198+
seq.append(w)
199+
else:
200+
seq_list.append(' '.join([trg_dict.get(w) for w in seq[1:]]))
201+
seq = []
202+
203+
prob = beam_result[0]
204+
beam_size = 3
205+
for i in xrange(gen_num):
206+
print "\n*******************************************************\n"
207+
print "src:", ' '.join(
208+
[src_dict.get(w) for w in gen_data[i][0]]), "\n"
209+
for j in xrange(beam_size):
210+
print "prob = %f:" % (prob[i][j]), seq_list[i * beam_size + j]
180211

181212

182213
if __name__ == '__main__':

python/paddle/v2/dataset/wmt14.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
2727
# this is a small set of data for testing. The original data is too large and will be added later.
2828
URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz'
29-
MD5_TRAIN = 'a755315dd01c2c35bde29a744ede23a6'
29+
MD5_TRAIN = '0791583d57d5beb693b9414c5b36798c'
3030
# this is the pretrained model, whose BLEU = 26.92
3131
URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
3232
MD5_MODEL = '4ce14a26607fb8a1cc23bcdedb1895e4'
@@ -108,17 +108,27 @@ def test(dict_size):
108108
download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'test/test', dict_size)
109109

110110

111+
def gen(dict_size):
112+
return reader_creator(
113+
download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'gen/gen', dict_size)
114+
115+
111116
def model():
112117
tar_file = download(URL_MODEL, 'wmt14', MD5_MODEL)
113118
with gzip.open(tar_file, 'r') as f:
114119
parameters = Parameters.from_tar(f)
115120
return parameters
116121

117122

118-
def trg_dict(dict_size):
123+
def get_dict(dict_size, reverse=True):
124+
# if reverse is False, return dict = {'a':'001', 'b':'002', ...}
125+
# if reverse is True, return dict = {'001':'a', '002':'b', ...}
119126
tar_file = download(URL_TRAIN, 'wmt14', MD5_TRAIN)
120127
src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
121-
return trg_dict
128+
if reverse:
129+
src_dict = {v: k for k, v in src_dict.items()}
130+
trg_dict = {v: k for k, v in trg_dict.items()}
131+
return src_dict, trg_dict
122132

123133

124134
def fetch():

0 commit comments

Comments
 (0)