diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py
index 7b0b0c81e9..8f1c4355f4 100644
--- a/torchtitan/experiments/forge/example_train.py
+++ b/torchtitan/experiments/forge/example_train.py
@@ -282,12 +282,13 @@ def train(self):
             self.checkpointer.load(step=job_config.checkpoint.load_step)
         logger.info(f"Training starts at step {self.step + 1}.")
 
+        torch_profiler = maybe_enable_profiling(
+            job_config.profiling,
+            global_step=self.step,
+            base_folder=job_config.job.dump_folder,
+        )
+
         with (
-            maybe_enable_profiling(
-                job_config.profiling,
-                global_step=self.step,
-                base_folder=job_config.job.dump_folder,
-            ) as torch_profiler,
             maybe_enable_memory_snapshot(
                 job_config.profiling,
                 global_step=self.step,
diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py
index f398dba9b5..049c780a73 100644
--- a/torchtitan/tools/profiling.py
+++ b/torchtitan/tools/profiling.py
@@ -18,7 +18,6 @@
 MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
 
 
-@contextlib.contextmanager
 def maybe_enable_profiling(
     profiling_config: ProfilingConfig,
     *,
@@ -68,7 +67,7 @@ def trace_handler(prof):
             gpu_device_profiled = torch.profiler.ProfilerActivity.CUDA
         elif torch.xpu.is_available():
             gpu_device_profiled = torch.profiler.ProfilerActivity.XPU
-        with torch.profiler.profile(
+        torch_profiler = torch.profiler.profile(
             activities=[
                 torch.profiler.ProfilerActivity.CPU,
                 gpu_device_profiled,
@@ -76,12 +75,12 @@ def trace_handler(prof):
             schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
             on_trace_ready=trace_handler,
             record_shapes=True,
-        ) as torch_profiler:
-            torch_profiler.step_num = global_step
-            yield torch_profiler
+        )
+        torch_profiler.step_num = global_step
+        torch_profiler.start()
+        return torch_profiler
     else:
-        torch_profiler = contextlib.nullcontext()
-        yield None
+        return None
 
 
 @contextlib.contextmanager
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 0070806e94..84309ddcda 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -4,8 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import ctypes
 import importlib
 import os
+import signal
 import time
 from datetime import timedelta
 from typing import Any, Generator, Iterable
@@ -33,8 +35,12 @@
     maybe_enable_profiling,
 )
 
+c_globals = ctypes.CDLL(None)  # POSIX
+
 
 class Trainer(torch.distributed.checkpoint.stateful.Stateful):
+    torch_profiler: torch.profiler.profile | None = None
+
     # core configs
     job_config: JobConfig
     parallel_dims: ParallelDims
@@ -613,13 +619,14 @@ def train(self):
             if not self.ft_manager.enabled
             else f"replica_{self.ft_manager.replica_id}"
         )
+        self.torch_profiler = maybe_enable_profiling(
+            job_config.profiling,
+            global_step=self.step,
+            base_folder=job_config.job.dump_folder,
+            leaf_folder=leaf_folder,
+        )
+
         with (
-            maybe_enable_profiling(
-                job_config.profiling,
-                global_step=self.step,
-                base_folder=job_config.job.dump_folder,
-                leaf_folder=leaf_folder,
-            ) as torch_profiler,
             maybe_enable_memory_snapshot(
                 job_config.profiling,
                 global_step=self.step,
@@ -643,6 +650,15 @@ def train(self):
                 ),
             ),
         ):
+            if self.torch_profiler:
+
+                @ctypes.CFUNCTYPE(None, ctypes.c_int)
+                def sigabrt_handler(signal):
+                    logger.info("SIGABRT received. Stopping profiler")
+                    self.torch_profiler.export_chrome_trace("trace.json")
+
+                c_globals.signal(signal.SIGABRT, sigabrt_handler)
+
             data_iterator = self.batch_generator(self.dataloader)
             while self.should_continue_training():
                 self.step += 1
@@ -666,8 +682,8 @@ def train(self):
                     self.validator.validate(self.model_parts, self.step)
 
                 # signal the profiler that the next profiling step has started
-                if torch_profiler:
-                    torch_profiler.step()
+                if self.torch_profiler:
+                    self.torch_profiler.step()
                 if memory_profiler:
                     memory_profiler.step()
 
@@ -730,10 +746,12 @@ def main(trainer_class: type[Trainer]) -> None:
         else:
            trainer.train()
     except Exception:
+        logger.info("Torchtitan training threw an exception")
         if trainer:
             trainer.close()
         raise
     else:
+        logger.info("Torchtitan training completed")
         trainer.close()
     torch.distributed.destroy_process_group()
     logger.info("Process group destroyed")