Skip to content

Commit 7116a9f

Browse files
ishandutta0098, pre-commit-ci[bot], Borda, awaelchli
authored
Include parent directory validation check for deepspeed (#17795)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
1 parent c31ef77 commit 7116a9f

File tree

6 files changed

+107
-9
lines changed

6 files changed

+107
-9
lines changed

src/lightning/fabric/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
9494
- Added the ability to set the `torch.distributed.fsdp.ShardingStrategy` via string in `FSDPStrategy` ([#18087](https://github.com/Lightning-AI/lightning/pull/18087))
9595

9696

97+
- Improved error messages when attempting to load a DeepSpeed checkpoint at an invalid path ([#17795](https://github.com/Lightning-AI/lightning/pull/17795))
98+
99+
97100
### Changed
98101

99102
- Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,7 @@ def load_checkpoint(
478478
f" a model instance to reload is required. Pass it in like so:"
479479
" DeepSpeedStrategy.load_checkpoint(..., state={'model': model, ...})"
480480
)
481+
_validate_checkpoint_directory(path)
481482

482483
engines = _get_deepspeed_engines_from_state(state)
483484
if len(engines) == 0:
@@ -503,6 +504,7 @@ def load_checkpoint(
503504
load_lr_scheduler_states=False,
504505
load_module_strict=strict,
505506
)
507+
506508
if client_state is None:
507509
raise RuntimeError(
508510
"DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint"
@@ -843,3 +845,43 @@ def _validate_device_index_selection(parallel_devices: List[torch.device]) -> No
843845
" If you need to select GPUs at a specific index, set the `CUDA_VISIBLE_DEVICES` environment variable"
844846
f" instead. For example: `CUDA_VISIBLE_DEVICES={','.join(str(i) for i in selected_device_indices)}`."
845847
)
848+
849+
850+
def _is_deepspeed_checkpoint(path: Path) -> bool:
    """Tell whether ``path`` looks like the top-level directory of a DeepSpeed checkpoint.

    This is only a heuristic: ``path`` qualifies when it is a directory that contains a
    sub-directory named ``checkpoint``. The contents of that sub-directory are not inspected.

    """
    # Anything that is not an existing directory (missing path, regular file) cannot qualify.
    if not path.is_dir():
        return False
    return (path / "checkpoint").is_dir()
853+
854+
855+
def _validate_checkpoint_directory(path: _PATH) -> None:
    """Check that ``path`` is a DeepSpeed checkpoint directory and hint at the likely fix when it is not.

    Returns silently when ``path`` passes the ``_is_deepspeed_checkpoint`` heuristic.

    Raises:
        FileNotFoundError: If ``path`` does not pass the heuristic. When its parent directory, or (for an
            existing file) its grandparent directory, does pass, the message suggests loading from that
            directory instead.

    """
    # Layout of a typical DeepSpeed checkpoint directory:
    #
    # epoch=5-step=10999.ckpt
    # ├── checkpoint
    # │   ├── zero_pp_rank_0_mp_rank_00_model_states.pt
    # │   ├── zero_pp_rank_0_mp_rank_00_optim_states.pt
    # │   ├── zero_pp_rank_1_mp_rank_00_model_states.pt
    # │   └── zero_pp_rank_1_mp_rank_00_optim_states.pt
    # ├── latest
    # └── zero_to_fp32.py

    candidate = Path(path)
    if _is_deepspeed_checkpoint(candidate):
        return

    default_message = f"The provided path is not a valid DeepSpeed checkpoint: {candidate}"

    # Case 1: the user may have pointed at the "checkpoint" subfolder rather than its parent
    enclosing_dir = candidate.parent
    if _is_deepspeed_checkpoint(enclosing_dir):
        raise FileNotFoundError(
            f"{default_message}. It looks like you passed the path to a subfolder."
            f" Try to load using this parent directory instead: {enclosing_dir}"
        )

    # Case 2: the user may have pointed at a file that lives inside the "checkpoint" subfolder
    checkpoint_root = enclosing_dir.parent
    if candidate.is_file() and _is_deepspeed_checkpoint(checkpoint_root):
        raise FileNotFoundError(
            f"{default_message}. It looks like you passed the path to a file inside a DeepSpeed checkpoint folder."
            f" Try to load using this parent directory instead: {checkpoint_root}"
        )

    # No nearby directory looks like a checkpoint, so there is nothing to suggest
    raise FileNotFoundError(default_message)

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
8585
- Added the ability to set the `torch.distributed.fsdp.ShardingStrategy` via string in `FSDPStrategy` ([#18087](https://github.com/Lightning-AI/lightning/pull/18087))
8686

8787

88+
- Improved error messages when attempting to load a DeepSpeed checkpoint at an invalid path ([#17795](https://github.com/Lightning-AI/lightning/pull/17795))
89+
90+
8891
### Changed
8992

9093
- Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))

src/lightning/pytorch/strategies/deepspeed.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,11 @@
3030
import lightning.pytorch as pl
3131
from lightning.fabric.plugins import ClusterEnvironment
3232
from lightning.fabric.strategies import _StrategyRegistry
33-
from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE, _validate_device_index_selection
33+
from lightning.fabric.strategies.deepspeed import (
34+
_DEEPSPEED_AVAILABLE,
35+
_validate_checkpoint_directory,
36+
_validate_device_index_selection,
37+
)
3438
from lightning.fabric.utilities.optimizer import _optimizers_to_device
3539
from lightning.fabric.utilities.seed import reset_seed
3640
from lightning.fabric.utilities.types import _PATH, LRScheduler, ReduceLROnPlateau
@@ -790,12 +794,15 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
790794
checkpoint_path = self.broadcast(checkpoint_path)
791795
return super().load_checkpoint(checkpoint_path)
792796

797+
_validate_checkpoint_directory(checkpoint_path)
798+
793799
# Rely on deepspeed to load the checkpoint and necessary information
794800
assert self.lightning_module is not None
795801

796802
from lightning.pytorch.trainer.states import TrainerFn
797803

798804
is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING
805+
799806
_, client_state = self.deepspeed_engine.load_checkpoint(
800807
checkpoint_path, load_optimizer_states=is_fitting, load_lr_scheduler_states=False
801808
)

tests/tests_fabric/strategies/test_deepspeed.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,26 @@ def test_deepspeed_save_checkpoint_warn_colliding_keys(tmp_path):
219219
strategy.save_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "mp_world_size": 2})
220220

221221

222+
@RunIf(deepspeed=True)
def test_deepspeed_load_checkpoint_validate_path(tmp_path):
    """Check that loading from an invalid DeepSpeed checkpoint path fails with a helpful suggestion."""
    strategy = DeepSpeedStrategy()
    parent_hint = f"Try to load using this parent directory instead: {tmp_path}"

    # An empty directory is not a DeepSpeed checkpoint
    with pytest.raises(FileNotFoundError, match="The provided path is not a valid DeepSpeed checkpoint"):
        strategy.load_checkpoint(path=tmp_path, state={"model": Mock()})

    # The "checkpoint" subfolder is passed instead of the top-level directory
    subfolder = tmp_path / "checkpoint"
    subfolder.mkdir()
    with pytest.raises(FileNotFoundError, match=parent_hint):
        strategy.load_checkpoint(path=subfolder, state={"model": Mock()})

    # A single file inside the "checkpoint" subfolder is passed
    shard_file = subfolder / "zero_pp_rank_0_mp_rank_00_model_states.pt"
    shard_file.touch()
    with pytest.raises(FileNotFoundError, match=parent_hint):
        strategy.load_checkpoint(path=shard_file, state={"model": Mock()})
240+
241+
222242
@RunIf(deepspeed=True)
223243
def test_deepspeed_load_checkpoint_no_state(tmp_path):
224244
"""Test that DeepSpeed can't load the full state without access to a model instance from the user."""
@@ -230,7 +250,8 @@ def test_deepspeed_load_checkpoint_no_state(tmp_path):
230250

231251

232252
@RunIf(deepspeed=True)
233-
def test_deepspeed_load_checkpoint_one_deepspeed_engine_required(tmp_path):
253+
@mock.patch("lightning.fabric.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True)
254+
def test_deepspeed_load_checkpoint_one_deepspeed_engine_required(_, tmp_path):
234255
"""Test that the DeepSpeed strategy can only load one DeepSpeedEngine per checkpoint."""
235256
from deepspeed import DeepSpeedEngine
236257

@@ -266,12 +287,13 @@ def test_deepspeed_load_checkpoint_client_state_missing(tmp_path):
266287
model.load_checkpoint.return_value = [None, None]
267288

268289
# Check for our custom user error
269-
with pytest.raises(RuntimeError, match="DeepSpeed was unable to load the checkpoint"):
290+
with pytest.raises(FileNotFoundError, match="The provided path is not a valid DeepSpeed checkpoint"):
270291
strategy.load_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "test": "data"})
271292

272293

273294
@RunIf(deepspeed=True)
274-
def test_deepspeed_load_checkpoint_state_updated_with_client_state(tmp_path):
295+
@mock.patch("lightning.fabric.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True)
296+
def test_deepspeed_load_checkpoint_state_updated_with_client_state(_, tmp_path):
275297
"""Test that the DeepSpeed strategy properly updates the state variables and returns additional metadata."""
276298
from deepspeed import DeepSpeedEngine
277299

@@ -295,7 +317,8 @@ def test_deepspeed_load_checkpoint_state_updated_with_client_state(tmp_path):
295317

296318
@RunIf(deepspeed=True)
297319
@pytest.mark.parametrize("optimzer_state_requested", [True, False])
298-
def test_deepspeed_load_checkpoint_optimzer_state_requested(optimzer_state_requested, tmp_path):
320+
@mock.patch("lightning.fabric.strategies.deepspeed._is_deepspeed_checkpoint", return_value=True)
321+
def test_deepspeed_load_checkpoint_optimzer_state_requested(_, optimzer_state_requested, tmp_path):
299322
"""Test that the DeepSpeed strategy loads the optimizer state only when requested."""
300323
from deepspeed import DeepSpeedEngine
301324

@@ -343,7 +366,7 @@ def test_errors_grad_clipping():
343366
strategy.clip_gradients_value(Mock(), Mock(), Mock())
344367

345368

346-
@RunIf(deepspeed=True)
369+
@RunIf(deepspeed=True, mps=False)
347370
def test_deepspeed_save_filter(tmp_path):
348371
strategy = DeepSpeedStrategy()
349372
with pytest.raises(TypeError, match="manages the state serialization internally"):

tests/tests_pytorch/strategies/test_deepspeed_strategy.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def test_deepspeed_strategy_env(tmpdir, monkeypatch, deepspeed_config):
136136
assert strategy.config == deepspeed_config
137137

138138

139-
@RunIf(deepspeed=True)
139+
@RunIf(deepspeed=True, mps=False)
140140
def test_deepspeed_precision_choice(cuda_count_1, tmpdir):
141141
"""Test to ensure precision plugin is also correctly chosen.
142142
@@ -547,7 +547,7 @@ def test_deepspeed_multigpu_single_file(tmpdir):
547547
strategy = trainer.strategy
548548
assert isinstance(strategy, DeepSpeedStrategy)
549549
assert not strategy.load_full_weights
550-
with pytest.raises(MisconfigurationException, match="DeepSpeed was unable to load the checkpoint."):
550+
with pytest.raises(FileNotFoundError, match="The provided path is not a valid DeepSpeed checkpoint"):
551551
trainer.test(model, ckpt_path=checkpoint_path)
552552

553553
trainer = Trainer(
@@ -955,7 +955,7 @@ def on_train_epoch_start(self) -> None:
955955
trainer.fit(model)
956956

957957

958-
@RunIf(deepspeed=True)
958+
@RunIf(deepspeed=True, mps=False)
959959
@mock.patch("deepspeed.init_distributed", autospec=True)
960960
@pytest.mark.parametrize("platform", ["Linux", "Windows"])
961961
def test_deepspeed_strategy_env_variables(mock_deepspeed_distributed, tmpdir, platform):
@@ -1315,3 +1315,23 @@ def test_deepspeed_init_module_with_stages_1_2(stage):
13151315

13161316
zero_init_mock.assert_not_called()
13171317
assert model.layer.weight.dtype == torch.bfloat16
1318+
1319+
1320+
@RunIf(deepspeed=True)
def test_deepspeed_load_checkpoint_validate_path(tmp_path):
    """Check that loading from an invalid DeepSpeed checkpoint path fails with a helpful suggestion."""
    strategy = DeepSpeedStrategy()
    parent_hint = f"Try to load using this parent directory instead: {tmp_path}"

    # An empty directory is not a DeepSpeed checkpoint
    with pytest.raises(FileNotFoundError, match="The provided path is not a valid DeepSpeed checkpoint"):
        strategy.load_checkpoint(checkpoint_path=tmp_path)

    # The "checkpoint" subfolder is passed instead of the top-level directory
    subfolder = tmp_path / "checkpoint"
    subfolder.mkdir()
    with pytest.raises(FileNotFoundError, match=parent_hint):
        strategy.load_checkpoint(checkpoint_path=subfolder)

    # A single file inside the "checkpoint" subfolder is passed
    shard_file = subfolder / "zero_pp_rank_0_mp_rank_00_model_states.pt"
    shard_file.touch()
    with pytest.raises(FileNotFoundError, match=parent_hint):
        strategy.load_checkpoint(checkpoint_path=shard_file)

0 commit comments

Comments
 (0)