
Expose weights_only for loading checkpoints with Trainer, LightningModule, LightningDataModule #21072


Open
matsumotosan wants to merge 31 commits into master from matsumotosan:weights-only-compatibility

Conversation

@matsumotosan (Contributor) commented Aug 14, 2025:

What does this PR do?

Fixes #20450 #20058 #20643

Before submitting
  • Was this discussed/agreed via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

📚 Documentation preview 📚: https://pytorch-lightning--21072.org.readthedocs.build/en/21072/

@github-actions github-actions bot added the fabric (lightning.fabric.Fabric) label Aug 14, 2025

codecov bot commented Aug 15, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 87%. Comparing base (7e9cea4) to head (7d6174a).

Additional details and impacted files
@@           Coverage Diff           @@
##           master   #21072   +/-   ##
=======================================
  Coverage      87%      87%           
=======================================
  Files         269      269           
  Lines       23515    23519    +4     
=======================================
+ Hits        20500    20504    +4     
  Misses       3015     3015           

@matsumotosan matsumotosan marked this pull request as draft August 15, 2025 18:21
@github-actions github-actions bot added the pl (Generic label for PyTorch Lightning package) label Aug 15, 2025
@matsumotosan matsumotosan force-pushed the weights-only-compatibility branch from d7cb702 to 601e300 Compare August 15, 2025 22:20
@@ -56,11 +56,17 @@ def _load_from_checkpoint(
map_location: _MAP_LOCATION_TYPE = None,
hparams_file: Optional[_PATH] = None,
strict: Optional[bool] = None,
weights_only: Optional[bool] = None,
@matsumotosan (Contributor, Author) commented on this diff:

Should we default to weights_only=None or weights_only=True? If we have no use for weights_only=None, we can simplify the type hint to weights_only: bool = True.
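
To make the question concrete, here is a minimal sketch (my own illustration, not this PR's code) of what the `None` default could buy: `None` defers to `torch.load`'s own version-dependent default (`False` before torch 2.6, `True` from 2.6 on), while an explicit bool forwards the user's choice.

```python
from typing import Any, Optional

import torch


def _load_checkpoint(path: str, weights_only: Optional[bool] = None) -> Any:
    # None -> omit the kwarg and let torch.load apply its own default;
    # an explicit bool is forwarded unchanged.
    kwargs = {} if weights_only is None else {"weights_only": weights_only}
    return torch.load(path, map_location="cpu", **kwargs)
```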

@@ -45,7 +46,12 @@ def test_load_legacy_checkpoints(tmp_path, pl_version: str):
assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"'
path_ckpt = path_ckpts[-1]

model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24)
# legacy load utility added in 1.5.0 (see https://github.com/Lightning-AI/pytorch-lightning/pull/9166)
if pl_version == "local":
@matsumotosan (Contributor, Author) commented on this diff:

This is the simplest way I could think of to ensure we continue testing the legacy checkpoints. Another way could be to use torch.serialization.add_safe_globals, but it seems a little more complicated (particularly since we're already using the pl_legacy_patch context manager).
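
For comparison, the `add_safe_globals` route mentioned above would look roughly like this (a sketch assuming torch >= 2.5; `ModelCheckpoint` stands in for whatever classes a given legacy checkpoint actually pickles, and the path is a placeholder):

```python
import torch
from lightning.pytorch.callbacks import ModelCheckpoint

path_ckpt = "legacy.ckpt"  # stand-in for path_ckpt from the test above

# Allowlist the classes pickled into the legacy checkpoint so that the
# weights_only unpickler will accept them (torch >= 2.5).
torch.serialization.add_safe_globals([ModelCheckpoint])
checkpoint = torch.load(path_ckpt, weights_only=True)
```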

@matsumotosan matsumotosan marked this pull request as ready for review August 16, 2025 15:37
@matsumotosan matsumotosan changed the title from "Compatibility for weights_only=True by default" to "Compatibility for weights_only=True by default for loading weights" Aug 16, 2025
@matsumotosan (Contributor, Author) commented:

@Borda I wanted to get your opinion on something before moving forward.

I've added weights_only as an argument to LightningModule.load_from_checkpoint and all downstream functions to allow users to determine which option they want to use to load checkpoints.

My issue right now is with resuming training from a checkpoint with Trainer.fit. I see a few options right now:

  1. Add weights_only as an argument to Trainer.fit (would also have to modify args for validate, test, and predict). Set default value to True.
  2. Use weights_only=True everywhere, and print an error message advising the user to set TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD if they would like to load with weights_only=False. Users must explicitly set the environment variable to force loading with weights_only=False.
  3. Add weights_only as an argument to Trainer initialization. Easy, but would not allow fine-grained control on loading models between different calls of fit, validate, etc.

I'm leaning towards option 1, but it involves changing up Trainer methods, which affects a lot of code, so I wanted to run this by you beforehand.
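
For concreteness, option 1 would amount to something like the following signature sketch (illustrative only, not the final API; the stub class just shows where the parameter would land):

```python
class Trainer:
    def fit(self, model, train_dataloaders=None, val_dataloaders=None,
            datamodule=None, ckpt_path=None, weights_only=True):
        # weights_only would be forwarded to the internal torch.load call
        # that restores the checkpoint pointed to by ckpt_path.
        ...
```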

@Borda (Member) commented Aug 18, 2025:

> My issue right now is with resuming training from a checkpoint with Trainer.fit. I see a few options right now:
>
>   1. Add weights_only as an argument to Trainer.fit (would also have to modify args for validate, test, and predict). Set default value to True.
>   2. Use weights_only=True everywhere, and print an error message advising the user to set TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD if they would like to load with weights_only=False. Users must explicitly set the environment variable to force loading with weights_only=False.
>   3. Add weights_only as an argument to Trainer initialization. Easy, but would not allow fine-grained control on loading models between different calls of fit, validate, etc.

The cleanest way would probably be 1), but it brings so many new arguments for a marginal use... so personally I would go with 2)
cc: @lantiga
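
For concreteness, option 2 could be wired up roughly as below. This is a sketch under one assumption: the loader leaves weights_only unset whenever the opt-out variable is present, since (as noted further down in this thread) torch ignores TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD once weights_only is passed explicitly. The helper name `_load_strict` is my own.

```python
import os
import pickle

import torch


def _load_strict(path: str):
    # torch only honours TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD when weights_only
    # is NOT passed explicitly, so leave it unset when the user opts out.
    if os.getenv("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD") == "1":
        return torch.load(path, map_location="cpu")
    try:
        return torch.load(path, map_location="cpu", weights_only=True)
    except pickle.UnpicklingError as err:
        raise RuntimeError(
            "Checkpoint could not be loaded with weights_only=True. "
            "Set TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 to load it anyway."
        ) from err
```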

@matsumotosan (Contributor, Author) commented:

@Borda I will go ahead with changing all cases of torch.load to use weights_only=True.

This will cause a lot of errors with checkpoints from previous versions, so I'll update the docs/warning messages as well to inform users to use either the context manager or global environment variable.

@lantiga (Collaborator) commented Aug 19, 2025:

Hi @matsumotosan, let's do that only if the underlying torch is >= 2.6 (since weights_only became True by default starting from that version); otherwise we're going to break a lot of older code.
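
Expressed as a version gate, this would be something like the sketch below (the constant name mirrors the style of Lightning's existing version flags but is my own here):

```python
import torch
from packaging.version import Version

# Flip the default only where torch.load itself already defaults to
# weights_only=True (i.e. torch >= 2.6); strip any local build suffix first.
_TORCH_GREATER_EQUAL_2_6 = Version(torch.__version__.split("+")[0]) >= Version("2.6.0")
_DEFAULT_WEIGHTS_ONLY = _TORCH_GREATER_EQUAL_2_6
```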

Borda and others added 3 commits August 19, 2025 10:17
@github-actions github-actions bot added the ci (Continuous Integration), dependencies (Pull requests that update a dependency file), dockers, and package labels Aug 19, 2025
@matsumotosan (Contributor, Author) commented Aug 21, 2025:

I am not sure it's possible to default to weights_only=True for trainer.{fit,validate,test,predict}, since each of these loads a checkpoint at some point, and that checkpoint may include elements that are not allowed by torch.load(..., weights_only=True) (Lightning itself saves classes like ModelCheckpoint), so it cannot be loaded unless weights_only=False or a context manager is used.

The big issue with context managers is that a different one has to be used each time a different checkpoint is loaded. Setting the environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 does not work either, as it is overridden by the value passed to torch.load.
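
The override behaviour described above can be reproduced in isolation (a quick demonstration, not PR code; the file name is a placeholder):

```python
import os

import torch

torch.save({"weight": torch.zeros(2)}, "demo.ckpt")
os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"

torch.load("demo.ckpt")                     # env var applies: loads as if weights_only=False
torch.load("demo.ckpt", weights_only=True)  # explicit argument wins; env var is ignored
```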

With this in mind, I think passing weights_only to trainer.{fit,validate,test,predict} may be the only short-term solution.

If we need to force the weights_only=True solution, every item saved in the checkpoint would need to be converted into a primitive type that torch.load(..., weights_only=True) would accept.

I have also added weights_only as an argument to the LightningModule and LightningDataModule classes, as there are a few issues that point this out as a source of errors.

Maybe we could add weights_only and have it default to False for now, until a future release, so that we don't break users' code while still adding an explicit option for this?
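
From the user's side, the exposed argument would look roughly like this (MyModel and MyDataModule are hypothetical user-defined subclasses, the checkpoint path is a placeholder, and the default value is the open question above):

```python
from lightning.pytorch import Trainer

# MyModel / MyDataModule are stand-ins for user-defined LightningModule /
# LightningDataModule subclasses that exist elsewhere in the user's code.
model = MyModel.load_from_checkpoint("legacy.ckpt", weights_only=False)
dm = MyDataModule.load_from_checkpoint("legacy.ckpt", weights_only=False)

trainer = Trainer(max_epochs=1)
trainer.fit(model, datamodule=dm, ckpt_path="legacy.ckpt", weights_only=False)
```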

@matsumotosan matsumotosan changed the title from "Compatibility for weights_only=True by default for loading weights" to "Expose weights_only for loading checkpoints with Trainer, LightningModule, LightningDataModule" Aug 21, 2025
Labels
ci (Continuous Integration), dependencies (Pull requests that update a dependency file), dockers, fabric (lightning.fabric.Fabric), package, pl (Generic label for PyTorch Lightning package)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Make sure the upcoming change in the default for weights_only from False to True is handled correctly
4 participants