diff --git a/tests/models/autoencoders/test_models_vae.py b/tests/models/autoencoders/test_models_vae.py
index 0188f9121ae0..d29defbf6085 100644
--- a/tests/models/autoencoders/test_models_vae.py
+++ b/tests/models/autoencoders/test_models_vae.py
@@ -39,7 +39,6 @@
     load_hf_numpy,
     require_torch_accelerator,
     require_torch_accelerator_with_fp16,
-    require_torch_accelerator_with_training,
     require_torch_gpu,
     skip_mps,
     slow,
@@ -170,52 +169,17 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
+    @unittest.skip("Not tested.")
     def test_forward_signature(self):
         pass
 
+    @unittest.skip("Not tested.")
     def test_training(self):
         pass
 
-    @require_torch_accelerator_with_training
-    def test_gradient_checkpointing(self):
-        # enable deterministic behavior for gradient checkpointing
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-
-        assert not model.is_gradient_checkpointing and model.training
-
-        out = model(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model.zero_grad()
-
-        labels = torch.randn_like(out)
-        loss = (out - labels).mean()
-        loss.backward()
-
-        # re-instantiate the model now enabling gradient checkpointing
-        model_2 = self.model_class(**init_dict)
-        # clone model
-        model_2.load_state_dict(model.state_dict())
-        model_2.to(torch_device)
-        model_2.enable_gradient_checkpointing()
-
-        assert model_2.is_gradient_checkpointing and model_2.training
-
-        out_2 = model_2(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model_2.zero_grad()
-        loss_2 = (out_2 - labels).mean()
-        loss_2.backward()
-
-        # compare the output and parameters gradients
-        self.assertTrue((loss - loss_2).abs() < 1e-5)
-        named_params = dict(model.named_parameters())
-        named_params_2 = dict(model_2.named_parameters())
-        for name, param in named_params.items():
-            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"Decoder", "Encoder"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
 
     def test_from_pretrained_hub(self):
         model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
@@ -329,9 +293,11 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
+    @unittest.skip("Not tested.")
     def test_forward_signature(self):
         pass
 
+    @unittest.skip("Not tested.")
     def test_forward_with_norm_groups(self):
         pass
 
@@ -364,9 +330,20 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
+    @unittest.skip("Not tested.")
     def test_outputs_equivalence(self):
         pass
 
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"DecoderTiny", "EncoderTiny"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+    @unittest.skip(
+        "Gradient checkpointing is supported but this test doesn't apply to this class because its forward is a bit different from the rest."
+    )
+    def test_effective_gradient_checkpointing(self):
+        pass
+
 
 class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
     model_class = ConsistencyDecoderVAE
@@ -443,55 +420,17 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
+    @unittest.skip("Not tested.")
     def test_forward_signature(self):
         pass
 
+    @unittest.skip("Not tested.")
     def test_training(self):
         pass
 
-    @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
-    def test_gradient_checkpointing(self):
-        # enable deterministic behavior for gradient checkpointing
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-
-        assert not model.is_gradient_checkpointing and model.training
-
-        out = model(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model.zero_grad()
-
-        labels = torch.randn_like(out)
-        loss = (out - labels).mean()
-        loss.backward()
-
-        # re-instantiate the model now enabling gradient checkpointing
-        model_2 = self.model_class(**init_dict)
-        # clone model
-        model_2.load_state_dict(model.state_dict())
-        model_2.to(torch_device)
-        model_2.enable_gradient_checkpointing()
-
-        assert model_2.is_gradient_checkpointing and model_2.training
-
-        out_2 = model_2(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model_2.zero_grad()
-        loss_2 = (out_2 - labels).mean()
-        loss_2.backward()
-
-        # compare the output and parameters gradients
-        self.assertTrue((loss - loss_2).abs() < 1e-5)
-        named_params = dict(model.named_parameters())
-        named_params_2 = dict(model_2.named_parameters())
-        for name, param in named_params.items():
-            if "post_quant_conv" in name:
-                continue
-
-            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"Encoder", "TemporalDecoder"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
 
 
 class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
@@ -522,9 +461,11 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
+    @unittest.skip("Not tested.")
     def test_forward_signature(self):
         pass
 
+    @unittest.skip("Not tested.")
     def test_forward_with_norm_groups(self):
         pass
 
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 5548fdd0723d..7f8dc63e00ac 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import inspect
 import json
 import os
@@ -57,6 +58,7 @@
     require_torch_gpu,
     require_torch_multi_gpu,
     run_test_in_subprocess,
+    torch_all_close,
     torch_device,
 )
 
@@ -785,6 +787,101 @@ def test_enable_disable_gradient_checkpointing(self):
         model.disable_gradient_checkpointing()
         self.assertFalse(model.is_gradient_checkpointing)
 
+    @require_torch_accelerator_with_training
+    def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5):
+        if not self.model_class._supports_gradient_checkpointing:
+            return  # Skip test if model does not support gradient checkpointing
+
+        # enable deterministic behavior for gradient checkpointing
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        inputs_dict_copy = copy.deepcopy(inputs_dict)
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        assert not model.is_gradient_checkpointing and model.training
+
+        out = model(**inputs_dict).sample
+        # run the backwards pass on the model. For the backwards pass, we keep things simple and
+        # backprop on the mean difference between the output and random labels instead of a real loss
+        model.zero_grad()
+
+        labels = torch.randn_like(out)
+        loss = (out - labels).mean()
+        loss.backward()
+
+        # re-instantiate the model now enabling gradient checkpointing
+        torch.manual_seed(0)
+        model_2 = self.model_class(**init_dict)
+        # clone model
+        model_2.load_state_dict(model.state_dict())
+        model_2.to(torch_device)
+        model_2.enable_gradient_checkpointing()
+
+        assert model_2.is_gradient_checkpointing and model_2.training
+
+        out_2 = model_2(**inputs_dict_copy).sample
+        # run the backwards pass on the gradient-checkpointed model with the same
+        # labels and simple mean loss so the two runs are directly comparable
+        model_2.zero_grad()
+        loss_2 = (out_2 - labels).mean()
+        loss_2.backward()
+
+        # compare the output and parameters gradients
+        self.assertTrue((loss - loss_2).abs() < loss_tolerance)
+        named_params = dict(model.named_parameters())
+        named_params_2 = dict(model_2.named_parameters())
+
+        for name, param in named_params.items():
+            if "post_quant_conv" in name:
+                continue
+            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol))
+
+    @unittest.skipIf(torch_device == "mps", "This test is not supported for MPS devices.")
+    def test_gradient_checkpointing_is_applied(
+        self, expected_set=None, attention_head_dim=None, num_attention_heads=None, block_out_channels=None
+    ):
+        if not self.model_class._supports_gradient_checkpointing:
+            return  # Skip test if model does not support gradient checkpointing
+        if self.model_class.__name__ in [
+            "UNetSpatioTemporalConditionModel",
+            "AutoencoderKLTemporalDecoder",
+        ]:
+            return
+
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        if attention_head_dim is not None:
+            init_dict["attention_head_dim"] = attention_head_dim
+        if num_attention_heads is not None:
+            init_dict["num_attention_heads"] = num_attention_heads
+        if block_out_channels is not None:
+            init_dict["block_out_channels"] = block_out_channels
+
+        model_class_copy = copy.copy(self.model_class)
+
+        modules_with_gc_enabled = {}
+
+        # now monkey patch the following function:
+        # def _set_gradient_checkpointing(self, module, value=False):
+        #     if hasattr(module, "gradient_checkpointing"):
+        #         module.gradient_checkpointing = value
+
+        def _set_gradient_checkpointing_new(self, module, value=False):
+            if hasattr(module, "gradient_checkpointing"):
+                module.gradient_checkpointing = value
+                modules_with_gc_enabled[module.__class__.__name__] = True
+
+        model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
+
+        model = model_class_copy(**init_dict)
+        model.enable_gradient_checkpointing()
+
+        print(f"{set(modules_with_gc_enabled.keys())=}, {expected_set=}")
+
+        assert set(modules_with_gc_enabled.keys()) == expected_set
+        assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
+
     def test_deprecated_kwargs(self):
         has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
         has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0
