@@ -448,7 +448,7 @@ def _get_checkpoint_shard_files(
448
448
_check_if_shards_exist_locally (
449
449
pretrained_model_name_or_path , subfolder = subfolder , original_shard_filenames = original_shard_filenames
450
450
)
451
- return pretrained_model_name_or_path , sharded_metadata
451
+ return shards_path , sharded_metadata
452
452
453
453
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
454
454
allow_patterns = original_shard_filenames
@@ -467,35 +467,37 @@ def _get_checkpoint_shard_files(
467
467
"required according to the checkpoint index."
468
468
)
469
469
470
- try :
471
- # Load from URL
472
- cached_folder = snapshot_download (
473
- pretrained_model_name_or_path ,
474
- cache_dir = cache_dir ,
475
- proxies = proxies ,
476
- local_files_only = local_files_only ,
477
- token = token ,
478
- revision = revision ,
479
- allow_patterns = allow_patterns ,
480
- ignore_patterns = ignore_patterns ,
481
- user_agent = user_agent ,
482
- )
483
- if subfolder is not None :
484
- cached_folder = os .path .join (cached_folder , subfolder )
470
+ try :
471
+ # Load from URL
472
+ cached_folder = snapshot_download (
473
+ pretrained_model_name_or_path ,
474
+ cache_dir = cache_dir ,
475
+ proxies = proxies ,
476
+ local_files_only = local_files_only ,
477
+ token = token ,
478
+ revision = revision ,
479
+ allow_patterns = allow_patterns ,
480
+ ignore_patterns = ignore_patterns ,
481
+ user_agent = user_agent ,
482
+ )
483
+ if subfolder is not None :
484
+ cached_folder = os .path .join (cached_folder , subfolder )
485
485
486
- # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
487
- # we don't have to catch them here. We have also dealt with EntryNotFoundError.
488
- except HTTPError as e :
489
- raise EnvironmentError (
490
- f"We couldn't connect to '{ HUGGINGFACE_CO_RESOLVE_ENDPOINT } ' to load { pretrained_model_name_or_path } . You should try"
491
- " again after checking your internet connection."
492
- ) from e
486
+ # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
487
+ # we don't have to catch them here. We have also dealt with EntryNotFoundError.
488
+ except HTTPError as e :
489
+ raise EnvironmentError (
490
+ f"We couldn't connect to '{ HUGGINGFACE_CO_RESOLVE_ENDPOINT } ' to load { pretrained_model_name_or_path } . You should try"
491
+ " again after checking your internet connection."
492
+ ) from e
493
493
494
494
# If `local_files_only=True`, `cached_folder` may not contain all the shard files.
495
- if local_files_only :
495
+ elif local_files_only :
496
496
_check_if_shards_exist_locally (
497
497
local_dir = cache_dir , subfolder = subfolder , original_shard_filenames = original_shard_filenames
498
498
)
499
+ if subfolder is not None :
500
+ cached_folder = os .path .join (cached_folder , subfolder )
499
501
500
502
return cached_folder , sharded_metadata
501
503
0 commit comments