fix load sharded checkpoint from a subfolder (local path) #8913

Merged · 2 commits · Aug 1, 2024
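
For context, a minimal sketch of the call pattern this PR fixes: loading a sharded checkpoint whose weight shards live in a subfolder of a local directory. Before this change, shard resolution dropped the subfolder component for local paths, so the shards were looked up in the wrong directory. The path below is hypothetical, and the directory layout is an assumption modeled on the tests added in this PR.

from diffusers import UNet2DConditionModel

# Assumed layout (hypothetical local path):
#   /path/to/checkpoint/unet/diffusion_pytorch_model.safetensors.index.json
#   /path/to/checkpoint/unet/diffusion_pytorch_model-00001-of-00002.safetensors
#   /path/to/checkpoint/unet/diffusion_pytorch_model-00002-of-00002.safetensors
unet = UNet2DConditionModel.from_pretrained(
    "/path/to/checkpoint",  # local path, not a Hub repo id
    subfolder="unet",       # the shards live under this subfolder
    local_files_only=True,
)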
src/diffusers/utils/hub_utils.py (26 additions, 24 deletions)

@@ -448,7 +448,7 @@ def _get_checkpoint_shard_files(
         _check_if_shards_exist_locally(
             pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames
         )
-        return pretrained_model_name_or_path, sharded_metadata
+        return shards_path, sharded_metadata
 
     # At this stage pretrained_model_name_or_path is a model identifier on the Hub
     allow_patterns = original_shard_filenames
@@ -467,35 +467,37 @@ def _get_checkpoint_shard_files(
                     "required according to the checkpoint index."
                 )
 
-    try:
-        # Load from URL
-        cached_folder = snapshot_download(
-            pretrained_model_name_or_path,
-            cache_dir=cache_dir,
-            proxies=proxies,
-            local_files_only=local_files_only,
-            token=token,
-            revision=revision,
-            allow_patterns=allow_patterns,
-            ignore_patterns=ignore_patterns,
-            user_agent=user_agent,
-        )
-        if subfolder is not None:
-            cached_folder = os.path.join(cached_folder, subfolder)
+        try:
+            # Load from URL
+            cached_folder = snapshot_download(
+                pretrained_model_name_or_path,
+                cache_dir=cache_dir,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+                allow_patterns=allow_patterns,
+                ignore_patterns=ignore_patterns,
+                user_agent=user_agent,
+            )
+            if subfolder is not None:
+                cached_folder = os.path.join(cached_folder, subfolder)
 
-    # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
-    # we don't have to catch them here. We have also dealt with EntryNotFoundError.
-    except HTTPError as e:
-        raise EnvironmentError(
-            f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
-            " again after checking your internet connection."
-        ) from e
+        # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
+        # we don't have to catch them here. We have also dealt with EntryNotFoundError.
+        except HTTPError as e:
+            raise EnvironmentError(
+                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
+                " again after checking your internet connection."
+            ) from e
 
     # If `local_files_only=True`, `cached_folder` may not contain all the shard files.
-    if local_files_only:
+    elif local_files_only:
         _check_if_shards_exist_locally(
             local_dir=cache_dir, subfolder=subfolder, original_shard_filenames=original_shard_filenames
         )
+        if subfolder is not None:
+            cached_folder = os.path.join(cached_folder, subfolder)
 
     return cached_folder, sharded_metadata

An inline review comment from a Member on the new block: "Good catch!"
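
For reference, a simplified sketch of the contract that _check_if_shards_exist_locally (called at both sites above) is expected to satisfy: every shard listed in the checkpoint index must exist under local_dir/subfolder, failing early otherwise. The body below illustrates that contract and is an assumption, not the library's actual implementation.

import os

def check_if_shards_exist_locally(local_dir, subfolder, original_shard_filenames):
    # Resolve the shard directory *including* the subfolder, which is the point of this PR.
    shards_path = os.path.join(local_dir, subfolder) if subfolder else local_dir
    for shard_filename in original_shard_filenames:
        if not os.path.exists(os.path.join(shards_path, shard_filename)):
            # Fail early with the offending path rather than during weight loading.
            raise EnvironmentError(
                f"{shards_path} is missing {shard_filename}, which is required "
                "according to the checkpoint index."
            )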
tests/models/unets/test_models_unet_2d_condition.py (34 additions, 0 deletions)

@@ -1068,6 +1068,17 @@ def test_load_sharded_checkpoint_from_hub_local(self):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)
 
+    @require_torch_gpu
+    def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
+        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
+        loaded_model = self.model_class.from_pretrained(ckpt_path, subfolder="unet", local_files_only=True)
+        loaded_model = loaded_model.to(torch_device)
+        new_output = loaded_model(**inputs_dict)
+
+        assert loaded_model
+        assert new_output.sample.shape == (4, 4, 16, 16)
+
     @require_torch_gpu
     def test_load_sharded_checkpoint_device_map_from_hub(self):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1077,6 +1088,17 @@ def test_load_sharded_checkpoint_device_map_from_hub(self):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)
 
+    @require_torch_gpu
+    def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self):
+        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        loaded_model = self.model_class.from_pretrained(
+            "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map="auto"
+        )
+        new_output = loaded_model(**inputs_dict)
+
+        assert loaded_model
+        assert new_output.sample.shape == (4, 4, 16, 16)
+
     @require_torch_gpu
     def test_load_sharded_checkpoint_device_map_from_hub_local(self):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1087,6 +1109,18 @@ def test_load_sharded_checkpoint_device_map_from_hub_local(self):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)
 
+    @require_torch_gpu
+    def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
+        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
+        loaded_model = self.model_class.from_pretrained(
+            ckpt_path, local_files_only=True, subfolder="unet", device_map="auto"
+        )
+        new_output = loaded_model(**inputs_dict)
+
+        assert loaded_model
+        assert new_output.sample.shape == (4, 4, 16, 16)
+
     @require_peft_backend
     def test_lora(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
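
Taken together, the new tests pin down the user-facing behavior. Below is a sketch of the equivalent end-user flow, assuming the same dummy repo the tests use and that the test class's model_class is UNet2DConditionModel (consistent with the test file's name); device_map="auto" additionally requires the accelerate package.

from huggingface_hub import snapshot_download
from diffusers import UNet2DConditionModel

# Download the sharded dummy checkpoint once, then load it fully offline.
local_dir = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
unet = UNet2DConditionModel.from_pretrained(
    local_dir,
    subfolder="unet",       # the sharded weights live under <local_dir>/unet
    local_files_only=True,  # exercises the local code path fixed in this PR
    device_map="auto",      # optional; dispatches shards across available devices
)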