Skip to content

Commit a013851

Browse files
committed
trigger profiling on abort
Summary: record the profile trace if the training process receives SIGABRT, e.g., when the Process Group watchdog aborts the process
1 parent 50b2f30 commit a013851

File tree

3 files changed

+38
-20
lines changed

3 files changed

+38
-20
lines changed

torchtitan/experiments/forge/example_train.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -282,12 +282,13 @@ def train(self):
282282
self.checkpointer.load(step=job_config.checkpoint.load_step)
283283
logger.info(f"Training starts at step {self.step + 1}.")
284284

285+
torch_profiler = maybe_enable_profiling(
286+
job_config.profiling,
287+
global_step=self.step,
288+
base_folder=job_config.job.dump_folder,
289+
)
290+
285291
with (
286-
maybe_enable_profiling(
287-
job_config.profiling,
288-
global_step=self.step,
289-
base_folder=job_config.job.dump_folder,
290-
) as torch_profiler,
291292
maybe_enable_memory_snapshot(
292293
job_config.profiling,
293294
global_step=self.step,

torchtitan/tools/profiling.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
1919

2020

21-
@contextlib.contextmanager
2221
def maybe_enable_profiling(
2322
profiling_config: ProfilingConfig,
2423
*,
@@ -68,20 +67,20 @@ def trace_handler(prof):
6867
gpu_device_profiled = torch.profiler.ProfilerActivity.CUDA
6968
elif torch.xpu.is_available():
7069
gpu_device_profiled = torch.profiler.ProfilerActivity.XPU
71-
with torch.profiler.profile(
70+
torch_profiler = torch.profiler.profile(
7271
activities=[
7372
torch.profiler.ProfilerActivity.CPU,
7473
gpu_device_profiled,
7574
],
7675
schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
7776
on_trace_ready=trace_handler,
7877
record_shapes=True,
79-
) as torch_profiler:
80-
torch_profiler.step_num = global_step
81-
yield torch_profiler
78+
)
79+
torch_profiler.step_num = global_step
80+
torch_profiler.start()
81+
return torch_profiler
8282
else:
83-
torch_profiler = contextlib.nullcontext()
84-
yield None
83+
return None
8584

8685

8786
@contextlib.contextmanager

torchtitan/train.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import ctypes
78
import importlib
89
import os
10+
import signal
911
import time
1012
from datetime import timedelta
1113
from typing import Any, Generator, Iterable
@@ -33,8 +35,12 @@
3335
maybe_enable_profiling,
3436
)
3537

38+
c_globals = ctypes.CDLL(None) # POSIX
39+
3640

3741
class Trainer(torch.distributed.checkpoint.stateful.Stateful):
42+
torch_profiler: torch.profiler.profile | None = None
43+
3844
# core configs
3945
job_config: JobConfig
4046
parallel_dims: ParallelDims
@@ -613,13 +619,14 @@ def train(self):
613619
if not self.ft_manager.enabled
614620
else f"replica_{self.ft_manager.replica_id}"
615621
)
622+
self.torch_profiler = maybe_enable_profiling(
623+
job_config.profiling,
624+
global_step=self.step,
625+
base_folder=job_config.job.dump_folder,
626+
leaf_folder=leaf_folder,
627+
)
628+
616629
with (
617-
maybe_enable_profiling(
618-
job_config.profiling,
619-
global_step=self.step,
620-
base_folder=job_config.job.dump_folder,
621-
leaf_folder=leaf_folder,
622-
) as torch_profiler,
623630
maybe_enable_memory_snapshot(
624631
job_config.profiling,
625632
global_step=self.step,
@@ -643,6 +650,15 @@ def train(self):
643650
),
644651
),
645652
):
653+
if self.torch_profiler:
654+
655+
@ctypes.CFUNCTYPE(None, ctypes.c_int)
656+
def sigabrt_handler(signal):
657+
logger.info("SIGABRT received. Stopping profiler")
658+
self.torch_profiler.export_chrome_trace("trace.json")
659+
660+
c_globals.signal(signal.SIGABRT, sigabrt_handler)
661+
646662
data_iterator = self.batch_generator(self.dataloader)
647663
while self.should_continue_training():
648664
self.step += 1
@@ -666,8 +682,8 @@ def train(self):
666682
self.validator.validate(self.model_parts, self.step)
667683

668684
# signal the profiler that the next profiling step has started
669-
if torch_profiler:
670-
torch_profiler.step()
685+
if self.torch_profiler:
686+
self.torch_profiler.step()
671687
if memory_profiler:
672688
memory_profiler.step()
673689

@@ -730,10 +746,12 @@ def main(trainer_class: type[Trainer]) -> None:
730746
else:
731747
trainer.train()
732748
except Exception:
749+
logger.info("Torchtitan training threw an exception")
733750
if trainer:
734751
trainer.close()
735752
raise
736753
else:
754+
logger.info("Torchtitan training completed")
737755
trainer.close()
738756
torch.distributed.destroy_process_group()
739757
logger.info("Process group destroyed")

0 commit comments

Comments (0)