diff --git a/tests/models/transformers/test_models_dit_transformer2d.py b/tests/models/transformers/test_models_dit_transformer2d.py
index b12cae1a8879..5f4a2f587e92 100644
--- a/tests/models/transformers/test_models_dit_transformer2d.py
+++ b/tests/models/transformers/test_models_dit_transformer2d.py
@@ -84,6 +84,13 @@ def test_correct_class_remapping_from_dict_config(self):
         model = Transformer2DModel.from_config(init_dict)
         assert isinstance(model, DiTTransformer2DModel)
 
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"DiTTransformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+    def test_effective_gradient_checkpointing(self):
+        super().test_effective_gradient_checkpointing(loss_tolerance=1e-4)
+
     def test_correct_class_remapping_from_pretrained_config(self):
         config = DiTTransformer2DModel.load_config("facebook/DiT-XL-2-256", subfolder="transformer")
         model = Transformer2DModel.from_config(config)
diff --git a/tests/models/transformers/test_models_pixart_transformer2d.py b/tests/models/transformers/test_models_pixart_transformer2d.py
index 30293f5d35cb..a544a3fc4607 100644
--- a/tests/models/transformers/test_models_pixart_transformer2d.py
+++ b/tests/models/transformers/test_models_pixart_transformer2d.py
@@ -92,6 +92,10 @@ def test_output(self):
             expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
         )
 
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"PixArtTransformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
     def test_correct_class_remapping_from_dict_config(self):
         init_dict, _ = self.prepare_init_args_and_inputs_for_common()
         model = Transformer2DModel.from_config(init_dict)
