diff --git a/pyproject.toml b/pyproject.toml index b4d5d0b1638f5..499fcbf8d0c37 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -182,6 +182,7 @@ filterwarnings = [ # "error::DeprecationWarning", "error::FutureWarning", "ignore::FutureWarning:onnxscript", # Temporary ignore until onnxscript is updated + "ignore:You are using `torch.load` with `weights_only=False`.*:FutureWarning", ] xfail_strict = true junit_duration_report = "call" diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 16cd42adc90d9..f4797ce3e18ae 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed -- +- Default to `weights_only=True` for `torch>=2.6` when loading checkpoints. ([#21072](https://github.com/Lightning-AI/pytorch-lightning/pull/21072)) ### Fixed diff --git a/src/lightning/fabric/plugins/io/checkpoint_io.py b/src/lightning/fabric/plugins/io/checkpoint_io.py index 3a33dac3335d1..db7578d9ca8c0 100644 --- a/src/lightning/fabric/plugins/io/checkpoint_io.py +++ b/src/lightning/fabric/plugins/io/checkpoint_io.py @@ -47,13 +47,20 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio """ @abstractmethod - def load_checkpoint(self, path: _PATH, map_location: Optional[Any] = None) -> dict[str, Any]: + def load_checkpoint( + self, path: _PATH, map_location: Optional[Any] = None, weights_only: Optional[bool] = None + ) -> dict[str, Any]: """Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages. Args: path: Path to checkpoint map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage locations. + weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain + ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains + an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we + recommend using ``weights_only=True``. For more information, please refer to the + `PyTorch Developer Notes on Serialization Semantics `_. Returns: The loaded checkpoint. diff --git a/src/lightning/fabric/plugins/io/torch_io.py b/src/lightning/fabric/plugins/io/torch_io.py index 90a5f62ba7413..c52ad6913e1e2 100644 --- a/src/lightning/fabric/plugins/io/torch_io.py +++ b/src/lightning/fabric/plugins/io/torch_io.py @@ -59,7 +59,10 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio @override def load_checkpoint( - self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage + self, + path: _PATH, + map_location: Optional[Callable] = lambda storage, loc: storage, + weights_only: Optional[bool] = None, ) -> dict[str, Any]: """Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of files. @@ -67,6 +70,11 @@ def load_checkpoint( path: Path to checkpoint map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage locations. + weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain + ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains + an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we + recommend using ``weights_only=True``. For more information, please refer to the + `PyTorch Developer Notes on Serialization Semantics `_. Returns: The loaded checkpoint. @@ -80,7 +88,7 @@ def load_checkpoint( if not fs.exists(path): raise FileNotFoundError(f"Checkpoint file not found: {path}") - return pl_load(path, map_location=map_location) + return pl_load(path, map_location=map_location, weights_only=weights_only) @override def remove_checkpoint(self, path: _PATH) -> None: diff --git a/src/lightning/fabric/utilities/cloud_io.py b/src/lightning/fabric/utilities/cloud_io.py index 637dfcd9b1671..a8c4007376b2b 100644 --- a/src/lightning/fabric/utilities/cloud_io.py +++ b/src/lightning/fabric/utilities/cloud_io.py @@ -17,7 +17,7 @@ import io import logging from pathlib import Path -from typing import IO, Any, Union +from typing import IO, Any, Optional, Union import fsspec import fsspec.utils @@ -34,13 +34,18 @@ def _load( path_or_url: Union[IO, _PATH], map_location: _MAP_LOCATION_TYPE = None, - weights_only: bool = False, + weights_only: Optional[bool] = None, ) -> Any: """Loads a checkpoint. Args: path_or_url: Path or URL of the checkpoint. map_location: a function, ``torch.device``, string or a dict specifying how to remap storage locations. + weights_only: If ``True``, restricts loading to ``state_dicts`` of plain ``torch.Tensor`` and other primitive + types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use + ``weights_only=False``. If loading checkpoint from an untrusted source, we recommend using + ``weights_only=True``. For more information, please refer to the + `PyTorch Developer Notes on Serialization Semantics `_. """ if not isinstance(path_or_url, (str, Path)): @@ -51,6 +56,9 @@ def _load( weights_only=weights_only, ) if str(path_or_url).startswith("http"): + if weights_only is None and _TORCH_GREATER_EQUAL_2_6: + weights_only = True + log.debug(f"Default to `weights_only=True` for remote checkpoint: {path_or_url}") return torch.hub.load_state_dict_from_url( str(path_or_url), map_location=map_location, # type: ignore[arg-type] @@ -70,7 +78,7 @@ def get_filesystem(path: _PATH, **kwargs: Any) -> AbstractFileSystem: return fs -def _atomic_save(checkpoint: dict[str, Any], filepath: Union[str, Path]) -> None: +def _atomic_save(checkpoint: dict[str, Any], filepath: _PATH) -> None: """Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints. Args: diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index 70239baac0e6d..5655f2674638e 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -35,6 +35,7 @@ _TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0") _TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1") _TORCH_GREATER_EQUAL_2_5 = compare_version("torch", operator.ge, "2.5.0") +_TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0") _TORCH_LESS_EQUAL_2_6 = compare_version("torch", operator.le, "2.6.0") _TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0") _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index fe9173d008230..b41f57d1194b0 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -18,7 +18,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed -- Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise. ([#9580](https://github.com/Lightning-AI/pytorch-lightning/pull/9580)) +- Default to `weights_only=True` for `torch>=2.6` when loading checkpoints. ([#21072](https://github.com/Lightning-AI/pytorch-lightning/pull/21072)) + + +- ### Removed diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index ff84c2fd8b199..07ec02ef87bd8 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -177,6 +177,7 @@ def load_from_checkpoint( checkpoint_path: Union[_PATH, IO], map_location: _MAP_LOCATION_TYPE = None, hparams_file: Optional[_PATH] = None, + weights_only: Optional[bool] = None, **kwargs: Any, ) -> Self: r"""Primary way of loading a datamodule from a checkpoint. When Lightning saves a checkpoint it stores the @@ -206,6 +207,11 @@ def load_from_checkpoint( If your datamodule's ``hparams`` argument is :class:`~argparse.Namespace` and ``.yaml`` file has hierarchical structure, you need to refactor your datamodule to treat ``hparams`` as :class:`~dict`. + weights_only: If ``True``, restricts loading to ``state_dicts`` of plain ``torch.Tensor`` and other + primitive types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use + ``weights_only=False``. If loading checkpoint from an untrusted source, we recommend using + ``weights_only=True``. For more information, please refer to the + `PyTorch Developer Notes on Serialization Semantics `_. \**kwargs: Any extra keyword args needed to init the datamodule. Can also be used to override saved hyperparameter values. @@ -242,6 +248,7 @@ def load_from_checkpoint( map_location=map_location, hparams_file=hparams_file, strict=None, + weights_only=weights_only, **kwargs, ) return cast(Self, loaded) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 85f631ee40f75..37b07f025f8e9 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -1690,6 +1690,7 @@ def load_from_checkpoint( map_location: _MAP_LOCATION_TYPE = None, hparams_file: Optional[_PATH] = None, strict: Optional[bool] = None, + weights_only: Optional[bool] = None, **kwargs: Any, ) -> Self: r"""Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments @@ -1723,6 +1724,11 @@ def load_from_checkpoint( strict: Whether to strictly enforce that the keys in :attr:`checkpoint_path` match the keys returned by this module's state dict. Defaults to ``True`` unless ``LightningModule.strict_loading`` is set, in which case it defaults to the value of ``LightningModule.strict_loading``. + weights_only: If ``True``, restricts loading to ``state_dicts`` of plain ``torch.Tensor`` and other + primitive types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use + ``weights_only=False``. If loading checkpoint from an untrusted source, we recommend using + ``weights_only=True``. For more information, please refer to the + `PyTorch Developer Notes on Serialization Semantics `_. \**kwargs: Any extra keyword args needed to init the model. Can also be used to override saved hyperparameter values. @@ -1778,6 +1784,7 @@ def load_from_checkpoint( map_location, hparams_file, strict, + weights_only, **kwargs, ) return cast(Self, loaded) diff --git a/src/lightning/pytorch/core/saving.py b/src/lightning/pytorch/core/saving.py index 21fd3912f7849..391e9dd5d0f25 100644 --- a/src/lightning/pytorch/core/saving.py +++ b/src/lightning/pytorch/core/saving.py @@ -56,11 +56,13 @@ def _load_from_checkpoint( map_location: _MAP_LOCATION_TYPE = None, hparams_file: Optional[_PATH] = None, strict: Optional[bool] = None, + weights_only: Optional[bool] = None, **kwargs: Any, ) -> Union["pl.LightningModule", "pl.LightningDataModule"]: map_location = map_location or _default_map_location + with pl_legacy_patch(): - checkpoint = pl_load(checkpoint_path, map_location=map_location) + checkpoint = pl_load(checkpoint_path, map_location=map_location, weights_only=weights_only) # convert legacy checkpoints to the new format checkpoint = _pl_migrate_checkpoint( diff --git a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py index 7f97a2f54bf19..52fb0e3230a82 100644 --- a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py +++ b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py @@ -414,7 +414,9 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: """Creating a model checkpoint dictionary object from various component states. Args: - weights_only: saving model weights only + weights_only: If True, only saves model and loops state_dict objects. If False, + additionally saves callbacks, optimizers, schedulers, and precision plugin states. + Return: structured dictionary: { 'epoch': training epoch diff --git a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py index 006a123356c98..a5ad77cf25c1a 100644 --- a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py +++ b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py @@ -18,6 +18,7 @@ import pytest import torch +from packaging.version import Version import lightning.pytorch as pl from lightning.pytorch import Callback, Trainer @@ -45,7 +46,12 @@ def test_load_legacy_checkpoints(tmp_path, pl_version: str): assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"' path_ckpt = path_ckpts[-1] - model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24) + if pl_version == "local": + pl_version = pl.__version__ + + weights_only = Version(pl_version) >= Version("1.5.0") + + model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24, weights_only=weights_only) trainer = Trainer(default_root_dir=tmp_path) dm = ClassifDataModule(num_features=24, length=6000, batch_size=128, n_clusters_per_class=2, n_informative=8) res = trainer.test(model, datamodule=dm) @@ -73,13 +79,18 @@ def test_legacy_ckpt_threading(pl_version: str): assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"' path_ckpt = path_ckpts[-1] + # legacy load utility added in 1.5.0 (see https://github.com/Lightning-AI/pytorch-lightning/pull/9166) + if pl_version == "local": + pl_version = pl.__version__ + weights_only = not Version(pl_version) < Version("1.5.0") + def load_model(): import torch from lightning.pytorch.utilities.migration import pl_legacy_patch with pl_legacy_patch(): - _ = torch.load(path_ckpt, weights_only=False) + _ = torch.load(path_ckpt, weights_only=weights_only) with patch("sys.path", [PATH_LEGACY] + sys.path): t1 = ThreadExceptionHandler(target=load_model) @@ -94,9 +105,14 @@ def load_model(): @pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) @RunIf(sklearn=True) -def test_resume_legacy_checkpoints(tmp_path, pl_version: str): +def test_resume_legacy_checkpoints(monkeypatch, tmp_path, pl_version: str): PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version) with patch("sys.path", [PATH_LEGACY] + sys.path): + if pl_version == "local": + pl_version = pl.__version__ + if Version(pl_version) < Version("1.5.0"): + monkeypatch.setenv("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "1") + path_ckpts = sorted(glob.glob(os.path.join(PATH_LEGACY, f"*{CHECKPOINT_EXTENSION}"))) assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"' path_ckpt = path_ckpts[-1] diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py index 575bcadadc404..d0c72721ce1be 100644 --- a/tests/tests_pytorch/models/test_hparams.py +++ b/tests/tests_pytorch/models/test_hparams.py @@ -19,6 +19,7 @@ from argparse import Namespace from dataclasses import dataclass, field from enum import Enum +from typing import Optional from unittest import mock import cloudpickle @@ -94,7 +95,9 @@ def __init__(self, hparams, *my_args, **my_kwargs): # ------------------------- # STANDARD TESTS # ------------------------- -def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwrite=False): +def _run_standard_hparams_test( + tmp_path, model, cls, datamodule=None, try_overwrite=False, weights_only: Optional[bool] = None +): """Tests for the existence of an arg 'test_arg=14'.""" obj = datamodule if issubclass(cls, LightningDataModule) else model @@ -108,19 +111,20 @@ def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwr # make sure the raw checkpoint saved the properties raw_checkpoint_path = _raw_checkpoint_path(trainer) - raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=False) + raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=weights_only) + assert cls.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint assert raw_checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]["test_arg"] == 14 # verify that model loads correctly - obj2 = cls.load_from_checkpoint(raw_checkpoint_path) + obj2 = cls.load_from_checkpoint(raw_checkpoint_path, weights_only=weights_only) assert obj2.hparams.test_arg == 14 assert isinstance(obj2.hparams, hparam_type) if try_overwrite: # verify that we can overwrite the property - obj3 = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78) + obj3 = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78, weights_only=weights_only) assert obj3.hparams.test_arg == 78 return raw_checkpoint_path @@ -175,8 +179,10 @@ def test_omega_conf_hparams(tmp_path, cls): assert isinstance(obj.hparams, Container) # run standard test suite - raw_checkpoint_path = _run_standard_hparams_test(tmp_path, model, cls, datamodule=datamodule) - obj2 = cls.load_from_checkpoint(raw_checkpoint_path) + # weights_only=False as omegaconf.DictConfig is not an allowed global by default + raw_checkpoint_path = _run_standard_hparams_test(tmp_path, model, cls, datamodule=datamodule, weights_only=False) + obj2 = cls.load_from_checkpoint(raw_checkpoint_path, weights_only=False) + assert isinstance(obj2.hparams, Container) # config specific tests @@ -367,13 +373,17 @@ class DictConfSubClassBoringModel: ... BoringModelWithMixinAndInit, ], ) -def test_collect_init_arguments(tmp_path, cls): +def test_collect_init_arguments(tmp_path, cls: BoringModel): """Test that the model automatically saves the arguments passed into the constructor.""" extra_args = {} + weights_only = True + if cls is AggSubClassBoringModel: extra_args.update(my_loss=torch.nn.CosineEmbeddingLoss()) + weights_only = False elif cls is DictConfSubClassBoringModel: extra_args.update(dict_conf=OmegaConf.create({"my_param": "anything"})) + weights_only = False model = cls(**extra_args) assert model.hparams.batch_size == 64 @@ -392,12 +402,12 @@ def test_collect_init_arguments(tmp_path, cls): raw_checkpoint_path = _raw_checkpoint_path(trainer) - raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=False) + raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=weights_only) assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]["batch_size"] == 179 # verify that model loads correctly - model = cls.load_from_checkpoint(raw_checkpoint_path) + model = cls.load_from_checkpoint(raw_checkpoint_path, weights_only=weights_only) assert model.hparams.batch_size == 179 if isinstance(model, AggSubClassBoringModel): @@ -408,7 +418,7 @@ def test_collect_init_arguments(tmp_path, cls): assert model.hparams.dict_conf["my_param"] == "anything" # verify that we can overwrite whatever we want - model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99) + model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99, weights_only=weights_only) assert model.hparams.batch_size == 99 @@ -781,7 +791,7 @@ def __init__(self, args_0, args_1, args_2, kwarg_1=None): logger=False, ) trainer.fit(model) - _ = TestHydraModel.load_from_checkpoint(checkpoint_callback.best_model_path) + _ = TestHydraModel.load_from_checkpoint(checkpoint_callback.best_model_path, weights_only=False) @pytest.mark.parametrize("ignore", ["arg2", ("arg2", "arg3")]) diff --git a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py index f7a76079cfca2..ef4c652d08660 100644 --- a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py +++ b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py @@ -32,7 +32,9 @@ class CustomCheckpointIO(CheckpointIO): def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: torch.save(checkpoint, path) - def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> dict[str, Any]: + def load_checkpoint( + self, path: _PATH, storage_options: Optional[Any] = None, weights_only: bool = True + ) -> dict[str, Any]: return torch.load(path, weights_only=True) def remove_checkpoint(self, path: _PATH) -> None: