From 146c387b812d722c5d88e5a5ba0044051931745d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 2 Jul 2024 10:51:37 +0530 Subject: [PATCH] fix sharding tests --- src/diffusers/models/embeddings.py | 3 ++- tests/models/autoencoders/test_models_vae.py | 7 ++++--- tests/models/test_modeling_common.py | 5 ++++- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 065b92425b19..cb6cb065dd32 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -415,9 +415,10 @@ def __init__( if set_W_to_weight: # to delete later + del self.weight self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) - self.weight = self.W + del self.W def forward(self, x): if self.log: diff --git a/tests/models/autoencoders/test_models_vae.py b/tests/models/autoencoders/test_models_vae.py index d78479247347..0fc185b602a3 100644 --- a/tests/models/autoencoders/test_models_vae.py +++ b/tests/models/autoencoders/test_models_vae.py @@ -361,9 +361,10 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase): forward_requires_fresh_args = True def inputs_dict(self, seed=None): - generator = torch.Generator("cpu") - if seed is not None: - generator.manual_seed(0) + if seed is None: + generator = torch.Generator("cpu").manual_seed(0) + else: + generator = torch.Generator("cpu").manual_seed(seed) image = randn_tensor((4, 3, 32, 32), generator=generator, device=torch.device(torch_device)) return {"sample": image, "generator": generator} diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index ac356d4c522d..259b4cc916d3 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -905,11 +905,13 @@ def test_sharded_checkpoints(self): actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) self.assertTrue(actual_num_shards == expected_num_shards) - new_model = self.model_class.from_pretrained(tmp_dir) + new_model = self.model_class.from_pretrained(tmp_dir).eval() new_model = new_model.to(torch_device) torch.manual_seed(0) + _, inputs_dict = self.prepare_init_args_and_inputs_for_common() new_output = new_model(**inputs_dict) + self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) @require_torch_gpu @@ -940,6 +942,7 @@ def test_sharded_checkpoints_device_map(self): new_model = new_model.to(torch_device) torch.manual_seed(0) + _, inputs_dict = self.prepare_init_args_and_inputs_for_common() new_output = new_model(**inputs_dict) self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))