Bug description
When using the ModelParallelStrategy, methods annotated with mark_forward_method raise an exception if the method's signature does not match that of the module's forward method. This fails specifically when the number of args/kwargs differs between the two functions. For example, calling generate on the model below would fail in an FSDP2 setting with the error TypeError: Model.forward() got an unexpected keyword argument 'cfg':
import torch
from torch import nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return x

    def generate(self, x, y, cfg: float = 0.5):
        z_1 = self.forward(x, y)
        z_2 = self.forward(x, torch.zeros_like(y))
        ...
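
For completeness, a minimal sketch of the kind of Fabric setup the report describes; the parallelize function, mesh sizes, device count, and tensor shapes below are illustrative assumptions, not the reporter's actual code:

import torch
from lightning.fabric import Fabric
from lightning.fabric.strategies import ModelParallelStrategy
from torch.distributed.fsdp import fully_shard


def parallelize(model, device_mesh):
    # Illustrative FSDP2 sharding over the data-parallel mesh dimension
    fully_shard(model, mesh=device_mesh["data_parallel"])
    return model


fabric = Fabric(
    strategy=ModelParallelStrategy(
        parallelize_fn=parallelize,
        data_parallel_size=2,
        tensor_parallel_size=1,
    ),
    devices=2,
)
fabric.launch()

model = fabric.setup(Model())
model.mark_forward_method("generate")  # register generate as a forward method

x = torch.randn(2, 8, device=fabric.device)
y = torch.randn(2, 8, device=fabric.device)

model(x, y)                    # plain forward works
model.generate(x, y, cfg=1.0)  # per the report: TypeError, unexpected keyword 'cfg'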
What version are you seeing the problem on?
v2.5
Error messages and logs
[rank0]: │ 473 │ │ ): │
[rank0]: │ 474 │ │ │ self.callbacks.on_validation_step_start(self, batch_idx) │
[rank0]: │ 475 │ │ │ │
[rank0]: │ ❱ 476 │ │ │ result = self.validation_step(batch, batch_idx) │
[rank0]: │ 477 │ │ │ self.callbacks.on_validation_step_end(self, result, batch_idx) │
[rank0]: │ 478 │ │ │
[rank0]: │ 479 │ │ result = self.on_validation_epoch_end() │
[rank0]: │ │
[rank0]: │ /home/tony/workspace/models/models/flow_matching/stage_1_train.py:112 in validation_step │
[rank0]: │ │
[rank0]: │ 109 │ │ B, _, T, H, W = samples.shape │
[rank0]: │ 110 │ │ ct, ch, cw = self.autoencoder.compression │
[rank0]: │ 111 │ │ │
[rank0]: │ ❱ 112 │ │ samples = self.model.sample( │
[rank0]: │ 113 │ │ │ shape=(B, (T - 1) // ct + 1, H // ch, W // cw, self.autoencoder.latent_dim), │
[rank0]: │ 114 │ │ │ text=text_embeds, │
[rank0]: │ 115 │ │ │ sample_steps=self.config.sample_steps, │
[rank0]: │ │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/lightning/fabric/wrappers.py:197 in call_forward_module │
[rank0]: │ │
[rank0]: │ 194 │ │ def call_forward_module(*args: Any, **kwargs: Any) -> Any: │
[rank0]: │ 195 │ │ │ # Patch the original_module's forward, so we can redirect the arguments back │
[rank0]: │ 196 │ │ │ self._original_module.forward = wrapped_forward │
[rank0]: │ ❱ 197 │ │ │ return self.forward(*args, **kwargs) │
[rank0]: │ 198 │ │ │
[rank0]: │ 199 │ │ return call_forward_module │
[rank0]: │ 200 │
[rank0]: │ │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/lightning/fabric/wrappers.py:136 in forward │
[rank0]: │ │
[rank0]: │ 133 │ │ args, kwargs = precision.convert_input((args, kwargs)) │
[rank0]: │ 134 │ │ │
[rank0]: │ 135 │ │ with precision.forward_context(): │
[rank0]: │ ❱ 136 │ │ │ output = self._forward_module(*args, **kwargs) │
[rank0]: │ 137 │ │ │
[rank0]: │ 138 │ │ output = precision.convert_output(output) │
[rank0]: │ 139 │
[rank0]: │ │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739 in _wrapped_call_impl │
[rank0]: │ │
[rank0]: │ 1736 │ │ if self._compiled_call_impl is not None: │
[rank0]: │ 1737 │ │ │ return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] │
[rank0]: │ 1738 │ │ else: │
[rank0]: │ ❱ 1739 │ │ │ return self._call_impl(*args, **kwargs) │
[rank0]: │ 1740 │ │
[rank0]: │ 1741 │ # torchrec tests the code consistency with the following code │
[rank0]: │ 1742 │ # fmt: off │
[rank0]: │ │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750 in _call_impl │
[rank0]: │ │
[rank0]: │ 1747 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
[rank0]: │ 1748 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
[rank0]: │ 1749 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
[rank0]: │ ❱ 1750 │ │ │ return forward_call(*args, **kwargs) │
[rank0]: │ 1751 │ │ │
[rank0]: │ 1752 │ │ result = None │
[rank0]: │ 1753 │ │ called_always_called_hooks = set() │
[rank0]: │ │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:574 in _fn │
[rank0]: │ │
[rank0]: │ 571 │ │ │ ) │
[rank0]: │ 572 │ │ │ │
[rank0]: │ 573 │ │ │ try: │
[rank0]: │ ❱ 574 │ │ │ │ return fn(*args, **kwargs) │
[rank0]: │ 575 │ │ │ finally: │
[rank0]: │ 576 │ │ │ │ # Restore the dynamic layer stack depth if necessary. │
[rank0]: │ 577 │ │ │ │ torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth( │
[rank0]: │ │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739 in _wrapped_call_impl │
[rank0]: │ │
[rank0]: │ 1736 │ │ if self._compiled_call_impl is not None: │
[rank0]: │ 1737 │ │ │ return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] │
[rank0]: │ 1738 │ │ else: │
[rank0]: │ ❱ 1739 │ │ │ return self._call_impl(*args, **kwargs) │
[rank0]: │ 1740 │ │
[rank0]: │ 1741 │ # torchrec tests the code consistency with the following code │
[rank0]: │ 1742 │ # fmt: off │
[rank0]: │ │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750 in _call_impl │
[rank0]: │ │
[rank0]: │ 1747 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
[rank0]: │ 1748 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
[rank0]: │ 1749 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
[rank0]: │ ❱ 1750 │ │ │ return forward_call(*args, **kwargs) │
[rank0]: │ 1751 │ │ │
[rank0]: │ 1752 │ │ result = None │
[rank0]: │ 1753 │ │ called_always_called_hooks = set() │
[rank0]: ╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
[rank0]: TypeError: Rem.forward() got an unexpected keyword argument 'shape'
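
For reference, a condensed sketch of the redirection shown in the wrappers.py frames above (precision handling and other details omitted; _FabricModuleSketch is illustrative, not the real lightning.fabric.wrappers._FabricModule). Marked methods are rewritten so that all of their arguments are routed through the wrapper's forward into the strategy-wrapped module; in the run above that call lands in the module's real forward(), which rejects the extra keyword:

import torch


class _FabricModuleSketch:
    # Condensed from the lightning/fabric/wrappers.py frames in the traceback;
    # simplified for illustration.

    def __init__(self, forward_module, original_module):
        self._forward_module = forward_module  # the strategy-wrapped module
        self._original_module = original_module
        self._original_forward = original_module.forward

    def forward(self, *args, **kwargs):
        # wrappers.py:136 -- every arg/kwarg goes straight into the wrapped module
        return self._forward_module(*args, **kwargs)

    def mark_forward_method(self, name):
        method = getattr(self._original_module, name)

        def wrapped_forward(*args, **kwargs):
            # Restore forward, then run the real marked method
            self._original_module.forward = self._original_forward
            return method(*args, **kwargs)

        def call_forward_module(*args, **kwargs):
            # wrappers.py:194-197 -- patch forward, then route ALL of the marked
            # method's args/kwargs through self.forward
            self._original_module.forward = wrapped_forward
            return self.forward(*args, **kwargs)

        setattr(self, name, call_forward_module)


model = Model()  # the Model from the snippet in the bug description
wrapper = _FabricModuleSketch(forward_module=model, original_module=model)
wrapper.mark_forward_method("generate")

x = y = torch.randn(2, 8)
wrapper.generate(x, y, cfg=1.0)  # succeeds here: on a plain nn.Module the
# instance-level forward patch intercepts the call. In the traceback above the
# interception does not happen, so the marked method's kwargs reach the
# module's real forward(), producing the TypeError.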
Environment
Current environment
#- PyTorch Lightning Version: 2.5.0.post
#- PyTorch Version: 2.6.0+cu124
#- Python version: 3.11
#- OS: Linux
#- CUDA/cuDNN version: 12.4
#- GPU models and configuration: 8xH100
#- How you installed Lightning(`conda`, `pip`, source): pip
More info
No response