
Commit 20fcfd7

set pg names (#1986)
Summary:
- We need to pass the global rank information to PyTorch so that the PG name can include this information.
- This is necessary to differentiate the default PGs on different replicas.
- These need to be different because flight recorder matches collectives based on PG name as well.
- Add FT training to the experiments folder; we'll move the remaining pieces of FT there gradually, but make new features available only through this folder.

---
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/torchtitan/pull/1986).
* #1988
* #1987
* __->__ #1986

Co-authored-by: Tushar Jain <tushar00jain@users.noreply.github.com>
1 parent e37f83f commit 20fcfd7
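For reference, a minimal sketch (not part of this commit) of the rank arithmetic the new FT trainer uses: each replica owns a contiguous block of group_size global ranks starting at replica_id * group_size, and forwarding that list through init_distributed(..., ranks=...) down to torch.distributed.init_process_group(_ranks=...) gives each replica's default PG a distinct name, which is what flight recorder keys on. The helper name below is hypothetical, not a torchtitan API.

# Sketch only: mirrors the rank derivation in FTTrainer.init_distributed (see diff below).
def replica_global_ranks(replica_id: int, group_size: int) -> list[int]:
    first_rank = replica_id * group_size
    last_rank = first_rank + group_size - 1
    return list(range(first_rank, last_rank + 1))

# Two replicas with 4 ranks each get disjoint rank lists, so their default PGs
# end up with different names once the list is forwarded as _ranks.
print(replica_global_ranks(0, 4))  # [0, 1, 2, 3]
print(replica_global_ranks(1, 4))  # [4, 5, 6, 7]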

File tree

2 files changed: +56 -1 lines changed


torchtitan/distributed/utils.py

Lines changed: 5 additions & 1 deletion
@@ -259,7 +259,10 @@ def maybe_enable_amp(
 
 
 def init_distributed(
-    comm_config: CommConfig, enable_cpu_backend: bool = False, base_folder: str = ""
+    comm_config: CommConfig,
+    enable_cpu_backend: bool = False,
+    base_folder: str = "",
+    ranks: list[int] | None = None,
 ):
     def _warn_overwrite_env(env, val):
         if env in os.environ:
@@ -303,6 +306,7 @@ def _get_distributed_backend(enable_cpu_backend):
     torch.distributed.init_process_group(
         backend=_get_distributed_backend(enable_cpu_backend),
         timeout=timedelta(seconds=comm_config.init_timeout_seconds),
+        _ranks=ranks if ranks is not None else [],
     )
 
 
torchtitan/experiments/ft/train.py

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+
+from torchtitan.distributed import ParallelDims, utils as dist_utils
+from torchtitan.train import main, Trainer
+
+
+class FTTrainer(Trainer):
+    def init_distributed(self) -> ParallelDims:
+        job_config = self.job_config
+
+        # determine the global ranks when fault tolerance is enabled
+        global_ranks = []
+        ft_config = job_config.fault_tolerance
+        if ft_config.enable:
+            group_size = ft_config.group_size
+            replica_id = ft_config.replica_id
+            first_rank = replica_id * group_size
+            last_rank = first_rank + group_size - 1
+            global_ranks = list(range(first_rank, last_rank + 1))
+
+        # init distributed and build meshes
+        dist_utils.init_distributed(
+            job_config.comm,
+            enable_cpu_backend=job_config.training.enable_cpu_offload,
+            base_folder=job_config.job.dump_folder,
+            ranks=global_ranks,
+        )
+
+        world_size = int(os.environ["WORLD_SIZE"])
+        parallelism_config = job_config.parallelism
+
+        return ParallelDims(
+            dp_shard=parallelism_config.data_parallel_shard_degree,
+            dp_replicate=parallelism_config.data_parallel_replicate_degree,
+            cp=parallelism_config.context_parallel_degree,
+            tp=parallelism_config.tensor_parallel_degree,
+            pp=parallelism_config.pipeline_parallel_degree,
+            ep=parallelism_config.expert_parallel_degree,
+            etp=parallelism_config.expert_tensor_parallel_degree,
+            world_size=world_size,
+        )
+
+
+if __name__ == "__main__":
+    main(FTTrainer)
