Merged
9 changes: 8 additions & 1 deletion docs/source-pytorch/common/trainer.rst
@@ -759,6 +759,9 @@ overfit_batches
Uses this much data from the training & validation sets.
If the training & validation dataloaders have ``shuffle=True``, Lightning will automatically disable it.

* When set to exactly ``1``, the same batch is used for both the training and validation steps, which is useful for debugging the model implementation.
* For any other value, sequential sampling (no shuffling) is used.

Useful for quickly debugging or trying to overfit on purpose.

.. testcode::
@@ -769,9 +772,13 @@ Useful for quickly debugging or trying to overfit on purpose.
# use only 1% of the train & val set
trainer = Trainer(overfit_batches=0.01)

# overfit on 10 of the same batches
# overfit on 10 (same) train batches & 10 (same) val batches
trainer = Trainer(overfit_batches=10)

# debug by training and validating on exactly the same single batch
# (useful for verifying model implementation)
trainer = Trainer(overfit_batches=1)
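
# note (added for clarity): a float is interpreted as a fraction of the data, as in the
# 1% example above, so 1.0 would use all batches rather than a single one; only the
# integer 1 pins training and validation to one batch
trainer = Trainer(overfit_batches=1.0)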

plugins
^^^^^^^

53 changes: 49 additions & 4 deletions src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -244,19 +244,64 @@ def _get_distributed_sampler(


def _resolve_overfit_batches(combined_loader: CombinedLoader, mode: RunningStage) -> None:
"""Resolve overfit batches by ensuring the same batch is used for both training and validation."""
all_have_sequential_sampler = all(
isinstance(dl.sampler, SequentialSampler) for dl in combined_loader.flattened if hasattr(dl, "sampler")
)
if all_have_sequential_sampler:
return

rank_zero_warn(
f"You requested to overfit but enabled {mode.dataloader_prefix} dataloader shuffling."
f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you."
)
updated = [
_update_dataloader(dl, sampler=SequentialSampler(dl.dataset), mode=mode) if hasattr(dl, "dataset") else dl
for dl in combined_loader.flattened
]

# Get the first batch from the training dataloader
first_batch = None
if mode == RunningStage.TRAINING:
for dl in combined_loader.flattened:
if hasattr(dl, "dataset"):
first_batch = next(iter(dl))
break

# Create new dataloaders with SequentialSampler
updated = []
for dl in combined_loader.flattened:
if hasattr(dl, "dataset"):
if mode == RunningStage.VALIDATING and first_batch is not None:
# For validation, create a custom sampler that always returns the first batch
class SingleBatchSampler(Sampler):
def __init__(self, batch):
self.batch = batch

def __iter__(self):
yield self.batch

def __len__(self):
return 1

sampler = SingleBatchSampler(first_batch)
else:
sampler = SequentialSampler(dl.dataset)

# Create a new dataloader with the new sampler
dl = DataLoader(
dataset=dl.dataset,
batch_size=dl.batch_size,
sampler=sampler,
num_workers=dl.num_workers,
collate_fn=dl.collate_fn,
pin_memory=dl.pin_memory,
drop_last=dl.drop_last,
timeout=dl.timeout,
worker_init_fn=dl.worker_init_fn,
multiprocessing_context=dl.multiprocessing_context,
generator=dl.generator,
prefetch_factor=dl.prefetch_factor,
persistent_workers=dl.persistent_workers,
)
updated.append(dl)

combined_loader.flattened = updated
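
The idea is easier to see in a self-contained sketch. The snippet below is illustrative only and is not the PR's code: TinyDataset and FixedIndicesSampler are made-up names, and the sampler here yields fixed dataset indices (the conventional DataLoader contract) rather than an already-collated batch as the SingleBatchSampler above does. Pinning both loaders to the same indices produces an identical batch for training and validation, which is the behavior overfit_batches=1 is meant to provide.

import torch
from torch.utils.data import DataLoader, Dataset, Sampler


class TinyDataset(Dataset):
    """Sixty-four random samples of size 32 (made-up dataset for illustration)."""

    def __init__(self):
        self.data = torch.randn(64, 32)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class FixedIndicesSampler(Sampler):
    """Always yields the same indices, so every pass over the loader sees the identical batch."""

    def __init__(self, indices):
        self.indices = list(indices)

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)


dataset = TinyDataset()
pinned = FixedIndicesSampler(range(4))  # pin the first four samples
train_loader = DataLoader(dataset, batch_size=4, sampler=pinned)
val_loader = DataLoader(dataset, batch_size=4, sampler=pinned)

train_batch = next(iter(train_loader))
val_batch = next(iter(val_loader))
assert torch.equal(train_batch, val_batch)  # both loaders produce the same single batch

Keeping the pinning at the sampler level leaves the DataLoader's own batching, collation, and worker settings untouched.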


6 changes: 6 additions & 0 deletions tests/tests_pytorch/conftest.py
@@ -95,6 +95,12 @@ def restore_env_variables():
"TF_GRPC_DEFAULT_OPTIONS",
"XLA_FLAGS",
"TORCHINDUCTOR_CACHE_DIR", # leaked by torch.compile
# TensorFlow and TPU related variables
"TF2_BEHAVIOR",
"TPU_ML_PLATFORM",
"TPU_ML_PLATFORM_VERSION",
"LD_LIBRARY_PATH",
"ENABLE_RUNTIME_UPTIME_TELEMETRY",
}
leaked_vars.difference_update(allowlist)
assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}"
41 changes: 41 additions & 0 deletions tests/tests_pytorch/trainer/flags/test_overfit_batches.py
@@ -170,3 +170,44 @@ def test_distributed_sampler_with_overfit_batches():
train_sampler = trainer.train_dataloader.sampler
assert isinstance(train_sampler, DistributedSampler)
assert train_sampler.shuffle is False


def test_overfit_batches_same_batch_for_train_and_val(tmp_path):
"""Test that when overfit_batches=1, the same batch is used for both training and validation."""

class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.train_batches = []
self.val_batches = []

def training_step(self, batch, batch_idx):
self.train_batches.append(batch)
return super().training_step(batch, batch_idx)

def validation_step(self, batch, batch_idx):
self.val_batches.append(batch)
return super().validation_step(batch, batch_idx)

model = TestModel()
trainer = Trainer(
default_root_dir=tmp_path,
max_epochs=2,
overfit_batches=1,
check_val_every_n_epoch=1,
enable_model_summary=False,
)
trainer.fit(model)

# Verify that the same batch was used for both training and validation
assert len(model.train_batches) > 0
assert len(model.val_batches) > 0

# Compare the actual batch contents
train_batch = model.train_batches[0]
val_batch = model.val_batches[0]

# Check if the batches are identical
assert torch.equal(train_batch, val_batch), (
"Training and validation batches should be identical when overfit_batches=1"
)
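
Assuming a standard development setup for the repository, the new test can be run in isolation with pytest tests/tests_pytorch/trainer/flags/test_overfit_batches.py -k same_batch_for_train_and_val.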