17 changes: 12 additions & 5 deletions src/lightning/fabric/fabric.py
@@ -85,6 +85,17 @@ def _do_nothing(*_: Any) -> None:
     pass


+def _recursively_update_state(old_state: Dict[str, Any], new_unwrapped_state: Dict[str, Any]) -> None:
+    for k in list(new_unwrapped_state.keys()):
+        obj, _ = _unwrap_compiled(old_state[k])
+        if isinstance(obj, (_FabricModule, _FabricOptimizer, _FabricDataLoader)):
+            pass
+        elif isinstance(obj, dict):
+            _recursively_update_state(old_state[k], new_unwrapped_state[k])
+        else:
+            old_state[k] = new_unwrapped_state[k]
+
+
 class Fabric:
     r"""Fabric accelerates your PyTorch training or inference code with minimal changes required.

@@ -775,11 +786,7 @@ def load(
         if state is not None:
             # We need to unwrap objects (see above) but this creates a new dictionary. In-place updates
             # (for user metadata) wouldn't show up in the original dict, so we need to copy the data back.
-            for k in list(unwrapped_state.keys()):
-                obj, _ = _unwrap_compiled(state[k])
-                if isinstance(obj, (_FabricModule, _FabricOptimizer, _FabricDataLoader)):
-                    continue
-                state[k] = unwrapped_state[k]
+            _recursively_update_state(state, unwrapped_state)
         return remainder

     def load_raw(self, path: Union[str, Path], obj: Union[nn.Module, Optimizer], strict: bool = True) -> None:
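The comment in the hunk above explains the constraint driving this change: `Fabric.load()` hands the strategy an unwrapped copy of the user's state dict, so everything restored into that copy has to be written back into the dict the caller still holds. With `_recursively_update_state` that write-back also descends into nested dicts, skipping Fabric wrappers (which the strategy already restored in place) at any depth instead of only at the top level. The sketch below is a toy illustration of the copy-back rule with stand-in names, not the actual Fabric wrappers:

from typing import Any, Dict


class Wrapped:
    """Stand-in for _FabricModule/_FabricOptimizer: already restored in place, must not be replaced."""


def recursively_update(old: Dict[str, Any], new: Dict[str, Any]) -> None:
    for key in list(new.keys()):
        if isinstance(old[key], Wrapped):
            continue  # keep the caller's wrapper object untouched
        if isinstance(old[key], dict):
            recursively_update(old[key], new[key])  # merge nested dicts in place
        else:
            old[key] = new[key]  # copy plain values back into the caller's dict


caller_state = {"components": {"model": Wrapped()}, "meta": {"epoch": 0}}
loaded_copy = {"components": {"model": {"weight": "..."}}, "meta": {"epoch": 3}}
recursively_update(caller_state, loaded_copy)
assert isinstance(caller_state["components"]["model"], Wrapped)  # wrapper kept
assert caller_state["meta"]["epoch"] == 3  # nested metadata restored in place

Merging nested dicts in place, rather than replacing them wholesale, also means that any other references the training script holds to those sub-dicts keep seeing the restored values.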
29 changes: 18 additions & 11 deletions src/lightning/fabric/strategies/strategy.py
@@ -301,6 +301,21 @@ def get_optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
         # for optimizers that are not sharded, we return the state dict on all ranks
         return optimizer.state_dict()

+    def _recursively_load_state(self, state: Dict[str, Any], checkpoint: Dict[str, Any], strict: bool = True) -> None:
+        _validate_keys_for_strict_loading(state.keys(), checkpoint.keys(), strict=strict)
+        for name, obj in state.copy().items():
+            if name not in checkpoint:
+                continue
+            if isinstance(obj, _Stateful):
+                if isinstance(obj, Module):
+                    self.load_module_state_dict(module=obj, state_dict=checkpoint.pop(name), strict=strict)
+                else:
+                    obj.load_state_dict(checkpoint.pop(name))
+            elif isinstance(obj, dict):
+                self._recursively_load_state(state=state[name], checkpoint=checkpoint.pop(name), strict=strict)
+            else:
+                state[name] = checkpoint.pop(name)
+
     def load_checkpoint(
         self,
         path: _PATH,
@@ -338,17 +353,7 @@ def load_checkpoint
             state.load_state_dict(checkpoint)
             return {}

-        _validate_keys_for_strict_loading(state.keys(), checkpoint.keys(), strict=strict)
-        for name, obj in state.copy().items():
-            if name not in checkpoint:
-                continue
-            if isinstance(obj, _Stateful):
-                if isinstance(obj, Module):
-                    self.load_module_state_dict(module=obj, state_dict=checkpoint.pop(name), strict=strict)
-                else:
-                    obj.load_state_dict(checkpoint.pop(name))
-            else:
-                state[name] = checkpoint.pop(name)
+        self._recursively_load_state(state, checkpoint, strict=strict)
         return checkpoint

     def teardown(self) -> None:
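The loading logic itself moved into `_recursively_load_state` unchanged except for the new `dict` branch, so stateful objects sitting inside nested dicts are now restored in place rather than having the enclosing dict overwritten with raw checkpoint data. A standalone toy mirroring the dispatch order (the `Counter` class and `load_recursively` helper are hypothetical, not the Fabric API):

from typing import Any, Dict


class Counter:
    """Minimal stand-in for a _Stateful object."""

    def __init__(self) -> None:
        self.count = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"count": self.count}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.count = state["count"]


def load_recursively(state: Dict[str, Any], checkpoint: Dict[str, Any]) -> None:
    for name, obj in state.copy().items():
        if name not in checkpoint:
            continue
        if hasattr(obj, "load_state_dict"):  # stateful object: restore it in place
            obj.load_state_dict(checkpoint.pop(name))
        elif isinstance(obj, dict):  # plain dict: recurse into it
            load_recursively(state[name], checkpoint.pop(name))
        else:  # raw value: copy it back by assignment
            state[name] = checkpoint.pop(name)


state = {"trackers": {"epochs": Counter()}, "meta": {"run": "old"}}
checkpoint = {"trackers": {"epochs": {"count": 7}}, "meta": {"run": "new"}}
load_recursively(state, checkpoint)
assert state["trackers"]["epochs"].count == 7 and state["meta"]["run"] == "new"

As in the real method, entries consumed from the checkpoint are popped, so whatever remains is handed back to the caller as leftover metadata.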
@@ -405,6 +410,8 @@ def _convert_stateful_objects_in_state(
                 converted = self.get_optimizer_state(optimizer=obj)
             elif isinstance(obj, _Stateful):
                 converted = obj.state_dict()
+            elif isinstance(obj, dict):
+                converted = self._convert_stateful_objects_in_state(obj, filter)
             else:
                 converted = obj
             _apply_filter(key, filter, converted, converted_state)
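On the save side, the added `dict` branch makes `_convert_stateful_objects_in_state` descend into nested dicts, so modules and optimizers grouped inside them are serialized via their state dicts rather than written out as raw objects. A toy sketch of that conversion (the `convert_stateful` helper is hypothetical, not the Fabric method, and checks only for plain torch types):

from typing import Any, Dict

import torch


def convert_stateful(state: Dict[str, Any]) -> Dict[str, Any]:
    converted: Dict[str, Any] = {}
    for key, obj in state.items():
        if isinstance(obj, (torch.nn.Module, torch.optim.Optimizer)):
            converted[key] = obj.state_dict()  # serialize stateful objects
        elif isinstance(obj, dict):
            converted[key] = convert_stateful(obj)  # new: recurse into nested dicts
        else:
            converted[key] = obj  # raw values pass through unchanged
    return converted


model = torch.nn.Linear(2, 2)
nested = {"components": {"model": model}, "meta": {"step": 3}}
out = convert_stateful(nested)
assert isinstance(out["components"]["model"], dict)  # a state_dict, not the module itself
assert out["meta"] == {"step": 3}

Together with the recursive load path above, this should keep the on-disk checkpoint free of pickled module objects even when the user groups them inside sub-dicts of the state.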