
Commit fa55429

rootonchair, sayakpaul, and a-r-r-o-w
committed
[Tests] Improve transformers model test suite coverage - Latte (#8919)
* add LatteTransformer3DModel model test
* change patch_size to 1
* reduce req len
* reduce channel dims
* increase num_layers
* reduce dims further
* run make style

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
1 parent 499b7d6 commit fa55429

File tree

1 file changed: 88 additions, 0 deletions
@@ -0,0 +1,88 @@
```python
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import LatteTransformer3DModel
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    torch_device,
)

from ..test_modeling_common import ModelTesterMixin


enable_full_determinism()


class LatteTransformerTests(ModelTesterMixin, unittest.TestCase):
    model_class = LatteTransformer3DModel
    main_input_name = "hidden_states"

    @property
    def dummy_input(self):
        batch_size = 2
        num_channels = 4
        num_frames = 1
        height = width = 8
        embedding_dim = 8
        sequence_length = 8

        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "timestep": timestep,
            "enable_temporal_attentions": True,
        }

    @property
    def input_shape(self):
        return (4, 1, 8, 8)

    @property
    def output_shape(self):
        return (8, 1, 8, 8)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "sample_size": 8,
            "num_layers": 1,
            "patch_size": 2,
            "attention_head_dim": 4,
            "num_attention_heads": 2,
            "caption_channels": 8,
            "in_channels": 4,
            "cross_attention_dim": 8,
            "out_channels": 8,
            "attention_bias": True,
            "activation_fn": "gelu-approximate",
            "num_embeds_ada_norm": 1000,
            "norm_type": "ada_norm_single",
            "norm_elementwise_affine": False,
            "norm_eps": 1e-6,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_output(self):
        super().test_output(
            expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
        )
```
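For reference, here is a minimal standalone sketch of the forward pass this test exercises, outside the ModelTesterMixin harness. It assumes the same tiny config as prepare_init_args_and_inputs_for_common above and the public diffusers API for LatteTransformer3DModel (the forward keyword arguments mirror the test's dummy_input, and the output is read via the .sample field); the comment about variance prediction is an inference from out_channels being twice in_channels, not something stated in the diff.

```python
import torch

from diffusers import LatteTransformer3DModel

# Tiny config mirroring prepare_init_args_and_inputs_for_common in the test.
model = LatteTransformer3DModel(
    sample_size=8,
    num_layers=1,
    patch_size=2,
    attention_head_dim=4,
    num_attention_heads=2,
    caption_channels=8,
    in_channels=4,
    cross_attention_dim=8,
    out_channels=8,
    attention_bias=True,
    activation_fn="gelu-approximate",
    num_embeds_ada_norm=1000,
    norm_type="ada_norm_single",
    norm_elementwise_affine=False,
    norm_eps=1e-6,
).eval()

# Inputs shaped like dummy_input: (batch, channels, frames, height, width).
hidden_states = torch.randn(2, 4, 1, 8, 8)
encoder_hidden_states = torch.randn(2, 8, 8)  # (batch, seq_len, embedding_dim)
timestep = torch.randint(0, 1000, (2,))

with torch.no_grad():
    sample = model(
        hidden_states,
        timestep=timestep,
        encoder_hidden_states=encoder_hidden_states,
        enable_temporal_attentions=True,
    ).sample

# out_channels=8 is twice in_channels=4 (presumably epsilon plus variance
# prediction), which is why output_shape is (8, 1, 8, 8) rather than (4, 1, 8, 8).
assert sample.shape == (2, 8, 1, 8, 8)
```

The mixin's test_output prepends the batch size from the main input to output_shape, so the full expected shape it checks is (2, 8, 1, 8, 8), matching the assertion in the sketch.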
