@@ -455,48 +455,39 @@ def _get_checkpoint_shard_files(
455
455
allow_patterns = [os .path .join (subfolder , p ) for p in allow_patterns ]
456
456
457
457
ignore_patterns = ["*.json" , "*.md" ]
458
- if not local_files_only :
459
- # `model_info` call must guarded with the above condition.
460
- model_files_info = model_info (pretrained_model_name_or_path , revision = revision , token = token )
461
- for shard_file in original_shard_filenames :
462
- shard_file_present = any (shard_file in k .rfilename for k in model_files_info .siblings )
463
- if not shard_file_present :
464
- raise EnvironmentError (
465
- f"{ shards_path } does not appear to have a file named { shard_file } which is "
466
- "required according to the checkpoint index."
467
- )
468
-
469
- try :
470
- # Load from URL
471
- cached_folder = snapshot_download (
472
- pretrained_model_name_or_path ,
473
- cache_dir = cache_dir ,
474
- proxies = proxies ,
475
- local_files_only = local_files_only ,
476
- token = token ,
477
- revision = revision ,
478
- allow_patterns = allow_patterns ,
479
- ignore_patterns = ignore_patterns ,
480
- user_agent = user_agent ,
481
- )
482
- if subfolder is not None :
483
- cached_folder = os .path .join (cached_folder , subfolder )
484
-
485
- # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
486
- # we don't have to catch them here. We have also dealt with EntryNotFoundError.
487
- except HTTPError as e :
458
+ # `model_info` call must guarded with the above condition.
459
+ model_files_info = model_info (pretrained_model_name_or_path , revision = revision , token = token )
460
+ for shard_file in original_shard_filenames :
461
+ shard_file_present = any (shard_file in k .rfilename for k in model_files_info .siblings )
462
+ if not shard_file_present :
488
463
raise EnvironmentError (
489
- f"We couldn't connect to ' { HUGGINGFACE_CO_RESOLVE_ENDPOINT } ' to load { pretrained_model_name_or_path } . You should try "
490
- " again after checking your internet connection ."
491
- ) from e
464
+ f"{ shards_path } does not appear to have a file named { shard_file } which is "
465
+ "required according to the checkpoint index ."
466
+ )
492
467
493
- # If `local_files_only=True`, `cached_folder` may not contain all the shard files.
494
- elif local_files_only :
495
- _check_if_shards_exist_locally (
496
- local_dir = cache_dir , subfolder = subfolder , original_shard_filenames = original_shard_filenames
468
+ try :
469
+ # Load from URL
470
+ cached_folder = snapshot_download (
471
+ pretrained_model_name_or_path ,
472
+ cache_dir = cache_dir ,
473
+ proxies = proxies ,
474
+ local_files_only = local_files_only ,
475
+ token = token ,
476
+ revision = revision ,
477
+ allow_patterns = allow_patterns ,
478
+ ignore_patterns = ignore_patterns ,
479
+ user_agent = user_agent ,
497
480
)
498
481
if subfolder is not None :
499
- cached_folder = os .path .join (cache_dir , subfolder )
482
+ cached_folder = os .path .join (cached_folder , subfolder )
483
+
484
+ # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
485
+ # we don't have to catch them here. We have also dealt with EntryNotFoundError.
486
+ except HTTPError as e :
487
+ raise EnvironmentError (
488
+ f"We couldn't connect to '{ HUGGINGFACE_CO_RESOLVE_ENDPOINT } ' to load { pretrained_model_name_or_path } . You should try"
489
+ " again after checking your internet connection."
490
+ ) from e
500
491
501
492
return cached_folder , sharded_metadata
502
493
0 commit comments