44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ import ctypes
78import importlib
89import os
10+ import signal
911import time
1012from datetime import timedelta
1113from typing import Any , Generator , Iterable
3335 maybe_enable_profiling ,
3436)
3537
# Handle to the current process's global symbol table (equivalent to
# dlopen(NULL)), used to reach the C library's signal() directly so a
# SIGABRT handler can be installed at the C level rather than through
# Python's signal module.
# NOTE(review): ctypes.CDLL(None) is POSIX-only and raises on Windows —
# confirm this module is only ever imported on POSIX hosts.
c_globals = ctypes.CDLL(None)  # POSIX
3640
3741class Trainer (torch .distributed .checkpoint .stateful .Stateful ):
42+ torch_profiler : torch .profiler .profile | None = None
43+
3844 # core configs
3945 job_config : JobConfig
4046 parallel_dims : ParallelDims
@@ -613,13 +619,14 @@ def train(self):
613619 if not self .ft_manager .enabled
614620 else f"replica_{ self .ft_manager .replica_id } "
615621 )
622+ self .torch_profiler = maybe_enable_profiling (
623+ job_config .profiling ,
624+ global_step = self .step ,
625+ base_folder = job_config .job .dump_folder ,
626+ leaf_folder = leaf_folder ,
627+ )
628+
616629 with (
617- maybe_enable_profiling (
618- job_config .profiling ,
619- global_step = self .step ,
620- base_folder = job_config .job .dump_folder ,
621- leaf_folder = leaf_folder ,
622- ) as torch_profiler ,
623630 maybe_enable_memory_snapshot (
624631 job_config .profiling ,
625632 global_step = self .step ,
@@ -643,6 +650,15 @@ def train(self):
643650 ),
644651 ),
645652 ):
653+ if self .torch_profiler :
654+
655+ @ctypes .CFUNCTYPE (None , ctypes .c_int )
656+ def sigabrt_handler (signal ):
657+ logger .info ("SIGABRT received. Stopping profiler" )
658+ self .torch_profiler .export_chrome_trace ("trace.json" )
659+
660+ c_globals .signal (signal .SIGABRT , sigabrt_handler )
661+
646662 data_iterator = self .batch_generator (self .dataloader )
647663 while self .should_continue_training ():
648664 self .step += 1
@@ -666,8 +682,8 @@ def train(self):
666682 self .validator .validate (self .model_parts , self .step )
667683
668684 # signal the profiler that the next profiling step has started
669- if torch_profiler :
670- torch_profiler .step ()
685+ if self . torch_profiler :
686+ self . torch_profiler .step ()
671687 if memory_profiler :
672688 memory_profiler .step ()
673689
@@ -730,10 +746,12 @@ def main(trainer_class: type[Trainer]) -> None:
730746 else :
731747 trainer .train ()
732748 except Exception :
749+ logger .info ("Torchtitan training threw an exception" )
733750 if trainer :
734751 trainer .close ()
735752 raise
736753 else :
754+ logger .info ("Torchtitan training completed" )
737755 trainer .close ()
738756 torch .distributed .destroy_process_group ()
739757 logger .info ("Process group destroyed" )
0 commit comments