Skip to content
10 changes: 6 additions & 4 deletions src/lightning/pytorch/profilers/advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import os
import pstats
import tempfile
from collections import defaultdict
from pathlib import Path
from typing import Optional, Union

Expand Down Expand Up @@ -66,14 +67,15 @@ def __init__(
If you attempt to stop recording an action which was never started.
"""
super().__init__(dirpath=dirpath, filename=filename)
self.profiled_actions: dict[str, cProfile.Profile] = {}
self.profiled_actions: dict[str, cProfile.Profile] = defaultdict(cProfile.Profile)
self.line_count_restriction = line_count_restriction
self.dump_stats = dump_stats

@override
def start(self, action_name: str) -> None:
if action_name not in self.profiled_actions:
self.profiled_actions[action_name] = cProfile.Profile()
# Disable all profilers before starting a new one
for pr in self.profiled_actions.values():
pr.disable()
self.profiled_actions[action_name].enable()

@override
Expand Down Expand Up @@ -114,7 +116,7 @@ def summary(self) -> str:
@override
def teardown(self, stage: Optional[str]) -> None:
super().teardown(stage=stage)
self.profiled_actions = {}
self.profiled_actions = defaultdict(cProfile.Profile)

def __reduce__(self) -> tuple:
# avoids `TypeError: cannot pickle 'cProfile.Profile' object`
Expand Down
Loading