
[Tests] clean up and refactor gradient checkpointing tests #9494

Merged
merged 11 commits into from
Oct 31, 2024
109 changes: 25 additions & 84 deletions tests/models/autoencoders/test_models_vae.py
@@ -39,7 +39,6 @@
load_hf_numpy,
require_torch_accelerator,
require_torch_accelerator_with_fp16,
require_torch_accelerator_with_training,
require_torch_gpu,
skip_mps,
slow,
@@ -170,52 +169,17 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict

@unittest.skip("Not tested.")
def test_forward_signature(self):
pass

@unittest.skip("Not tested.")
def test_training(self):
pass

@require_torch_accelerator_with_training
def test_gradient_checkpointing(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)

assert not model.is_gradient_checkpointing and model.training

out = model(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model.zero_grad()

labels = torch.randn_like(out)
loss = (out - labels).mean()
loss.backward()

# re-instantiate the model now enabling gradient checkpointing
model_2 = self.model_class(**init_dict)
# clone model
model_2.load_state_dict(model.state_dict())
model_2.to(torch_device)
model_2.enable_gradient_checkpointing()

assert model_2.is_gradient_checkpointing and model_2.training

out_2 = model_2(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model_2.zero_grad()
loss_2 = (out_2 - labels).mean()
loss_2.backward()

# compare the output and parameters gradients
self.assertTrue((loss - loss_2).abs() < 1e-5)
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
for name, param in named_params.items():
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Decoder", "Encoder"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

def test_from_pretrained_hub(self):
model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
@@ -329,9 +293,11 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict

@unittest.skip("Not tested.")
def test_forward_signature(self):
pass

@unittest.skip("Not tested.")
def test_forward_with_norm_groups(self):
pass

@@ -364,9 +330,20 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict

@unittest.skip("Not tested.")
def test_outputs_equivalence(self):
pass

def test_gradient_checkpointing_is_applied(self):
expected_set = {"DecoderTiny", "EncoderTiny"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

@unittest.skip(
"Gradient checkpointing is supported but this test doesn't apply to this class because it's forward is a bit different from the rest."
)
Comment on lines +341 to +343 (Member Author):

Other forwards don't apply scaling and unscaling the way this one does, so that makes this one a little different.
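
For context on what the comment above is describing, here is a minimal, hypothetical sketch (not taken from this diff) of a forward that round-trips its latents through a scale/unscale step before decoding, which is the kind of forward the shared `test_effective_gradient_checkpointing` was not written for. The class name, helper names, and scaling constant below are illustrative assumptions, not the diffusers implementation.

```python
import torch
import torch.nn as nn


class TinyAutoencoderSketch(nn.Module):
    """Illustration only: unlike the plain encode -> decode forwards of the other
    VAEs covered by the shared test, this forward squashes the latents into a
    fixed range and expands them back before decoding."""

    def __init__(self, encoder: nn.Module, decoder: nn.Module, latent_magnitude: float = 3.0):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.latent_magnitude = latent_magnitude  # assumed scaling constant

    def scale_latents(self, x: torch.Tensor) -> torch.Tensor:
        # squash raw latents into [0, 1]
        return x.div(2 * self.latent_magnitude).add(0.5).clamp(0, 1)

    def unscale_latents(self, x: torch.Tensor) -> torch.Tensor:
        # map [0, 1] latents back to their original range
        return x.sub(0.5).mul(2 * self.latent_magnitude)

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        enc = self.encoder(sample)
        # the scale/unscale round-trip is the step the other forwards don't have,
        # which is why the shared gradient-comparison test is skipped for this
        # class instead of being reused
        latents = self.unscale_latents(self.scale_latents(enc))
        return self.decoder(latents)
```

For example, `TinyAutoencoderSketch(nn.Identity(), nn.Identity())(torch.randn(1, 4, 8, 8))` runs the round-trip with no-op encoder and decoder.
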

def test_effective_gradient_checkpointing(self):
pass


class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
model_class = ConsistencyDecoderVAE
@@ -443,55 +420,17 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict

@unittest.skip("Not tested.")
def test_forward_signature(self):
pass

@unittest.skip("Not tested.")
def test_training(self):
pass

@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
def test_gradient_checkpointing(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)

assert not model.is_gradient_checkpointing and model.training

out = model(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model.zero_grad()

labels = torch.randn_like(out)
loss = (out - labels).mean()
loss.backward()

# re-instantiate the model now enabling gradient checkpointing
model_2 = self.model_class(**init_dict)
# clone model
model_2.load_state_dict(model.state_dict())
model_2.to(torch_device)
model_2.enable_gradient_checkpointing()

assert model_2.is_gradient_checkpointing and model_2.training

out_2 = model_2(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model_2.zero_grad()
loss_2 = (out_2 - labels).mean()
loss_2.backward()

# compare the output and parameters gradients
self.assertTrue((loss - loss_2).abs() < 1e-5)
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
for name, param in named_params.items():
if "post_quant_conv" in name:
continue

self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Encoder", "TemporalDecoder"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
@@ -522,9 +461,11 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict

@unittest.skip("Not tested.")
def test_forward_signature(self):
pass

@unittest.skip("Not tested.")
def test_forward_with_norm_groups(self):
pass

94 changes: 94 additions & 0 deletions tests/models/test_modeling_common.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import inspect
import json
import os
@@ -57,6 +58,7 @@
require_torch_gpu,
require_torch_multi_gpu,
run_test_in_subprocess,
torch_all_close,
torch_device,
)

@@ -785,6 +787,98 @@ def test_enable_disable_gradient_checkpointing(self):
model.disable_gradient_checkpointing()
self.assertFalse(model.is_gradient_checkpointing)

@require_torch_accelerator_with_training
def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5):
if not self.model_class._supports_gradient_checkpointing:
return # Skip test if model does not support gradient checkpointing

# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
inputs_dict_copy = copy.deepcopy(inputs_dict)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)

assert not model.is_gradient_checkpointing and model.training

out = model(**inputs_dict).sample
# run the backward pass on the model. For simplicity, we don't compute a real training
# loss; we just backprop through the mean difference from random targets
model.zero_grad()

labels = torch.randn_like(out)
loss = (out - labels).mean()
loss.backward()

# re-instantiate the model now enabling gradient checkpointing
torch.manual_seed(0)
model_2 = self.model_class(**init_dict)
# clone model
model_2.load_state_dict(model.state_dict())
model_2.to(torch_device)
model_2.enable_gradient_checkpointing()

assert model_2.is_gradient_checkpointing and model_2.training

out_2 = model_2(**inputs_dict_copy).sample
# run the backward pass on the model. For simplicity, we don't compute a real training
# loss; we just backprop through the mean difference from the same random targets
model_2.zero_grad()
loss_2 = (out_2 - labels).mean()
loss_2.backward()

# compare the output and parameters gradients
self.assertTrue((loss - loss_2).abs() < loss_tolerance)
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())

for name, param in named_params.items():
if "post_quant_conv" in name:
continue
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol))

def test_gradient_checkpointing_is_applied(
self, expected_set=None, attention_head_dim=None, num_attention_heads=None, block_out_channels=None
):
if not self.model_class._supports_gradient_checkpointing:
return # Skip test if model does not support gradient checkpointing
if torch_device == "mps" and self.model_class.__name__ in [
"UNetSpatioTemporalConditionModel",
"AutoencoderKLTemporalDecoder",
]:
return

init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

if attention_head_dim is not None:
init_dict["attention_head_dim"] = attention_head_dim
if num_attention_heads is not None:
init_dict["num_attention_heads"] = num_attention_heads
if block_out_channels is not None:
init_dict["block_out_channels"] = block_out_channels

model_class_copy = copy.copy(self.model_class)

modules_with_gc_enabled = {}

# now monkey patch the following function:
# def _set_gradient_checkpointing(self, module, value=False):
# if hasattr(module, "gradient_checkpointing"):
# module.gradient_checkpointing = value

def _set_gradient_checkpointing_new(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
modules_with_gc_enabled[module.__class__.__name__] = True

model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new

model = model_class_copy(**init_dict)
model.enable_gradient_checkpointing()

assert set(modules_with_gc_enabled.keys()) == expected_set
assert all(modules_with_gc_enabled.values()), "All modules should be enabled"

def test_deprecated_kwargs(self):
has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0
7 changes: 7 additions & 0 deletions tests/models/transformers/test_models_dit_transformer2d.py
@@ -84,6 +84,13 @@ def test_correct_class_remapping_from_dict_config(self):
model = Transformer2DModel.from_config(init_dict)
assert isinstance(model, DiTTransformer2DModel)

def test_gradient_checkpointing_is_applied(self):
expected_set = {"DiTTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

def test_effective_gradient_checkpointing(self):
super().test_effective_gradient_checkpointing(loss_tolerance=1e-4)

def test_correct_class_remapping_from_pretrained_config(self):
config = DiTTransformer2DModel.load_config("facebook/DiT-XL-2-256", subfolder="transformer")
model = Transformer2DModel.from_config(config)
4 changes: 4 additions & 0 deletions tests/models/transformers/test_models_pixart_transformer2d.py
@@ -92,6 +92,10 @@ def test_output(self):
expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
)

def test_gradient_checkpointing_is_applied(self):
expected_set = {"PixArtTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

def test_correct_class_remapping_from_dict_config(self):
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = Transformer2DModel.from_config(init_dict)
@@ -74,6 +74,10 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict

def test_gradient_checkpointing_is_applied(self):
expected_set = {"AuraFlowTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

@unittest.skip("AuraFlowTransformer2DModel uses its own dedicated attention processor. This test does not apply")
def test_set_attn_processor_for_determinism(self):
pass
@@ -81,3 +81,7 @@ def prepare_init_args_and_inputs_for_common(self):
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict

def test_gradient_checkpointing_is_applied(self):
expected_set = {"CogVideoXTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
4 changes: 4 additions & 0 deletions tests/models/transformers/test_models_transformer_flux.py
@@ -111,3 +111,7 @@ def test_deprecated_inputs_img_txt_ids_3d(self):
torch.allclose(output_1, output_2, atol=1e-5),
msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
)

def test_gradient_checkpointing_is_applied(self):
expected_set = {"FluxTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
4 changes: 4 additions & 0 deletions tests/models/transformers/test_models_transformer_latte.py
@@ -86,3 +86,7 @@ def test_output(self):
super().test_output(
expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
)

def test_gradient_checkpointing_is_applied(self):
expected_set = {"LatteTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
4 changes: 4 additions & 0 deletions tests/models/transformers/test_models_transformer_sd3.py
@@ -139,3 +139,7 @@ def prepare_init_args_and_inputs_for_common(self):
@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
def test_set_attn_processor_for_determinism(self):
pass

def test_gradient_checkpointing_is_applied(self):
expected_set = {"SD3Transformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)