From a6183416f70f3c90568b39ad4d6040b13f775567 Mon Sep 17 00:00:00 2001
From: Markus Krimmel
Date: Sun, 18 Aug 2024 09:18:50 +0200
Subject: [PATCH] allow loading of nested states when Strategy.load_checkpoint
 is used

---
 src/lightning/fabric/fabric.py              | 17 ++++++++----
 src/lightning/fabric/strategies/strategy.py | 29 +++++++++++++--------
 2 files changed, 30 insertions(+), 16 deletions(-)

diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py
index 0ff5b04b30b0a..c76299eebf268 100644
--- a/src/lightning/fabric/fabric.py
+++ b/src/lightning/fabric/fabric.py
@@ -85,6 +85,17 @@ def _do_nothing(*_: Any) -> None:
     pass
 
 
+def _recursively_update_state(old_state: Dict[str, Any], new_unwrapped_state: Dict[str, Any]) -> None:
+    for k in list(new_unwrapped_state.keys()):
+        obj, _ = _unwrap_compiled(old_state[k])
+        if isinstance(obj, (_FabricModule, _FabricOptimizer, _FabricDataLoader)):
+            pass
+        elif isinstance(obj, dict):
+            _recursively_update_state(old_state[k], new_unwrapped_state[k])
+        else:
+            old_state[k] = new_unwrapped_state[k]
+
+
 class Fabric:
     r"""Fabric accelerates your PyTorch training or inference code with minimal changes required.
 
@@ -775,11 +786,7 @@ def load(
         if state is not None:
             # We need to unwrap objects (see above) but this creates a new dictionary. In-place updates
             # (for user metadata) wouldn't show up in the original dict, so we need to copy the data back.
-            for k in list(unwrapped_state.keys()):
-                obj, _ = _unwrap_compiled(state[k])
-                if isinstance(obj, (_FabricModule, _FabricOptimizer, _FabricDataLoader)):
-                    continue
-                state[k] = unwrapped_state[k]
+            _recursively_update_state(state, unwrapped_state)
         return remainder
 
     def load_raw(self, path: Union[str, Path], obj: Union[nn.Module, Optimizer], strict: bool = True) -> None:
diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py
index 6bfed6a270b68..7bc21aa2fdff2 100644
--- a/src/lightning/fabric/strategies/strategy.py
+++ b/src/lightning/fabric/strategies/strategy.py
@@ -301,6 +301,21 @@ def get_optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
         # for optimizers that are not sharded, we return the state dict on all ranks
         return optimizer.state_dict()
 
+    def _recursively_load_state(self, state: Dict[str, Any], checkpoint: Dict[str, Any], strict: bool = True) -> None:
+        _validate_keys_for_strict_loading(state.keys(), checkpoint.keys(), strict=strict)
+        for name, obj in state.copy().items():
+            if name not in checkpoint:
+                continue
+            if isinstance(obj, _Stateful):
+                if isinstance(obj, Module):
+                    self.load_module_state_dict(module=obj, state_dict=checkpoint.pop(name), strict=strict)
+                else:
+                    obj.load_state_dict(checkpoint.pop(name))
+            elif isinstance(obj, dict):
+                self._recursively_load_state(state=state[name], checkpoint=checkpoint.pop(name), strict=strict)
+            else:
+                state[name] = checkpoint.pop(name)
+
     def load_checkpoint(
         self,
         path: _PATH,
@@ -338,17 +353,7 @@ def load_checkpoint(
             state.load_state_dict(checkpoint)
             return {}
 
-        _validate_keys_for_strict_loading(state.keys(), checkpoint.keys(), strict=strict)
-        for name, obj in state.copy().items():
-            if name not in checkpoint:
-                continue
-            if isinstance(obj, _Stateful):
-                if isinstance(obj, Module):
-                    self.load_module_state_dict(module=obj, state_dict=checkpoint.pop(name), strict=strict)
-                else:
-                    obj.load_state_dict(checkpoint.pop(name))
-            else:
-                state[name] = checkpoint.pop(name)
+        self._recursively_load_state(state, checkpoint, strict=strict)
         return checkpoint
 
     def teardown(self) -> None:
@@ -405,6 +410,8 @@ def _convert_stateful_objects_in_state(
             converted = self.get_optimizer_state(optimizer=obj)
         elif isinstance(obj, _Stateful):
             converted = obj.state_dict()
+        elif isinstance(obj, dict):
+            converted = self._convert_stateful_objects_in_state(obj, filter)
         else:
             converted = obj
         _apply_filter(key, filter, converted, converted_state)
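
Usage sketch (not part of the patch): a minimal, illustrative example of what this change enables, namely nested dictionaries inside the state passed to Fabric's save/load. The checkpoint file name, model, and hyperparameters below are made up for illustration and assume this patch is applied on top of lightning.fabric.

import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

# With this patch, stateful objects may be grouped inside nested dicts.
state = {"inner": {"model": model, "optimizer": optimizer, "step": 0}, "epoch": 0}
fabric.save("checkpoint.ckpt", state)

# Nested modules/optimizers are restored in place; nested plain values
# (e.g. "step", "epoch") are copied back into the original dict.
fabric.load("checkpoint.ckpt", state)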