Skip to content

Commit 102cabe

Browse files
kashifanton-l
andauthored
split tests_modeling_utils (#223)
* split tests_modeling_utils * Fix SD tests .to(device) * fix merge * Fix style Co-authored-by: anton-l <anton@huggingface.co>
1 parent 511bd3a commit 102cabe

File tree

6 files changed

+1026
-921
lines changed

6 files changed

+1026
-921
lines changed

tests/test_modeling_common.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
# coding=utf-8
2+
# Copyright 2022 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import inspect
17+
import tempfile
18+
19+
import numpy as np
20+
import torch
21+
22+
from diffusers.testing_utils import torch_device
23+
from diffusers.training_utils import EMAModel
24+
25+
26+
class ModelTesterMixin:
27+
def test_from_pretrained_save_pretrained(self):
28+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
29+
30+
model = self.model_class(**init_dict)
31+
model.to(torch_device)
32+
model.eval()
33+
34+
with tempfile.TemporaryDirectory() as tmpdirname:
35+
model.save_pretrained(tmpdirname)
36+
new_model = self.model_class.from_pretrained(tmpdirname)
37+
new_model.to(torch_device)
38+
39+
with torch.no_grad():
40+
image = model(**inputs_dict)
41+
if isinstance(image, dict):
42+
image = image["sample"]
43+
44+
new_image = new_model(**inputs_dict)
45+
46+
if isinstance(new_image, dict):
47+
new_image = new_image["sample"]
48+
49+
max_diff = (image - new_image).abs().sum().item()
50+
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
51+
52+
def test_determinism(self):
53+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
54+
model = self.model_class(**init_dict)
55+
model.to(torch_device)
56+
model.eval()
57+
with torch.no_grad():
58+
first = model(**inputs_dict)
59+
if isinstance(first, dict):
60+
first = first["sample"]
61+
62+
second = model(**inputs_dict)
63+
if isinstance(second, dict):
64+
second = second["sample"]
65+
66+
out_1 = first.cpu().numpy()
67+
out_2 = second.cpu().numpy()
68+
out_1 = out_1[~np.isnan(out_1)]
69+
out_2 = out_2[~np.isnan(out_2)]
70+
max_diff = np.amax(np.abs(out_1 - out_2))
71+
self.assertLessEqual(max_diff, 1e-5)
72+
73+
def test_output(self):
74+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
75+
model = self.model_class(**init_dict)
76+
model.to(torch_device)
77+
model.eval()
78+
79+
with torch.no_grad():
80+
output = model(**inputs_dict)
81+
82+
if isinstance(output, dict):
83+
output = output["sample"]
84+
85+
self.assertIsNotNone(output)
86+
expected_shape = inputs_dict["sample"].shape
87+
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
88+
89+
def test_forward_signature(self):
90+
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
91+
92+
model = self.model_class(**init_dict)
93+
signature = inspect.signature(model.forward)
94+
# signature.parameters is an OrderedDict => so arg_names order is deterministic
95+
arg_names = [*signature.parameters.keys()]
96+
97+
expected_arg_names = ["sample", "timestep"]
98+
self.assertListEqual(arg_names[:2], expected_arg_names)
99+
100+
def test_model_from_config(self):
101+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
102+
103+
model = self.model_class(**init_dict)
104+
model.to(torch_device)
105+
model.eval()
106+
107+
# test if the model can be loaded from the config
108+
# and has all the expected shape
109+
with tempfile.TemporaryDirectory() as tmpdirname:
110+
model.save_config(tmpdirname)
111+
new_model = self.model_class.from_config(tmpdirname)
112+
new_model.to(torch_device)
113+
new_model.eval()
114+
115+
# check if all paramters shape are the same
116+
for param_name in model.state_dict().keys():
117+
param_1 = model.state_dict()[param_name]
118+
param_2 = new_model.state_dict()[param_name]
119+
self.assertEqual(param_1.shape, param_2.shape)
120+
121+
with torch.no_grad():
122+
output_1 = model(**inputs_dict)
123+
124+
if isinstance(output_1, dict):
125+
output_1 = output_1["sample"]
126+
127+
output_2 = new_model(**inputs_dict)
128+
129+
if isinstance(output_2, dict):
130+
output_2 = output_2["sample"]
131+
132+
self.assertEqual(output_1.shape, output_2.shape)
133+
134+
def test_training(self):
135+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
136+
137+
model = self.model_class(**init_dict)
138+
model.to(torch_device)
139+
model.train()
140+
output = model(**inputs_dict)
141+
142+
if isinstance(output, dict):
143+
output = output["sample"]
144+
145+
noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
146+
loss = torch.nn.functional.mse_loss(output, noise)
147+
loss.backward()
148+
149+
def test_ema_training(self):
150+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
151+
152+
model = self.model_class(**init_dict)
153+
model.to(torch_device)
154+
model.train()
155+
ema_model = EMAModel(model, device=torch_device)
156+
157+
output = model(**inputs_dict)
158+
159+
if isinstance(output, dict):
160+
output = output["sample"]
161+
162+
noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
163+
loss = torch.nn.functional.mse_loss(output, noise)
164+
loss.backward()
165+
ema_model.step(model)

0 commit comments

Comments
 (0)