Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
074b01e
change weights_only default to True
matsumotosan Aug 14, 2025
65cc1ed
add docs on weights_only arg
matsumotosan Aug 14, 2025
4eaaf58
Merge branch 'master' into weights-only-compatibility
SkafteNicki Aug 15, 2025
f276114
add weights_only arg to checkpoint save. weights_only during test set…
matsumotosan Aug 15, 2025
601e300
Merge branch 'master' into weights-only-compatibility
matsumotosan Aug 15, 2025
28f53ae
add weights_only arg to checkpoint_io
matsumotosan Aug 16, 2025
b1cfdf1
woops, reverting changes
matsumotosan Aug 16, 2025
4d96a78
permissions too
matsumotosan Aug 16, 2025
4c39c30
fix link
matsumotosan Aug 16, 2025
861d7e0
fix another link
matsumotosan Aug 16, 2025
12bd0d6
datamodule weights_only args
matsumotosan Aug 17, 2025
5eacb6e
wip: try safe_globals context manager for tests
matsumotosan Aug 17, 2025
0430e22
add weights_only arg to _run_standard_hparams_test
matsumotosan Aug 18, 2025
2abe915
weights_only=False when adding extra_args
matsumotosan Aug 18, 2025
8e0f61e
Merge branch 'master' into weights-only-compatibility
Borda Aug 18, 2025
525d9a8
Merge branch 'master' into weights-only-compatibility
matsumotosan Aug 18, 2025
83fd824
switch to lightning_utilities.cli requirements set-oldest (#21077)
Borda Aug 19, 2025
93cbe94
bump: try `deepspeed >=0.14.1,<=0.15.0` (#21076)
Borda Aug 19, 2025
2a53f2f
weights_only=True default for torch>=2.6
matsumotosan Aug 19, 2025
3833892
Merge branch 'master' into weights-only-compatibility
matsumotosan Aug 19, 2025
9d8997e
Merge branch 'Lightning-AI:master' into weights-only-compatibility
matsumotosan Aug 19, 2025
561c02c
changelog
matsumotosan Aug 19, 2025
c67c8a3
ignore torch.load futurewarning
matsumotosan Aug 19, 2025
005c439
add .*
matsumotosan Aug 19, 2025
2ab89a2
will this woork
matsumotosan Aug 19, 2025
74e5e5a
weights_only according pl version
matsumotosan Aug 20, 2025
a4c9efe
set env var TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 for pl < 1.5.0
matsumotosan Aug 20, 2025
2c2ab9e
weights_only=False for omegaconf hparams test
matsumotosan Aug 21, 2025
54b859a
default to weights_only=true for loading from state_dict from url
matsumotosan Aug 21, 2025
7ddb4f8
weights_only=False for hydra
matsumotosan Aug 21, 2025
7d6174a
Merge branch 'master' into weights-only-compatibility
matsumotosan Aug 23, 2025
906e52e
Update src/lightning/fabric/utilities/cloud_io.py
matsumotosan Aug 27, 2025
377cf11
Merge branch 'Lightning-AI:master' into weights-only-compatibility
matsumotosan Aug 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/lightning/fabric/plugins/io/checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,20 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio
"""

@abstractmethod
def load_checkpoint(self, path: _PATH, map_location: Optional[Any] = None) -> dict[str, Any]:
def load_checkpoint(
self, path: _PATH, map_location: Optional[Any] = None, weights_only: bool = True
) -> dict[str, Any]:
"""Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages.

Args:
path: Path to checkpoint
map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
locations.
weights_only: If ``True``, restricts loading to ``state_dicts`` of plain ``torch.Tensor`` and other
primitive types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use
``weights_only=False``. If loading checkpoint from an untrusted source, we recommend using
``weights_only=True``. For more information, please refer to the
`PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.

Returns: The loaded checkpoint.

Expand Down
4 changes: 2 additions & 2 deletions src/lightning/fabric/plugins/io/torch_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio

@override
def load_checkpoint(
self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage
self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage, weights_only: bool = True
) -> dict[str, Any]:
"""Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of files.

Expand All @@ -80,7 +80,7 @@ def load_checkpoint(
if not fs.exists(path):
raise FileNotFoundError(f"Checkpoint file not found: {path}")

return pl_load(path, map_location=map_location)
return pl_load(path, map_location=map_location, weights_only=weights_only)

@override
def remove_checkpoint(self, path: _PATH) -> None:
Expand Down
9 changes: 7 additions & 2 deletions src/lightning/fabric/utilities/cloud_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,18 @@
def _load(
path_or_url: Union[IO, _PATH],
map_location: _MAP_LOCATION_TYPE = None,
weights_only: bool = False,
weights_only: bool = True,
) -> Any:
"""Loads a checkpoint.

Args:
path_or_url: Path or URL of the checkpoint.
map_location: a function, ``torch.device``, string or a dict specifying how to remap storage locations.
weights_only: If ``True``, restricts loading to ``state_dicts`` of plain ``torch.Tensor`` and other primitive
types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use
``weights_only=False``. If loading checkpoint from an untrusted source, we recommend using
``weights_only=True``. For more information, please refer to the
`PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`__.

"""
if not isinstance(path_or_url, (str, Path)):
Expand Down Expand Up @@ -70,7 +75,7 @@ def get_filesystem(path: _PATH, **kwargs: Any) -> AbstractFileSystem:
return fs


def _atomic_save(checkpoint: dict[str, Any], filepath: Union[str, Path]) -> None:
def _atomic_save(checkpoint: dict[str, Any], filepath: _PATH) -> None:
"""Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints.

Args:
Expand Down
2 changes: 2 additions & 0 deletions src/lightning/pytorch/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1690,6 +1690,7 @@ def load_from_checkpoint(
map_location: _MAP_LOCATION_TYPE = None,
hparams_file: Optional[_PATH] = None,
strict: Optional[bool] = None,
weights_only: Optional[bool] = None,
**kwargs: Any,
) -> Self:
r"""Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments
Expand Down Expand Up @@ -1778,6 +1779,7 @@ def load_from_checkpoint(
map_location,
hparams_file,
strict,
weights_only,
**kwargs,
)
return cast(Self, loaded)
Expand Down
8 changes: 7 additions & 1 deletion src/lightning/pytorch/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,17 @@ def _load_from_checkpoint(
map_location: _MAP_LOCATION_TYPE = None,
hparams_file: Optional[_PATH] = None,
strict: Optional[bool] = None,
weights_only: Optional[bool] = None,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we default to weights_only=None or weights_only=True? If we have no use for weights_only=None, we can simplify the type hint to weights_only: bool = True.

**kwargs: Any,
) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
map_location = map_location or _default_map_location

if weights_only is None:
log.debug("`weights_only` not specified, defaulting to `True`.")
weights_only = True

with pl_legacy_patch():
checkpoint = pl_load(checkpoint_path, map_location=map_location)
checkpoint = pl_load(checkpoint_path, map_location=map_location, weights_only=weights_only)

# convert legacy checkpoints to the new format
checkpoint = _pl_migrate_checkpoint(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,9 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
"""Creating a model checkpoint dictionary object from various component states.

Args:
weights_only: saving model weights only
weights_only: If True, only saves model and loops state_dict objects. If False,
additionally saves callbacks, optimizers, schedulers, and precision plugin states.

Return:
structured dictionary: {
'epoch': training epoch
Expand Down
15 changes: 13 additions & 2 deletions tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import pytest
import torch
from packaging.version import Version

import lightning.pytorch as pl
from lightning.pytorch import Callback, Trainer
Expand Down Expand Up @@ -45,7 +46,12 @@ def test_load_legacy_checkpoints(tmp_path, pl_version: str):
assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"'
path_ckpt = path_ckpts[-1]

model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24)
# legacy load utility added in 1.5.0 (see https://github.com/Lightning-AI/pytorch-lightning/pull/9166)
if pl_version == "local":
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the simplest way that I could think of ensuring we continue testing the legacy checkpoints. Another way could be to use torch.serialization.add_safe_globals, but it seems a little more complicated (particularly since we're using the pl_legacy_patch context manager already.

pl_version = pl.__version__
weights_only = not Version(pl_version) < Version("1.5.0")

model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24, weights_only=weights_only)
trainer = Trainer(default_root_dir=tmp_path)
dm = ClassifDataModule(num_features=24, length=6000, batch_size=128, n_clusters_per_class=2, n_informative=8)
res = trainer.test(model, datamodule=dm)
Expand Down Expand Up @@ -73,13 +79,18 @@ def test_legacy_ckpt_threading(pl_version: str):
assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"'
path_ckpt = path_ckpts[-1]

# legacy load utility added in 1.5.0 (see https://github.com/Lightning-AI/pytorch-lightning/pull/9166)
if pl_version == "local":
pl_version = pl.__version__
weights_only = not Version(pl_version) < Version("1.5.0")

def load_model():
import torch

from lightning.pytorch.utilities.migration import pl_legacy_patch

with pl_legacy_patch():
_ = torch.load(path_ckpt, weights_only=False)
_ = torch.load(path_ckpt, weights_only=weights_only)

with patch("sys.path", [PATH_LEGACY] + sys.path):
t1 = ThreadExceptionHandler(target=load_model)
Expand Down
4 changes: 3 additions & 1 deletion tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ class CustomCheckpointIO(CheckpointIO):
def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
torch.save(checkpoint, path)

def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> dict[str, Any]:
def load_checkpoint(
self, path: _PATH, storage_options: Optional[Any] = None, weights_only: bool = True
) -> dict[str, Any]:
return torch.load(path, weights_only=True)

def remove_checkpoint(self, path: _PATH) -> None:
Expand Down
Loading