
Predict on TPU using all cores #11417

@stekiri

Description

🐛 Bug

When writing predictions with torch.save via a BasePredictionWriter (see this example) on Colab with a TPU runtime using all 8 cores, only an eighth of the predictions is actually saved to disk.

To Reproduce

The following code is based on the TPU tutorial with a few modifications:

Package installation:

!pip install torch==1.9.1 torchtext==0.10.1 torchvision==0.10.1 pytorch-lightning==1.5.8 cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl

Code:

import os
from typing import Any, List, Optional

import torch
import torch.nn.functional as F
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import BasePredictionWriter
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchmetrics.functional import accuracy
from torchvision import transforms
from torchvision.datasets import MNIST

BATCH_SIZE = 1024


class MNISTDataModule(LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

        # self.dims is returned when you call dm.size()
        # Setting default dims here because we know them.
        # Could optionally be assigned dynamically in dm.setup()
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)

    def predict_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)


class LitModel(LightningModule):
    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):
        super().__init__()

        self.save_hyperparameters()

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return loss

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any:
        x, y = batch
        return self(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer


class CustomWriter(BasePredictionWriter):
    def __init__(self, output_dir: str, write_interval: str):
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_batch_end(
            self, trainer, pl_module: LightningModule, prediction: Any, batch_indices: List[int], batch: Any,
            batch_idx: int, dataloader_idx: int):
        torch.save(prediction, os.path.join(self.output_dir, f"{dataloader_idx}_{batch_idx}.pt"))

    def write_on_epoch_end(
            self, trainer, pl_module: LightningModule, predictions: List[Any], batch_indices: List[Any]):
        torch.save(predictions, os.path.join(self.output_dir, "predictions.pt"))


tmp_dir = "/tmp"
dm = MNISTDataModule()
model = LitModel(*dm.size(), dm.num_classes)
prediction_writer = CustomWriter(
    output_dir=tmp_dir,
    write_interval="epoch")
trainer = Trainer(
    tpu_cores=8,
    callbacks=[prediction_writer])

trainer.predict(model=model, datamodule=dm)

written_predictions = torch.load(os.path.join(tmp_dir, 'predictions.pt'))
nb_predictions = sum([t.shape[0] for t in written_predictions[0]])

assert nb_predictions == 10_000

When using tpu_cores=[1], all predictions are saved correctly, with the downside of using only one core instead of all eight. A rough per-rank workaround sketch that keeps all eight cores is included below.
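
As a possible workaround that keeps all eight cores, each process could write its own shard of the predictions, keyed by its global rank, and the shards could then be merged on the host after trainer.predict returns. This is only a rough sketch reusing the imports from the repro above, not an official Lightning API; the PerRankWriter name, the file naming scheme, and the merge loop are assumptions:

class PerRankWriter(BasePredictionWriter):
    # Hypothetical per-rank writer: every TPU process saves its own predictions.
    def __init__(self, output_dir: str):
        super().__init__(write_interval="epoch")
        self.output_dir = output_dir

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        # Each of the 8 cores runs in its own process; trainer.global_rank identifies it.
        torch.save(
            predictions,
            os.path.join(self.output_dir, f"predictions_rank_{trainer.global_rank}.pt"))


# After trainer.predict(...) returns, merge the shards on the host (assumes 8 ranks):
shards = [torch.load(os.path.join(tmp_dir, f"predictions_rank_{r}.pt")) for r in range(8)]

Note that the distributed sampler interleaves the dataset across cores, so the merged shards are not in the original sample order; batch_indices could be saved alongside each shard to restore it.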

Expected behavior

The predictions from all eight cores should be written to the file.

Environment

Colab with TPU runtime.

  • CUDA:
    • GPU:
    • available: False
    • version: 10.2
  • Packages:
    • numpy: 1.19.5
    • pyTorch_debug: False
    • pyTorch_version: 1.9.1+cu102
    • pytorch-lightning: 1.5.8
    • tqdm: 4.62.3
  • System:
    • OS: Linux
    • architecture:
      • 64bit
    • processor: x86_64
    • python: 3.7.12
      • version: #1 SMP Tue Dec 7 09:58:10 PST 2021

Additional context

Using the BasePredictionWriter was suggested in this issue. As requested by @kaushikb11, I created this new issue.

cc @kaushikb11 @rohitgr7