diff --git a/tests/models/transformers/test_models_transformer_allegro.py b/tests/models/transformers/test_models_transformer_allegro.py
index ad8b7a3824ba..3479803da61d 100644
--- a/tests/models/transformers/test_models_transformer_allegro.py
+++ b/tests/models/transformers/test_models_transformer_allegro.py
@@ -77,3 +77,7 @@ def prepare_init_args_and_inputs_for_common(self):
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"AllegroTransformer3DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/transformers/test_models_transformer_aura_flow.py b/tests/models/transformers/test_models_transformer_aura_flow.py
index 376d8b57da4d..d1ff7d2c96d3 100644
--- a/tests/models/transformers/test_models_transformer_aura_flow.py
+++ b/tests/models/transformers/test_models_transformer_aura_flow.py
@@ -74,6 +74,10 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
+    def test_gradient_checkpointing_is_applied(self):
{"AuraFlowTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + @unittest.skip("AuraFlowTransformer2DModel uses its own dedicated attention processor. This test does not apply") def test_set_attn_processor_for_determinism(self): pass diff --git a/tests/models/transformers/test_models_transformer_cogvideox.py b/tests/models/transformers/test_models_transformer_cogvideox.py index 6db4113cbd1b..1342577f0114 100644 --- a/tests/models/transformers/test_models_transformer_cogvideox.py +++ b/tests/models/transformers/test_models_transformer_cogvideox.py @@ -81,3 +81,7 @@ def prepare_init_args_and_inputs_for_common(self): } inputs_dict = self.dummy_input return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"CogVideoXTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/models/transformers/test_models_transformer_cogview3plus.py b/tests/models/transformers/test_models_transformer_cogview3plus.py index 46612dbd9190..eda9813808e9 100644 --- a/tests/models/transformers/test_models_transformer_cogview3plus.py +++ b/tests/models/transformers/test_models_transformer_cogview3plus.py @@ -83,3 +83,7 @@ def prepare_init_args_and_inputs_for_common(self): } inputs_dict = self.dummy_input return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"CogView3PlusTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index 6cf7a4f75707..4a784eee4732 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -111,3 +111,7 @@ def test_deprecated_inputs_img_txt_ids_3d(self): torch.allclose(output_1, output_2, atol=1e-5), msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs", ) + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"FluxTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/models/transformers/test_models_transformer_latte.py b/tests/models/transformers/test_models_transformer_latte.py index 3fe0a6098045..0cb9094f5165 100644 --- a/tests/models/transformers/test_models_transformer_latte.py +++ b/tests/models/transformers/test_models_transformer_latte.py @@ -86,3 +86,7 @@ def test_output(self): super().test_output( expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape ) + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"LatteTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/models/transformers/test_models_transformer_sd3.py b/tests/models/transformers/test_models_transformer_sd3.py index 2be4744c5ac4..af86fa9c3bc1 100644 --- a/tests/models/transformers/test_models_transformer_sd3.py +++ b/tests/models/transformers/test_models_transformer_sd3.py @@ -84,6 +84,10 @@ def prepare_init_args_and_inputs_for_common(self): def test_set_attn_processor_for_determinism(self): pass + def test_gradient_checkpointing_is_applied(self): + expected_set = {"SD3Transformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + class SD35TransformerTests(ModelTesterMixin, unittest.TestCase): 
     model_class = SD3Transformer2DModel
@@ -139,3 +143,7 @@ def prepare_init_args_and_inputs_for_common(self):
     @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
     def test_set_attn_processor_for_determinism(self):
         pass
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"SD3Transformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py
index 37d55cedeb28..fec34822904c 100644
--- a/tests/models/unets/test_models_unet_2d_condition.py
+++ b/tests/models/unets/test_models_unet_2d_condition.py
@@ -43,7 +43,6 @@
     require_peft_backend,
     require_torch_accelerator,
     require_torch_accelerator_with_fp16,
-    require_torch_accelerator_with_training,
     require_torch_gpu,
     skip_mps,
     slow,
@@ -406,47 +405,6 @@ def test_xformers_enable_works(self):
             == "XFormersAttnProcessor"
         ), "xformers is not enabled"
 
