mark_forward_method does not work with ModelParallelStrategy #20710

Open

Description

@tonyf

Bug description

When using ModelParallelStrategy, methods marked with mark_forward_method raise an exception when their signature does not match that of the module's forward method. Specifically, the call fails when the number of args/kwargs differs between the two methods.

For example, calling generate on the model below fails in an FSDP2 setting with TypeError: Model.forward() got an unexpected keyword argument 'cfg':

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return x

    def generate(self, x, y, cfg: float = 0.5):
        z_1 = self.forward(x, y)
        z_2 = self.forward(x, torch.zeros_like(y))
        ...
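
For reference, this is roughly how the module is set up when the error occurs. A minimal sketch, assuming a single-node FSDP2 run; the parallelize function and the "data_parallel" mesh dimension name are illustrative, not taken from the actual training code:

import torch
from lightning.fabric import Fabric
from lightning.fabric.strategies import ModelParallelStrategy


def parallelize(model: Model, device_mesh) -> Model:
    # Shard the module with FSDP2 (torch >= 2.6 exposes fully_shard here;
    # older versions have it under torch.distributed._composable.fsdp).
    from torch.distributed.fsdp import fully_shard

    fully_shard(model, mesh=device_mesh["data_parallel"])
    return model


fabric = Fabric(strategy=ModelParallelStrategy(parallelize_fn=parallelize))
fabric.launch()

model = fabric.setup(Model())
model.mark_forward_method("generate")

x = y = torch.zeros(1, device=fabric.device)
out = model.generate(x, y, cfg=1.0)  # raises: forward() got an unexpected keyword argument 'cfg'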

What version are you seeing the problem on?

v2.5

Error messages and logs

[rank0]: │   473 │   │   ):                                                                                                                                      │
[rank0]: │   474 │   │   │   self.callbacks.on_validation_step_start(self, batch_idx)                                                                            │
[rank0]: │   475 │   │   │                                                                                                                                       │
[rank0]: │ ❱ 476 │   │   │   result = self.validation_step(batch, batch_idx)                                                                                     │
[rank0]: │   477 │   │   │   self.callbacks.on_validation_step_end(self, result, batch_idx)                                                                      │
[rank0]: │   478 │   │                                                                                                                                           │
[rank0]: │   479 │   │   result = self.on_validation_epoch_end()                                                                                                 │
[rank0]: │                                                                                                                                                       │
[rank0]: │ /home/tony/workspace/models/models/flow_matching/stage_1_train.py:112 in validation_step                                                              │
[rank0]: │                                                                                                                                                       │
[rank0]: │   109 │   │   B, _, T, H, W = samples.shape                                                                                                           │
[rank0]: │   110 │   │   ct, ch, cw = self.autoencoder.compression                                                                                               │
[rank0]: │   111 │   │                                                                                                                                           │
[rank0]: │ ❱ 112 │   │   samples = self.model.sample(                                                                                                            │
[rank0]: │   113 │   │   │   shape=(B, (T - 1) // ct + 1, H // ch, W // cw, self.autoencoder.latent_dim),                                                        │
[rank0]: │   114 │   │   │   text=text_embeds,                                                                                                                   │
[rank0]: │   115 │   │   │   sample_steps=self.config.sample_steps,                                                                                              │
[rank0]: │                                                                                                                                                       │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/lightning/fabric/wrappers.py:197 in call_forward_module                                │
[rank0]: │                                                                                                                                                       │
[rank0]: │   194 │   │   def call_forward_module(*args: Any, **kwargs: Any) -> Any:                                                                              │
[rank0]: │   195 │   │   │   # Patch the original_module's forward, so we can redirect the arguments back                                                        │
[rank0]: │   196 │   │   │   self._original_module.forward = wrapped_forward                                                                                     │
[rank0]: │ ❱ 197 │   │   │   return self.forward(*args, **kwargs)                                                                                                │
[rank0]: │   198 │   │                                                                                                                                           │
[rank0]: │   199 │   │   return call_forward_module                                                                                                              │
[rank0]: │   200                                                                                                                                                 │
[rank0]: │                                                                                                                                                       │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/lightning/fabric/wrappers.py:136 in forward                                            │
[rank0]: │                                                                                                                                                       │
[rank0]: │   133 │   │   args, kwargs = precision.convert_input((args, kwargs))                                                                                  │
[rank0]: │   134 │   │                                                                                                                                           │
[rank0]: │   135 │   │   with precision.forward_context():                                                                                                       │
[rank0]: │ ❱ 136 │   │   │   output = self._forward_module(*args, **kwargs)                                                                                      │
[rank0]: │   137 │   │                                                                                                                                           │
[rank0]: │   138 │   │   output = precision.convert_output(output)                                                                                               │
[rank0]: │   139                                                                                                                                                 │
[rank0]: │                                                                                                                                                       │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739 in _wrapped_call_impl                                  │
[rank0]: │                                                                                                                                                       │
[rank0]: │   1736 │   │   if self._compiled_call_impl is not None:                                                                                               │
[rank0]: │   1737 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]                                                             │
[rank0]: │   1738 │   │   else:                                                                                                                                  │
[rank0]: │ ❱ 1739 │   │   │   return self._call_impl(*args, **kwargs)                                                                                            │
[rank0]: │   1740 │                                                                                                                                              │
[rank0]: │   1741 │   # torchrec tests the code consistency with the following code                                                                              │
[rank0]: │   1742 │   # fmt: off                                                                                                                                 │
[rank0]: │                                                                                                                                                       │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750 in _call_impl                                          │
[rank0]: │                                                                                                                                                       │
[rank0]: │   1747 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks                                                        │
[rank0]: │   1748 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                                                                        │
[rank0]: │   1749 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                                                                        │
[rank0]: │ ❱ 1750 │   │   │   return forward_call(*args, **kwargs)                                                                                               │
[rank0]: │   1751 │   │                                                                                                                                          │
[rank0]: │   1752 │   │   result = None                                                                                                                          │
[rank0]: │   1753 │   │   called_always_called_hooks = set()                                                                                                     │
[rank0]: │                                                                                                                                                       │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:574 in _fn                                                 │
[rank0]: │                                                                                                                                                       │
[rank0]: │    571 │   │   │   )                                                                                                                                  │
[rank0]: │    572 │   │   │                                                                                                                                      │
[rank0]: │    573 │   │   │   try:                                                                                                                               │
[rank0]: │ ❱  574 │   │   │   │   return fn(*args, **kwargs)                                                                                                     │
[rank0]: │    575 │   │   │   finally:                                                                                                                           │
[rank0]: │    576 │   │   │   │   # Restore the dynamic layer stack depth if necessary.                                                                          │
[rank0]: │    577 │   │   │   │   torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(                                                                 │
[rank0]: │                                                                                                                                                       │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739 in _wrapped_call_impl                                  │
[rank0]: │                                                                                                                                                       │
[rank0]: │   1736 │   │   if self._compiled_call_impl is not None:                                                                                               │
[rank0]: │   1737 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]                                                             │
[rank0]: │   1738 │   │   else:                                                                                                                                  │
[rank0]: │ ❱ 1739 │   │   │   return self._call_impl(*args, **kwargs)                                                                                            │
[rank0]: │   1740 │                                                                                                                                              │
[rank0]: │   1741 │   # torchrec tests the code consistency with the following code                                                                              │
[rank0]: │   1742 │   # fmt: off                                                                                                                                 │
[rank0]: │                                                                                                                                                       │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750 in _call_impl                                          │
[rank0]: │                                                                                                                                                       │
[rank0]: │   1747 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks                                                        │
[rank0]: │   1748 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                                                                        │
[rank0]: │   1749 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                                                                        │
[rank0]: │ ❱ 1750 │   │   │   return forward_call(*args, **kwargs)                                                                                               │
[rank0]: │   1751 │   │                                                                                                                                          │
[rank0]: │   1752 │   │   result = None                                                                                                                          │
[rank0]: │   1753 │   │   called_always_called_hooks = set()                                                                                                     │
[rank0]: ╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
[rank0]: TypeError: Rem.forward() got an unexpected keyword argument 'shape'
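
The wrappers.py frames above show the mechanism: mark_forward_method routes the call through call_forward_module, which patches self._original_module.forward to the marked method and then invokes the forward module with the marked method's full argument list. If any part of the call path captured a reference to the original bound forward before that patch (for example a compiled or FSDP2-wrapped forward, as the _dynamo frame in the trace suggests), the patch is bypassed and the extra keyword arguments hit the original signature. A minimal, Lightning-free sketch of that failure mode (this is a hypothesis about the root cause, not a confirmed diagnosis):

import torch
import torch.nn as nn


class Model(nn.Module):
    def forward(self, x, y):
        return x

    def generate(self, x, y, cfg: float = 0.5):
        return self.forward(x, y) * cfg


m = Model()
saved_forward = m.forward     # bound method captured *before* the patch
m.forward = m.generate        # instance-level patch, as call_forward_module does

x = y = torch.zeros(1)
m.forward(x, y, cfg=1.0)      # OK: attribute lookup finds the patched method
saved_forward(x, y, cfg=1.0)  # TypeError: Model.forward() got an unexpected keyword argument 'cfg'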

Environment

- PyTorch Lightning Version: 2.5.0.post
- PyTorch Version: 2.6.0+cu124
- Python version: 3.11
- OS: Linux
- CUDA/cuDNN version: 12.4
- GPU models and configuration: 8x H100
- How you installed Lightning (`conda`, `pip`, source): pip

More info

No response


    Labels

    bug · needs triage · ver: 2.5.x
