@@ -134,7 +134,7 @@ def run_worker(self):
134
134
if save_model_path and (not os .path .exists (save_model_path )):
135
135
os .makedirs (save_model_path )
136
136
137
- reader_type = self .config .get ("runner.reader_type" , None )
137
+ reader_type = self .config .get ("runner.reader_type" , "QueueDataset" )
138
138
epochs = int (self .config .get ("runner.epochs" ))
139
139
sync_mode = self .config .get ("runner.sync_mode" )
140
140
@@ -150,10 +150,6 @@ def run_worker(self):
150
150
self .dataset_train_loop (epoch )
151
151
elif reader_type == "InmemoryDataset" :
152
152
self .dataset_train_loop (epoch )
153
- elif reader_type == "DataLoader" :
154
- self .dataloader_train_loop (epoch )
155
- elif reader_type == None or reader_type == "RecDataset" :
156
- self .recdataset_train_loop (epoch )
157
153
158
154
epoch_time = time .time () - epoch_start_time
159
155
epoch_speed = self .example_nums / epoch_time
@@ -182,6 +178,8 @@ def run_worker(self):
182
178
def init_reader (self ):
183
179
if fleet .is_server ():
184
180
return
181
+ self .config ["runner.reader_type" ] = self .config .get (
182
+ "runner.reader_type" , "QueueDataset" )
185
183
self .reader , self .file_list = get_reader (self .input_data , config )
186
184
self .example_nums = 0
187
185
self .count_method = self .config .get ("runner.example_count_method" ,
@@ -222,91 +220,6 @@ def dataset_train_loop(self, epoch):
222
220
print_period = print_step ,
223
221
debug = debug )
224
222
225
- def dataloader_train_loop (self , epoch ):
226
- logger .info ("Epoch: {}, Running DataLoader Begin." .format (epoch ))
227
- batch_id = 0
228
- train_run_cost = 0.0
229
- total_examples = 0
230
- self .reader .start ()
231
- while True :
232
- try :
233
- train_start = time .time ()
234
- # --------------------------------------------------- #
235
- fetch_var = self .exe .run (
236
- program = paddle .static .default_main_program (),
237
- fetch_list = [var for _ , var in self .metrics .items ()])
238
- # --------------------------------------------------- #
239
- train_run_cost += time .time () - train_start
240
- total_examples += (self .config .get ("runner.train_batch_size" ))
241
- batch_id += 1
242
- print_step = int (config .get ("runner.print_interval" ))
243
- if batch_id % print_step == 0 :
244
- metrics_string = ""
245
- for var_idx , var_name in enumerate (self .metrics ):
246
- metrics_string += "{}: {}, " .format (
247
- var_name , fetch_var [var_idx ]
248
- if var_name != "LOSS" or not config ['pure_bf16' ]
249
- else bf16_to_fp32 (fetch_var [var_idx ][0 ]))
250
- profiler_string = ""
251
- profiler_string += "avg_batch_cost: {} sec, " .format (
252
- format ((train_run_cost ) / print_step , '.5f' ))
253
- profiler_string += "avg_samples: {}, " .format (
254
- format (total_examples / print_step , '.5f' ))
255
- profiler_string += "ips: {} {}/sec " .format (
256
- format (total_examples / (train_run_cost ), '.5f' ),
257
- self .count_method )
258
- logger .info ("Epoch: {}, Batch: {}, {} {}" .format (
259
- epoch , batch_id , metrics_string , profiler_string ))
260
- train_run_cost = 0.0
261
- total_examples = 0
262
- except paddle .fluid .core .EOFException :
263
- self .reader .reset ()
264
- break
265
-
266
- def recdataset_train_loop (self , epoch ):
267
- logger .info ("Epoch: {}, Running RecDatast Begin." .format (epoch ))
268
-
269
- input_data_names = [var .name for var in self .input_data ]
270
- batch_size = config .get ("runner.train_batch_size" , None )
271
- print_interval = config .get ("runner.print_interval" , None )
272
-
273
- batch_id = 0
274
- train_run_cost = 0.0
275
- train_reader_cost = 0.0
276
- total_samples = 0
277
- reader_start = time .time ()
278
- for batch_id , batch_data in enumerate (self .reader ()):
279
- train_reader_cost += time .time () - reader_start
280
- train_start = time .time ()
281
- # --------------------------------------------------- #
282
- fetch_batch_var = self .exe .run (
283
- program = paddle .static .default_main_program (),
284
- feed = dict (zip (input_data_names , batch_data )),
285
- fetch_list = [var for _ , var in self .metrics .items ()])
286
- # --------------------------------------------------- #
287
- train_run_cost += time .time () - train_start
288
- total_samples += batch_size
289
- if batch_id % print_interval == 0 :
290
- metric_str = ""
291
- for var_idx , var_name in enumerate (self .metrics ):
292
- metric_str += "{}: {}, " .format (
293
- var_name , fetch_batch_var [var_idx ]
294
- if var_name != "LOSS" or config ['pure_bf16' ] is False
295
- else bf16_to_fp32 (fetch_batch_var [var_idx ][0 ]))
296
- logger .info (
297
- "Epoch: {}, Batch_id: {}, " .format (epoch ,
298
- batch_id ) + metric_str +
299
- " avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} {}/sec"
300
- .format (train_reader_cost / print_interval , (
301
- train_reader_cost + train_run_cost ) / print_interval ,
302
- total_samples / print_interval , total_samples / (
303
- train_reader_cost + train_run_cost ),
304
- self .count_method ))
305
- train_reader_cost = 0.0
306
- train_run_cost = 0.0
307
- total_samples = 0
308
- reader_start = time .time ()
309
-
310
223
def heter_train_loop (self , epoch ):
311
224
logger .info (
312
225
"Epoch: {}, Running Begin. Check running metrics at heter_log" .
0 commit comments