@@ -88,12 +88,14 @@ def _add_train_model_parser(subparser: argparse._SubParsersAction) -> None:
         ),
     )
     parser.add_argument(
-        "-c",
-        "--continue",
-        dest="continue_from",
-        type=_process_continue_from,
+        "--restart",
+        dest="restart_from",
+        type=_process_restart_from,
         required=False,
-        help="Checkpoint file (.ckpt) to continue training from.",
+        help=(
+            "Checkpoint file (.ckpt) to continue interrupted training. "
+            "Set to `'auto'` to use latest checkpoint from the outputs directory."
+        ),
     )
     parser.add_argument(
         "-r",
@@ -115,9 +117,9 @@ def _prepare_train_model_args(args: argparse.Namespace) -> None:
     args.options = OmegaConf.merge(args.options, override_options)
 
 
-def _process_continue_from(continue_from: str) -> Optional[str]:
-    # covers the case where `continue_from` is `auto`
-    if continue_from == "auto":
+def _process_restart_from(restart_from: str) -> Optional[str]:
+    # covers the case where `restart_from` is `auto`
+    if restart_from == "auto":
         # try to find the `outputs` directory; if it doesn't exist
         # then we are not continuing from a previous run
         if Path("outputs/").exists():
@@ -129,12 +131,12 @@ def _process_continue_from(continue_from: str) -> Optional[str]:
             # `sorted` because some checkpoint files are named with
             # the epoch number (e.g. `epoch_10.ckpt` would be before
             # `epoch_8.ckpt`). We therefore sort by file creation time.
-            new_continue_from = str(
+            new_restart_from = str(
                 sorted(dir.glob("*.ckpt"), key=lambda f: f.stat().st_ctime)[-1]
             )
-            logging.info(f"Auto-continuing from `{new_continue_from}`")
+            logging.info(f"Auto-continuing from `{new_restart_from}`")
         else:
-            new_continue_from = None
+            new_restart_from = None
             logging.info(
                 "Auto-continuation did not find any previous runs, "
                 "training from scratch"
@@ -145,17 +147,17 @@ def _process_continue_from(continue_from: str) -> Optional[str]:
             # still executing this function
             time.sleep(3)
     else:
-        new_continue_from = continue_from
+        new_restart_from = restart_from
 
-    return new_continue_from
+    return new_restart_from
 
 
 def train_model(
     options: Union[DictConfig, Dict],
     output: str = "model.pt",
     extensions: str = "extensions/",
     checkpoint_dir: Union[str, Path] = ".",
-    continue_from: Optional[str] = None,
+    restart_from: Optional[str] = None,
 ) -> None:
     """Train an atomistic machine learning model using provided ``options``.
 
@@ -169,7 +171,7 @@ def train_model(
     :param output: Path to save the final model
     :param checkpoint_dir: Path to save checkpoints and other intermediate output files
         like the fully expanded training options for a later restart.
-    :param continue_from: File to continue training from.
+    :param restart_from: File to continue training from.
     """
     ###########################
     # VALIDATE BASE OPTIONS ###
@@ -439,10 +441,12 @@ def train_model(
 
     logging.info("Setting up model")
     try:
-        if continue_from is not None:
-            logging.info(f"Loading checkpoint from `{continue_from}`")
-            trainer = trainer_from_checkpoint(continue_from, hypers["training"])
-            model = model_from_checkpoint(continue_from)
+        if restart_from is not None:
+            logging.info(f"Restarting training from `{restart_from}`")
+            trainer = trainer_from_checkpoint(
+                path=restart_from, context="restart", hypers=hypers["training"]
+            )
+            model = model_from_checkpoint(path=restart_from, context="restart")
             model = model.restart(dataset_info)
         else:
             model = Model(hypers["model"], dataset_info)
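
For readers less familiar with argparse's `type=` hook, here is a minimal, self-contained sketch of the mechanics the first hunk relies on: the converter runs on the raw command-line string at parse time and `dest=` stores its return value, so the parsed namespace already holds a resolved checkpoint path (or `None`). The `_resolve` helper below is a hypothetical stand-in for `_process_restart_from`, not code from this PR.

# Hypothetical sketch, not part of the diff: how `type=` + `dest=` cooperate.
import argparse

def _resolve(value: str):
    # stand-in for _process_restart_from: "auto" would be resolved to the
    # newest checkpoint here; any other value is passed through unchanged
    return value

parser = argparse.ArgumentParser()
parser.add_argument("--restart", dest="restart_from", type=_resolve, required=False)

args = parser.parse_args(["--restart", "model.ckpt"])
print(args.restart_from)  # -> model.ckpt
args = parser.parse_args([])
print(args.restart_from)  # -> None (flag omitted, default stays None)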
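And a usage sketch of the renamed Python API, assuming the training options live in a hypothetical `options.yaml` file: passing `"auto"` through `_process_restart_from` yields the newest `.ckpt` under `outputs/` (sorted by file creation time) or `None` when no previous run exists, and `train_model` then either restarts from that checkpoint or trains from scratch.

# Hypothetical usage sketch, not part of the diff. On the CLI, argparse already
# applies _process_restart_from via `type=`, so `--restart auto` arrives at
# train_model as a concrete path or None.
from omegaconf import OmegaConf

options = OmegaConf.load("options.yaml")      # hypothetical options file
restart_from = _process_restart_from("auto")  # newest outputs/**/*.ckpt, or None
train_model(options, restart_from=restart_from)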