@@ -238,18 +238,18 @@ def build_network(self, context):
238
238
else :
239
239
context ["fleet" ].init_worker ()
240
240
context ["dataset" ] = {}
241
- for dataset in context ["env" ]["dataset " ]:
242
- type = envs .get_global_env ("dataset." + dataset [ "name " ] +
241
+ for phase in context ["env" ]["phase " ]:
242
+ type = envs .get_global_env ("dataset." + phase [ "dataset_name " ] +
243
243
".type" )
244
244
if type == "DataLoader" :
245
245
data_loader = DataLoader (context )
246
246
data_loader .get_dataloader (context , dataset_name ,
247
247
model ._data_loader )
248
248
elif type == "QueueDataset" :
249
249
dataset_class = QueueDataset (context )
250
- context ["dataset" ][dataset [
251
- "name " ]] = dataset_class .create_dataset (
252
- dataset [ "name " ], context )
250
+ context ["dataset" ][phase [
251
+ "dataset_name " ]] = dataset_class .create_dataset (
252
+ phase [ "dataset_name " ], context )
253
253
context ["status" ] = "startup_pass"
254
254
255
255
def _build_strategy (self , context ):
@@ -336,7 +336,7 @@ def build_network(self, context):
336
336
self ._server (context )
337
337
else :
338
338
context ["dataset" ] = {}
339
- for dataset in context ["env" ]["dataset " ]:
339
+ for phase in context ["env" ]["phase " ]:
340
340
type = envs .get_global_env ("dataset." + dataset ["name" ] +
341
341
".type" )
342
342
if type == "DataLoader" :
@@ -363,6 +363,7 @@ def __init__(self, context):
363
363
def build_network (self , context ):
364
364
context ["model" ] = {}
365
365
if len (context ["env" ]["phase" ]) > 1 :
366
+ print ("CollectiveNetwork phase:{}" .format (context ["env" ]["phase" ]))
366
367
warnings .warn (
367
368
"Cluster Train Only Support One Phase." ,
368
369
category = UserWarning ,
@@ -407,16 +408,17 @@ def build_network(self, context):
407
408
context ["model" ][model_dict ["name" ]]["compiled_program" ] = None
408
409
409
410
context ["dataset" ] = {}
410
- for dataset in context ["env" ]["dataset" ]:
411
- type = envs .get_global_env ("dataset." + dataset ["name" ] + ".type" )
411
+ for phase in context ["env" ]["phase" ]:
412
+ type = envs .get_global_env ("dataset." + phase ["dataset_name" ] +
413
+ ".type" )
412
414
if type == "QueueDataset" :
413
415
raise ValueError (
414
416
"Collective don't support QueueDataset training, please use DataLoader."
415
417
)
416
418
dataset_class = QueueDataset (context )
417
- context ["dataset" ][dataset [
418
- "name " ]] = dataset_class .create_dataset (dataset [ "name" ],
419
- context )
419
+ context ["dataset" ][phase [
420
+ "dataset_name " ]] = dataset_class .create_dataset (
421
+ phase [ "dataset_name" ], context )
420
422
context ["status" ] = "startup_pass"
421
423
422
424
def _build_strategy (self , context ):
0 commit comments