Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit ec22004

Browse files
Merge pull request #2 from mapingshuo/sequence_pad
add sequence_pad example
2 parents e123934 + 2c7e79f commit ec22004

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# api: paddle.fluid.layers.sequence_pad
16+
# env: local
17+
# device: cpu
18+
# text: sequence-pad
19+
20+
import paddle.fluid as fluid
21+
import numpy
22+
23+
x = fluid.layers.data(name="question", shape=[1], dtype="int64", lod_level=1)
24+
25+
# define net here
26+
embed = fluid.layers.embedding(input=x, size=[32, 2],
27+
param_attr=fluid.ParamAttr(name='emb.w'))
28+
29+
pad_value = fluid.layers.assign(input=numpy.array([0], dtype=numpy.float32))
30+
z, mask = fluid.layers.sequence_pad(x=embed, pad_value=pad_value)
31+
32+
place = fluid.CPUPlace()
33+
exe = fluid.Executor(place)
34+
feeder = fluid.DataFeeder(feed_list=[x], place=place)
35+
exe.run(fluid.default_startup_program())
36+
37+
# prepare a batch of data
38+
data = [([0, 1, 2, 3, 3],), ([0, 1, 2],)]
39+
40+
mask_out, z_out = exe.run(fluid.default_main_program(),
41+
feed=feeder.feed(data),
42+
fetch_list=[mask, z],
43+
return_numpy=True)
44+
45+
print(mask_out)
46+
print(z_out)
47+
48+
#[[5]
49+
# [3]]
50+
#[[[ 0.03990805 -0.10303718]
51+
# [ 0.08801201 -0.30412018]
52+
# [ 0.0706093 -0.18075395]
53+
# [-0.0283702 0.01683199]
54+
# [-0.0283702 0.01683199]]
55+
56+
# [[ 0.03990805 -0.10303718]
57+
# [ 0.08801201 -0.30412018]
58+
# [ 0.0706093 -0.18075395]
59+
# [ 0. 0. ]
60+
# [ 0. 0. ]]]

0 commit comments

Comments
 (0)