torchtitan/experiments/forge/example_train.py (11 changes: 6 additions & 5 deletions)
@@ -282,12 +282,13 @@ def train(self):
         self.checkpointer.load(step=job_config.checkpoint.load_step)
         logger.info(f"Training starts at step {self.step + 1}.")
 
+        torch_profiler = maybe_enable_profiling(
+            job_config.profiling,
+            global_step=self.step,
+            base_folder=job_config.job.dump_folder,
+        )
+
         with (
-            maybe_enable_profiling(
-                job_config.profiling,
-                global_step=self.step,
-                base_folder=job_config.job.dump_folder,
-            ) as torch_profiler,
             maybe_enable_memory_snapshot(
                 job_config.profiling,
                 global_step=self.step,
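Since maybe_enable_profiling no longer acts as a context manager here, the caller keeps the returned profiler object itself and drives its lifecycle. A hedged sketch of that lifecycle outside the trainer, where `job_config`, `run_one_step`, and the explicit `stop()` call are illustrative assumptions rather than code from this PR:

# Hedged caller-side sketch; `job_config` and `run_one_step` are placeholders.
from torchtitan.tools.profiling import maybe_enable_profiling

torch_profiler = maybe_enable_profiling(
    job_config.profiling,
    global_step=0,
    base_folder=job_config.job.dump_folder,
)
try:
    for _ in range(job_config.training.steps):
        run_one_step()                # one training iteration (placeholder)
        if torch_profiler:
            torch_profiler.step()     # advance the wait/warmup/active schedule
finally:
    if torch_profiler:
        torch_profiler.stop()         # assumed: explicit stop, since no __exit__ runs it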
torchtitan/tools/profiling.py (13 changes: 6 additions & 7 deletions)
@@ -18,7 +18,6 @@
 MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
 
 
-@contextlib.contextmanager
 def maybe_enable_profiling(
     profiling_config: ProfilingConfig,
     *,
@@ -68,20 +67,20 @@ def trace_handler(prof):
             gpu_device_profiled = torch.profiler.ProfilerActivity.CUDA
         elif torch.xpu.is_available():
             gpu_device_profiled = torch.profiler.ProfilerActivity.XPU
-        with torch.profiler.profile(
+        torch_profiler = torch.profiler.profile(
             activities=[
                 torch.profiler.ProfilerActivity.CPU,
                 gpu_device_profiled,
             ],
             schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
             on_trace_ready=trace_handler,
             record_shapes=True,
-        ) as torch_profiler:
-            torch_profiler.step_num = global_step
-            yield torch_profiler
+        )
+        torch_profiler.step_num = global_step
+        torch_profiler.start()
+        return torch_profiler
     else:
-        torch_profiler = contextlib.nullcontext()
-        yield None
+        return None
 
 
 @contextlib.contextmanager
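The helper now returns an already-started torch.profiler.profile (or None when profiling is disabled), so trace export still happens through on_trace_ready once the schedule's active window completes; only ownership of start/stop moves to the caller. A standalone sketch of that start()/step()/stop() pattern against the plain PyTorch API, with an illustrative trace handler writing to the current directory:

# Minimal, self-contained illustration of the pattern the revised helper relies on.
import torch

def trace_handler(prof: torch.profiler.profile) -> None:
    # Called automatically at the end of each "active" window of the schedule.
    prof.export_chrome_trace(f"trace_step_{prof.step_num}.json")

prof = torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU],
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=2),
    on_trace_ready=trace_handler,
)
prof.start()          # what the with-statement's __enter__ used to do
for _ in range(8):
    pass              # a real training step would run here
    prof.step()
prof.stop()           # what __exit__ used to do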
torchtitan/train.py (34 changes: 26 additions & 8 deletions)
@@ -4,8 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import ctypes
 import importlib
 import os
+import signal
 import time
 from datetime import timedelta
 from typing import Any, Generator, Iterable
@@ -33,8 +35,12 @@
     maybe_enable_profiling,
 )
 
+c_globals = ctypes.CDLL(None)  # POSIX
+
 
 class Trainer(torch.distributed.checkpoint.stateful.Stateful):
+    torch_profiler: torch.profiler.profile | None = None
+
     # core configs
     job_config: JobConfig
     parallel_dims: ParallelDims
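Two small additions here: the module-level ctypes.CDLL(None) handle and the class-level torch_profiler attribute. On POSIX, CDLL(None) asks the dynamic loader for the symbols already mapped into the running process, which includes libc and therefore the C signal() function used later in train(). A quick hedged check of what that handle exposes (getpid is just an example symbol, not something the PR calls):

# Illustration only (not from the PR): CDLL(None) resolves libc symbols on POSIX.
import ctypes
import os

libc = ctypes.CDLL(None)
assert libc.getpid() == os.getpid()  # same process id via libc and via os

The torch_profiler: torch.profiler.profile | None = None annotation, in turn, gives the signal handler and the training loop one shared attribute on the trainer instance, with None meaning profiling is disabled.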
@@ -613,13 +619,14 @@ def train(self):
             if not self.ft_manager.enabled
             else f"replica_{self.ft_manager.replica_id}"
         )
+        self.torch_profiler = maybe_enable_profiling(
+            job_config.profiling,
+            global_step=self.step,
+            base_folder=job_config.job.dump_folder,
+            leaf_folder=leaf_folder,
+        )
+
         with (
-            maybe_enable_profiling(
-                job_config.profiling,
-                global_step=self.step,
-                base_folder=job_config.job.dump_folder,
-                leaf_folder=leaf_folder,
-            ) as torch_profiler,
             maybe_enable_memory_snapshot(
                 job_config.profiling,
                 global_step=self.step,
@@ -643,6 +650,15 @@
                 ),
             ),
         ):
+            if self.torch_profiler:
+
+                @ctypes.CFUNCTYPE(None, ctypes.c_int)
+                def sigabrt_handler(signal):
+                    logger.info("SIGABRT received. Stopping profiler")
+                    self.torch_profiler.export_chrome_trace("trace.json")
+
+                c_globals.signal(signal.SIGABRT, sigabrt_handler)
+
             data_iterator = self.batch_generator(self.dataloader)
             while self.should_continue_training():
                 self.step += 1
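Installing the handler through libc rather than signal.signal is the important detail: CPython runs Python-level handlers only between bytecode instructions on the main thread, so a SIGABRT raised from native code (for example a collective-timeout abort) would normally terminate the process before any Python handler fires. A CFUNCTYPE callback registered via libc's signal() is invoked as a real C handler, which then re-enters Python and can export the chrome trace from the still-live profiler. A standalone, hedged sketch of the mechanism (handler name and message are illustrative):

# Standalone sketch of a C-level SIGABRT handler on POSIX; mirrors the mechanism
# above but is not the PR's code.
import ctypes
import os
import signal

libc = ctypes.CDLL(None)

@ctypes.CFUNCTYPE(None, ctypes.c_int)
def on_sigabrt(signum):  # named `signum` to avoid shadowing the `signal` module
    # Re-enters Python from a C signal handler via ctypes; keep the body short.
    os.write(2, b"SIGABRT received, dumping state\n")

_keepalive = on_sigabrt                       # keep a reference so the callback is not GC'd
libc.signal(int(signal.SIGABRT), on_sigabrt)

# From here on, a native abort() would run on_sigabrt before the process dies,
# e.g. (destructive, so left commented out): libc.abort()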
@@ -666,8 +682,8 @@
                     self.validator.validate(self.model_parts, self.step)
 
                 # signal the profiler that the next profiling step has started
-                if torch_profiler:
-                    torch_profiler.step()
+                if self.torch_profiler:
+                    self.torch_profiler.step()
                 if memory_profiler:
                     memory_profiler.step()
 
@@ -730,10 +746,12 @@ def main(trainer_class: type[Trainer]) -> None:
         else:
             trainer.train()
     except Exception:
+        logger.info("Torchtitan training threw an exception")
         if trainer:
             trainer.close()
         raise
     else:
+        logger.info("Torchtitan training completed")
         trainer.close()
         torch.distributed.destroy_process_group()
         logger.info("Process group destroyed")