@@ -28,17 +28,9 @@ def load_parameter(file_name, h, w):
28
28
return np .fromfile (f , dtype = np .float32 ).reshape (h , w )
29
29
30
30
31
- def db_lstm ():
31
+ def db_lstm (word , predicate , ctx_n2 , ctx_n1 , ctx_0 , ctx_p1 , ctx_p2 , mark ,
32
+ ** ignored ):
32
33
# 8 features
33
- word = fluid .layers .data (name = 'word_data' , shape = [1 ], dtype = 'int64' )
34
- predicate = fluid .layers .data (name = 'verb_data' , shape = [1 ], dtype = 'int64' )
35
- ctx_n2 = fluid .layers .data (name = 'ctx_n2_data' , shape = [1 ], dtype = 'int64' )
36
- ctx_n1 = fluid .layers .data (name = 'ctx_n1_data' , shape = [1 ], dtype = 'int64' )
37
- ctx_0 = fluid .layers .data (name = 'ctx_0_data' , shape = [1 ], dtype = 'int64' )
38
- ctx_p1 = fluid .layers .data (name = 'ctx_p1_data' , shape = [1 ], dtype = 'int64' )
39
- ctx_p2 = fluid .layers .data (name = 'ctx_p2_data' , shape = [1 ], dtype = 'int64' )
40
- mark = fluid .layers .data (name = 'mark_data' , shape = [1 ], dtype = 'int64' )
41
-
42
34
predicate_embedding = fluid .layers .embedding (
43
35
input = predicate ,
44
36
size = [pred_len , word_dim ],
@@ -120,8 +112,25 @@ def to_lodtensor(data, place):
120
112
121
113
def main ():
122
114
# define network topology
123
- feature_out = db_lstm ()
124
- target = fluid .layers .data (name = 'target' , shape = [1 ], dtype = 'int64' )
115
+ word = fluid .layers .data (
116
+ name = 'word_data' , shape = [1 ], dtype = 'int64' , lod_level = 1 )
117
+ predicate = fluid .layers .data (
118
+ name = 'verb_data' , shape = [1 ], dtype = 'int64' , lod_level = 1 )
119
+ ctx_n2 = fluid .layers .data (
120
+ name = 'ctx_n2_data' , shape = [1 ], dtype = 'int64' , lod_level = 1 )
121
+ ctx_n1 = fluid .layers .data (
122
+ name = 'ctx_n1_data' , shape = [1 ], dtype = 'int64' , lod_level = 1 )
123
+ ctx_0 = fluid .layers .data (
124
+ name = 'ctx_0_data' , shape = [1 ], dtype = 'int64' , lod_level = 1 )
125
+ ctx_p1 = fluid .layers .data (
126
+ name = 'ctx_p1_data' , shape = [1 ], dtype = 'int64' , lod_level = 1 )
127
+ ctx_p2 = fluid .layers .data (
128
+ name = 'ctx_p2_data' , shape = [1 ], dtype = 'int64' , lod_level = 1 )
129
+ mark = fluid .layers .data (
130
+ name = 'mark_data' , shape = [1 ], dtype = 'int64' , lod_level = 1 )
131
+ feature_out = db_lstm (** locals ())
132
+ target = fluid .layers .data (
133
+ name = 'target' , shape = [1 ], dtype = 'int64' , lod_level = 1 )
125
134
crf_cost = fluid .layers .linear_chain_crf (
126
135
input = feature_out ,
127
136
label = target ,
@@ -139,6 +148,11 @@ def main():
139
148
paddle .dataset .conll05 .test (), buf_size = 8192 ),
140
149
batch_size = BATCH_SIZE )
141
150
place = fluid .CPUPlace ()
151
+ feeder = fluid .DataFeeder (
152
+ feed_list = [
153
+ word , ctx_n2 , ctx_n1 , ctx_0 , ctx_p1 , ctx_p2 , predicate , mark , target
154
+ ],
155
+ place = place )
142
156
exe = fluid .Executor (place )
143
157
144
158
exe .run (fluid .default_startup_program ())
@@ -150,28 +164,8 @@ def main():
150
164
batch_id = 0
151
165
for pass_id in xrange (PASS_NUM ):
152
166
for data in train_data ():
153
- word_data = to_lodtensor (map (lambda x : x [0 ], data ), place )
154
- ctx_n2_data = to_lodtensor (map (lambda x : x [1 ], data ), place )
155
- ctx_n1_data = to_lodtensor (map (lambda x : x [2 ], data ), place )
156
- ctx_0_data = to_lodtensor (map (lambda x : x [3 ], data ), place )
157
- ctx_p1_data = to_lodtensor (map (lambda x : x [4 ], data ), place )
158
- ctx_p2_data = to_lodtensor (map (lambda x : x [5 ], data ), place )
159
- verb_data = to_lodtensor (map (lambda x : x [6 ], data ), place )
160
- mark_data = to_lodtensor (map (lambda x : x [7 ], data ), place )
161
- target = to_lodtensor (map (lambda x : x [8 ], data ), place )
162
-
163
167
outs = exe .run (fluid .default_main_program (),
164
- feed = {
165
- 'word_data' : word_data ,
166
- 'ctx_n2_data' : ctx_n2_data ,
167
- 'ctx_n1_data' : ctx_n1_data ,
168
- 'ctx_0_data' : ctx_0_data ,
169
- 'ctx_p1_data' : ctx_p1_data ,
170
- 'ctx_p2_data' : ctx_p2_data ,
171
- 'verb_data' : verb_data ,
172
- 'mark_data' : mark_data ,
173
- 'target' : target
174
- },
168
+ feed = feeder .feed (data ),
175
169
fetch_list = [avg_cost ])
176
170
avg_cost_val = np .array (outs [0 ])
177
171
0 commit comments