
Commit e1a2cfe

fix the resume bug: the lr is not related to iteration, but epoch
1 parent 61fe292 commit e1a2cfe

2 files changed: +36 -41 lines changed


deepspeech/exps/deepspeech2/model.py

Lines changed: 0 additions & 4 deletions

@@ -123,10 +123,6 @@ def valid(self):
     def setup_model(self):
         config = self.config.clone()
         config.defrost()
-        assert (self.train_loader.collate_fn.feature_size ==
-                self.test_loader.collate_fn.feature_size)
-        assert (self.train_loader.collate_fn.vocab_size ==
-                self.test_loader.collate_fn.vocab_size)
         config.model.feat_size = self.train_loader.collate_fn.feature_size
         config.model.dict_size = self.train_loader.collate_fn.vocab_size
         config.freeze()

deepspeech/training/trainer.py

Lines changed: 36 additions & 37 deletions

@@ -29,37 +29,37 @@

 class Trainer():
     """
-    An experiment template in order to structure the training code and take
-    care of saving, loading, logging, visualization stuffs. It's intended to
-    be flexible and simple.
-
-    So it only handles output directory (create directory for the output,
-    create a checkpoint directory, dump the config in use and create
+    An experiment template in order to structure the training code and take
+    care of saving, loading, logging, visualization stuffs. It's intended to
+    be flexible and simple.
+
+    So it only handles output directory (create directory for the output,
+    create a checkpoint directory, dump the config in use and create
     visualizer and logger) in a standard way without enforcing any
-    input-output protocols to the model and dataloader. It leaves the main
-    part for the user to implement their own (setup the model, criterion,
-    optimizer, define a training step, define a validation function and
+    input-output protocols to the model and dataloader. It leaves the main
+    part for the user to implement their own (setup the model, criterion,
+    optimizer, define a training step, define a validation function and
     customize all the text and visual logs).
-    It does not save too much boilerplate code. The users still have to write
-    the forward/backward/update mannually, but they are free to add
+    It does not save too much boilerplate code. The users still have to write
+    the forward/backward/update mannually, but they are free to add
     non-standard behaviors if needed.
     We have some conventions to follow.
-    1. Experiment should have ``model``, ``optimizer``, ``train_loader`` and
+    1. Experiment should have ``model``, ``optimizer``, ``train_loader`` and
     ``valid_loader``, ``config`` and ``args`` attributes.
-    2. The config should have a ``training`` field, which has
-    ``valid_interval``, ``save_interval`` and ``max_iteration`` keys. It is
-    used as the trigger to invoke validation, checkpointing and stop of the
+    2. The config should have a ``training`` field, which has
+    ``valid_interval``, ``save_interval`` and ``max_iteration`` keys. It is
+    used as the trigger to invoke validation, checkpointing and stop of the
     experiment.
-    3. There are four methods, namely ``train_batch``, ``valid``,
+    3. There are four methods, namely ``train_batch``, ``valid``,
     ``setup_model`` and ``setup_dataloader`` that should be implemented.
-    Feel free to add/overwrite other methods and standalone functions if you
+    Feel free to add/overwrite other methods and standalone functions if you
     need.
-
+
     Parameters
     ----------
     config: yacs.config.CfgNode
         The configuration used for the experiment.
-
+
     args: argparse.Namespace
         The parsed command line arguments.
     Examples
@@ -68,16 +68,16 @@ class Trainer():
     >>> exp = Trainer(config, args)
     >>> exp.setup()
     >>> exp.run()
-    >>>
+    >>>
     >>> config = get_cfg_defaults()
     >>> parser = default_argument_parser()
     >>> args = parser.parse_args()
-    >>> if args.config:
+    >>> if args.config:
     >>>     config.merge_from_file(args.config)
     >>> if args.opts:
     >>>     config.merge_from_list(args.opts)
     >>> config.freeze()
-    >>>
+    >>>
     >>> if args.nprocs > 1 and args.device == "gpu":
     >>>     dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
     >>> else:
@@ -114,7 +114,7 @@ def setup(self):

     @property
     def parallel(self):
-        """A flag indicating whether the experiment should run with
+        """A flag indicating whether the experiment should run with
         multiprocessing.
         """
         return self.args.device == "gpu" and self.args.nprocs > 1
@@ -144,9 +144,9 @@ def save(self, tag=None, infos: dict=None):
                          self.optimizer, infos)

     def resume_or_scratch(self):
-        """Resume from latest checkpoint at checkpoints in the output
+        """Resume from latest checkpoint at checkpoints in the output
         directory or load a specified checkpoint.
-
+
         If ``args.checkpoint_path`` is not None, load the checkpoint, else
         resume training.
         """
@@ -181,8 +181,7 @@ def train(self):
         if from_scratch:
             # save init model, i.e. 0 epoch
             self.save(tag='init', infos=None)
-
-        self.lr_scheduler.step(self.iteration)
+        self.lr_scheduler.step(self.epoch)
         if self.parallel:
             self.train_loader.batch_sampler.set_epoch(self.epoch)

@@ -254,7 +253,7 @@ def setup_output_dir(self):

     def setup_checkpointer(self):
         """Create a directory used to save checkpoints into.
-
+
         It is "checkpoints" inside the output directory.
         """
         # checkpoint dir
@@ -277,13 +276,13 @@ def destory(self):
     @mp_tools.rank_zero_only
     def setup_visualizer(self):
         """Initialize a visualizer to log the experiment.
-
+
         The visual log is saved in the output directory.
-
+
         Notes
         ------
-        Only the main process has a visualizer with it. Use multiple
-        visualizers in multiprocess to write to a same log file may cause
+        Only the main process has a visualizer with it. Use multiple
+        visualizers in multiprocess to write to a same log file may cause
         unexpected behaviors.
         """
         # visualizer
@@ -292,9 +291,9 @@ def setup_visualizer(self):

     @mp_tools.rank_zero_only
     def dump_config(self):
-        """Save the configuration used for this experiment.
-
-        It is saved in to ``config.yaml`` in the output directory at the
+        """Save the configuration used for this experiment.
+
+        It is saved in to ``config.yaml`` in the output directory at the
         beginning of the experiment.
         """
         with open(self.output_dir / "config.yaml", 'wt') as f:
@@ -312,13 +311,13 @@ def valid(self):
         raise NotImplementedError("valid should be implemented.")

     def setup_model(self):
-        """Setup model, criterion and optimizer, etc. A subclass should
+        """Setup model, criterion and optimizer, etc. A subclass should
         implement this method.
         """
         raise NotImplementedError("setup_model should be implemented.")

     def setup_dataloader(self):
-        """Setup training dataloader and validation dataloader. A subclass
+        """Setup training dataloader and validation dataloader. A subclass
         should implement this method.
         """
         raise NotImplementedError("setup_dataloader should be implemented.")
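
The substantive change is the hunk in train(): the learning-rate schedule here is defined per epoch, so when a run is resumed the scheduler has to be fast-forwarded with self.epoch rather than self.iteration. Below is a minimal, self-contained sketch of that difference; ExponentialDecayPerEpoch and the numbers are hypothetical stand-ins for an epoch-based scheduler, not code from this repository.

# Minimal sketch of the bug and the fix. ExponentialDecayPerEpoch and the
# numbers below are hypothetical stand-ins for an epoch-based lr scheduler;
# this is not code from the repository.

class ExponentialDecayPerEpoch:
    """lr = base_lr * gamma ** last_epoch, meant to advance once per epoch."""

    def __init__(self, base_lr, gamma):
        self.base_lr = base_lr
        self.gamma = gamma
        self.last_epoch = 0

    def step(self, epoch=None):
        # Fast-forward to `epoch` when given (e.g. after loading a
        # checkpoint), otherwise advance by a single epoch.
        self.last_epoch = epoch if epoch is not None else self.last_epoch + 1

    def get_lr(self):
        return self.base_lr * self.gamma ** self.last_epoch


if __name__ == "__main__":
    # Suppose a checkpoint was written after 5 epochs of 1000 iterations each.
    epoch, iteration = 5, 5000
    scheduler = ExponentialDecayPerEpoch(base_lr=2e-3, gamma=0.83)

    # Buggy resume: stepping by iteration applies the per-epoch decay
    # 5000 times, collapsing the learning rate to (effectively) zero.
    scheduler.step(iteration)
    print("step(iteration) -> lr =", scheduler.get_lr())

    # Fixed resume: stepping by epoch restores the lr the run actually had.
    scheduler.step(epoch)
    print("step(epoch)     -> lr =", scheduler.get_lr())

With step(self.epoch), a resumed run continues from the learning rate it had when the checkpoint was written, instead of one that has already been decayed thousands of extra times.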
