Computation graph not being built #20644

Open
@joeagriffith

Bug description

The computation graph is not built when precision='bf16-mixed' is set and an nn.Module has been called earlier inside a torch.no_grad() block.

The example script below raises the error shown under "Error messages and logs" at 'out2.backward()'.

I encountered this bug when trying to create an inner optimisation loop for one of my modules.

What version are you seeing the problem on?

v2.5

How to reproduce the bug

import torch 
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as L

class Net(nn.Module):
    def __init__(self, in_size):
        super().__init__()
        self.layer = nn.Linear(in_size, 1)

    def forward(self, x):
        return self.layer(x)

    def loss(self, x):

        # although unused, calculating out1 breaks grad tracking for out2
        with torch.no_grad():
            out1 = self.layer(x)

        out2 = self.layer(x).mean()
        out2.backward()

        out3 = self.layer(x).mean()
        return out3

class LightningModel(L.LightningModule):
    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, z):
        return self.net(z)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self.net.loss(x)
        return loss
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.net.parameters(), lr=3e-4)

device = 'cpu' # or cuda 
net = Net(10).to(device)
model = LightningModel(net)

trainer = L.Trainer(
    accelerator=device,
    precision='bf16-mixed',
)

X = torch.randn(10, 10)
Y = torch.randn(10)
batch = (X, Y)
trainer.fit(model, batch)
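
A Lightning-free sketch that may help narrow this down. It assumes (not confirmed in this report) that the cause is torch.autocast's weight cache: under autocast, the first low-precision cast of a parameter is cached, and if that first cast happens inside torch.no_grad() the cached copy carries no grad_fn and may be reused by later calls in the same autocast region. The helper name grad_tracked and its flags are hypothetical, made up for this illustration; torch.clear_autocast_cache() and the cache_enabled argument of torch.autocast are existing PyTorch APIs.

```python
import torch
import torch.nn as nn


def grad_tracked(clear_cache: bool, cache_enabled: bool = True) -> bool:
    """Return whether a forward pass that follows a no_grad forward pass
    (both inside one bf16 autocast region) is attached to the autograd graph.

    Hypothetical helper for illustration only.
    """
    torch.manual_seed(0)
    layer = nn.Linear(10, 1)
    x = torch.randn(10, 10)

    with torch.autocast(device_type="cpu", dtype=torch.bfloat16,
                        cache_enabled=cache_enabled):
        # First use of the weights happens without grad tracking,
        # mirroring the out1 computation in the script above.
        with torch.no_grad():
            layer(x)

        if clear_cache:
            # Drop any cached low-precision copies of the weights so the
            # next call has to cast them again, this time with grad enabled.
            torch.clear_autocast_cache()

        out = layer(x).mean()

    return out.requires_grad


print("no workaround:      ", grad_tracked(clear_cache=False))
print("cache cleared:      ", grad_tracked(clear_cache=True))
print("cache_enabled=False:", grad_tracked(clear_cache=False, cache_enabled=False))
```

If the first line prints False while the other two print True, the autocast cache is the likely culprit, and calling torch.clear_autocast_cache() right after the no_grad block in Net.loss would be a workaround worth trying. That change has not been verified against the Lightning script above.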

Error messages and logs

(.venv) joeag@deepthought:~/projects$ python bug.py 
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.

  | Name | Type | Params | Mode 
--------------------------------------
0 | net  | Net  | 11     | train
--------------------------------------
11        Trainable params
0         Non-trainable params
11        Total params
0.000     Total estimated model params size (MB)
2         Modules in train mode
0         Modules in eval mode
/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
Epoch 0:   0%|                                                                                                                                       | 0/10 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/home/joeag/projects/bug.py", line 53, in <module>
    trainer.fit(model, batch)
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 539, in fit
    call._call_and_handle_interrupt(
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 47, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 575, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 982, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 1026, in _run_stage
    self.fit_loop.run()
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py", line 216, in run
    self.advance()
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py", line 455, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 150, in run
    self.advance(data_fetcher)
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 320, in advance
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 192, in run
    self._optimizer_step(batch_idx, closure)
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 270, in _optimizer_step
    call._call_lightning_module_hook(
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 171, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/core/module.py", line 1302, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/core/optimizer.py", line 154, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/strategies/strategy.py", line 239, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/plugins/precision/amp.py", line 76, in optimizer_step
    return super().optimizer_step(optimizer, model=model, closure=closure, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/plugins/precision/precision.py", line 123, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/torch/optim/optimizer.py", line 493, in wrapper
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/torch/optim/optimizer.py", line 91, in _use_grad
    ret = func(self, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/torch/optim/adam.py", line 223, in step
    loss = closure()
           ^^^^^^^^^
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/plugins/precision/precision.py", line 109, in _wrap_closure
    closure_result = closure()
                     ^^^^^^^^^
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 146, in __call__
    self._result = self.closure(*args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 131, in closure
    step_output = self._step_fn()
                  ^^^^^^^^^^^^^^^
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 319, in _training_step
    training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 323, in _call_strategy_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/strategies/strategy.py", line 391, in training_step
    return self.lightning_module.training_step(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/joeag/projects/bug.py", line 35, in training_step
    loss = self.net.loss(x)
           ^^^^^^^^^^^^^^^^
  File "/home/joeag/projects/bug.py", line 20, in loss
    out2.backward()
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/torch/_tensor.py", line 626, in backward
    torch.autograd.backward(
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/home/joeag/projects/.venv/lib/python3.11/site-packages/torch/autograd/graph.py", line 823, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Environment

Current environment
  • CUDA:
    - GPU:
        - NVIDIA GeForce RTX 4090
    - available: True
    - version: 12.4
  • Lightning:
    - joetorch: 0.0.40
    - lightning: 2.5.0.post0
    - lightning-utilities: 0.12.0
    - pytorch-lightning: 2.5.0.post0
    - torch: 2.6.0
    - torch-tb-profiler: 0.4.3
    - torchaudio: 2.6.0
    - torchmetrics: 1.6.1
    - torchvision: 0.21.0
  • Packages:
    - absl-py: 2.1.0
    - affine: 2.4.0
    - aiohappyeyeballs: 2.4.6
    - aiohttp: 3.11.13
    - aiosignal: 1.3.2
    - annotated-types: 0.7.0
    - asttokens: 3.0.0
    - attrs: 25.1.0
    - certifi: 2025.1.31
    - charset-normalizer: 3.4.1
    - click: 8.1.8
    - click-plugins: 1.1.1
    - cligj: 0.7.2
    - comm: 0.2.2
    - contourpy: 1.3.1
    - cycler: 0.12.1
    - datasets: 3.3.2
    - debugpy: 1.8.12
    - decorator: 5.2.1
    - dill: 0.3.8
    - docker-pycreds: 0.4.0
    - executing: 2.2.0
    - filelock: 3.17.0
    - fonttools: 4.56.0
    - frozenlist: 1.5.0
    - fsspec: 2024.12.0
    - geopandas: 1.0.1
    - gitdb: 4.0.12
    - gitpython: 3.1.44
    - grpcio: 1.70.0
    - huggingface-hub: 0.29.2
    - idna: 3.10
    - imageio: 2.37.0
    - ipykernel: 6.29.5
    - ipython: 8.32.0
    - ipywidgets: 8.1.5
    - jedi: 0.19.2
    - jinja2: 3.1.5
    - joetorch: 0.0.40
    - jupyter-client: 8.6.3
    - jupyter-core: 5.7.2
    - jupyterlab-widgets: 3.0.13
    - kiwisolver: 1.4.8
    - lightning: 2.5.0.post0
    - lightning-utilities: 0.12.0
    - markdown: 3.7
    - markupsafe: 3.0.2
    - matplotlib: 3.10.0
    - matplotlib-inline: 0.1.7
    - mpmath: 1.3.0
    - multidict: 6.1.0
    - multiprocess: 0.70.16
    - narwhals: 1.30.0
    - nest-asyncio: 1.6.0
    - networkx: 3.4.2
    - numpy: 2.2.3
    - nvidia-cublas-cu12: 12.4.5.8
    - nvidia-cuda-cupti-cu12: 12.4.127
    - nvidia-cuda-nvrtc-cu12: 12.4.127
    - nvidia-cuda-runtime-cu12: 12.4.127
    - nvidia-cudnn-cu12: 9.1.0.70
    - nvidia-cufft-cu12: 11.2.1.3
    - nvidia-curand-cu12: 10.3.5.147
    - nvidia-cusolver-cu12: 11.6.1.9
    - nvidia-cusparse-cu12: 12.3.1.170
    - nvidia-cusparselt-cu12: 0.6.2
    - nvidia-nccl-cu12: 2.21.5
    - nvidia-nvjitlink-cu12: 12.4.127
    - nvidia-nvtx-cu12: 12.4.127
    - packaging: 24.2
    - pandas: 2.2.3
    - parso: 0.8.4
    - pexpect: 4.9.0
    - pillow: 11.1.0
    - pip: 25.0.1
    - platformdirs: 4.3.6
    - plotly: 6.0.0
    - prompt-toolkit: 3.0.50
    - propcache: 0.3.0
    - protobuf: 5.29.3
    - psutil: 7.0.0
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.3
    - pyarrow: 19.0.1
    - pydantic: 2.10.6
    - pydantic-core: 2.27.2
    - pygments: 2.19.1
    - pyogrio: 0.10.0
    - pyparsing: 3.2.1
    - pyproj: 3.7.1
    - python-dateutil: 2.9.0.post0
    - pytorch-lightning: 2.5.0.post0
    - pytz: 2025.1
    - pyyaml: 6.0.2
    - pyzmq: 26.2.1
    - rasterio: 1.4.3
    - requests: 2.32.3
    - sentry-sdk: 2.22.0
    - setproctitle: 1.3.5
    - setuptools: 65.5.0
    - shapely: 2.0.7
    - six: 1.17.0
    - smmap: 5.0.2
    - stack-data: 0.6.3
    - sympy: 1.13.1
    - tensorboard: 2.19.0
    - tensorboard-data-server: 0.7.2
    - torch: 2.6.0
    - torch-tb-profiler: 0.4.3
    - torchaudio: 2.6.0
    - torchmetrics: 1.6.1
    - torchvision: 0.21.0
    - tornado: 6.4.2
    - tqdm: 4.67.1
    - traitlets: 5.14.3
    - triton: 3.2.0
    - typing-extensions: 4.12.2
    - tzdata: 2025.1
    - urllib3: 2.3.0
    - wandb: 0.19.7
    - wcwidth: 0.2.13
    - werkzeug: 3.1.3
    - widgetsnbextension: 4.0.13
    - xxhash: 3.5.0
    - yarl: 1.18.3
  • System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor: x86_64
    - python: 3.11.11
    - release: 6.8.0-52-generic
    - version: #53-Ubuntu SMP PREEMPT_DYNAMIC Sat Jan 11 00:06:25 UTC 2025

More info

No response
