Description
Bug description
No computation graph is built when precision='bf16-mixed' is set and an nn.Module has been called earlier inside a torch.no_grad() block.
The example script below raises the error shown under "Error messages and logs" at 'out2.backward()'.
I ran into this while writing an inner optimisation loop for one of my modules.
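A Lightning-free sketch of what might be going on. It assumes (this is not confirmed anywhere in this report) that 'bf16-mixed' runs the training step inside torch.autocast, and that autocast's weight cache keeps the bf16 copies that were first created under no_grad(). The helper name out2_requires_grad is made up for illustration.

```python
import torch
import torch.nn as nn


def out2_requires_grad(clear_cache: bool) -> bool:
    """Run the no_grad-then-grad pattern under autocast and report
    whether the second output is attached to a computation graph."""
    layer = nn.Linear(10, 1)
    x = torch.randn(10, 10)
    # Assumption: this context stands in for what precision='bf16-mixed' sets up.
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        with torch.no_grad():
            layer(x)  # first use of the weights happens without grad tracking
        if clear_cache:
            # Drop any cached low-precision weight copies before the tracked pass.
            torch.clear_autocast_cache()
        out2 = layer(x).mean()
    return out2.requires_grad


print("without clearing the cache:", out2_requires_grad(False))
print("with clearing the cache:   ", out2_requires_grad(True))
```

If the two printed flags differ, that would point at the autocast cache rather than at the module itself.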
What version are you seeing the problem on?
v2.5
How to reproduce the bug
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as L


class Net(nn.Module):
    def __init__(self, in_size):
        super().__init__()
        self.layer = nn.Linear(in_size, 1)

    def forward(self, x):
        return self.layer(x)

    def loss(self, x):
        # although unused, calculating out1 breaks grad tracking for out2
        with torch.no_grad():
            out1 = self.layer(x)

        out2 = self.layer(x).mean()
        out2.backward()  # raises the RuntimeError shown below

        out3 = self.layer(x).mean()
        return out3


class LightningModel(L.LightningModule):
    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, z):
        # was self.actor(z); the module has no `actor` attribute
        return self.net(z)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self.net.loss(x)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.net.parameters(), lr=3e-4)


device = 'cpu'  # or cuda
net = Net(10).to(device)
model = LightningModel(net)
trainer = L.Trainer(
    accelerator=device,
    precision='bf16-mixed',
)

X = torch.randn(10, 10)
Y = torch.randn(10)
batch = (X, Y)
# the tuple is passed positionally as train_dataloaders
trainer.fit(model, batch)
Error messages and logs
(.venv) joeag@deepthought:~/projects$ python bug.py
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
| Name | Type | Params | Mode
--------------------------------------
0 | net | Net | 11 | train
--------------------------------------
11 Trainable params
0 Non-trainable params
11 Total params
0.000 Total estimated model params size (MB)
2 Modules in train mode
0 Modules in eval mode
/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
Epoch 0: 0%| | 0/10 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/home/joeag/projects/bug.py", line 53, in <module>
trainer.fit(model, batch)
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 539, in fit
call._call_and_handle_interrupt(
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 47, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 575, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 982, in _run
results = self._run_stage()
^^^^^^^^^^^^^^^^^
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 1026, in _run_stage
self.fit_loop.run()
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py", line 216, in run
self.advance()
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py", line 455, in advance
self.epoch_loop.run(self._data_fetcher)
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 150, in run
self.advance(data_fetcher)
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 320, in advance
batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 192, in run
self._optimizer_step(batch_idx, closure)
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 270, in _optimizer_step
call._call_lightning_module_hook(
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 171, in _call_lightning_module_hook
output = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/core/module.py", line 1302, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/core/optimizer.py", line 154, in step
step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/strategies/strategy.py", line 239, in optimizer_step
return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/plugins/precision/amp.py", line 76, in optimizer_step
return super().optimizer_step(optimizer, model=model, closure=closure, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/plugins/precision/precision.py", line 123, in optimizer_step
return optimizer.step(closure=closure, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/torch/optim/optimizer.py", line 493, in wrapper
out = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/torch/optim/optimizer.py", line 91, in _use_grad
ret = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/torch/optim/adam.py", line 223, in step
loss = closure()
^^^^^^^^^
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/plugins/precision/precision.py", line 109, in _wrap_closure
closure_result = closure()
^^^^^^^^^
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 146, in __call__
self._result = self.closure(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 131, in closure
step_output = self._step_fn()
^^^^^^^^^^^^^^^
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 319, in _training_step
training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 323, in _call_strategy_hook
output = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/pytorch_lightning/strategies/strategy.py", line 391, in training_step
return self.lightning_module.training_step(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/joeag/projects/bug.py", line 35, in training_step
loss = self.net.loss(x)
^^^^^^^^^^^^^^^^
File "/home/joeag/projects/bug.py", line 20, in loss
out2.backward()
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/torch/_tensor.py", line 626, in backward
torch.autograd.backward(
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/torch/autograd/__init__.py", line 347, in backward
_engine_run_backward(
File "/home/joeag/projects/.venv/lib/python3.11/site-packages/torch/autograd/graph.py", line 823, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
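A possible workaround sketch for the RuntimeError above, under the same unconfirmed assumption that autocast's weight cache is involved: run the no_grad() pass inside a nested torch.autocast with cache_enabled=False, so that nothing created without grad tracking is stored for reuse. The outer torch.autocast context here only stands in for what precision='bf16-mixed' is assumed to set up; it is not Lightning's actual code path.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self, in_size):
        super().__init__()
        self.layer = nn.Linear(in_size, 1)

    def loss(self, x):
        # Hypothetical workaround: disable the autocast cache for the
        # no_grad pass only; the previous cache setting is restored on exit.
        with torch.no_grad(), torch.autocast(
            device_type=x.device.type, dtype=torch.bfloat16, cache_enabled=False
        ):
            out1 = self.layer(x)

        out2 = self.layer(x).mean()
        out2.backward()

        out3 = self.layer(x).mean()
        return out3


net = Net(10)
x = torch.randn(10, 10)

# Assumption: stands in for the context created by precision='bf16-mixed'.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out3 = net.loss(x)

print("weight.grad populated:", net.layer.weight.grad is not None)
print("out3 requires grad:   ", out3.requires_grad)
```

Calling torch.clear_autocast_cache() right after the no_grad() block is another option worth trying; either way this only sidesteps the behaviour in user code and does not change what the Trainer does.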
Environment
Current environment
- CUDA:
- GPU:
- NVIDIA GeForce RTX 4090
- available: True
- version: 12.4
- Lightning:
- joetorch: 0.0.40
- lightning: 2.5.0.post0
- lightning-utilities: 0.12.0
- pytorch-lightning: 2.5.0.post0
- torch: 2.6.0
- torch-tb-profiler: 0.4.3
- torchaudio: 2.6.0
- torchmetrics: 1.6.1
- torchvision: 0.21.0
- Packages:
- absl-py: 2.1.0
- affine: 2.4.0
- aiohappyeyeballs: 2.4.6
- aiohttp: 3.11.13
- aiosignal: 1.3.2
- annotated-types: 0.7.0
- asttokens: 3.0.0
- attrs: 25.1.0
- certifi: 2025.1.31
- charset-normalizer: 3.4.1
- click: 8.1.8
- click-plugins: 1.1.1
- cligj: 0.7.2
- comm: 0.2.2
- contourpy: 1.3.1
- cycler: 0.12.1
- datasets: 3.3.2
- debugpy: 1.8.12
- decorator: 5.2.1
- dill: 0.3.8
- docker-pycreds: 0.4.0
- executing: 2.2.0
- filelock: 3.17.0
- fonttools: 4.56.0
- frozenlist: 1.5.0
- fsspec: 2024.12.0
- geopandas: 1.0.1
- gitdb: 4.0.12
- gitpython: 3.1.44
- grpcio: 1.70.0
- huggingface-hub: 0.29.2
- idna: 3.10
- imageio: 2.37.0
- ipykernel: 6.29.5
- ipython: 8.32.0
- ipywidgets: 8.1.5
- jedi: 0.19.2
- jinja2: 3.1.5
- joetorch: 0.0.40
- jupyter-client: 8.6.3
- jupyter-core: 5.7.2
- jupyterlab-widgets: 3.0.13
- kiwisolver: 1.4.8
- lightning: 2.5.0.post0
- lightning-utilities: 0.12.0
- markdown: 3.7
- markupsafe: 3.0.2
- matplotlib: 3.10.0
- matplotlib-inline: 0.1.7
- mpmath: 1.3.0
- multidict: 6.1.0
- multiprocess: 0.70.16
- narwhals: 1.30.0
- nest-asyncio: 1.6.0
- networkx: 3.4.2
- numpy: 2.2.3
- nvidia-cublas-cu12: 12.4.5.8
- nvidia-cuda-cupti-cu12: 12.4.127
- nvidia-cuda-nvrtc-cu12: 12.4.127
- nvidia-cuda-runtime-cu12: 12.4.127
- nvidia-cudnn-cu12: 9.1.0.70
- nvidia-cufft-cu12: 11.2.1.3
- nvidia-curand-cu12: 10.3.5.147
- nvidia-cusolver-cu12: 11.6.1.9
- nvidia-cusparse-cu12: 12.3.1.170
- nvidia-cusparselt-cu12: 0.6.2
- nvidia-nccl-cu12: 2.21.5
- nvidia-nvjitlink-cu12: 12.4.127
- nvidia-nvtx-cu12: 12.4.127
- packaging: 24.2
- pandas: 2.2.3
- parso: 0.8.4
- pexpect: 4.9.0
- pillow: 11.1.0
- pip: 25.0.1
- platformdirs: 4.3.6
- plotly: 6.0.0
- prompt-toolkit: 3.0.50
- propcache: 0.3.0
- protobuf: 5.29.3
- psutil: 7.0.0
- ptyprocess: 0.7.0
- pure-eval: 0.2.3
- pyarrow: 19.0.1
- pydantic: 2.10.6
- pydantic-core: 2.27.2
- pygments: 2.19.1
- pyogrio: 0.10.0
- pyparsing: 3.2.1
- pyproj: 3.7.1
- python-dateutil: 2.9.0.post0
- pytorch-lightning: 2.5.0.post0
- pytz: 2025.1
- pyyaml: 6.0.2
- pyzmq: 26.2.1
- rasterio: 1.4.3
- requests: 2.32.3
- sentry-sdk: 2.22.0
- setproctitle: 1.3.5
- setuptools: 65.5.0
- shapely: 2.0.7
- six: 1.17.0
- smmap: 5.0.2
- stack-data: 0.6.3
- sympy: 1.13.1
- tensorboard: 2.19.0
- tensorboard-data-server: 0.7.2
- torch: 2.6.0
- torch-tb-profiler: 0.4.3
- torchaudio: 2.6.0
- torchmetrics: 1.6.1
- torchvision: 0.21.0
- tornado: 6.4.2
- tqdm: 4.67.1
- traitlets: 5.14.3
- triton: 3.2.0
- typing-extensions: 4.12.2
- tzdata: 2025.1
- urllib3: 2.3.0
- wandb: 0.19.7
- wcwidth: 0.2.13
- werkzeug: 3.1.3
- widgetsnbextension: 4.0.13
- xxhash: 3.5.0
- yarl: 1.18.3
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.11.11
- release: 6.8.0-52-generic
- version: #53-Ubuntu SMP PREEMPT_DYNAMIC Sat Jan 11 00:06:25 UTC 2025
More info
No response