Commit fd6d518
Add checkpoint selection option (#573)
Implement a checkpoint selection option through a tag named `context`, and update relevant features.

Co-authored-by: SanggyuChong <sanggyu.chong@epfl.ch>
1 parent bf06e78 commit fd6d518

22 files changed: +163 -70 lines

docs/src/advanced-concepts/auto-restarting.rst (+1 -1)

@@ -5,7 +5,7 @@ When restarting multiple times (for example, when training an expensive model
 or running on an HPC cluster with short time limits), it is useful to be able
 to train and restart multiple times with the same command.

-In ``metatrain``, this functionality is provided via the ``--continue auto``
+In ``metatrain``, this functionality is provided via the ``--restart auto``
 (or ``-c auto``) flag of ``mtt train``. This flag will automatically restart
 the training from the last checkpoint, if one is found in the ``outputs/``
 of the current directory. If no checkpoint is found, the training will start

docs/src/dev-docs/new-architecture.rst (+24 -2)

@@ -105,7 +105,18 @@ method.
         self.dataset_info = dataset_info

     @classmethod
-    def load_checkpoint(cls, checkpoint: Dict[str, Any]) -> "ModelInterface":
+    def load_checkpoint(
+        cls,
+        checkpoint: Dict[str, Any],
+        context: Literal["restart", "finetune", "export"],
+    ) -> "ModelInterface":
+        """Create a model from a checkpoint's state dictionary.
+
+        :param checkpoint: Checkpoint's state dictionary.
+        :param context: Purpose of the model to load from the checkpoint file.
+            Required values are "restart", "finetune" and "export", but the set
+            can be extended to other values.
+        """
         pass

     def restart(cls, dataset_info: DatasetInfo) -> "ModelInterface":

@@ -168,8 +179,19 @@ methods for ``train()``, ``save_checkpoint()`` and ``load_checkpoint()``.

     @classmethod
     def load_checkpoint(
-        cls, checkpoint: Dict[str, Any], train_hypers: Dict[str, Any]
+        cls,
+        checkpoint: Dict[str, Any],
+        train_hypers: Dict[str, Any],
+        context: Literal["restart", "finetune"],
     ) -> "TrainerInterface":
+        """Create a trainer from a checkpoint's state dictionary.
+
+        :param checkpoint: Checkpoint's state dictionary.
+        :param context: Purpose of the model to load from the checkpoint file.
+            Required values are "restart" and "finetune" but can be
+            extended to other values.
+        :param train_hypers: Hyperparameters used to create the trainer.
+        """
         pass

 The format of checkpoints is not defined by ``metatrain`` and can be any format that
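For orientation, here is a minimal sketch of how an architecture could satisfy the updated ``load_checkpoint`` signature. It is not part of this commit; ``ToyModel`` and the checkpoint keys are hypothetical, loosely mirroring the PET/NanoPET changes further down.

from typing import Any, Dict, Literal

import torch


class ToyModel(torch.nn.Module):
    def __init__(self, hidden: int = 16) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(hidden, 1)

    @classmethod
    def load_checkpoint(
        cls,
        checkpoint: Dict[str, Any],
        context: Literal["restart", "finetune", "export"],
    ) -> "ToyModel":
        # "restart" resumes from the latest weights; "finetune" and "export"
        # pick the best weights, as in the PET/NanoPET changes below.
        model = cls(**checkpoint["model_data"])  # hypothetical checkpoint layout
        key = "model_state_dict" if context == "restart" else "best_model_state_dict"
        model.load_state_dict(checkpoint[key])
        return model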

docs/src/getting-started/checkpoints.rst (+1 -1)

@@ -18,7 +18,7 @@ The sub-command to continue training from a checkpoint is

 .. code-block:: bash

-    mtt train options.yaml --continue model.ckpt
+    mtt train options.yaml --restart model.ckpt

 or
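The same restart can also be driven from Python through the ``train_model`` entry point changed below in ``src/metatrain/cli/train.py``; a rough sketch (the options file name is just an example):

from omegaconf import OmegaConf

from metatrain.cli.train import train_model

# Equivalent of `mtt train options.yaml --restart model.ckpt`
options = OmegaConf.load("options.yaml")
train_model(options, restart_from="model.ckpt")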

pyproject.toml (+2 -3)

@@ -4,7 +4,7 @@ dynamic = ["version"]
 requires-python = ">=3.9"

 readme = "README.rst"
-license = {text = "BSD-3-Clause"}
+license = "BSD-3-Clause"
 description = "Training and evaluating machine learning models for atomistic systems."
 authors = [{name = "metatrain developers"}]

@@ -24,7 +24,6 @@ keywords = ["machine learning", "molecular modeling"]
 classifiers = [
     "Development Status :: 4 - Beta",
     "Intended Audience :: Science/Research",
-    "License :: OSI Approved :: BSD License",
     "Operating System :: POSIX",
     "Operating System :: MacOS :: MacOS X",
     "Operating System :: Microsoft :: Windows",

@@ -52,7 +51,7 @@ mtt = "metatrain.__main__:main"

 [build-system]
 requires = [
-    "setuptools >= 68",
+    "setuptools >= 77",
     "setuptools_scm>=8",
     "wheel",
 ]
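The packaging tweaks go together: with an SPDX-style ``license = "BSD-3-Clause"`` string (PEP 639), the ``License :: OSI Approved`` classifier becomes redundant, and setuptools supports this form natively from around version 77, hence the bumped build requirement. This reading is an inference; the commit message does not state it.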

src/metatrain/cli/train.py (+23 -19)

@@ -88,12 +88,14 @@ def _add_train_model_parser(subparser: argparse._SubParsersAction) -> None:
         ),
     )
     parser.add_argument(
-        "-c",
-        "--continue",
-        dest="continue_from",
-        type=_process_continue_from,
+        "--restart",
+        dest="restart_from",
+        type=_process_restart_from,
         required=False,
-        help="Checkpoint file (.ckpt) to continue training from.",
+        help=(
+            "Checkpoint file (.ckpt) to continue interrupted training. "
+            "Set to `'auto'` to use latest checkpoint from the outputs directory."
+        ),
     )
     parser.add_argument(
         "-r",

@@ -115,9 +117,9 @@ def _prepare_train_model_args(args: argparse.Namespace) -> None:
     args.options = OmegaConf.merge(args.options, override_options)


-def _process_continue_from(continue_from: str) -> Optional[str]:
-    # covers the case where `continue_from` is `auto`
-    if continue_from == "auto":
+def _process_restart_from(restart_from: str) -> Optional[str]:
+    # covers the case where `restart_from` is `auto`
+    if restart_from == "auto":
         # try to find the `outputs` directory; if it doesn't exist
         # then we are not continuing from a previous run
         if Path("outputs/").exists():

@@ -129,12 +131,12 @@ def _process_continue_from(continue_from: str) -> Optional[str]:
             # `sorted` because some checkpoint files are named with
             # the epoch number (e.g. `epoch_10.ckpt` would be before
             # `epoch_8.ckpt`). We therefore sort by file creation time.
-            new_continue_from = str(
+            new_restart_from = str(
                 sorted(dir.glob("*.ckpt"), key=lambda f: f.stat().st_ctime)[-1]
             )
-            logging.info(f"Auto-continuing from `{new_continue_from}`")
+            logging.info(f"Auto-continuing from `{new_restart_from}`")
         else:
-            new_continue_from = None
+            new_restart_from = None
             logging.info(
                 "Auto-continuation did not find any previous runs, "
                 "training from scratch"

@@ -145,17 +147,17 @@ def _process_continue_from(continue_from: str) -> Optional[str]:
         # still executing this function
         time.sleep(3)
     else:
-        new_continue_from = continue_from
+        new_restart_from = restart_from

-    return new_continue_from
+    return new_restart_from


 def train_model(
     options: Union[DictConfig, Dict],
     output: str = "model.pt",
     extensions: str = "extensions/",
     checkpoint_dir: Union[str, Path] = ".",
-    continue_from: Optional[str] = None,
+    restart_from: Optional[str] = None,
 ) -> None:
     """Train an atomistic machine learning model using provided ``options``.

@@ -169,7 +171,7 @@ def train_model(
     :param output: Path to save the final model
     :param checkpoint_dir: Path to save checkpoints and other intermediate output files
         like the fully expanded training options for a later restart.
-    :param continue_from: File to continue training from.
+    :param restart_from: File to continue training from.
     """
     ###########################
     # VALIDATE BASE OPTIONS ###

@@ -439,10 +441,12 @@ def train_model(

     logging.info("Setting up model")
     try:
-        if continue_from is not None:
-            logging.info(f"Loading checkpoint from `{continue_from}`")
-            trainer = trainer_from_checkpoint(continue_from, hypers["training"])
-            model = model_from_checkpoint(continue_from)
+        if restart_from is not None:
+            logging.info(f"Restarting training from `{restart_from}`")
+            trainer = trainer_from_checkpoint(
+                path=restart_from, context="restart", hypers=hypers["training"]
+            )
+            model = model_from_checkpoint(path=restart_from, context="restart")
             model = model.restart(dataset_info)
         else:
             model = Model(hypers["model"], dataset_info)
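A standalone sketch of the "latest checkpoint" selection that ``--restart auto`` performs, assuming an ``outputs/`` layout like the one described in the function above; this helper is illustrative and not part of the commit.

from pathlib import Path
from typing import Optional


def latest_checkpoint(outputs: Path = Path("outputs/")) -> Optional[str]:
    # Sort by file creation time, not by name, so that e.g. `epoch_10.ckpt`
    # is correctly considered newer than `epoch_8.ckpt`.
    ckpts = sorted(outputs.glob("**/*.ckpt"), key=lambda f: f.stat().st_ctime)
    return str(ckpts[-1]) if ckpts else None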

src/metatrain/deprecated/pet/model.py (+6 -2)

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Literal, Optional

 import metatensor.torch
 import torch

@@ -243,7 +243,11 @@ def forward(
         return output_quantities

     @classmethod
-    def load_checkpoint(cls, checkpoint: Dict[str, Any]) -> "PET":
+    def load_checkpoint(
+        cls,
+        checkpoint: Dict[str, Any],
+        context: Literal["restart", "finetune", "export"],
+    ) -> "PET":
         hypers = checkpoint["hypers"]
         model_hypers = hypers["ARCHITECTURAL_HYPERS"]
         dataset_info = checkpoint["dataset_info"]

src/metatrain/deprecated/pet/trainer.py (+5 -2)

@@ -5,7 +5,7 @@
 import time
 import warnings
 from pathlib import Path
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Literal, Union

 import numpy as np
 import torch

@@ -784,7 +784,10 @@ def save_checkpoint(self, model, path: Union[str, Path]):

     @classmethod
     def load_checkpoint(
-        cls, checkpoint: Dict[str, Any], train_hypers: Dict[str, Any]
+        cls,
+        checkpoint: Dict[str, Any],
+        train_hypers: Dict,
+        context: Literal["restart", "finetune"],
     ) -> "Trainer":
         # This function takes a metatrain PET checkpoint and returns a Trainer
         # instance with the hypers, while also saving the checkpoint in the

src/metatrain/experimental/nanopet/model.py (+13 -3)

@@ -1,6 +1,6 @@
 import warnings
 from math import prod
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Literal, Optional

 import metatensor.torch
 import torch

@@ -551,9 +551,19 @@ def requested_neighbor_lists(
         return [self.requested_nl]

     @classmethod
-    def load_checkpoint(cls, checkpoint: Dict[str, Any]) -> "NanoPET":
+    def load_checkpoint(
+        cls,
+        checkpoint: Dict[str, Any],
+        context: Literal["restart", "finetune", "export"],
+    ) -> "NanoPET":
         model_data = checkpoint["model_data"]
-        model_state_dict = checkpoint["model_state_dict"]
+
+        if context == "restart":
+            model_state_dict = checkpoint["model_state_dict"]
+        elif context == "finetune" or context == "export":
+            model_state_dict = checkpoint["best_model_state_dict"]
+        else:
+            raise ValueError("Unknown context tag for checkpoint loading!")

         # Create the model
         model = cls(**model_data)
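The branching above reduces to a small mapping from ``context`` to a checkpoint key; a condensed sketch of the same logic (the helper name is hypothetical, the keys are taken from the diff):

from typing import Any, Dict, Literal


def select_state_dict(
    checkpoint: Dict[str, Any],
    context: Literal["restart", "finetune", "export"],
) -> Dict[str, Any]:
    if context == "restart":
        return checkpoint["model_state_dict"]       # latest weights, to resume training
    if context in ("finetune", "export"):
        return checkpoint["best_model_state_dict"]  # best weights seen so far
    raise ValueError("Unknown context tag for checkpoint loading!")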

src/metatrain/experimental/nanopet/tests/test_continue.py (+1 -1)

@@ -70,7 +70,7 @@ def test_continue(monkeypatch, tmp_path):

     trainer.save_checkpoint(model, "tmp.ckpt")

-    model_after = model_from_checkpoint("tmp.ckpt")
+    model_after = model_from_checkpoint("tmp.ckpt", context="restart")
     assert isinstance(model_after, NanoPET)
     model_after.restart(dataset_info)

src/metatrain/experimental/nanopet/trainer.py (+5 -2)

@@ -1,7 +1,7 @@
 import copy
 import logging
 from pathlib import Path
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Literal, Union

 import torch
 import torch.distributed

@@ -494,7 +494,10 @@ def save_checkpoint(self, model, path: Union[str, Path]):

     @classmethod
     def load_checkpoint(
-        cls, checkpoint: Dict[str, Any], train_hypers: Dict[str, Any]
+        cls,
+        checkpoint: Dict[str, Any],
+        train_hypers: Dict[str, Any],
+        context: Literal["restart", "finetune"],  # not used at the moment
     ) -> "Trainer":
         epoch = checkpoint["epoch"]
         optimizer_state_dict = checkpoint["optimizer_state_dict"]

src/metatrain/gap/trainer.py (+5 -2)

@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Literal, Union

 import metatensor
 import metatensor.torch

@@ -145,6 +145,9 @@ def save_checkpoint(self, model, checkpoint_dir: str):

     @classmethod
     def load_checkpoint(
-        cls, checkpoint: Dict[str, Any], hypers_train: Dict[str, Any]
+        cls,
+        checkpoint: Dict[str, Any],
+        hypers_train: Dict[str, Any],
+        context: Literal["restart", "finetune"],
     ) -> "GAP":
         raise ValueError("GAP does not allow restarting training")

src/metatrain/pet/model.py (+14 -3)

@@ -1,6 +1,6 @@
 import warnings
 from math import prod
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Literal, Optional

 import metatensor.torch
 import torch

@@ -656,9 +656,20 @@ def forward(
         return return_dict

     @classmethod
-    def load_checkpoint(cls, checkpoint: Dict[str, Any]) -> "PET":
+    def load_checkpoint(
+        cls,
+        checkpoint: Dict[str, Any],
+        context: Literal["restart", "finetune", "export"],
+    ) -> "PET":
         model_data = checkpoint["model_data"]
-        model_state_dict = checkpoint["model_state_dict"]
+
+        if context == "restart":
+            model_state_dict = checkpoint["model_state_dict"]
+        elif context == "finetune" or context == "export":
+            model_state_dict = checkpoint["best_model_state_dict"]
+        else:
+            raise ValueError("Unknown context tag for checkpoint loading!")
+
         finetune_config = checkpoint["train_hypers"].get("finetune", {})

         # Create the model

src/metatrain/pet/tests/test_continue.py (+1 -1)

@@ -70,7 +70,7 @@ def test_continue(monkeypatch, tmp_path):

     trainer.save_checkpoint(model, "tmp.ckpt")

-    model_after = model_from_checkpoint("tmp.ckpt")
+    model_after = model_from_checkpoint("tmp.ckpt", context="restart")
     assert isinstance(model_after, PET)
     model_after.restart(dataset_info)

src/metatrain/pet/tests/test_finetuning.py (+3 -3)

@@ -111,7 +111,7 @@ def test_finetuning_restart(monkeypatch, tmp_path):

     hypers = DEFAULT_HYPERS.copy()

-    hypers["training"]["num_epochs"] = 0
+    hypers["training"]["num_epochs"] = 1

     # Pre-training
     trainer = Trainer(hypers["training"])

@@ -126,7 +126,7 @@ def test_finetuning_restart(monkeypatch, tmp_path):
     trainer.save_checkpoint(model, "tmp.ckpt")

     # Finetuning
-    model_finetune = model_from_checkpoint("tmp.ckpt")
+    model_finetune = model_from_checkpoint("tmp.ckpt", context="finetune")
     assert isinstance(model_finetune, PET)
     model_finetune.restart(dataset_info)

@@ -158,7 +158,7 @@ def test_finetuning_restart(monkeypatch, tmp_path):
     assert any(["lora_" in name for name, _ in model_finetune.named_parameters()])

     # Finetuning restart
-    model_finetune_restart = model_from_checkpoint("finetuned.ckpt")
+    model_finetune_restart = model_from_checkpoint("finetuned.ckpt", context="restart")
     assert isinstance(model_finetune_restart, PET)
     model_finetune_restart.restart(dataset_info)
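The test exercises the intended workflow: start finetuning from the best pre-trained weights, then resume that finetuning run from its latest weights. A compressed sketch (the import path of ``model_from_checkpoint`` is an assumption; the ``.ckpt`` names are the test's temporaries):

from metatrain.utils.io import model_from_checkpoint  # assumed location of the helper

# Finetuning starts from the *best* weights of the pre-trained checkpoint ...
model_finetune = model_from_checkpoint("tmp.ckpt", context="finetune")

# ... while restarting an interrupted finetuning run uses its *latest* weights.
model_restarted = model_from_checkpoint("finetuned.ckpt", context="restart")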

src/metatrain/pet/trainer.py (+5 -2)

@@ -1,7 +1,7 @@
 import copy
 import logging
 from pathlib import Path
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Literal, Union

 import torch
 from torch.optim.lr_scheduler import LambdaLR

@@ -511,7 +511,10 @@ def save_checkpoint(self, model, path: Union[str, Path]):

     @classmethod
     def load_checkpoint(
-        cls, checkpoint: Dict[str, Any], train_hypers: Dict[str, Any]
+        cls,
+        checkpoint: Dict[str, Any],
+        context: Literal["restart", "finetune", "export"],  # not used at the moment
+        train_hypers: Dict[str, Any],
     ) -> "Trainer":
         epoch = checkpoint["epoch"]
         optimizer_state_dict = checkpoint["optimizer_state_dict"]
