Closed
Bug description
If a dataset simply pops items off a queue and the dataloader is created with `persistent_workers=True` and `num_workers > 0`, the first items on the queue are thrown away. The number lost is `num_workers * prefetch_factor`, which is two in the reproduction below (`num_workers=1` with the default `prefetch_factor=2`).
This happens because the training loop calls https://github.com/Lightning-AI/lightning/blob/722fdeac44cce49928184d89684eeb668742bf37/src/lightning/pytorch/trainer/connectors/data_connector.py#L391. That call starts the dataloading process, which fills the prefetch buffer; the buffer is then discarded when the first epoch starts.
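The mechanism can be illustrated without PyTorch or Lightning. The sketch below is a toy model only: `PrefetchingIterator` and `run` are hypothetical names invented for illustration, and the class merely imitates the eager prefetch of a multi-worker dataloader iterator under the assumption that creating the iterator immediately pulls `num_workers * prefetch_factor` items from the source.

```python
import queue
from collections import deque
from typing import List


class PrefetchingIterator:
    """Toy stand-in for a multi-worker dataloader iterator (hypothetical)."""

    def __init__(self, source: queue.Queue, num_workers: int, prefetch_factor: int) -> None:
        self.source = source
        self.buffer: deque = deque()
        # Assumption: the buffer is filled as soon as the iterator is created,
        # before anything asks for the first batch.
        for _ in range(num_workers * prefetch_factor):
            self._fetch()

    def _fetch(self) -> None:
        # Pull one item if available; a drained source just leaves the buffer short.
        try:
            self.buffer.append(self.source.get_nowait())
        except queue.Empty:
            pass

    def __iter__(self) -> "PrefetchingIterator":
        return self

    def __next__(self) -> int:
        if not self.buffer:
            raise StopIteration
        item = self.buffer.popleft()
        self._fetch()  # top the buffer back up after handing out an item
        return item


def run(num_items: int = 5, num_workers: int = 1, prefetch_factor: int = 2) -> List[int]:
    q: queue.Queue = queue.Queue()
    for i in range(num_items):
        q.put(i)
    # First iterator: created and dropped without being consumed,
    # so the items it prefetched are gone from the queue for good.
    discarded = PrefetchingIterator(q, num_workers, prefetch_factor)
    del discarded
    # Second iterator: the one the first epoch actually reads from.
    return list(PrefetchingIterator(q, num_workers, prefetch_factor))


seen = run()
print(seen)  # [2, 3, 4]
```

With one worker and a prefetch factor of two, items 0 and 1 never reach the consumer. A source that can be re-read from the start would hide the problem, which is why it only surfaces with destructive sources such as a queue.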
What version are you seeing the problem on?
v2.0
How to reproduce the bug
python
import multiprocessing as mp
from queue import Queue
from typing import Iterator

import numpy as np
from torch.utils.data import DataLoader, IterableDataset

from lightning import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel


class QueueDataset(IterableDataset):
    def __init__(self, queue: Queue) -> None:
        super().__init__()
        self.queue = queue

    def __iter__(self) -> Iterator:
        for k in range(5):
            print(f"getting {k}")
            tensor, index = self.queue.get(timeout=10)
            print(f"got {index}")
            yield tensor


if __name__ == "__main__":
    q = mp.Queue()
    arr = np.random.random([1, 32]).astype(np.float32)
    for ind in range(5):
        q.put((arr, ind))

    dataloader = DataLoader(QueueDataset(q), num_workers=1, batch_size=None, persistent_workers=True)
    trainer = Trainer(max_epochs=1, enable_progress_bar=False)
    trainer.fit(BoringModel(), dataloader)
Error messages and logs
getting 0
got 0
getting 1
got 1
getting 0
got 2
getting 1
got 3
getting 2
got 4
getting 3
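The worker then raises `_queue.Empty`: the `get` call times out because the queue is empty. Items 0 and 1 were consumed by the discarded prefetch, so only three of the five items reach the training epoch.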