@@ -84,20 +84,10 @@ def __init__(self, job_config: JobConfig):
         # Device has to be set before creating TorchFT manager.
         device_module.set_device(self.device)
 
-        # init distributed and build meshes
-        dist_utils.init_distributed(
-            job_config.comm,
-            enable_cpu_backend=job_config.training.enable_cpu_offload,
-            base_folder=job_config.job.dump_folder,
-        )
-
         job_config.maybe_log()
 
-        world_size = int(os.environ["WORLD_SIZE"])
-        parallelism_config = job_config.parallelism
-        self.parallel_dims = parallel_dims = self._create_parallel_dims(
-            parallelism_config, world_size
-        )
+        # init distributed and build meshes
+        self.parallel_dims = parallel_dims = self.init_distributed()
 
         world_mesh = parallel_dims.world_mesh
         if parallel_dims.dp_enabled:
@@ -319,7 +309,8 @@ def __init__(self, job_config: JobConfig):
         )
 
         loss_parallel_enabled = (
-            parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
+            parallel_dims.tp_enabled
+            and not job_config.parallelism.disable_loss_parallel
         )
         self.train_context = dist_utils.get_train_context(loss_parallel_enabled)
         self.maybe_enable_amp = dist_utils.maybe_enable_amp(
@@ -367,7 +358,17 @@ def __init__(self, job_config: JobConfig):
             f"(warmup {job_config.lr_scheduler.warmup_steps})"
         )
 
-    def _create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
+    def init_distributed(self) -> ParallelDims:
+        job_config = self.job_config
+        dist_utils.init_distributed(
+            job_config.comm,
+            enable_cpu_backend=job_config.training.enable_cpu_offload,
+            base_folder=job_config.job.dump_folder,
+        )
+
+        world_size = int(os.environ["WORLD_SIZE"])
+        parallelism_config = job_config.parallelism
+
         return ParallelDims(
             dp_shard=parallelism_config.data_parallel_shard_degree,
             dp_replicate=parallelism_config.data_parallel_replicate_degree,
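The net effect of these hunks is to move the process-group setup and `ParallelDims` construction out of `__init__` and behind a single `init_distributed()` method. Below is a minimal, self-contained sketch of that pattern under stated assumptions: `SimpleParallelDims`, `_init_process_group`, and the `SimpleNamespace` config are hypothetical stand-ins for illustration, not the real torchtitan APIs.

```python
import os
from dataclasses import dataclass
from types import SimpleNamespace


@dataclass
class SimpleParallelDims:
    # Hypothetical stand-in for torchtitan's ParallelDims.
    world_size: int
    dp_shard: int


class Trainer:
    def __init__(self, job_config):
        self.job_config = job_config
        # One call now owns both process-group init and dims construction,
        # mirroring the refactor in the diff above.
        self.parallel_dims = self.init_distributed()

    def init_distributed(self) -> SimpleParallelDims:
        job_config = self.job_config
        # Placeholder for the real dist_utils.init_distributed(...) call.
        self._init_process_group(job_config)
        world_size = int(os.environ.get("WORLD_SIZE", "1"))
        return SimpleParallelDims(
            world_size=world_size,
            dp_shard=job_config.data_parallel_shard_degree,
        )

    def _init_process_group(self, job_config):
        # No-op here; real code would initialize torch.distributed.
        pass


if __name__ == "__main__":
    trainer = Trainer(SimpleNamespace(data_parallel_shard_degree=2))
    print(trainer.parallel_dims)
```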