Skip to content

Commit 3b9c100

Browse files
authored
Merge pull request #236 from vslyu/fix_phase
fix phase & format save_step output information
2 parents 7cfe354 + dad5966 commit 3b9c100

File tree

3 files changed

+21
-18
lines changed

3 files changed

+21
-18
lines changed

core/trainer.py

+3-4
Original file line number | Diff line number | Diff line change
@@ -76,18 +76,17 @@ def __init__(self, config=None):
7676

7777
_config = envs.load_yaml(config)
7878

79-
self._context["env"] = _config
80-
self._context["dataset"] = _config.get("dataset")
81-
8279
phases = []
8380
if phase_names is None:
8481
phases = _config.get("phase")
8582
else:
8683
for phase in _config.get("phase"):
8784
if phase["name"] in phase_names:
8885
phases.append(phase)
89-
9086
self._context["phases"] = phases
87+
_config["phase"] = phases
88+
self._context["env"] = _config
89+
self._context["dataset"] = _config.get("dataset")
9190
print("PaddleRec: Runner {} Begin".format(self._runner_name))
9291
self.which_engine()
9392
self.which_device()

core/trainers/framework/network.py

+13-11
Original file line number | Diff line number | Diff line change
@@ -238,18 +238,18 @@ def build_network(self, context):
238238
else:
239239
context["fleet"].init_worker()
240240
context["dataset"] = {}
241-
for dataset in context["env"]["dataset"]:
242-
type = envs.get_global_env("dataset." + dataset["name"] +
241+
for phase in context["env"]["phase"]:
242+
type = envs.get_global_env("dataset." + phase["dataset_name"] +
243243
".type")
244244
if type == "DataLoader":
245245
data_loader = DataLoader(context)
246246
data_loader.get_dataloader(context, dataset_name,
247247
model._data_loader)
248248
elif type == "QueueDataset":
249249
dataset_class = QueueDataset(context)
250-
context["dataset"][dataset[
251-
"name"]] = dataset_class.create_dataset(
252-
dataset["name"], context)
250+
context["dataset"][phase[
251+
"dataset_name"]] = dataset_class.create_dataset(
252+
phase["dataset_name"], context)
253253
context["status"] = "startup_pass"
254254

255255
def _build_strategy(self, context):
@@ -336,7 +336,7 @@ def build_network(self, context):
336336
self._server(context)
337337
else:
338338
context["dataset"] = {}
339-
for dataset in context["env"]["dataset"]:
339+
for phase in context["env"]["phase"]:
340340
type = envs.get_global_env("dataset." + dataset["name"] +
341341
".type")
342342
if type == "DataLoader":
@@ -363,6 +363,7 @@ def __init__(self, context):
363363
def build_network(self, context):
364364
context["model"] = {}
365365
if len(context["env"]["phase"]) > 1:
366+
print("CollectiveNetwork phase:{}".format(context["env"]["phase"]))
366367
warnings.warn(
367368
"Cluster Train Only Support One Phase.",
368369
category=UserWarning,
@@ -407,16 +408,17 @@ def build_network(self, context):
407408
context["model"][model_dict["name"]]["compiled_program"] = None
408409

409410
context["dataset"] = {}
410-
for dataset in context["env"]["dataset"]:
411-
type = envs.get_global_env("dataset." + dataset["name"] + ".type")
411+
for phase in context["env"]["phase"]:
412+
type = envs.get_global_env("dataset." + phase["dataset_name"] +
413+
".type")
412414
if type == "QueueDataset":
413415
raise ValueError(
414416
"Collective don't support QueueDataset training, please use DataLoader."
415417
)
416418
dataset_class = QueueDataset(context)
417-
context["dataset"][dataset[
418-
"name"]] = dataset_class.create_dataset(dataset["name"],
419-
context)
419+
context["dataset"][phase[
420+
"dataset_name"]] = dataset_class.create_dataset(
421+
phase["dataset_name"], context)
420422
context["status"] = "startup_pass"
421423

422424
def _build_strategy(self, context):

core/trainers/framework/runner.py

+5-3
Original file line number | Diff line number | Diff line change
@@ -436,9 +436,11 @@ def save_checkpoint_step():
436436
dirname = envs.get_global_env(name + "save_step_path", None)
437437
if dirname is None or dirname == "":
438438
return
439-
dirname = os.path.join(dirname, str(batch_id))
440-
logging.info("\tsave batch_id:%d model into: \"%s\"" %
441-
(batch_id, dirname))
439+
dirname = os.path.join(dirname,
440+
"epoch_" + str(context["current_epoch"]) +
441+
"_batch_" + str(batch_id))
442+
logging.info("\tsave epoch_id:%d, batch_id:%d model into: \"%s\"" %
443+
(context["current_epoch"], batch_id, dirname))
442444
if is_fleet:
443445
if context["fleet"].worker_index() == 0:
444446
context["fleet"].save_persistables(context["exe"], dirname)

0 commit comments

Comments
 (0)