From 7f4b2e9a7c46589d8f0f2418c8e6667d6e623d91 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 2 Aug 2024 14:39:33 +0200 Subject: [PATCH 1/3] Fix loading sharded checkpoint when we have variant --- src/diffusers/models/modeling_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index f7324009f3c6..cfe692dcc54a 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -773,7 +773,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P try: accelerate.load_checkpoint_and_dispatch( model, - model_file if not is_sharded else sharded_ckpt_cached_folder, + model_file if not is_sharded else index_file, device_map, max_memory=max_memory, offload_folder=offload_folder, @@ -803,7 +803,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model._temp_convert_self_to_deprecated_attention_blocks() accelerate.load_checkpoint_and_dispatch( model, - model_file if not is_sharded else sharded_ckpt_cached_folder, + model_file if not is_sharded else index_file, device_map, max_memory=max_memory, offload_folder=offload_folder, From 22035254d8c4d2d956b0e0fa15a886228e6d56df Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Tue, 6 Aug 2024 19:18:22 +0200 Subject: [PATCH 2/3] add test --- tests/models/unets/test_models_unet_2d_condition.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index 1c688c9e9c8a..ac0e209bad58 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -1121,6 +1121,19 @@ def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self): assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) + @require_torch_gpu + def 
test_load_sharded_checkpoint_with_variant_from_hub(self): + _, inputs_dict = self.prepare_init_args_and_inputs_for_common() + loaded_model = self.model_class.from_pretrained( + "hf-internal-testing/unet2d-sharded-with-variant-dummy", variant="fp16" + ) + print("helooooo") + loaded_model = loaded_model.to(torch_device) + new_output = loaded_model(**inputs_dict) + + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + @require_peft_backend def test_lora(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() From 50bf20f8644950e2359ec9887c2fb0e5de6ff65e Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Tue, 6 Aug 2024 19:25:36 +0200 Subject: [PATCH 3/3] remove print --- tests/models/unets/test_models_unet_2d_condition.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index ac0e209bad58..df88e7960ba9 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -1127,7 +1127,6 @@ def test_load_sharded_checkpoint_with_variant_from_hub(self): loaded_model = self.model_class.from_pretrained( "hf-internal-testing/unet2d-sharded-with-variant-dummy", variant="fp16" ) - print("helooooo") loaded_model = loaded_model.to(torch_device) new_output = loaded_model(**inputs_dict)