🐛 Bug
When writing predictions with `torch.save` together with a `BasePredictionWriter` (see this example) on Colab using a TPU runtime with all 8 cores, only an eighth of the predictions are actually saved to disk.
To Reproduce
The following code is based on the TPU tutorial with a few modifications:
Package installation:
```
!pip install torch==1.9.1 torchtext==0.10.1 torchvision==0.10.1 pytorch-lightning==1.5.8 cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl
```
Code:
```python
import os
from typing import Any, List, Optional

import torch
import torch.nn.functional as F
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import BasePredictionWriter
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchmetrics.functional import accuracy
from torchvision import transforms
from torchvision.datasets import MNIST

BATCH_SIZE = 1024


class MNISTDataModule(LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
        # self.dims is returned when you call dm.size()
        # Setting default dims here because we know them.
        # Could optionally be assigned dynamically in dm.setup()
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)

    def predict_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)


class LitModel(LightningModule):
    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):
        super().__init__()
        self.save_hyperparameters()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return loss

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any:
        x, y = batch
        return self(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer


class CustomWriter(BasePredictionWriter):
    def __init__(self, output_dir: str, write_interval: str):
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_batch_end(
        self, trainer, pl_module: LightningModule, prediction: Any, batch_indices: List[int],
        batch: Any, batch_idx: int, dataloader_idx: int,
    ):
        torch.save(prediction, os.path.join(self.output_dir, f"{dataloader_idx}_{batch_idx}.pt"))

    def write_on_epoch_end(
        self, trainer, pl_module: LightningModule, predictions: List[Any], batch_indices: List[Any]
    ):
        torch.save(predictions, os.path.join(self.output_dir, "predictions.pt"))


tmp_dir = "/tmp"
dm = MNISTDataModule()
model = LitModel(*dm.size(), dm.num_classes)
prediction_writer = CustomWriter(output_dir=tmp_dir, write_interval="epoch")
trainer = Trainer(tpu_cores=8, callbacks=[prediction_writer])
trainer.predict(model=model, datamodule=dm)

written_predictions = torch.load(os.path.join(tmp_dir, "predictions.pt"))
nb_predictions = sum(t.shape[0] for t in written_predictions[0])
assert nb_predictions == 10_000
```
When using `tpu_cores=[1]`, all predictions are saved correctly, at the cost of using only one core instead of all eight.
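
As an interim workaround, here is a minimal sketch (untested on the Colab TPU runtime) that avoids losing predictions by writing one file per process, assuming `trainer.global_rank` identifies the core running the callback. `PerRankWriter` and `merge_shards` are hypothetical names introduced here, not part of the Lightning API:

```python
import glob
import os
from typing import Any, List

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import BasePredictionWriter


class PerRankWriter(BasePredictionWriter):
    """Workaround sketch: each process writes its own shard instead of
    all processes racing to write the same predictions.pt."""

    def __init__(self, output_dir: str, write_interval: str = "epoch"):
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_epoch_end(
        self, trainer, pl_module: LightningModule, predictions: List[Any], batch_indices: List[Any]
    ):
        # One file per process, e.g. predictions_0.pt ... predictions_7.pt
        path = os.path.join(self.output_dir, f"predictions_{trainer.global_rank}.pt")
        torch.save(predictions, path)


def merge_shards(output_dir: str) -> List[torch.Tensor]:
    # Concatenate the shards of the first (and only) dataloader.
    # Single-digit ranks sort correctly lexicographically.
    preds = []
    for path in sorted(glob.glob(os.path.join(output_dir, "predictions_*.pt"))):
        preds.extend(torch.load(path)[0])
    return preds
```

Note that the distributed sampler interleaves (and may pad) the samples across cores, so the merged predictions would not be in the original dataset order.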
Expected behavior
The predictions from all eight cores should be collected, so that `predictions.pt` contains all 10,000 test predictions and the final assertion passes.
Environment
Colab with TPU runtime.
- CUDA:
  - GPU:
  - available: False
  - version: 10.2
- Packages:
  - numpy: 1.19.5
  - pyTorch_debug: False
  - pyTorch_version: 1.9.1+cu102
  - pytorch-lightning: 1.5.8
  - tqdm: 4.62.3
- System:
  - OS: Linux
  - architecture:
    - 64bit
  - processor: x86_64
  - python: 3.7.12
  - version: #1 SMP Tue Dec 7 09:58:10 PST 2021
Additional context
Using the `BasePredictionWriter` was suggested in this issue. As requested by @kaushikb11, I created this new issue.
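
For completeness, another avenue I have not verified: gathering the per-core prediction batches inside `predict_step` with `LightningModule.all_gather`, so that every process (including whichever one ends up writing `predictions.pt`) holds the full set. A rough sketch under that assumption:

```python
from typing import Any, Optional

from pytorch_lightning import LightningModule


class LitModel(LightningModule):  # same model as above, only predict_step changes
    ...

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any:
        x, y = batch
        preds = self(x)  # shape: (batch_per_core, num_classes)
        # all_gather stacks the tensors from every process, giving
        # (num_cores, batch_per_core, num_classes); flatten the first two dims.
        gathered = self.all_gather(preds)
        return gathered.reshape(-1, gathered.shape[-1])
```

Whether `all_gather` behaves correctly inside `predict_step` on the TPU strategy is exactly the kind of thing this issue is about, so treat this as a guess rather than a fix.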