diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index ef4715ee0e1e..a6dfe18433e3 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -455,48 +455,39 @@ def _get_checkpoint_shard_files( allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns] ignore_patterns = ["*.json", "*.md"] - if not local_files_only: - # `model_info` call must guarded with the above condition. - model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token) - for shard_file in original_shard_filenames: - shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings) - if not shard_file_present: - raise EnvironmentError( - f"{shards_path} does not appear to have a file named {shard_file} which is " - "required according to the checkpoint index." - ) - - try: - # Load from URL - cached_folder = snapshot_download( - pretrained_model_name_or_path, - cache_dir=cache_dir, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - allow_patterns=allow_patterns, - ignore_patterns=ignore_patterns, - user_agent=user_agent, - ) - if subfolder is not None: - cached_folder = os.path.join(cached_folder, subfolder) - - # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so - # we don't have to catch them here. We have also dealt with EntryNotFoundError. - except HTTPError as e: + # `model_info` call must guarded with the above condition. + model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token) + for shard_file in original_shard_filenames: + shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings) + if not shard_file_present: raise EnvironmentError( - f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try" - " again after checking your internet connection." - ) from e + f"{shards_path} does not appear to have a file named {shard_file} which is " + "required according to the checkpoint index." + ) - # If `local_files_only=True`, `cached_folder` may not contain all the shard files. - elif local_files_only: - _check_if_shards_exist_locally( - local_dir=cache_dir, subfolder=subfolder, original_shard_filenames=original_shard_filenames + try: + # Load from URL + cached_folder = snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + user_agent=user_agent, ) if subfolder is not None: - cached_folder = os.path.join(cache_dir, subfolder) + cached_folder = os.path.join(cached_folder, subfolder) + + # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so + # we don't have to catch them here. We have also dealt with EntryNotFoundError. + except HTTPError as e: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try" + " again after checking your internet connection." + ) from e return cached_folder, sharded_metadata