diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py
index 7ecb7de89cd3..2ad4cc7b87d1 100644
--- a/src/diffusers/utils/hub_utils.py
+++ b/src/diffusers/utils/hub_utils.py
@@ -271,7 +271,8 @@ def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str]
 def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
     if variant is not None:
         splits = weights_name.split(".")
-        splits = splits[:-1] + [variant] + splits[-1:]
+        split_index = -2 if weights_name.endswith(".index.json") else -1
+        splits = splits[:split_index] + [variant] + splits[split_index:]
         weights_name = ".".join(splits)
 
     return weights_name
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 87ed1d9d17e5..034c7434a2d5 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -40,6 +40,7 @@
 )
 from diffusers.training_utils import EMAModel
 from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, is_torch_npu_available, is_xformers_available, logging
+from diffusers.utils.hub_utils import _add_variant
 from diffusers.utils.testing_utils import (
     CaptureLogger,
     get_python_version,
@@ -915,6 +916,43 @@ def test_sharded_checkpoints(self):
 
         self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
 
+    @require_torch_gpu
+    def test_sharded_checkpoints_with_variant(self):
+        torch.manual_seed(0)
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**config).eval()
+        model = model.to(torch_device)
+
+        base_output = model(**inputs_dict)
+
+        model_size = compute_module_sizes(model)[""]
+        max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
+        variant = "fp16"
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # It doesn't matter if the actual model is in fp16 or not. Just adding the variant and
+            # testing if loading works with the variant when the checkpoint is sharded should be
+            # enough.
+            model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", variant=variant)
+            index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
+            self.assertTrue(os.path.exists(os.path.join(tmp_dir, index_filename)))
+
+            # Now check if the right number of shards exists. First, let's get the number of shards.
+            # Since this number can be dependent on the model being tested, it's important that we calculate it
+            # instead of hardcoding it.
+            expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, index_filename))
+            actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
+            self.assertTrue(actual_num_shards == expected_num_shards)
+
+            new_model = self.model_class.from_pretrained(tmp_dir, variant=variant).eval()
+            new_model = new_model.to(torch_device)
+
+            torch.manual_seed(0)
+            if "generator" in inputs_dict:
+                _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+            new_output = new_model(**inputs_dict)
+
+            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+
     @require_torch_gpu
     def test_sharded_checkpoints_device_map(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
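
Note: a quick standalone sanity check of what the patched helper produces (the filenames below are just illustrative):

    from typing import Optional

    def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
        # Insert the variant before the final extension, or before ".index.json"
        # for sharded-checkpoint index files.
        if variant is not None:
            splits = weights_name.split(".")
            split_index = -2 if weights_name.endswith(".index.json") else -1
            splits = splits[:split_index] + [variant] + splits[split_index:]
            weights_name = ".".join(splits)
        return weights_name

    assert _add_variant("diffusion_pytorch_model.safetensors", "fp16") == "diffusion_pytorch_model.fp16.safetensors"
    assert (
        _add_variant("diffusion_pytorch_model.safetensors.index.json", "fp16")
        == "diffusion_pytorch_model.safetensors.fp16.index.json"
    )

Before the fix, the index name came out as "diffusion_pytorch_model.safetensors.index.fp16.json" (variant wedged between "index" and "json"), so variant-aware loading of sharded checkpoints never found the index file.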
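
The new test reuses the existing caculate_expected_num_shards helper from tests/models/test_modeling_common.py (the name is spelled that way in the file, so the call site matches it). A minimal sketch of what such a helper does, assuming shard filenames carry the usual "-00001-of-00002" suffix recorded in the index's weight_map:

    import json

    def caculate_expected_num_shards(index_map_path):
        # The index JSON maps each parameter name to the shard file holding it;
        # every shard filename embeds the total shard count, e.g.
        # "diffusion_pytorch_model.fp16-00001-of-00002.safetensors".
        with open(index_map_path) as f:
            weight_map_dict = json.load(f)["weight_map"]
        first_key = list(weight_map_dict.keys())[0]
        weight_loc = weight_map_dict[first_key]
        expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
        return expected_num_shards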
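
End to end, what the patch enables is combining sharding with variants through the public save/load path. A hedged sketch, using a deliberately tiny model so that a small max_shard_size forces multiple shards (the config values and shard size are illustrative, not from the PR):

    import tempfile

    from diffusers import UNet2DModel

    # Tiny UNet so a 50KB shard limit actually splits the checkpoint.
    model = UNet2DModel(
        sample_size=8,
        in_channels=3,
        out_channels=3,
        block_out_channels=(8, 16),
        norm_num_groups=4,
        down_block_types=("DownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "UpBlock2D"),
    )

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Writes diffusion_pytorch_model.fp16-0000X-of-0000N.safetensors shards
        # plus diffusion_pytorch_model.safetensors.fp16.index.json.
        model.save_pretrained(tmp_dir, max_shard_size="50KB", variant="fp16")
        reloaded = UNet2DModel.from_pretrained(tmp_dir, variant="fp16")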