Skip to content

Commit 074798b

Browse files
authored
Fix local_files_only for checkpoints with shards (#10294)
1 parent 3ee9669 commit 074798b

File tree

1 file changed

+29
-38
lines changed

1 file changed

+29
-38
lines changed

src/diffusers/utils/hub_utils.py

Lines changed: 29 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -455,48 +455,39 @@ def _get_checkpoint_shard_files(
455455
allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]
456456

457457
ignore_patterns = ["*.json", "*.md"]
458-
if not local_files_only:
459-
# `model_info` call must guarded with the above condition.
460-
model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
461-
for shard_file in original_shard_filenames:
462-
shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
463-
if not shard_file_present:
464-
raise EnvironmentError(
465-
f"{shards_path} does not appear to have a file named {shard_file} which is "
466-
"required according to the checkpoint index."
467-
)
468-
469-
try:
470-
# Load from URL
471-
cached_folder = snapshot_download(
472-
pretrained_model_name_or_path,
473-
cache_dir=cache_dir,
474-
proxies=proxies,
475-
local_files_only=local_files_only,
476-
token=token,
477-
revision=revision,
478-
allow_patterns=allow_patterns,
479-
ignore_patterns=ignore_patterns,
480-
user_agent=user_agent,
481-
)
482-
if subfolder is not None:
483-
cached_folder = os.path.join(cached_folder, subfolder)
484-
485-
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
486-
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
487-
except HTTPError as e:
458+
# `model_info` call must guarded with the above condition.
459+
model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
460+
for shard_file in original_shard_filenames:
461+
shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
462+
if not shard_file_present:
488463
raise EnvironmentError(
489-
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
490-
" again after checking your internet connection."
491-
) from e
464+
f"{shards_path} does not appear to have a file named {shard_file} which is "
465+
"required according to the checkpoint index."
466+
)
492467

493-
# If `local_files_only=True`, `cached_folder` may not contain all the shard files.
494-
elif local_files_only:
495-
_check_if_shards_exist_locally(
496-
local_dir=cache_dir, subfolder=subfolder, original_shard_filenames=original_shard_filenames
468+
try:
469+
# Load from URL
470+
cached_folder = snapshot_download(
471+
pretrained_model_name_or_path,
472+
cache_dir=cache_dir,
473+
proxies=proxies,
474+
local_files_only=local_files_only,
475+
token=token,
476+
revision=revision,
477+
allow_patterns=allow_patterns,
478+
ignore_patterns=ignore_patterns,
479+
user_agent=user_agent,
497480
)
498481
if subfolder is not None:
499-
cached_folder = os.path.join(cache_dir, subfolder)
482+
cached_folder = os.path.join(cached_folder, subfolder)
483+
484+
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
485+
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
486+
except HTTPError as e:
487+
raise EnvironmentError(
488+
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
489+
" again after checking your internet connection."
490+
) from e
500491

501492
return cached_folder, sharded_metadata
502493

0 commit comments

Comments
 (0)