Skip to content

Commit e5d0a32

Browse files
authored
[refactor] LoRA tests (#9481)
* refactor scheduler class usage
* reorder to make tests more readable
* remove pipeline-specific checks and skip tests directly
* rewrite denoiser conditions more cleanly
* bump tolerance for the CogVideoX test
1 parent 14a1b86 commit e5d0a32

File tree

4 files changed

+142
-289
lines changed

4 files changed

+142
-289
lines changed

tests/lora/test_lora_layers_cogvideox.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
4848
pipeline_class = CogVideoXPipeline
4949
scheduler_cls = CogVideoXDPMScheduler
5050
scheduler_kwargs = {"timestep_spacing": "trailing"}
51+
scheduler_classes = [CogVideoXDDIMScheduler, CogVideoXDPMScheduler]
5152

5253
transformer_kwargs = {
5354
"num_attention_heads": 4,
@@ -126,8 +127,7 @@ def get_dummy_inputs(self, with_generator=True):
126127

127128
@skip_mps
128129
def test_lora_fuse_nan(self):
129-
scheduler_classes = [CogVideoXDDIMScheduler, CogVideoXDPMScheduler]
130-
for scheduler_cls in scheduler_classes:
130+
for scheduler_cls in self.scheduler_classes:
131131
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
132132
pipe = self.pipeline_class(**components)
133133
pipe = pipe.to(torch_device)
@@ -156,10 +156,22 @@ def test_lora_fuse_nan(self):
156156
self.assertTrue(np.isnan(out).all())
157157

158158
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
159-
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=5e-3)
159+
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
160160

161161
def test_simple_inference_with_text_denoiser_lora_unfused(self):
162-
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=5e-3)
162+
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
163+
164+
@unittest.skip("Not supported in CogVideoX.")
165+
def test_simple_inference_with_text_denoiser_block_scale(self):
166+
pass
167+
168+
@unittest.skip("Not supported in CogVideoX.")
169+
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
170+
pass
171+
172+
@unittest.skip("Not supported in CogVideoX.")
173+
def test_modify_padding_mode(self):
174+
pass
163175

164176
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
165177
def test_simple_inference_with_partial_text_lora(self):

tests/lora/test_lora_layers_flux.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
4747
pipeline_class = FluxPipeline
4848
scheduler_cls = FlowMatchEulerDiscreteScheduler()
4949
scheduler_kwargs = {}
50-
uses_flow_matching = True
50+
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
5151
transformer_kwargs = {
5252
"patch_size": 1,
5353
"in_channels": 4,
@@ -154,6 +154,14 @@ def test_with_alpha_in_state_dict(self):
154154
)
155155
self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
156156

157+
@unittest.skip("Not supported in Flux.")
158+
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
159+
pass
160+
161+
@unittest.skip("Not supported in Flux.")
162+
def test_modify_padding_mode(self):
163+
pass
164+
157165

158166
@slow
159167
@require_torch_gpu

tests/lora/test_lora_layers_sd3.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
3434
pipeline_class = StableDiffusion3Pipeline
3535
scheduler_cls = FlowMatchEulerDiscreteScheduler
3636
scheduler_kwargs = {}
37-
uses_flow_matching = True
37+
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
3838
transformer_kwargs = {
3939
"sample_size": 32,
4040
"patch_size": 1,
@@ -92,3 +92,19 @@ def test_sd3_lora(self):
9292

9393
lora_filename = "lora_peft_format.safetensors"
9494
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
95+
96+
@unittest.skip("Not supported in SD3.")
97+
def test_simple_inference_with_text_denoiser_block_scale(self):
98+
pass
99+
100+
@unittest.skip("Not supported in SD3.")
101+
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
102+
pass
103+
104+
@unittest.skip("Not supported in SD3.")
105+
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
106+
pass
107+
108+
@unittest.skip("Not supported in SD3.")
109+
def test_modify_padding_mode(self):
110+
pass

0 commit comments

Comments (0)