
Commit d23d1db

remove exe.run in ps mode

1 parent e6e00ac

File tree

1 file changed: +3 −90 lines changed

tools/static_ps_trainer.py (file mode 100644 → 100755): +3 −90
@@ -134,7 +134,7 @@ def run_worker(self):
         if save_model_path and (not os.path.exists(save_model_path)):
             os.makedirs(save_model_path)

-        reader_type = self.config.get("runner.reader_type", None)
+        reader_type = self.config.get("runner.reader_type", "QueueDataset")
         epochs = int(self.config.get("runner.epochs"))
         sync_mode = self.config.get("runner.sync_mode")

@@ -150,10 +150,6 @@ def run_worker(self):
                 self.dataset_train_loop(epoch)
             elif reader_type == "InmemoryDataset":
                 self.dataset_train_loop(epoch)
-            elif reader_type == "DataLoader":
-                self.dataloader_train_loop(epoch)
-            elif reader_type == None or reader_type == "RecDataset":
-                self.recdataset_train_loop(epoch)

             epoch_time = time.time() - epoch_start_time
             epoch_speed = self.example_nums / epoch_time
@@ -182,6 +178,8 @@ def run_worker(self):
     def init_reader(self):
         if fleet.is_server():
             return
+        self.config["runner.reader_type"] = self.config.get(
+            "runner.reader_type", "QueueDataset")
         self.reader, self.file_list = get_reader(self.input_data, config)
         self.example_nums = 0
         self.count_method = self.config.get("runner.example_count_method",
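
Both the run_worker hunk above and this init_reader hunk swap a None fallback for an explicit "QueueDataset" default, so a config that omits runner.reader_type now takes the dataset path instead of the removed exe.run loops. A minimal, hypothetical illustration of the new fallback (the flattened-dict config shape is assumed; it is not shown in this diff):

# Hypothetical flattened config with "runner.reader_type" left unset.
config = {"runner.epochs": "2", "runner.sync_mode": "async"}

# New behavior: a missing key now resolves to "QueueDataset" ...
reader_type = config.get("runner.reader_type", "QueueDataset")
assert reader_type == "QueueDataset"

# ... whereas the old default of None routed run_worker into the
# recdataset_train_loop branch that this commit deletes.
config["runner.reader_type"] = reader_type  # write-back, as in init_reader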
@@ -222,91 +220,6 @@ def dataset_train_loop(self, epoch):
             print_period=print_step,
             debug=debug)

-    def dataloader_train_loop(self, epoch):
-        logger.info("Epoch: {}, Running DataLoader Begin.".format(epoch))
-        batch_id = 0
-        train_run_cost = 0.0
-        total_examples = 0
-        self.reader.start()
-        while True:
-            try:
-                train_start = time.time()
-                # --------------------------------------------------- #
-                fetch_var = self.exe.run(
-                    program=paddle.static.default_main_program(),
-                    fetch_list=[var for _, var in self.metrics.items()])
-                # --------------------------------------------------- #
-                train_run_cost += time.time() - train_start
-                total_examples += (self.config.get("runner.train_batch_size"))
-                batch_id += 1
-                print_step = int(config.get("runner.print_interval"))
-                if batch_id % print_step == 0:
-                    metrics_string = ""
-                    for var_idx, var_name in enumerate(self.metrics):
-                        metrics_string += "{}: {}, ".format(
-                            var_name, fetch_var[var_idx]
-                            if var_name != "LOSS" or not config['pure_bf16']
-                            else bf16_to_fp32(fetch_var[var_idx][0]))
-                    profiler_string = ""
-                    profiler_string += "avg_batch_cost: {} sec, ".format(
-                        format((train_run_cost) / print_step, '.5f'))
-                    profiler_string += "avg_samples: {}, ".format(
-                        format(total_examples / print_step, '.5f'))
-                    profiler_string += "ips: {} {}/sec ".format(
-                        format(total_examples / (train_run_cost), '.5f'),
-                        self.count_method)
-                    logger.info("Epoch: {}, Batch: {}, {} {}".format(
-                        epoch, batch_id, metrics_string, profiler_string))
-                    train_run_cost = 0.0
-                    total_examples = 0
-            except paddle.fluid.core.EOFException:
-                self.reader.reset()
-                break
-
-    def recdataset_train_loop(self, epoch):
-        logger.info("Epoch: {}, Running RecDatast Begin.".format(epoch))
-
-        input_data_names = [var.name for var in self.input_data]
-        batch_size = config.get("runner.train_batch_size", None)
-        print_interval = config.get("runner.print_interval", None)
-
-        batch_id = 0
-        train_run_cost = 0.0
-        train_reader_cost = 0.0
-        total_samples = 0
-        reader_start = time.time()
-        for batch_id, batch_data in enumerate(self.reader()):
-            train_reader_cost += time.time() - reader_start
-            train_start = time.time()
-            # --------------------------------------------------- #
-            fetch_batch_var = self.exe.run(
-                program=paddle.static.default_main_program(),
-                feed=dict(zip(input_data_names, batch_data)),
-                fetch_list=[var for _, var in self.metrics.items()])
-            # --------------------------------------------------- #
-            train_run_cost += time.time() - train_start
-            total_samples += batch_size
-            if batch_id % print_interval == 0:
-                metric_str = ""
-                for var_idx, var_name in enumerate(self.metrics):
-                    metric_str += "{}: {}, ".format(
-                        var_name, fetch_batch_var[var_idx]
-                        if var_name != "LOSS" or config['pure_bf16'] is False
-                        else bf16_to_fp32(fetch_batch_var[var_idx][0]))
-                logger.info(
-                    "Epoch: {}, Batch_id: {}, ".format(epoch,
-                                                       batch_id) + metric_str +
-                    " avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} {}/sec"
-                    .format(train_reader_cost / print_interval, (
-                        train_reader_cost + train_run_cost) / print_interval,
-                            total_samples / print_interval, total_samples / (
-                                train_reader_cost + train_run_cost),
-                            self.count_method))
-                train_reader_cost = 0.0
-                train_run_cost = 0.0
-                total_samples = 0
-                reader_start = time.time()
-
     def heter_train_loop(self, epoch):
         logger.info(
             "Epoch: {}, Running Begin. Check running metrics at heter_log".
