diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 5616defeffc8a..c794603990737 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Added
 
 - Add enable_autolog_hparams argument to Trainer ([#20593](https://github.com/Lightning-AI/pytorch-lightning/pull/20593))
+- Support `clip_grad_norm_()` for FSDP ([#20784](https://github.com/Lightning-AI/pytorch-lightning/pull/20784))
 
 
 ### Changed
diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index b8624daac3fa3..d1b0cca4feeae 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -1207,7 +1207,9 @@ def clip_gradients(
         )
 
         gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm)
-        self.trainer.precision_plugin.clip_gradients(optimizer, gradient_clip_val, gradient_clip_algorithm)
+        self.trainer.precision_plugin.clip_gradients(
+            self.trainer.model, optimizer, gradient_clip_val, gradient_clip_algorithm
+        )
 
     def configure_gradient_clipping(
         self,
diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py
index 75e792af46b90..f6ec37e7d4edb 100644
--- a/src/lightning/pytorch/plugins/precision/amp.py
+++ b/src/lightning/pytorch/plugins/precision/amp.py
@@ -15,6 +15,7 @@
 
 import torch
 from torch import Tensor
+from torch.nn import Module
 from torch.optim import LBFGS, Optimizer
 from typing_extensions import override
 
@@ -100,6 +101,7 @@ def optimizer_step(  # type: ignore[override]
     @override
     def clip_gradients(
         self,
+        module: Optional[Module],
         optimizer: Optimizer,
         clip_val: Union[int, float] = 0.0,
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
@@ -109,7 +111,9 @@ def clip_gradients(
                 f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping"
                 " because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?"
             )
-        super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)
+        super().clip_gradients(
+            module=module, optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm
+        )
 
     def autocast_context_manager(self) -> torch.autocast:
         return torch.autocast(self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half))
diff --git a/src/lightning/pytorch/plugins/precision/deepspeed.py b/src/lightning/pytorch/plugins/precision/deepspeed.py
index 9225e3bb9e7be..e09eb67f4fecf 100644
--- a/src/lightning/pytorch/plugins/precision/deepspeed.py
+++ b/src/lightning/pytorch/plugins/precision/deepspeed.py
@@ -141,6 +141,7 @@ def optimizer_step(  # type: ignore[override]
     @override
     def clip_gradients(
         self,
+        module: Optional[Module],
         optimizer: Optimizer,
         clip_val: Union[int, float] = 0.0,
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py
index f3bab3e915e91..280bc4351f237 100644
--- a/src/lightning/pytorch/plugins/precision/fsdp.py
+++ b/src/lightning/pytorch/plugins/precision/fsdp.py
@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import AbstractContextManager
-from typing import TYPE_CHECKING, Any, Callable, Optional
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch
 from lightning_utilities import apply_to_collection
 from torch import Tensor
 from torch.nn import Module
+from torch.optim import Optimizer
 from typing_extensions import get_args, override
 
 import lightning.pytorch as pl
@@ -81,14 +82,11 @@ def convert_module(self, module: Module) -> Module:
         return module
 
     @override
-    def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
+    def clip_grad_by_norm(self, module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float]) -> None:
         # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
-        # section `Gradient Clipping`, using `torch.nn.utils.clip_grad_norm_` is incorrect with FSDP.
-        # To overcome this we need to call root_sharded_module.clip_grad_norm(clip_val), but we don't have a reference
-        # to the root module
-        raise MisconfigurationException(
-            f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`"
-        )
+        if module is None:
+            return
+        module.clip_grad_norm_(clip_val)
 
     @property
     def mixed_precision_config(self) -> "TorchMixedPrecision":
diff --git a/src/lightning/pytorch/plugins/precision/precision.py b/src/lightning/pytorch/plugins/precision/precision.py
index 327fb2d4f5a27..a11182db68f97 100644
--- a/src/lightning/pytorch/plugins/precision/precision.py
+++ b/src/lightning/pytorch/plugins/precision/precision.py
@@ -143,6 +143,7 @@ def _clip_gradients(
 
     def clip_gradients(
         self,
+        module: Optional[Module],
         optimizer: Optimizer,
         clip_val: Union[int, float] = 0.0,
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
@@ -153,14 +154,14 @@ def clip_gradients(
         if gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
             self.clip_grad_by_value(optimizer, clip_val)
         elif gradient_clip_algorithm == GradClipAlgorithmType.NORM:
-            self.clip_grad_by_norm(optimizer, clip_val)
+            self.clip_grad_by_norm(module, optimizer, clip_val)
 
     def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
         """Clip gradients by value."""
         parameters = self.main_params(optimizer)
         torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)
 
-    def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
+    def clip_grad_by_norm(self, module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float]) -> None:
         """Clip gradients by norm."""
         parameters = self.main_params(optimizer)
         torch.nn.utils.clip_grad_norm_(parameters, clip_val)
diff --git a/tests/tests_pytorch/plugins/precision/test_amp.py b/tests/tests_pytorch/plugins/precision/test_amp.py
index cb061c540b2be..900892fad5fdd 100644
--- a/tests/tests_pytorch/plugins/precision/test_amp.py
+++ b/tests/tests_pytorch/plugins/precision/test_amp.py
@@ -14,6 +14,7 @@
 from unittest.mock import Mock
 
 import pytest
+from torch.nn import Module
 from torch.optim import Optimizer
 
 from lightning.pytorch.plugins import MixedPrecision
@@ -22,22 +23,23 @@
 
 def test_clip_gradients():
     """Test that `.clip_gradients()` is a no-op when clipping is disabled."""
+    module = Mock(spec=Module)
     optimizer = Mock(spec=Optimizer)
     precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock())
     precision.clip_grad_by_value = Mock()
     precision.clip_grad_by_norm = Mock()
-    precision.clip_gradients(optimizer)
+    precision.clip_gradients(module, optimizer)
     precision.clip_grad_by_value.assert_not_called()
     precision.clip_grad_by_norm.assert_not_called()
 
-    precision.clip_gradients(optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.VALUE)
+    precision.clip_gradients(module, optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.VALUE)
     precision.clip_grad_by_value.assert_called_once()
     precision.clip_grad_by_norm.assert_not_called()
     precision.clip_grad_by_value.reset_mock()
     precision.clip_grad_by_norm.reset_mock()
 
-    precision.clip_gradients(optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM)
+    precision.clip_gradients(module, optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM)
     precision.clip_grad_by_value.assert_not_called()
     precision.clip_grad_by_norm.assert_called_once()
 
@@ -46,8 +48,9 @@ def test_optimizer_amp_scaling_support_in_step_method():
     """Test that the plugin checks if the optimizer takes over unscaling in its step, making it incompatible with
     gradient clipping (example: fused Adam)."""
+    module = Mock(spec=Module)
     optimizer = Mock(_step_supports_amp_scaling=True)
     precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock())
 
     with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"):
-        precision.clip_gradients(optimizer, clip_val=1.0)
+        precision.clip_gradients(module, optimizer, clip_val=1.0)
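
For context, a minimal usage sketch (not part of this diff) of the behaviour the change enables: with the FSDP strategy, `gradient_clip_algorithm="norm"` no longer raises a `MisconfigurationException`. The precision plugin now receives the wrapped root module from the trainer and delegates to `FullyShardedDataParallel.clip_grad_norm_`, which computes the total gradient norm across ranks before clipping, instead of `torch.nn.utils.clip_grad_norm_` over local shards. The toy module, dataset shapes, and the 2-GPU device count below are assumptions for illustration only.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

import lightning.pytorch as pl


class ToyModule(pl.LightningModule):
    """A tiny stand-in LightningModule, used only to exercise the clipping path."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))), batch_size=8)
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=2,  # assumes a 2-GPU machine
        strategy="fsdp",
        precision="16-mixed",
        max_epochs=1,
        gradient_clip_val=1.0,
        # Norm clipping under FSDP: the precision plugin now calls
        # `clip_grad_norm_(1.0)` on the FSDP-wrapped root module.
        gradient_clip_algorithm="norm",
    )
    trainer.fit(ToyModule(), data)
```

Before this change, the FSDP precision plugin rejected `gradient_clip_algorithm="norm"` outright and only value clipping was usable with this strategy.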