From 3858050b7925e8632a01ff17f566cfe6eac019a4 Mon Sep 17 00:00:00 2001 From: Furkan-rgb <50831308+Furkan-rgb@users.noreply.github.com> Date: Wed, 30 Apr 2025 06:45:08 +0000 Subject: [PATCH 1/3] Add 'start_from_epoch' parameter to the EarlyStopping callback so that monitoring can begin after a warm-up period. --- .../pytorch/callbacks/early_stopping.py | 26 ++++++++++++++----- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/src/lightning/pytorch/callbacks/early_stopping.py b/src/lightning/pytorch/callbacks/early_stopping.py index d108894f614e6..11b117961056f 100644 --- a/src/lightning/pytorch/callbacks/early_stopping.py +++ b/src/lightning/pytorch/callbacks/early_stopping.py @@ -26,10 +26,10 @@ from torch import Tensor from typing_extensions import override -import lightning.pytorch as pl -from lightning.pytorch.callbacks.callback import Callback -from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning.pytorch.utilities.rank_zero import rank_prefixed_message, rank_zero_warn +import pytorch_lightning as pl +from pytorch_lightning.callbacks.callback import Callback +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.rank_zero import rank_prefixed_message, rank_zero_warn log = logging.getLogger(__name__) @@ -64,6 +64,8 @@ class EarlyStopping(Callback): check_on_train_epoch_end: whether to run early stopping at the end of the training epoch. If this is ``False``, then the check runs at the end of the validation. log_rank_zero_only: When set ``True``, logs the status of the early stopping callback only for rank 0 process. + start_from_epoch: the epoch from which to start monitoring for early stopping. Defaults to 0 (start from the + beginning). Set to a higher value to let the model train for a minimum number of epochs before monitoring. 
Raises: MisconfigurationException: @@ -73,10 +75,14 @@ class EarlyStopping(Callback): Example:: - >>> from lightning.pytorch import Trainer - >>> from lightning.pytorch.callbacks import EarlyStopping + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import EarlyStopping >>> early_stopping = EarlyStopping('val_loss') >>> trainer = Trainer(callbacks=[early_stopping]) + + >>> # Start monitoring only from epoch 5 + >>> early_stopping = EarlyStopping('val_loss', start_from_epoch=5) + >>> trainer = Trainer(callbacks=[early_stopping]) .. tip:: Saving and restoring multiple early stopping callbacks at the same time is supported under variation in the following arguments: @@ -104,6 +110,7 @@ def __init__( divergence_threshold: Optional[float] = None, check_on_train_epoch_end: Optional[bool] = None, log_rank_zero_only: bool = False, + start_from_epoch: int = 0, ): super().__init__() self.monitor = monitor @@ -119,6 +126,7 @@ def __init__( self.stopped_epoch = 0 self._check_on_train_epoch_end = check_on_train_epoch_end self.log_rank_zero_only = log_rank_zero_only + self.start_from_epoch = start_from_epoch if self.mode not in self.mode_dict: raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}") @@ -179,7 +187,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.patience = state_dict["patience"] def _should_skip_check(self, trainer: "pl.Trainer") -> bool: - from lightning.pytorch.trainer.states import TrainerFn + from pytorch_lightning.trainer.states import TrainerFn return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking @@ -197,6 +205,10 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None: """Checks whether the early stopping condition is met and if so tells the trainer to stop the training.""" + # Skip early stopping check if current epoch is less than 
start_from_epoch + if trainer.current_epoch < self.start_from_epoch: + return + logs = trainer.callback_metrics if trainer.fast_dev_run or not self._validate_condition_metric( # disable early_stopping with fast_dev_run From 014d3da7e002b6de5291ac1090d3337b5c82ca56 Mon Sep 17 00:00:00 2001 From: Furkan-rgb <50831308+Furkan-rgb@users.noreply.github.com> Date: Wed, 30 Apr 2025 06:52:04 +0000 Subject: [PATCH 2/3] test: add early stopping test for start_from_epoch functionality --- .../callbacks/test_early_stopping.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py index 9a87b3daaad6e..026f6d655fada 100644 --- a/tests/tests_pytorch/callbacks/test_early_stopping.py +++ b/tests/tests_pytorch/callbacks/test_early_stopping.py @@ -505,3 +505,36 @@ def test_early_stopping_log_info(log_rank_zero_only, world_size, global_rank, ex log_mock.assert_called_once_with(expected_log) else: log_mock.assert_not_called() + + +def test_early_stopping_start_from_epoch(tmp_path): + """Test that early stopping checks only activate after start_from_epoch.""" + losses = [6, 5, 4, 3, 2, 1] # decreasing losses + start_from_epoch = 3 + expected_stop_epoch = None # Should not stop early + + class CurrentModel(BoringModel): + def on_validation_epoch_end(self): + val_loss = losses[self.current_epoch] + self.log("val_loss", val_loss) + + model = CurrentModel() + + # Mock the _run_early_stopping_check method to verify when it's called + with mock.patch("lightning.pytorch.callbacks.early_stopping.EarlyStopping._evaluate_stopping_criteria") as es_mock: + es_mock.return_value = (False, "") + early_stopping = EarlyStopping(monitor="val_loss", start_from_epoch=start_from_epoch) + trainer = Trainer( + default_root_dir=tmp_path, + callbacks=[early_stopping], + limit_train_batches=0.2, + limit_val_batches=0.2, + max_epochs=len(losses), + ) + trainer.fit(model) + + # Check that 
_evaluate_stopping_criteria is not called for epochs before start_from_epoch + assert es_mock.call_count == len(losses) - start_from_epoch + # Check that only the correct epochs were processed + for i, call_args in enumerate(es_mock.call_args_list): + assert torch.allclose(call_args[0][0], torch.tensor(losses[i + start_from_epoch])) From 8bbb6d52c944e32deb6c069f8b5bb502ebfc431f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 30 Apr 2025 06:53:50 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/pytorch/callbacks/early_stopping.py | 2 +- tests/tests_pytorch/callbacks/test_early_stopping.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/lightning/pytorch/callbacks/early_stopping.py b/src/lightning/pytorch/callbacks/early_stopping.py index 11b117961056f..393d8f699f80e 100644 --- a/src/lightning/pytorch/callbacks/early_stopping.py +++ b/src/lightning/pytorch/callbacks/early_stopping.py @@ -79,7 +79,7 @@ class EarlyStopping(Callback): >>> from pytorch_lightning.callbacks import EarlyStopping >>> early_stopping = EarlyStopping('val_loss') >>> trainer = Trainer(callbacks=[early_stopping]) - + >>> # Start monitoring only from epoch 5 >>> early_stopping = EarlyStopping('val_loss', start_from_epoch=5) >>> trainer = Trainer(callbacks=[early_stopping]) diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py index 026f6d655fada..0ec23d33f3e7c 100644 --- a/tests/tests_pytorch/callbacks/test_early_stopping.py +++ b/tests/tests_pytorch/callbacks/test_early_stopping.py @@ -511,7 +511,6 @@ def test_early_stopping_start_from_epoch(tmp_path): """Test that early stopping checks only activate after start_from_epoch.""" losses = [6, 5, 4, 3, 2, 1] # decreasing losses start_from_epoch = 3 - expected_stop_epoch = None # Should not stop 
early class CurrentModel(BoringModel): def on_validation_epoch_end(self): @@ -519,7 +518,7 @@ def on_validation_epoch_end(self): self.log("val_loss", val_loss) model = CurrentModel() - + # Mock the _run_early_stopping_check method to verify when it's called with mock.patch("lightning.pytorch.callbacks.early_stopping.EarlyStopping._evaluate_stopping_criteria") as es_mock: es_mock.return_value = (False, "")