Commit 1ec459f

Seppo Enarvi authored
Generic weight averaging callback that supports EMA (#20545)
* Weight averaging callback
* A callback that updates a torch.optim.swa_utils.AveragedModel after specific steps or epochs.
* The user can provide a callback that defines after which steps or epochs the average model is updated.
* More generic customization of the WeightAveraging callback
  - The user can specify when to update the average model by overriding the should_update() method
  - Any keyword arguments will be passed to the AveragedModel constructor
* Training tricks mentions WeightAveraging and EMA
* Removed logging from WeightAveraging
* Fixed the documentation
* Fixed checkpoint loading with WeightAveraging
* WeightAveraging calls the configure_model hook but issues a warning
* Fixed a reference in a docstring.
* Removed two unit tests to avoid running out of memory in the CI pipeline.
* The default device for the averaged model is the device of the original model
* Added seealso to WeightAveraging and StochasticWeightAveraging
* More verbose description of WeightAveraging
* Describe the magic number 7 in a comment
* Update src/lightning/pytorch/CHANGELOG.md

---------

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
Co-authored-by: Seppo Enarvi <lingo-rise-lesser@duck.com>
Co-authored-by: Seppo Enarvi <seppo.git@marjaniemi.com>
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 39e24f5 commit 1ec459f
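The commit message above describes overriding should_update() so the callback refreshes the average only at chosen steps or epochs. A torch-free sketch of such a schedule predicate (start_step and every_n_steps are illustrative parameters, not part of the Lightning signature, which only passes step_idx and epoch_idx):

```python
def should_update(step_idx=None, epoch_idx=None, start_step=100, every_n_steps=10):
    """Illustrative schedule: refresh the average every 10th step once 100 steps have passed."""
    if step_idx is None:
        return False
    return step_idx >= start_step and step_idx % every_n_steps == 0

# The callback would consult the predicate after every optimizer step:
updates = [step for step in range(131) if should_update(step_idx=step)]
print(updates)  # → [100, 110, 120, 130]
```

Returning False when only epoch_idx is given restricts updates to step boundaries, mirroring how the callback distinguishes step-end from epoch-end calls.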

File tree

11 files changed (+749, -23 lines)


docs/source-pytorch/advanced/training_tricks.rst

Lines changed: 36 additions & 11 deletions
@@ -50,23 +50,48 @@ Read more about :ref:`Configuring Gradient Clipping <configure_gradient_clipping
 
 ----------
 
-***************************
-Stochastic Weight Averaging
-***************************
+****************
+Weight Averaging
+****************
 
-Stochastic Weight Averaging (SWA) can make your models generalize better at virtually no additional cost.
-This can be used with both non-trained and trained models. The SWA procedure smooths the loss landscape thus making
-it harder to end up in a local minimum during optimization.
+Weight averaging methods such as Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA) can make your
+models generalize better at virtually no additional cost. Averaging smooths the loss landscape thus making it harder to
+end up in a local minimum during optimization.
 
-For a more detailed explanation of SWA and how it works,
-read `this post <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging>`__ by the PyTorch team.
+Lightning provides two callbacks to facilitate weight averaging. :class:`~lightning.pytorch.callbacks.WeightAveraging`
+is a generic callback that wraps the
+`AveragedModel <https://pytorch.org/docs/stable/generated/torch.optim.swa_utils.AveragedModel.html>`__ class from
+PyTorch. It allows SWA, EMA, or a custom averaging strategy to be used. By default, it updates the weights after every
+step, but it can be customized to update at specific steps or epochs by overriding the `should_update()` method.
 
-.. seealso:: The :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback
+The older :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback is specific to SWA. It starts the SWA
+procedure after a certain number of epochs and always runs on every epoch. Additionally, it switches to a constant
+learning rate schedule (`SWALR <https://pytorch.org/docs/stable/generated/torch.optim.swa_utils.SWALR.html>`__) when the
+procedure starts.
+
+.. seealso::
+    For a more detailed explanation of SWA and how it works, read
+    `this post <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging>`__ by the PyTorch team.
+
+.. seealso::
+    The :class:`~lightning.pytorch.callbacks.WeightAveraging` callback and
+    :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback
 
 .. testcode::
 
-    # Enable Stochastic Weight Averaging using the callback
-    trainer = Trainer(callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)])
+    from lightning.pytorch.callbacks import StochasticWeightAveraging, WeightAveraging
+    from torch.optim.swa_utils import get_ema_avg_fn
+
+    # Enable Exponential Moving Average after 100 steps
+    class EMAWeightAveraging(WeightAveraging):
+        def __init__(self):
+            super().__init__(avg_fn=get_ema_avg_fn())
+        def should_update(self, step_idx=None, epoch_idx=None):
+            return (step_idx is not None) and (step_idx >= 100)
+    trainer = Trainer(callbacks=EMAWeightAveraging())
+
+    # Enable Stochastic Weight Averaging after 10 epochs with learning rate 0.01
+    trainer = Trainer(callbacks=StochasticWeightAveraging(swa_epoch_start=10, swa_lrs=0.01))
 
 ----------
 
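The get_ema_avg_fn strategy used in the new testcode reduces to an exponential moving average over each parameter. A minimal torch-free sketch of that update rule (ema_update is an illustrative helper, not part of any library):

```python
def ema_update(averaged, current, decay=0.999):
    """One EMA step: averaged <- decay * averaged + (1 - decay) * current."""
    return [decay * a + (1 - decay) * c for a, c in zip(averaged, current)]

# Feed the same weights in three times with an aggressive decay to watch convergence:
avg = [0.0, 0.0]
for _ in range(3):
    avg = ema_update(avg, [1.0, 2.0], decay=0.5)
print(avg)  # → [0.875, 1.75]
```

A decay close to 1 (the usual 0.999) makes the average track the training weights slowly, which is what smooths out step-to-step noise.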

docs/source-pytorch/api_references.rst

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@ callbacks
     ThroughputMonitor
     Timer
     TQDMProgressBar
+    WeightAveraging
 
 cli
 -----

docs/source-pytorch/extensions/callbacks.rst

Lines changed: 1 addition & 0 deletions
@@ -83,6 +83,7 @@ Lightning has a few built-in callbacks.
     StochasticWeightAveraging
     Timer
     TQDMProgressBar
+    WeightAveraging
 
 ----------
 

docs/source-pytorch/glossary/index.rst

Lines changed: 8 additions & 8 deletions
@@ -42,13 +42,13 @@
    Strategy registry <../advanced/strategy_registry>
    Strategy integrations <../integrations/strategies/index>
    Style guide <../starter/style_guide>
-   SWA <../advanced/training_tricks>
    SLURM <../clouds/cluster_advanced>
    Tensor Parallel <../advanced/model_parallel/tp>
    Transfer learning <../advanced/transfer_learning>
    Trainer <../common/trainer>
    TorchRun (TorchElastic) <../clouds/cluster_intermediate_2>
    Warnings <../advanced/warnings>
+   Weight averaging <../advanced/training_tricks>
 
 
 ########
@@ -326,13 +326,6 @@ Glossary
     :button_link: ../starter/style_guide.html
     :height: 100
 
-.. displayitem::
-    :header: SWA
-    :description: Stochastic Weight Averaging (SWA) can make your models generalize better
-    :col_css: col-md-12
-    :button_link: ../advanced/training_tricks.html#stochastic-weight-averaging
-    :height: 100
-
 .. displayitem::
     :header: SLURM
     :description: Simple Linux Utility for Resource Management, or simply Slurm, is a free and open-source job scheduler for Linux clusters
@@ -375,6 +368,13 @@ Glossary
     :button_link: ../advanced/warnings.html
     :height: 100
 
+.. displayitem::
+    :header: Weight averaging
+    :description: Stochastic Weight Averaging (SWA) or Exponential Moving Average (EMA) can make your models generalize better
+    :col_css: col-md-12
+    :button_link: ../advanced/training_tricks.html#weight-averaging
+    :height: 100
+
 .. raw:: html
 
     </div>

docs/source-pytorch/model/build_model_intermediate.rst

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ Enable advanced training features using Trainer arguments. These are SOTA techni
     )
 
     # access the latest state of the art techniques
-    trainer = Trainer(callbacks=[StochasticWeightAveraging(...)])
+    trainer = Trainer(callbacks=[WeightAveraging(...)])
 
 ----
 

docs/source-pytorch/starter/introduction.rst

Lines changed: 1 addition & 1 deletion
@@ -252,7 +252,7 @@ Enable advanced training features using Trainer arguments. These are state-of-th
     )
 
     # access the latest state of the art techniques
-    trainer = L.Trainer(callbacks=[StochasticWeightAveraging(...)])
+    trainer = L.Trainer(callbacks=[WeightAveraging(...)])
 
 ----
 

src/lightning/pytorch/CHANGELOG.md

Lines changed: 4 additions & 1 deletion
@@ -10,7 +10,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
-- Added Torch-Tensorrt Integration with `LightningModule` ([#20808](https://github.com/Lightning-AI/pytorch-lightning/pull/20808))
+- Added `WeightAveraging` callback that wraps the PyTorch `AveragedModel` class ([#20545](https://github.com/Lightning-AI/pytorch-lightning/pull/20545))
+
+
+- Added Torch-Tensorrt integration with `LightningModule` ([#20808](https://github.com/Lightning-AI/pytorch-lightning/pull/20808))
 
 
 ### Changed

src/lightning/pytorch/callbacks/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -32,6 +32,7 @@
 from lightning.pytorch.callbacks.stochastic_weight_avg import StochasticWeightAveraging
 from lightning.pytorch.callbacks.throughput_monitor import ThroughputMonitor
 from lightning.pytorch.callbacks.timer import Timer
+from lightning.pytorch.callbacks.weight_averaging import WeightAveraging
 
 __all__ = [
     "BackboneFinetuning",
@@ -58,4 +59,5 @@
     "ThroughputMonitor",
     "Timer",
     "TQDMProgressBar",
+    "WeightAveraging",
 ]

src/lightning/pytorch/callbacks/stochastic_weight_avg.py

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ def __init__(
 
         .. warning:: ``StochasticWeightAveraging`` is currently only supported on every epoch.
 
-        See also how to :ref:`enable it directly on the Trainer <advanced/training_tricks:Stochastic Weight Averaging>`
+        See also how to :ref:`enable it directly on the Trainer <advanced/training_tricks:Weight Averaging>`.
 
         Arguments:
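The docstring hunk above cross-references the SWA section. SWA itself is just a running equal-weight mean over periodic weight snapshots (the default averaging of AveragedModel). A torch-free sketch of that arithmetic, with swa_update as a hypothetical helper:

```python
def swa_update(averaged, current, num_averaged):
    """Fold one more snapshot into a running equal-weight mean of num_averaged models."""
    return [a + (c - a) / (num_averaged + 1) for a, c in zip(averaged, current)]

snapshots = [[1.0], [3.0], [5.0]]  # stand-ins for per-epoch weights after swa_epoch_start
avg = snapshots[0]
for n, weights in enumerate(snapshots[1:], start=1):
    avg = swa_update(avg, weights, n)
print(avg)  # → [3.0]
```

The incremental form avoids storing all snapshots: after each epoch the running mean equals the plain average of every snapshot seen so far.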