-    @require_torch_accelerator_with_training
-    def test_gradient_checkpointing(self):
-        # enable deterministic behavior for gradient checkpointing
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-
-        assert not model.is_gradient_checkpointing and model.training
-
-        out = model(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model.zero_grad()
-
-        labels = torch.randn_like(out)
-        loss = (out - labels).mean()
-        loss.backward()
-
-        # re-instantiate the model now enabling gradient checkpointing
-        model_2 = self.model_class(**init_dict)
-        # clone model
-        model_2.load_state_dict(model.state_dict())
-        model_2.to(torch_device)
-        model_2.enable_gradient_checkpointing()
-
-        assert model_2.is_gradient_checkpointing and model_2.training
-
-        out_2 = model_2(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model_2.zero_grad()
-        loss_2 = (out_2 - labels).mean()
-        loss_2.backward()
-
-        # compare the output and parameters gradients
-        self.assertTrue((loss - loss_2).abs() < 1e-5)
-        named_params = dict(model.named_parameters())
-        named_params_2 = dict(model_2.named_parameters())
-        for name, param in named_params.items():
-            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
-
     def test_model_with_attention_head_dim_tuple(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
@@ -599,31 +557,7 @@ def check_sliceable_dim_attr(module: torch.nn.Module):
             check_sliceable_dim_attr(module)
 
     def test_gradient_checkpointing_is_applied(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        init_dict["block_out_channels"] = (16, 32)
-        init_dict["attention_head_dim"] = (8, 16)
-
-        model_class_copy = copy.copy(self.model_class)
-
-        modules_with_gc_enabled = {}
-
-        # now monkey patch the following function:
-        # def _set_gradient_checkpointing(self, module, value=False):
-        #     if hasattr(module, "gradient_checkpointing"):
-        #         module.gradient_checkpointing = value
-
-        def _set_gradient_checkpointing_new(self, module, value=False):
-            if hasattr(module, "gradient_checkpointing"):
-                module.gradient_checkpointing = value
-                modules_with_gc_enabled[module.__class__.__name__] = True
-
-        model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
-
-        model = model_class_copy(**init_dict)
-        model.enable_gradient_checkpointing()
-
-        EXPECTED_SET = {
+        expected_set = {
             "CrossAttnUpBlock2D",
             "CrossAttnDownBlock2D",
             "UNetMidBlock2DCrossAttn",
@@ -631,9 +565,11 @@ def _set_gradient_checkpointing_new(self, module, value=False):
             "Transformer2DModel",
             "DownBlock2D",
         }
-
-        assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
-        assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
+        attention_head_dim = (8, 16)
+        block_out_channels = (16, 32)
+        super().test_gradient_checkpointing_is_applied(
+            expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
+        )
 
     def test_special_attn_proc(self):
         class AttnEasyProc(torch.nn.Module):
diff --git a/tests/models/unets/test_models_unet_controlnetxs.py b/tests/models/unets/test_models_unet_controlnetxs.py
index 6f3662e01750..3025d7117f35 100644
--- a/tests/models/unets/test_models_unet_controlnetxs.py
+++ b/tests/models/unets/test_models_unet_controlnetxs.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 import unittest
 
 import numpy as np
@@ -269,37 +268,14 @@ def assert_unfrozen(module):
             assert_unfrozen(u.ctrl_to_base)
 
     def test_gradient_checkpointing_is_applied(self):
-        model_class_copy = copy.copy(UNetControlNetXSModel)
-
-        modules_with_gc_enabled = {}
-
-        # now monkey patch the following function:
-        # def _set_gradient_checkpointing(self, module, value=False):
-        #     if hasattr(module, "gradient_checkpointing"):
-        #         module.gradient_checkpointing = value
-
-        def _set_gradient_checkpointing_new(self, module, value=False):
-            if hasattr(module, "gradient_checkpointing"):
-                module.gradient_checkpointing = value
-                modules_with_gc_enabled[module.__class__.__name__] = True
-
-        model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
-
-        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
-        model = model_class_copy(**init_dict)
-
-        model.enable_gradient_checkpointing()
-
-        EXPECTED_SET = {
+        expected_set = {
             "Transformer2DModel",
             "UNetMidBlock2DCrossAttn",
             "ControlNetXSCrossAttnDownBlock2D",
             "ControlNetXSCrossAttnMidBlock2D",
             "ControlNetXSCrossAttnUpBlock2D",
         }
-
-        assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
-        assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
 
     @is_flaky
     def test_forward_no_control(self):
diff --git a/tests/models/unets/test_models_unet_motion.py b/tests/models/unets/test_models_unet_motion.py
index ee05f0d93824..209806a5fe26 100644
--- a/tests/models/unets/test_models_unet_motion.py
+++ b/tests/models/unets/test_models_unet_motion.py
@@ -161,27 +161,7 @@ def test_xformers_enable_works(self):
         ), "xformers is not enabled"
 
     def test_gradient_checkpointing_is_applied(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model_class_copy = copy.copy(self.model_class)
-
-        modules_with_gc_enabled = {}
-
-        # now monkey patch the following function:
-        # def _set_gradient_checkpointing(self, module, value=False):
-        #     if hasattr(module, "gradient_checkpointing"):
-        #         module.gradient_checkpointing = value
-
-        def _set_gradient_checkpointing_new(self, module, value=False):
-            if hasattr(module, "gradient_checkpointing"):
-                module.gradient_checkpointing = value
-                modules_with_gc_enabled[module.__class__.__name__] = True
-
-        model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
-
-        model = model_class_copy(**init_dict)
-        model.enable_gradient_checkpointing()
-
-        EXPECTED_SET = {
+        expected_set = {
             "CrossAttnUpBlockMotion",
             "CrossAttnDownBlockMotion",
             "UNetMidBlockCrossAttnMotion",
@@ -189,9 +169,7 @@ def _set_gradient_checkpointing_new(self, module, value=False):
             "Transformer2DModel",
             "DownBlockMotion",
         }
-
-        assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
-        assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
 
     def test_feed_forward_chunking(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
diff --git a/tests/models/unets/test_models_unet_spatiotemporal.py b/tests/models/unets/test_models_unet_spatiotemporal.py
index afdd3d127702..0d7dc823b026 100644
--- a/tests/models/unets/test_models_unet_spatiotemporal.py
+++ b/tests/models/unets/test_models_unet_spatiotemporal.py
@@ -25,7 +25,6 @@
     enable_full_determinism,
     floats_tensor,
     skip_mps,
-    torch_all_close,
     torch_device,
 )
 
@@ -160,47 +159,6 @@ def test_xformers_enable_works(self):
"XFormersAttnProcessor" ), "xformers is not enabled" - @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS") - def test_gradient_checkpointing(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - - assert not model.is_gradient_checkpointing and model.training - - out = model(**inputs_dict).sample - # run the backwards pass on the model. For backwards pass, for simplicity purpose, - # we won't calculate the loss and rather backprop on out.sum() - model.zero_grad() - - labels = torch.randn_like(out) - loss = (out - labels).mean() - loss.backward() - - # re-instantiate the model now enabling gradient checkpointing - model_2 = self.model_class(**init_dict) - # clone model - model_2.load_state_dict(model.state_dict()) - model_2.to(torch_device) - model_2.enable_gradient_checkpointing() - - assert model_2.is_gradient_checkpointing and model_2.training - - out_2 = model_2(**inputs_dict).sample - # run the backwards pass on the model. For backwards pass, for simplicity purpose, - # we won't calculate the loss and rather backprop on out.sum() - model_2.zero_grad() - loss_2 = (out_2 - labels).mean() - loss_2.backward() - - # compare the output and parameters gradients - self.assertTrue((loss - loss_2).abs() < 1e-5) - named_params = dict(model.named_parameters()) - named_params_2 = dict(model_2.named_parameters()) - for name, param in named_params.items(): - self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5)) - def test_model_with_num_attention_heads_tuple(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -239,30 +197,7 @@ def test_model_with_cross_attention_dim_tuple(self): self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") def test_gradient_checkpointing_is_applied(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["num_attention_heads"] = (8, 16) - - model_class_copy = copy.copy(self.model_class) - - modules_with_gc_enabled = {} - - # now monkey patch the following function: - # def _set_gradient_checkpointing(self, module, value=False): - # if hasattr(module, "gradient_checkpointing"): - # module.gradient_checkpointing = value - - def _set_gradient_checkpointing_new(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - modules_with_gc_enabled[module.__class__.__name__] = True - - model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new - - model = model_class_copy(**init_dict) - model.enable_gradient_checkpointing() - - EXPECTED_SET = { + expected_set = { "TransformerSpatioTemporalModel", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal", @@ -270,9 +205,10 @@ def _set_gradient_checkpointing_new(self, module, value=False): "CrossAttnUpBlockSpatioTemporal", "UNetMidBlockSpatioTemporal", } - - assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET - assert all(modules_with_gc_enabled.values()), "All modules should be enabled" + num_attention_heads = (8, 16) + super().test_gradient_checkpointing_is_applied( + expected_set=expected_set, num_attention_heads=num_attention_heads + ) def test_pickle(self): # enable deterministic behavior for gradient checkpointing