Commit 22cdd10

awaelchli authored and lantiga committed
Only validate schedulers in automatic optimization (#18092)
(cherry picked from commit ea92c21)
1 parent 8054e4f commit 22cdd10
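
For context (not part of the diff), here is a minimal sketch of the usage this commit unblocks: under manual optimization, `configure_optimizers` can now return a scheduler that does not subclass `torch.optim.lr_scheduler.LRScheduler`, and the user steps it by hand. The `DecayEveryStep` scheduler and `ManualModule` below are hypothetical illustrations, not code from this repository; note that `state_dict`/`load_state_dict` are still required by the earlier check in `_validate_scheduler_api`.

import torch
from lightning.pytorch import LightningModule


class DecayEveryStep:
    """Hypothetical scheduler that is not a torch.optim.lr_scheduler.LRScheduler."""

    def __init__(self, optimizer, factor=0.99):
        self.optimizer = optimizer
        self.factor = factor

    def step(self):
        # Decay every parameter group's learning rate by a constant factor.
        for group in self.optimizer.param_groups:
            group["lr"] *= self.factor

    # state_dict/load_state_dict remain mandatory for checkpointing.
    def state_dict(self):
        return {"factor": self.factor}

    def load_state_dict(self, state_dict):
        self.factor = state_dict["factor"]


class ManualModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # manual optimization
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch).sum()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()
        # The user decides when to step the scheduler; after this fix,
        # Lightning no longer rejects it for not being a PyTorch LRScheduler.
        self.lr_schedulers().step()
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        return [optimizer], [DecayEveryStep(optimizer)]

Before this change, `trainer.fit(ManualModule())` raised a `MisconfigurationException` for a scheduler like the one above, even though Lightning never steps schedulers itself in manual optimization.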

File tree: 3 files changed, +43 -1 lines

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -16,6 +16,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - `LightningCLI` not saving correctly `seed_everything` when `run=True` and `seed_everything=True` ([#18056](https://github.com/Lightning-AI/lightning/pull/18056))
 
 
+- Fixed validation of non-PyTorch LR schedulers in manual optimization mode ([#18092](https://github.com/Lightning-AI/lightning/pull/18092))
+
+
 ## [2.0.5] - 2023-07-07
 
 ### Fixed

src/lightning/pytorch/core/optimizer.py

Lines changed: 5 additions & 1 deletion

@@ -325,7 +325,11 @@ def _validate_scheduler_api(lr_scheduler_configs: List[LRSchedulerConfig], model
                 " It should have `state_dict` and `load_state_dict` methods defined."
             )
 
-        if not isinstance(scheduler, LRSchedulerTypeTuple) and not is_overridden("lr_scheduler_step", model):
+        if (
+            not isinstance(scheduler, LRSchedulerTypeTuple)
+            and not is_overridden("lr_scheduler_step", model)
+            and model.automatic_optimization
+        ):
             raise MisconfigurationException(
                 f"The provided lr scheduler `{scheduler.__class__.__name__}` doesn't follow PyTorch's LRScheduler"
                 " API. You should override the `LightningModule.lr_scheduler_step` hook with your own logic if"

tests/tests_pytorch/trainer/optimization/test_manual_optimization.py

Lines changed: 35 additions & 0 deletions

@@ -22,6 +22,7 @@
 import torch.distributed as torch_distrib
 import torch.nn.functional as F
 
+from lightning.fabric.utilities.exceptions import MisconfigurationException
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
 from lightning.pytorch import seed_everything, Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel, ManualOptimBoringModel
@@ -886,3 +887,37 @@ def configure_optimizers(self):
 
     assert set(trainer.logged_metrics) == {"loss_d", "loss_g"}
     assert set(trainer.progress_bar_metrics) == {"loss_d", "loss_g"}
+
+
+@pytest.mark.parametrize("automatic_optimization", [True, False])
+def test_manual_optimization_with_non_pytorch_scheduler(automatic_optimization):
+    """In manual optimization, the user can provide a custom scheduler that doesn't follow PyTorch's interface."""
+
+    class IncompatibleScheduler:
+        def __init__(self, optimizer):
+            self.optimizer = optimizer
+
+        def state_dict(self):
+            return {}
+
+        def load_state_dict(self, _):
+            pass
+
+    class Model(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.automatic_optimization = automatic_optimization
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            scheduler = IncompatibleScheduler(optimizer)
+            return [optimizer], [scheduler]
+
+    model = Model()
+    trainer = Trainer(accelerator="cpu", max_epochs=0)
+    if automatic_optimization:
+        with pytest.raises(MisconfigurationException, match="doesn't follow PyTorch's LRScheduler"):
+            trainer.fit(model)
+    else:
+        # No error for manual optimization
+        trainer.fit(model)
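
A brief design note on the test: `max_epochs=0` keeps it cheap, since the scheduler validation runs while `trainer.fit` sets up the optimizers, before any training batches execute. The parametrization covers both branches of the new condition: only the `automatic_optimization=True` case is expected to raise.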
