Skip to content

Commit 41320d8

Browse files
awaelchlilantiga
authored andcommitted
Enable loading legacy checkpoints that pickled the _FaultToleranceMode enum (#18094)
(cherry picked from commit 5308e90)
1 parent c161d06 commit 41320d8

File tree

3 files changed

+24
-0
lines changed

3 files changed

+24
-0
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1919
- Fixed validation of non-PyTorch LR schedulers in manual optimization mode ([#18092](https://github.com/Lightning-AI/lightning/pull/18092))
2020

2121

22+
- Fixed an attribute error for `_FaultTolerantMode` when loading an old checkpoint that pickled the enum ([#18094](https://github.com/Lightning-AI/lightning/pull/18094))
23+
24+
2225
## [2.0.5] - 2023-07-07
2326

2427
### Fixed

src/lightning/pytorch/utilities/migration/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from packaging.version import Version
2222

2323
import lightning.pytorch as pl
24+
from lightning.fabric.utilities.enums import LightningEnum
2425
from lightning.fabric.utilities.imports import _IS_WINDOWS
2526
from lightning.fabric.utilities.types import _PATH
2627
from lightning.fabric.utilities.warnings import PossibleUserWarning
@@ -79,6 +80,8 @@ class pl_legacy_patch:
7980
version 1.2.8. See: https://github.com/Lightning-AI/lightning/pull/6898
8081
2. ``lightning.pytorch.utilities.argparse_utils``: A module that was deprecated in 1.2 and removed in 1.4,
8182
but still needs to be available for import for legacy checkpoints.
83+
3. ``lightning.pytorch.utilities.enums._FaultTolerantMode``: This enum was removed in 2.0 but was pickled
84+
into older checkpoints.
8285
8386
Example:
8487
@@ -95,6 +98,14 @@ def __enter__(self) -> "pl_legacy_patch":
9598
# `_gpus_arg_default` used to be imported from these locations
9699
legacy_argparse_module._gpus_arg_default = lambda x: x
97100
pl.utilities.argparse._gpus_arg_default = lambda x: x
101+
102+
# `_FaultTolerantMode` was removed from the enums
103+
class _FaultTolerantMode(LightningEnum):
104+
DISABLED = "disabled"
105+
AUTOMATIC = "automatic"
106+
MANUAL = "manual"
107+
108+
pl.utilities.enums._FaultTolerantMode = _FaultTolerantMode
98109
return self
99110

100111
def __exit__(
@@ -106,6 +117,8 @@ def __exit__(
106117
if hasattr(pl.utilities.argparse, "_gpus_arg_default"):
107118
delattr(pl.utilities.argparse, "_gpus_arg_default")
108119
del sys.modules["lightning.pytorch.utilities.argparse_utils"]
120+
if hasattr(pl.utilities.enums, "_FaultTolerantMode"):
121+
delattr(pl.utilities.enums, "_FaultTolerantMode")
109122
_lock.release()
110123

111124

tests/tests_pytorch/utilities/migration/test_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ def test_patch_legacy_gpus_arg_default():
4242
assert not hasattr(pl.utilities.argparse, "_gpus_arg_default")
4343

4444

45+
def test_patch_legacy_fault_tolerant_mode():
46+
with pl_legacy_patch():
47+
from lightning.pytorch.utilities.enums import _FaultTolerantMode
48+
49+
assert _FaultTolerantMode.AUTOMATIC.value == "automatic"
50+
assert not hasattr(pl.utilities.enums, "_FaultTolerantMode")
51+
52+
4553
def test_migrate_checkpoint(monkeypatch):
4654
"""Test that the correct migration function gets executed given the current version of the checkpoint."""
4755
# A checkpoint that is older than any migration point in the index

0 commit comments

Comments
 (0)