
prefetch_factor*worker datums thrown away because of _check_dataloader_iterable #18414

Closed
@ben-davidson-6

Description


Bug description

If you have a dataset that just pops items off a queue, and you set `persistent_workers=True` and `num_workers > 0` in your DataLoader, then the first `num_workers * prefetch_factor` items on the queue are thrown away. With `num_workers=1` and PyTorch's default `prefetch_factor=2`, as in the reproduction below, that is the first two items.

This is caused by the `_check_dataloader_iterable` call at https://github.com/Lightning-AI/lightning/blob/722fdeac44cce49928184d89684eeb668742bf37/src/lightning/pytorch/trainer/connectors/data_connector.py#L391, which the training loop makes before the first epoch. That call creates an iterator over the dataloader, which starts the worker and fills the prefetch buffer. The buffer is then thrown away when the first epoch creates its own iterator.
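If the up-front check only needs to confirm that the object is iterable, it could in principle do so without creating an iterator. The following is a minimal sketch under that assumption, not Lightning's code or its actual fix: `is_iterable_without_iterating` and `CountingLoader` are hypothetical names, and `CountingLoader` stands in for a `DataLoader` whose `__iter__` has a visible side effect (in a real `DataLoader`, starting workers and prefetching).

```python
from collections.abc import Iterable
from typing import Any, Iterator


def is_iterable_without_iterating(obj: Any) -> bool:
    """Hypothetical helper: check iterability structurally, without calling iter().

    isinstance(obj, Iterable) looks for an __iter__ method on the type (or an
    explicit ABC registration), so no iterator is created and nothing is prefetched.
    """
    return isinstance(obj, Iterable)


class CountingLoader:
    """Hypothetical stand-in for a DataLoader whose __iter__ has a side effect."""

    def __init__(self) -> None:
        self.iterators_created = 0

    def __iter__(self) -> Iterator[int]:
        # A real DataLoader would start worker processes and fill its prefetch buffer here.
        self.iterators_created += 1
        return iter([])


loader = CountingLoader()
print(is_iterable_without_iterating(loader))  # True
print(loader.iterators_created)               # 0: __iter__ was never called
```

The trade-off is coverage: `collections.abc.Iterable` only looks for `__iter__`, so an object that is iterable solely through the legacy `__getitem__` protocol would be reported as not iterable, whereas calling `iter()` accepts it.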

What version are you seeing the problem on?

v2.0

How to reproduce the bug

```python
import multiprocessing as mp
from queue import Queue
from typing import Iterator

import numpy as np
from torch.utils.data import DataLoader, IterableDataset

from lightning import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel


class QueueDataset(IterableDataset):
    def __init__(self, queue: Queue) -> None:
        super().__init__()
        self.queue = queue

    def __iter__(self) -> Iterator:
        for k in range(5):
            print(f"getting {k}")
            tensor, index = self.queue.get(timeout=10)
            print(f"got {index}")
            yield tensor


if __name__ == "__main__":
    q = mp.Queue()
    arr = np.random.random([1, 32]).astype(np.float32)
    for ind in range(5):
        q.put((arr, ind))
    dataloader = DataLoader(QueueDataset(q), num_workers=1, batch_size=None, persistent_workers=True)
    trainer = Trainer(max_epochs=1, enable_progress_bar=False)
    trainer.fit(BoringModel(), dataloader)
```

Error messages and logs

```
getting 0
got 0
getting 1
got 1
getting 0
got 2
getting 1
got 3
getting 2
got 4
getting 3
```

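The first `getting 0`/`getting 1` pair comes from the prefetch triggered by the check, so items 0 and 1 are consumed there and discarded. The first epoch then restarts `__iter__` at `k = 0` and receives items 2, 3 and 4. At `getting 3` the queue is already empty, so `self.queue.get(timeout=10)` raises `_queue.Empty` once the timeout expires.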
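To see why the log has exactly this shape, the following toy model replays it in pure Python. It is a hedged sketch, not PyTorch's DataLoader: `ToyPrefetchingLoader` is a hypothetical class that models a single worker with a prefetch buffer of `num_workers * prefetch_factor = 2` items, a `collections.deque` stands in for the `mp.Queue`, and the model ends the epoch when the queue is empty instead of raising `_queue.Empty` after a timeout.

```python
from collections import deque
from typing import Deque, Iterator, List

_END = object()  # sentinel: the dataset iterator is exhausted


class ToyPrefetchingLoader:
    """Hypothetical toy model of a one-worker loader with a prefetch buffer.

    Not PyTorch's DataLoader: it only mimics two behaviours relevant to this
    issue. iter() eagerly fills a prefetch buffer, and each new iterator
    restarts the dataset while the old buffer is thrown away.
    """

    def __init__(self, queue: Deque[int], num_items: int, prefetch: int, trace: List[str]) -> None:
        self.queue = queue          # shared queue, like the mp.Queue in the repro
        self.num_items = num_items  # range(5) in QueueDataset.__iter__
        self.prefetch = prefetch    # stands in for num_workers * prefetch_factor
        self.trace = trace

    def _dataset_iter(self) -> Iterator[int]:
        # Mirrors QueueDataset.__iter__: k restarts at 0 on every call,
        # but the shared queue keeps only what earlier iterators left behind.
        for k in range(self.num_items):
            self.trace.append(f"getting {k}")
            if not self.queue:
                self.trace.append("queue empty")  # the real repro raises _queue.Empty here
                return
            index = self.queue.popleft()
            self.trace.append(f"got {index}")
            yield index

    def _fill(self, source: Iterator[int], buffer: Deque[int]) -> None:
        # Top the buffer up to `prefetch` items, or stop when the source ends.
        while len(buffer) < self.prefetch:
            item = next(source, _END)
            if item is _END:
                return
            buffer.append(item)

    def _drain(self, source: Iterator[int], buffer: Deque[int]) -> Iterator[int]:
        # Hand out buffered items, refilling one slot after each.
        while buffer:
            yield buffer.popleft()
            self._fill(source, buffer)

    def __iter__(self) -> Iterator[int]:
        source = self._dataset_iter()
        buffer: Deque[int] = deque()
        self._fill(source, buffer)  # prefetching happens inside iter() itself
        return self._drain(source, buffer)


trace: List[str] = []
loader = ToyPrefetchingLoader(deque(range(5)), num_items=5, prefetch=2, trace=trace)
iter(loader)              # an up-front iterability check whose iterator is discarded
delivered = list(loader)  # the first epoch creates a fresh iterator
print(delivered)          # [2, 3, 4]: items 0 and 1 were lost with the first buffer
print("\n".join(trace))
```

The model assumes, as the log suggests, that each new iterator restarts the dataset's `__iter__` while the shared queue keeps only what earlier iterators left behind; that combination is what turns a discarded prefetch buffer into lost items.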

Environment

Current environment

```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):
```

More info

No response

cc @justusschock @awaelchli @Borda
