|
40 | 40 | )
|
41 | 41 | from diffusers.training_utils import EMAModel
|
42 | 42 | from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, is_torch_npu_available, is_xformers_available, logging
|
| 43 | +from diffusers.utils.hub_utils import _add_variant
43 | 44 | from diffusers.utils.testing_utils import (
|
44 | 45 |     CaptureLogger,

45 | 46 |     get_python_version,
|
@@ -915,6 +916,43 @@ def test_sharded_checkpoints(self):
|
915 | 916 |
|
916 | 917 |         self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
|
917 | 918 |
|
| 919 | +    @require_torch_gpu
| 920 | +    def test_sharded_checkpoints_with_variant(self):
| 921 | +        torch.manual_seed(0)
| 922 | +        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
| 923 | +        model = self.model_class(**config).eval()
| 924 | +        model = model.to(torch_device)
| 925 | +
| 926 | +        base_output = model(**inputs_dict)
| 927 | +
| 928 | +        model_size = compute_module_sizes(model)[""]
| 929 | +        max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
| 930 | +        variant = "fp16"
| 931 | +        with tempfile.TemporaryDirectory() as tmp_dir:
| 932 | +            # It doesn't matter if the actual model is in fp16 or not. Just adding the variant and
| 933 | +            # testing if loading works with the variant when the checkpoint is sharded should be
| 934 | +            # enough.
| 935 | +            model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", variant=variant)
| 936 | +            index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
| 937 | +            self.assertTrue(os.path.exists(os.path.join(tmp_dir, index_filename)))
| 938 | +
| 939 | +            # Now check if the right number of shards exists. First, let's get the number of shards.
| 940 | +            # Since this number can be dependent on the model being tested, it's important that we calculate it
| 941 | +            # instead of hardcoding it.
| 942 | +            expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, index_filename))
| 943 | +            actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
| 944 | +            self.assertTrue(actual_num_shards == expected_num_shards)
| 945 | +
| 946 | +            new_model = self.model_class.from_pretrained(tmp_dir, variant=variant).eval()
| 947 | +            new_model = new_model.to(torch_device)
| 948 | +
| 949 | +            torch.manual_seed(0)
| 950 | +            if "generator" in inputs_dict:
| 951 | +                _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
| 952 | +            new_output = new_model(**inputs_dict)
| 953 | +
| 954 | +            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
| 955 | +
918 | 956 |     @require_torch_gpu

919 | 957 |     def test_sharded_checkpoints_device_map(self):

920 | 958 |         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
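
Note on the shard-count check in the added test: the helper it calls (spelled `caculate_expected_num_shards` in the test module) is defined outside this hunk. As a rough, non-authoritative sketch of how such a count can be derived, assuming the standard sharded-safetensors index layout (a JSON file with a top-level `"weight_map"` mapping each parameter name to the shard file that stores it) and using an illustrative function name:

```python
import json


def expected_num_shards_from_index(index_path: str) -> int:
    # Illustrative sketch only: assumes the index is a JSON file whose "weight_map"
    # maps parameter names to shard filenames, e.g.
    # "diffusion_pytorch_model-00001-of-00002.safetensors" in the unvarianted case.
    with open(index_path) as f:
        weight_map = json.load(f)["weight_map"]
    # The number of distinct shard files referenced by the map is the expected shard count.
    return len(set(weight_map.values()))
```

The test then compares the expected count against the number of `.safetensors` files actually written to `tmp_dir`.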
|
|