-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Exposeweights_only
for loading checkpoints with Trainer
, LightningModule
, LightningDataModule
#21072
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Exposeweights_only
for loading checkpoints with Trainer
, LightningModule
, LightningDataModule
#21072
Changes from 9 commits
074b01e
65cc1ed
4eaaf58
f276114
601e300
28f53ae
b1cfdf1
4d96a78
4c39c30
861d7e0
12bd0d6
5eacb6e
0430e22
2abe915
8e0f61e
525d9a8
83fd824
93cbe94
2a53f2f
3833892
9d8997e
561c02c
c67c8a3
005c439
2ab89a2
74e5e5a
a4c9efe
2c2ab9e
54b859a
7ddb4f8
7d6174a
906e52e
377cf11
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,7 @@ | |
|
||
import pytest | ||
import torch | ||
from packaging.version import Version | ||
|
||
import lightning.pytorch as pl | ||
from lightning.pytorch import Callback, Trainer | ||
|
@@ -45,7 +46,12 @@ def test_load_legacy_checkpoints(tmp_path, pl_version: str): | |
assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"' | ||
path_ckpt = path_ckpts[-1] | ||
|
||
model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24) | ||
# legacy load utility added in 1.5.0 (see https://github.com/Lightning-AI/pytorch-lightning/pull/9166) | ||
if pl_version == "local": | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the simplest way that I could think of ensuring we continue testing the legacy checkpoints. Another way could be to use |
||
pl_version = pl.__version__ | ||
weights_only = not Version(pl_version) < Version("1.5.0") | ||
|
||
model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24, weights_only=weights_only) | ||
trainer = Trainer(default_root_dir=tmp_path) | ||
dm = ClassifDataModule(num_features=24, length=6000, batch_size=128, n_clusters_per_class=2, n_informative=8) | ||
res = trainer.test(model, datamodule=dm) | ||
|
@@ -73,13 +79,18 @@ def test_legacy_ckpt_threading(pl_version: str): | |
assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"' | ||
path_ckpt = path_ckpts[-1] | ||
|
||
# legacy load utility added in 1.5.0 (see https://github.com/Lightning-AI/pytorch-lightning/pull/9166) | ||
if pl_version == "local": | ||
pl_version = pl.__version__ | ||
weights_only = not Version(pl_version) < Version("1.5.0") | ||
|
||
def load_model(): | ||
import torch | ||
|
||
from lightning.pytorch.utilities.migration import pl_legacy_patch | ||
|
||
with pl_legacy_patch(): | ||
_ = torch.load(path_ckpt, weights_only=False) | ||
_ = torch.load(path_ckpt, weights_only=weights_only) | ||
|
||
with patch("sys.path", [PATH_LEGACY] + sys.path): | ||
t1 = ThreadExceptionHandler(target=load_model) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we default to
weights_only=None
orweights_only=True
? If we have no use forweights_only=None
, we can simplify the type hint toweights_only: bool = True
.