Skip to content

Commit 8bdafc6

Browse files
SunMarc authored and sayakpaul committed
Fix loading sharded checkpoints when we have variants (#9061)
* Fix loading sharded checkpoint when we have variant * add test * remove print --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent f258237 commit 8bdafc6

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -773,7 +773,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
773773
try:
774774
accelerate.load_checkpoint_and_dispatch(
775775
model,
776-
model_file if not is_sharded else sharded_ckpt_cached_folder,
776+
model_file if not is_sharded else index_file,
777777
device_map,
778778
max_memory=max_memory,
779779
offload_folder=offload_folder,
@@ -803,7 +803,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
803803
model._temp_convert_self_to_deprecated_attention_blocks()
804804
accelerate.load_checkpoint_and_dispatch(
805805
model,
806-
model_file if not is_sharded else sharded_ckpt_cached_folder,
806+
model_file if not is_sharded else index_file,
807807
device_map,
808808
max_memory=max_memory,
809809
offload_folder=offload_folder,

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1121,6 +1121,18 @@ def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
11211121
assert loaded_model
11221122
assert new_output.sample.shape == (4, 4, 16, 16)
11231123

1124+
@require_torch_gpu
1125+
def test_load_sharded_checkpoint_with_variant_from_hub(self):
1126+
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1127+
loaded_model = self.model_class.from_pretrained(
1128+
"hf-internal-testing/unet2d-sharded-with-variant-dummy", variant="fp16"
1129+
)
1130+
loaded_model = loaded_model.to(torch_device)
1131+
new_output = loaded_model(**inputs_dict)
1132+
1133+
assert loaded_model
1134+
assert new_output.sample.shape == (4, 4, 16, 16)
1135+
11241136
@require_peft_backend
11251137
def test_lora(self):
11261138
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

0 commit comments

Comments (0)