
Commit 0f09b01

[Core] fix: shard loading and saving when variant is provided. (#8869)
1 parent f6cfe0a commit 0f09b01

File tree

src/diffusers/utils/hub_utils.py
tests/models/test_modeling_common.py

2 files changed: 40 additions & 1 deletion

src/diffusers/utils/hub_utils.py

Lines changed: 2 additions & 1 deletion

@@ -271,7 +271,8 @@ def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str]
 def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
     if variant is not None:
         splits = weights_name.split(".")
-        splits = splits[:-1] + [variant] + splits[-1:]
+        split_index = -2 if weights_name.endswith(".index.json") else -1
+        splits = splits[:-split_index] + [variant] + splits[-split_index:]
         weights_name = ".".join(splits)

     return weights_name

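To see what the one-line change fixes, here is a standalone before/after sketch of the renaming logic (an illustration reproducing the two versions, not the library code itself; the filenames assume diffusers' default safetensors names):

def add_variant_old(weights_name: str, variant: str) -> str:
    # Old behavior: always insert the variant before the final extension.
    splits = weights_name.split(".")
    splits = splits[:-1] + [variant] + splits[-1:]
    return ".".join(splits)

def add_variant_new(weights_name: str, variant: str) -> str:
    # New behavior: for sharded-index files, insert the variant before the
    # whole ".index.json" suffix instead of just before the final ".json".
    splits = weights_name.split(".")
    split_index = -2 if weights_name.endswith(".index.json") else -1
    splits = splits[:-split_index] + [variant] + splits[-split_index:]
    return ".".join(splits)

# Plain (non-sharded) weight files are unaffected by the fix:
assert add_variant_old("diffusion_pytorch_model.safetensors", "fp16") == "diffusion_pytorch_model.fp16.safetensors"
assert add_variant_new("diffusion_pytorch_model.safetensors", "fp16") == "diffusion_pytorch_model.fp16.safetensors"

# Sharded-index filenames were mangled before: the variant landed inside the
# ".index.json" suffix, so sharded variant checkpoints could not be resolved.
assert add_variant_old("diffusion_pytorch_model.safetensors.index.json", "fp16") == "diffusion_pytorch_model.safetensors.index.fp16.json"  # wrong
assert add_variant_new("diffusion_pytorch_model.safetensors.index.json", "fp16") == "diffusion_pytorch_model.safetensors.fp16.index.json"  # fixed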
tests/models/test_modeling_common.py

Lines changed: 38 additions & 0 deletions
@@ -40,6 +40,7 @@
 )
 from diffusers.training_utils import EMAModel
 from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, is_torch_npu_available, is_xformers_available, logging
+from diffusers.utils.hub_utils import _add_variant
 from diffusers.utils.testing_utils import (
     CaptureLogger,
     get_python_version,
@@ -915,6 +916,43 @@ def test_sharded_checkpoints(self):

         self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

+    @require_torch_gpu
+    def test_sharded_checkpoints_with_variant(self):
+        torch.manual_seed(0)
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**config).eval()
+        model = model.to(torch_device)
+
+        base_output = model(**inputs_dict)
+
+        model_size = compute_module_sizes(model)[""]
+        max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
+        variant = "fp16"
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # It doesn't matter if the actual model is in fp16 or not. Just adding the variant and
+            # testing if loading works with the variant when the checkpoint is sharded should be
+            # enough.
+            model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", variant=variant)
+            index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
+            self.assertTrue(os.path.exists(os.path.join(tmp_dir, index_filename)))
+
+            # Now check if the right number of shards exists. First, let's get the number of shards.
+            # Since this number can be dependent on the model being tested, it's important that we calculate it
+            # instead of hardcoding it.
+            expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, index_filename))
+            actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
+            self.assertTrue(actual_num_shards == expected_num_shards)
+
+            new_model = self.model_class.from_pretrained(tmp_dir, variant=variant).eval()
+            new_model = new_model.to(torch_device)
+
+            torch.manual_seed(0)
+            if "generator" in inputs_dict:
+                _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+            new_output = new_model(**inputs_dict)
+
+            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+
     @require_torch_gpu
     def test_sharded_checkpoints_device_map(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
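One note on the test body: caculate_expected_num_shards is a helper defined elsewhere in test_modeling_common.py and is not part of this diff. A plausible sketch of what it needs to do, given the sharded-index format, is:

import json

def caculate_expected_num_shards(index_filepath: str) -> int:
    # Sketch only: the real helper lives elsewhere in the test module and may
    # differ in detail. Entries in the index's "weight_map" point at shard
    # files named like "diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
    # so the total shard count can be read off any one entry.
    with open(index_filepath) as f:
        weight_map = json.load(f)["weight_map"]
    first_shard_file = next(iter(weight_map.values()))
    return int(first_shard_file.split("-")[-1].split(".")[0])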

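With the fix in place, the round trip the new test exercises also works outside the test harness. A minimal sketch, assuming a small model class such as UNet2DModel and illustrative shard-size/variant values:

from diffusers import UNet2DModel

# Save a model as a sharded "fp16" variant checkpoint, then reload it by
# variant. Before this commit, the variant-suffixed index filename was built
# inconsistently, so reloading a sharded variant checkpoint failed.
model = UNet2DModel(sample_size=32, in_channels=3, out_channels=3)
model.save_pretrained("tiny-unet", max_shard_size="5MB", variant="fp16")
reloaded = UNet2DModel.from_pretrained("tiny-unet", variant="fp16")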