Skip to content

Commit 95a7832

Browse files
yiyixuxusayakpaul
andauthored
fix load sharded checkpoint from a subfolder (local path) (#8913)
fix Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent c646fbc commit 95a7832

File tree

2 files changed

+60
-24
lines changed

2 files changed

+60
-24
lines changed

src/diffusers/utils/hub_utils.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ def _get_checkpoint_shard_files(
448448
_check_if_shards_exist_locally(
449449
pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames
450450
)
451-
return pretrained_model_name_or_path, sharded_metadata
451+
return shards_path, sharded_metadata
452452

453453
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
454454
allow_patterns = original_shard_filenames
@@ -467,35 +467,37 @@ def _get_checkpoint_shard_files(
467467
"required according to the checkpoint index."
468468
)
469469

470-
try:
471-
# Load from URL
472-
cached_folder = snapshot_download(
473-
pretrained_model_name_or_path,
474-
cache_dir=cache_dir,
475-
proxies=proxies,
476-
local_files_only=local_files_only,
477-
token=token,
478-
revision=revision,
479-
allow_patterns=allow_patterns,
480-
ignore_patterns=ignore_patterns,
481-
user_agent=user_agent,
482-
)
483-
if subfolder is not None:
484-
cached_folder = os.path.join(cached_folder, subfolder)
470+
try:
471+
# Load from URL
472+
cached_folder = snapshot_download(
473+
pretrained_model_name_or_path,
474+
cache_dir=cache_dir,
475+
proxies=proxies,
476+
local_files_only=local_files_only,
477+
token=token,
478+
revision=revision,
479+
allow_patterns=allow_patterns,
480+
ignore_patterns=ignore_patterns,
481+
user_agent=user_agent,
482+
)
483+
if subfolder is not None:
484+
cached_folder = os.path.join(cached_folder, subfolder)
485485

486-
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
487-
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
488-
except HTTPError as e:
489-
raise EnvironmentError(
490-
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
491-
" again after checking your internet connection."
492-
) from e
486+
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
487+
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
488+
except HTTPError as e:
489+
raise EnvironmentError(
490+
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
491+
" again after checking your internet connection."
492+
) from e
493493

494494
# If `local_files_only=True`, `cached_folder` may not contain all the shard files.
495-
if local_files_only:
495+
elif local_files_only:
496496
_check_if_shards_exist_locally(
497497
local_dir=cache_dir, subfolder=subfolder, original_shard_filenames=original_shard_filenames
498498
)
499+
if subfolder is not None:
500+
cached_folder = os.path.join(cached_folder, subfolder)
499501

500502
return cached_folder, sharded_metadata
501503

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,6 +1068,17 @@ def test_load_sharded_checkpoint_from_hub_local(self):
10681068
assert loaded_model
10691069
assert new_output.sample.shape == (4, 4, 16, 16)
10701070

1071+
@require_torch_gpu
1072+
def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
1073+
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1074+
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
1075+
loaded_model = self.model_class.from_pretrained(ckpt_path, subfolder="unet", local_files_only=True)
1076+
loaded_model = loaded_model.to(torch_device)
1077+
new_output = loaded_model(**inputs_dict)
1078+
1079+
assert loaded_model
1080+
assert new_output.sample.shape == (4, 4, 16, 16)
1081+
10711082
@require_torch_gpu
10721083
def test_load_sharded_checkpoint_device_map_from_hub(self):
10731084
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1077,6 +1088,17 @@ def test_load_sharded_checkpoint_device_map_from_hub(self):
10771088
assert loaded_model
10781089
assert new_output.sample.shape == (4, 4, 16, 16)
10791090

1091+
@require_torch_gpu
1092+
def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self):
1093+
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1094+
loaded_model = self.model_class.from_pretrained(
1095+
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map="auto"
1096+
)
1097+
new_output = loaded_model(**inputs_dict)
1098+
1099+
assert loaded_model
1100+
assert new_output.sample.shape == (4, 4, 16, 16)
1101+
10801102
@require_torch_gpu
10811103
def test_load_sharded_checkpoint_device_map_from_hub_local(self):
10821104
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1087,6 +1109,18 @@ def test_load_sharded_checkpoint_device_map_from_hub_local(self):
10871109
assert loaded_model
10881110
assert new_output.sample.shape == (4, 4, 16, 16)
10891111

1112+
@require_torch_gpu
1113+
def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
1114+
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1115+
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
1116+
loaded_model = self.model_class.from_pretrained(
1117+
ckpt_path, local_files_only=True, subfolder="unet", device_map="auto"
1118+
)
1119+
new_output = loaded_model(**inputs_dict)
1120+
1121+
assert loaded_model
1122+
assert new_output.sample.shape == (4, 4, 16, 16)
1123+
10901124
@require_peft_backend
10911125
def test_lora(self):
10921126
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

0 commit comments

Comments
 (0)