
Commit 61c25f8

[RFC] Seperate init_distributed_env from the Trainer.__init__ (#2003)
This allows people to customize the distributed environment, including ParallelDims and the distributed backend.
1 parent 4caa379 commit 61c25f8

2 files changed: 29 additions & 16 deletions

torchtitan/experiments/torchcomms/train.py

Lines changed: 14 additions & 2 deletions

@@ -4,7 +4,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from torchtitan.distributed import ParallelDims
+import os
+
+from torchtitan.distributed import ParallelDims, utils as dist_utils
 from torchtitan.train import main, Trainer

 from .parallel_dims import TorchCommsParallelDims
@@ -13,7 +15,17 @@
 class TorchCommsTrainer(Trainer):
     parallel_dims: TorchCommsParallelDims

-    def _create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
+    def init_distributed(self) -> ParallelDims:
+        job_config = self.job_config
+        dist_utils.init_distributed(
+            job_config.comm,
+            enable_cpu_backend=job_config.training.enable_cpu_offload,
+            base_folder=job_config.job.dump_folder,
+        )
+
+        world_size = int(os.environ["WORLD_SIZE"])
+        parallelism_config = job_config.parallelism
+
         return TorchCommsParallelDims(
             dp_shard=parallelism_config.data_parallel_shard_degree,
             dp_replicate=parallelism_config.data_parallel_replicate_degree,

torchtitan/train.py

Lines changed: 15 additions & 14 deletions

@@ -84,20 +84,10 @@ def __init__(self, job_config: JobConfig):
         # Device has to be set before creating TorchFT manager.
         device_module.set_device(self.device)

-        # init distributed and build meshes
-        dist_utils.init_distributed(
-            job_config.comm,
-            enable_cpu_backend=job_config.training.enable_cpu_offload,
-            base_folder=job_config.job.dump_folder,
-        )
-
         job_config.maybe_log()

-        world_size = int(os.environ["WORLD_SIZE"])
-        parallelism_config = job_config.parallelism
-        self.parallel_dims = parallel_dims = self._create_parallel_dims(
-            parallelism_config, world_size
-        )
+        # init distributed and build meshes
+        self.parallel_dims = parallel_dims = self.init_distributed()

         world_mesh = parallel_dims.world_mesh
         if parallel_dims.dp_enabled:
@@ -319,7 +309,8 @@ def __init__(self, job_config: JobConfig):
         )

         loss_parallel_enabled = (
-            parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
+            parallel_dims.tp_enabled
+            and not job_config.parallelism.disable_loss_parallel
         )
         self.train_context = dist_utils.get_train_context(loss_parallel_enabled)
         self.maybe_enable_amp = dist_utils.maybe_enable_amp(
@@ -367,7 +358,17 @@ def __init__(self, job_config: JobConfig):
             f"(warmup {job_config.lr_scheduler.warmup_steps})"
         )

-    def _create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
+    def init_distributed(self) -> ParallelDims:
+        job_config = self.job_config
+        dist_utils.init_distributed(
+            job_config.comm,
+            enable_cpu_backend=job_config.training.enable_cpu_offload,
+            base_folder=job_config.job.dump_folder,
+        )
+
+        world_size = int(os.environ["WORLD_SIZE"])
+        parallelism_config = job_config.parallelism
+
         return ParallelDims(
             dp_shard=parallelism_config.data_parallel_shard_degree,
             dp_replicate=parallelism_config.data_parallel_replicate_degree,
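The customization the description mentions covers both halves of the new hook: the torchcomms diff above swaps in TorchCommsParallelDims, and the distributed backend can be swapped the same way. Below is a minimal sketch of the backend half; the class name and the gloo choice are illustrative and not part of this commit, and the tail of the ParallelDims call is elided here exactly as it is in the truncated diff above.

import os

import torch.distributed as dist

from torchtitan.distributed import ParallelDims
from torchtitan.train import Trainer


class GlooBackendTrainer(Trainer):
    """Hypothetical subclass: pick the process-group backend directly
    instead of going through dist_utils.init_distributed()."""

    def init_distributed(self) -> ParallelDims:
        # Custom backend setup replaces the default dist_utils call.
        dist.init_process_group(backend="gloo")

        # Build ParallelDims the same way the base hook does.
        world_size = int(os.environ["WORLD_SIZE"])
        parallelism_config = self.job_config.parallelism
        return ParallelDims(
            dp_shard=parallelism_config.data_parallel_shard_degree,
            dp_replicate=parallelism_config.data_parallel_replicate_degree,
            # ... remaining ParallelDims fields (presumably including
            # world_size) as in Trainer.init_distributed(); the diff above
            # is truncated at this point.
        )

Because Trainer.__init__ now only calls self.init_distributed(), a subclass like this one, or TorchCommsTrainer above, can change the backend or the ParallelDims type without copying the rest of __init__.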
