diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst
index 339c59771001a..6a8a8135a1843 100644
--- a/docs/source-pytorch/common/trainer.rst
+++ b/docs/source-pytorch/common/trainer.rst
@@ -759,6 +759,9 @@ overfit_batches
 Uses this much data of the training & validation set. If the training & validation dataloaders have ``shuffle=True``, Lightning will automatically disable it.
 
+* When set to a value > 0, sequential sampling (no shuffling) is used
+* The same batches are reused across epochs for both training and validation, but training and validation draw from different sets of data
+
 Useful for quickly debugging or trying to overfit on purpose.
 
 .. testcode::
 
@@ -769,11 +772,14 @@ Useful for quickly debugging or trying to overfit on purpose.
 
     # use only 1% of the train & val set
     trainer = Trainer(overfit_batches=0.01)
 
-    # overfit on 10 of the same batches
+    # overfit on 10 consistent train batches & 10 consistent val batches
     trainer = Trainer(overfit_batches=10)
 
+    # debug using a single consistent train batch and a single consistent val batch
+    trainer = Trainer(overfit_batches=1)
+
 plugins
 ^^^^^^^
 
 :ref:`Plugins` allow you to connect arbitrary backends, precision libraries, clusters etc. For example:
@@ -895,7 +901,7 @@ DataSource can be a ``LightningModule`` or a ``LightningDataModule``.
 
     # if 0 (default)
     train_loader = model.train_dataloader()
-    # or if using data module: datamodule.train_dataloader()
+    # or, if using a datamodule: datamodule.train_dataloader()
     for epoch in epochs:
         for batch in train_loader:
             ...
diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py
index 3e5273085ed2b..841d78b457d48 100644
--- a/src/lightning/pytorch/trainer/connectors/data_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -244,15 +244,23 @@ def _get_distributed_sampler(
 
 
 def _resolve_overfit_batches(combined_loader: CombinedLoader, mode: RunningStage) -> None:
+    """Resolve ``overfit_batches`` by disabling shuffling.
+
+    When ``overfit_batches > 0``, this function enforces sequential sampling (no shuffling) so that the same batches
+    are seen across epochs. Training and validation use different sets of data.
+
+    """
     all_have_sequential_sampler = all(
         isinstance(dl.sampler, SequentialSampler) for dl in combined_loader.flattened if hasattr(dl, "sampler")
     )
     if all_have_sequential_sampler:
         return
+
     rank_zero_warn(
         f"You requested to overfit but enabled {mode.dataloader_prefix} dataloader shuffling."
         f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you."
     )
+
     updated = [
         _update_dataloader(dl, sampler=SequentialSampler(dl.dataset), mode=mode) if hasattr(dl, "dataset") else dl
         for dl in combined_loader.flattened
diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py
index b02d9d089a354..fb5f4b04400e6 100644
--- a/tests/tests_pytorch/conftest.py
+++ b/tests/tests_pytorch/conftest.py
@@ -95,6 +95,12 @@ def restore_env_variables():
         "TF_GRPC_DEFAULT_OPTIONS",
         "XLA_FLAGS",
         "TORCHINDUCTOR_CACHE_DIR",  # leaked by torch.compile
+        # TensorFlow and TPU related variables
+        "TF2_BEHAVIOR",
+        "TPU_ML_PLATFORM",
+        "TPU_ML_PLATFORM_VERSION",
+        "LD_LIBRARY_PATH",
+        "ENABLE_RUNTIME_UPTIME_TELEMETRY",
     }
     leaked_vars.difference_update(allowlist)
     assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}"
diff --git a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py
index 050818287ba45..6322698ef3b73 100644
--- a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py
+++ b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py
@@ -170,3 +170,44 @@ def test_distributed_sampler_with_overfit_batches():
     train_sampler = trainer.train_dataloader.sampler
     assert isinstance(train_sampler, DistributedSampler)
     assert train_sampler.shuffle is False
+
+
+def test_overfit_batches_same_batch_for_train_and_val(tmp_path):
+    """Test that when overfit_batches=1, the same batch is used for both training and validation."""
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.train_batches = []
+            self.val_batches = []
+
+        def training_step(self, batch, batch_idx):
+            self.train_batches.append(batch)
+            return super().training_step(batch, batch_idx)
+
+        def validation_step(self, batch, batch_idx):
+            self.val_batches.append(batch)
+            return super().validation_step(batch, batch_idx)
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=2,
+        overfit_batches=1,
+        check_val_every_n_epoch=1,
+        enable_model_summary=False,
+    )
+    trainer.fit(model)
+
+    # Verify that batches were recorded for both training and validation
+    assert len(model.train_batches) > 0
+    assert len(model.val_batches) > 0
+
+    # Compare the actual batch contents
+    train_batch = model.train_batches[0]
+    val_batch = model.val_batches[0]
+
+    # Check that the batches are identical
+    assert torch.equal(train_batch, val_batch), (
+        "Training and validation batches should be identical when overfit_batches=1"
+    )
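
A minimal sketch (not part of the patch) of the sampler-swapping behavior documented above, assuming only the public ``Trainer`` API plus the ``BoringModel`` and ``RandomDataset`` demo helpers already used in the tests::

    from torch.utils.data import DataLoader, SequentialSampler

    from lightning.pytorch import Trainer
    from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset

    model = BoringModel()
    # A deliberately shuffled train loader: with overfit_batches set, Lightning
    # should warn and swap the shuffling sampler for a SequentialSampler.
    train_loader = DataLoader(RandomDataset(32, 64), batch_size=2, shuffle=True)

    trainer = Trainer(overfit_batches=1, max_epochs=1, enable_model_summary=False)
    trainer.fit(model, train_dataloaders=train_loader)

    # Sequential (unshuffled) sampling means the same batch is drawn every epoch.
    assert isinstance(trainer.train_dataloader.sampler, SequentialSampler)

``BoringModel``'s own dataloaders do not shuffle, so the shuffled loader is passed explicitly here to exercise the warning path in ``_resolve_overfit_batches``.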