
Commit 601e300

Merge branch 'master' into weights-only-compatibility

2 parents: f276114 + 8d1a734

File tree: 17 files changed (+1033, -27 lines)


.github/checkgroup.yml (2 additions, 0 deletions)

@@ -48,6 +48,7 @@ subprojects:
       - "!**/*.md"
     checks:
       - "pytorch-lightning (GPUs) (testing Lightning | latest)"
+      - "pytorch-lightning (GPUs) (testing PyTorch | oldest)"
       - "pytorch-lightning (GPUs) (testing PyTorch | latest)"
 
   - id: "pytorch_lightning: Benchmarks"
@@ -174,6 +175,7 @@ subprojects:
       - "!*.md"
       - "!**/*.md"
     checks:
+      - "lightning-fabric (GPUs) (testing Fabric | oldest)"
      - "lightning-fabric (GPUs) (testing Fabric | latest)"
      - "lightning-fabric (GPUs) (testing Lightning | latest)"

.github/workflows/ci-tests-pytorch.yml (1 addition, 1 deletion)

@@ -139,7 +139,7 @@ jobs:
           pip install ".[${EXTRA_PREFIX}extra,${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies]" \
             -U --upgrade-strategy=eager --prefer-binary \
             -r requirements/_integrations/accelerators.txt \
-            --extra-index-url="${TORCH_URL}" --find-links="${PYPI_CACHE_DIR}"
+            --extra-index-url="${TORCH_URL}" --find-links="${PYPI_CACHE_DIR}" --find-links="https://download.pytorch.org/whl/torch-tensorrt"
           pip list
       - name: Drop LAI from extensions
         if: ${{ matrix.pkg-name != 'lightning' }}

docs/source-pytorch/advanced/training_tricks.rst (36 additions, 11 deletions)

@@ -50,23 +50,48 @@ Read more about :ref:`Configuring Gradient Clipping <configure_gradient_clipping
 
 ----------
 
-***************************
-Stochastic Weight Averaging
-***************************
+****************
+Weight Averaging
+****************
 
-Stochastic Weight Averaging (SWA) can make your models generalize better at virtually no additional cost.
-This can be used with both non-trained and trained models. The SWA procedure smooths the loss landscape thus making
-it harder to end up in a local minimum during optimization.
+Weight averaging methods such as Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA) can make your
+models generalize better at virtually no additional cost. Averaging smooths the loss landscape thus making it harder to
+end up in a local minimum during optimization.
 
-For a more detailed explanation of SWA and how it works,
-read `this post <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging>`__ by the PyTorch team.
+Lightning provides two callbacks to facilitate weight averaging. :class:`~lightning.pytorch.callbacks.WeightAveraging`
+is a generic callback that wraps the
+`AveragedModel <https://pytorch.org/docs/stable/generated/torch.optim.swa_utils.AveragedModel.html>`__ class from
+PyTorch. It allows SWA, EMA, or a custom averaging strategy to be used. By default, it updates the weights after every
+step, but it can be customized to update at specific steps or epochs by overriding the `should_update()` method.
 
-.. seealso:: The :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback
+The older :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback is specific to SWA. It starts the SWA
+procedure after a certain number of epochs and always runs on every epoch. Additionally, it switches to a constant
+learning rate schedule (`SWALR <https://pytorch.org/docs/stable/generated/torch.optim.swa_utils.SWALR.html>`__) when the
+procedure starts.
+
+.. seealso::
+    For a more detailed explanation of SWA and how it works, read
+    `this post <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging>`__ by the PyTorch team.
+
+.. seealso::
+    The :class:`~lightning.pytorch.callbacks.WeightAveraging` callback and
+    :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback
 
 .. testcode::
 
-    # Enable Stochastic Weight Averaging using the callback
-    trainer = Trainer(callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)])
+    from lightning.pytorch.callbacks import StochasticWeightAveraging, WeightAveraging
+    from torch.optim.swa_utils import get_ema_avg_fn
+
+    # Enable Exponential Moving Average after 100 steps
+    class EMAWeightAveraging(WeightAveraging):
+        def __init__(self):
+            super().__init__(avg_fn=get_ema_avg_fn())
+        def should_update(self, step_idx=None, epoch_idx=None):
+            return (step_idx is not None) and (step_idx >= 100)
+    trainer = Trainer(callbacks=EMAWeightAveraging())
+
+    # Enable Stochastic Weight Averaging after 10 epochs with learning rate 0.01
+    trainer = Trainer(callbacks=StochasticWeightAveraging(swa_epoch_start=10, swa_lrs=0.01))
 
 ----------
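As context for the doc change above: the `should_update()` hook and the `avg_fn` argument shown in the diff make other averaging schedules easy to express. Below is a minimal sketch, not part of this commit, of an EMA variant that updates only at epoch boundaries; the class name and decay value are illustrative.

    from lightning.pytorch.callbacks import WeightAveraging
    from torch.optim.swa_utils import get_ema_avg_fn

    class EpochEMAWeightAveraging(WeightAveraging):
        """Hypothetical variant: exponential moving average, updated once per epoch."""

        def __init__(self, decay: float = 0.999):
            # get_ema_avg_fn(decay) implements avg = decay * avg + (1 - decay) * w,
            # the update rule AveragedModel applies to each parameter.
            super().__init__(avg_fn=get_ema_avg_fn(decay))

        def should_update(self, step_idx=None, epoch_idx=None):
            # Per the docs above, the callback consults this hook with step_idx
            # after optimizer steps and epoch_idx at epoch ends; reacting only
            # to epoch_idx skips the default per-step updates.
            return epoch_idx is not None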

docs/source-pytorch/api_references.rst (1 addition, 0 deletions)

@@ -48,6 +48,7 @@ callbacks
     ThroughputMonitor
     Timer
     TQDMProgressBar
+    WeightAveraging
 
 cli
 -----

docs/source-pytorch/extensions/callbacks.rst (1 addition, 0 deletions)

@@ -83,6 +83,7 @@ Lightning has a few built-in callbacks.
     StochasticWeightAveraging
     Timer
     TQDMProgressBar
+    WeightAveraging
 
 ----------

docs/source-pytorch/glossary/index.rst (8 additions, 8 deletions)

@@ -42,13 +42,13 @@
     Strategy registry <../advanced/strategy_registry>
     Strategy integrations <../integrations/strategies/index>
     Style guide <../starter/style_guide>
-    SWA <../advanced/training_tricks>
     SLURM <../clouds/cluster_advanced>
     Tensor Parallel <../advanced/model_parallel/tp>
     Transfer learning <../advanced/transfer_learning>
     Trainer <../common/trainer>
     TorchRun (TorchElastic) <../clouds/cluster_intermediate_2>
     Warnings <../advanced/warnings>
+    Weight averaging <../advanced/training_tricks>
 
 
 ########
@@ -326,13 +326,6 @@ Glossary
     :button_link: ../starter/style_guide.html
     :height: 100
 
-.. displayitem::
-    :header: SWA
-    :description: Stochastic Weight Averaging (SWA) can make your models generalize better
-    :col_css: col-md-12
-    :button_link: ../advanced/training_tricks.html#stochastic-weight-averaging
-    :height: 100
-
 .. displayitem::
     :header: SLURM
     :description: Simple Linux Utility for Resource Management, or simply Slurm, is a free and open-source job scheduler for Linux clusters
@@ -375,6 +368,13 @@ Glossary
     :button_link: ../advanced/warnings.html
     :height: 100
 
+.. displayitem::
+    :header: Weight averaging
+    :description: Stochastic Weight Averaging (SWA) or Exponential Moving Average (EMA) can make your models generalize better
+    :col_css: col-md-12
+    :button_link: ../advanced/training_tricks.html#weight-averaging
+    :height: 100
+
 .. raw:: html
 
     </div>

docs/source-pytorch/model/build_model_intermediate.rst (1 addition, 1 deletion)

@@ -27,7 +27,7 @@ Enable advanced training features using Trainer arguments. These are SOTA techni
     )
 
     # access the latest state of the art techniques
-    trainer = Trainer(callbacks=[StochasticWeightAveraging(...)])
+    trainer = Trainer(callbacks=[WeightAveraging(...)])
 
 ----

docs/source-pytorch/starter/introduction.rst (1 addition, 1 deletion)

@@ -252,7 +252,7 @@ Enable advanced training features using Trainer arguments. These are state-of-th
     )
 
     # access the latest state of the art techniques
-    trainer = L.Trainer(callbacks=[StochasticWeightAveraging(...)])
+    trainer = L.Trainer(callbacks=[WeightAveraging(...)])
 
 ----

requirements/pytorch/test.txt (3 additions, 0 deletions)

@@ -18,3 +18,6 @@ fastapi # for `ServableModuleValidator` # not setting version as re-defined in
 uvicorn # for `ServableModuleValidator` # not setting version as re-defined in App
 
 tensorboard >=2.9.1, <2.21.0 # for `TensorBoardLogger`
+
+--find-links https://download.pytorch.org/whl/torch-tensorrt
+torch-tensorrt; platform_system == "Linux" and python_version >= "3.12"
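A note on the environment marker above: `torch-tensorrt` only resolves on Linux with Python 3.12 or newer, and the `--find-links` line points pip at the PyTorch wheel index. As a quick sketch of how such a PEP 508 marker evaluates on the current interpreter, using the `packaging` library (an assumption here, not something this commit adds):

    from packaging.markers import Marker

    # Evaluates the same marker expression that gates the torch-tensorrt requirement.
    marker = Marker('platform_system == "Linux" and python_version >= "3.12"')
    print(marker.evaluate())  # True only on Linux with Python >= 3.12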

src/lightning/pytorch/CHANGELOG.md (4 additions, 1 deletion)

@@ -10,7 +10,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
--
+- Added `WeightAveraging` callback that wraps the PyTorch `AveragedModel` class ([#20545](https://github.com/Lightning-AI/pytorch-lightning/pull/20545))
+
+
+- Added Torch-Tensorrt integration with `LightningModule` ([#20808](https://github.com/Lightning-AI/pytorch-lightning/pull/20808))
 
 
 ### Changed
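The second changelog entry references #20808; the `LightningModule` hookup itself is not visible in this merge's diff. As a rough sketch of what Torch-TensorRT compilation looks like through the public `torch_tensorrt` API, with the demo model and input shape being assumptions rather than code from this commit:

    import torch
    import torch_tensorrt
    from lightning.pytorch.demos.boring_classes import BoringModel

    # Compile an eval-mode module to a TensorRT engine; needs Linux, a CUDA GPU,
    # and the torch-tensorrt wheel added to requirements/pytorch/test.txt above.
    model = BoringModel().eval().cuda()
    example = torch.randn(4, 32, device="cuda")

    trt_model = torch_tensorrt.compile(model, inputs=[example], enabled_precisions={torch.float32})
    print(trt_model(example).shape)  # same outputs, now executed by TensorRT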
