Commit 9e1b1ca

[Tests] Make sure tests are on GPU (#269)

* [Tests] Make sure tests are on GPU
* move more models
* speed up tests

1 parent 16172c1 · commit 9e1b1ca
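
The change follows one pattern throughout: move each model or pipeline onto the test device, move its input tensors with it, and pull outputs back to the CPU before comparing against hard-coded expected values. A minimal sketch of that pattern, assembled from the test_models_unet.py diff below; the definition of torch_device here is an assumption standing in for the helper the tests import ("cuda" when a GPU is visible, "cpu" otherwise):

    import torch
    from diffusers import UNet2DModel

    # stand-in for the torch_device helper the test suite imports
    torch_device = "cuda" if torch.cuda.is_available() else "cpu"

    model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update")
    model.eval()
    model.to(torch_device)

    torch.manual_seed(0)
    noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
    noise = noise.to(torch_device)  # inputs must live on the same device as the weights
    time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)

    with torch.no_grad():
        sample = model(noise, time_step)["sample"]

    # expected values are plain CPU tensors, so the slice comes back before comparison
    output_slice = sample[0, -1, -3:, -3:].flatten().cpu()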

File tree

5 files changed: +47 additions, −14 deletions

File renamed without changes.

tests/test_models_unet.py

Lines changed: 7 additions & 2 deletions

@@ -24,6 +24,9 @@
 from .test_modeling_common import ModelTesterMixin
 
 
+torch.backends.cuda.matmul.allow_tf32 = False
+
+
 class UnetModelTests(ModelTesterMixin, unittest.TestCase):
     model_class = UNet2DModel
 
@@ -133,18 +136,20 @@ def test_from_pretrained_hub(self):
     def test_output_pretrained(self):
         model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update")
         model.eval()
+        model.to(torch_device)
 
         torch.manual_seed(0)
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(0)
 
         noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
-        time_step = torch.tensor([10] * noise.shape[0])
+        noise = noise.to(torch_device)
+        time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)
 
         with torch.no_grad():
             output = model(noise, time_step)["sample"]
 
-        output_slice = output[0, -1, -3:, -3:].flatten()
+        output_slice = output[0, -1, -3:, -3:].flatten().cpu()
         # fmt: off
         expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
         # fmt: on
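
A note on the module-level flag each model-test file now sets: torch.backends.cuda.matmul.allow_tf32 = False forces true fp32 matmuls. On Ampere and newer GPUs, PyTorch may otherwise run fp32 matmuls in TF32, which keeps the fp32 exponent range but rounds the mantissa, so GPU results would drift from expected slices like the one above. A small CUDA-only illustration of the effect (the magnitude of the difference depends on the hardware):

    import torch

    a = torch.randn(256, 256, device="cuda")
    b = torch.randn(256, 256, device="cuda")

    torch.backends.cuda.matmul.allow_tf32 = False
    exact = a @ b  # true fp32 multiply-accumulate

    torch.backends.cuda.matmul.allow_tf32 = True
    fast = a @ b   # TF32: fp32 range, reduced mantissa precision

    print((exact - fast).abs().max())  # typically nonzero on TF32-capable GPUs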

tests/test_models_vae.py

Lines changed: 6 additions & 1 deletion

@@ -23,6 +23,9 @@
 from .test_modeling_common import ModelTesterMixin
 
 
+torch.backends.cuda.matmul.allow_tf32 = False
+
+
 class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
     model_class = AutoencoderKL
 
@@ -74,17 +77,19 @@ def test_from_pretrained_hub(self):
 
     def test_output_pretrained(self):
         model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy")
+        model = model.to(torch_device)
         model.eval()
 
         torch.manual_seed(0)
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(0)
 
         image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
+        image = image.to(torch_device)
         with torch.no_grad():
             output = model(image, sample_posterior=True)
 
-        output_slice = output[0, -1, -3:, -3:].flatten()
+        output_slice = output[0, -1, -3:, -3:].flatten().cpu()
         # fmt: off
         expected_output_slice = torch.tensor([-4.0078e-01, -3.8304e-04, -1.2681e-01, -1.1462e-01, 2.0095e-01, 1.0893e-01, -8.8248e-02, -3.0361e-01, -9.8646e-03])
         # fmt: on
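
Both seeds matter in this test: torch.randn draws from the host generator, because the tensor is created on the CPU and only moved afterwards, while sample_posterior=True makes the model draw noise internally after it has moved to torch_device, which on a GPU run would come from the CUDA generators (this reading of where the posterior noise is drawn is an assumption about the model internals, not something the diff shows). A short sketch of the split:

    import torch

    torch_device = "cuda" if torch.cuda.is_available() else "cpu"

    torch.manual_seed(0)               # seeds the host generator used by torch.randn below
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)  # seeds the per-device CUDA generators

    image = torch.randn(1, 3, 32, 32)  # drawn on the CPU, values fixed by the host seed
    image = image.to(torch_device)     # moving does not re-randomize the values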

tests/test_models_vq.py

Lines changed: 6 additions & 2 deletions

@@ -23,6 +23,9 @@
 from .test_modeling_common import ModelTesterMixin
 
 
+torch.backends.cuda.matmul.allow_tf32 = False
+
+
 class VQModelTests(ModelTesterMixin, unittest.TestCase):
     model_class = VQModel
 
@@ -73,17 +76,18 @@ def test_from_pretrained_hub(self):
 
     def test_output_pretrained(self):
         model = VQModel.from_pretrained("fusing/vqgan-dummy")
-        model.eval()
+        model.to(torch_device).eval()
 
         torch.manual_seed(0)
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(0)
 
         image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
+        image = image.to(torch_device)
         with torch.no_grad():
             output = model(image)
 
-        output_slice = output[0, -1, -3:, -3:].flatten()
+        output_slice = output[0, -1, -3:, -3:].flatten().cpu()
         # fmt: off
         expected_output_slice = torch.tensor([-0.0153, -0.4044, -0.1880, -0.5161, -0.2418, -0.4072, -0.1612, -0.0633, -0.0143])
         # fmt: on
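
Unlike the other two files, the VQModel test chains the calls. That is equivalent: nn.Module.to and nn.Module.eval both return the module itself, so model.to(torch_device).eval() does the same work as the two separate statements. A tiny self-contained check:

    import torch
    from torch import nn

    model = nn.Linear(4, 4)
    returned = model.to("cpu").eval()  # .to() and .eval() both return the module
    assert returned is model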

tests/test_pipelines.py

Lines changed: 28 additions & 9 deletions

@@ -59,10 +59,12 @@ def test_from_pretrained_save_pretrained(self):
         schedular = DDPMScheduler(num_train_timesteps=10)
 
         ddpm = DDPMPipeline(model, schedular)
+        ddpm.to(torch_device)
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             ddpm.save_pretrained(tmpdirname)
             new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
+            new_ddpm.to(torch_device)
 
         generator = torch.manual_seed(0)
 
@@ -76,11 +78,12 @@ def test_from_pretrained_save_pretrained(self):
     def test_from_pretrained_hub(self):
         model_path = "google/ddpm-cifar10-32"
 
-        ddpm = DDPMPipeline.from_pretrained(model_path)
-        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)
+        scheduler = DDPMScheduler(num_train_timesteps=10)
 
-        ddpm.scheduler.num_timesteps = 10
-        ddpm_from_hub.scheduler.num_timesteps = 10
+        ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler)
+        ddpm.to(torch_device)
+        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
+        ddpm_from_hub.to(torch_device)
 
         generator = torch.manual_seed(0)
 
@@ -94,14 +97,15 @@ def test_from_pretrained_hub(self):
     def test_from_pretrained_hub_pass_model(self):
         model_path = "google/ddpm-cifar10-32"
 
+        scheduler = DDPMScheduler(num_train_timesteps=10)
+
         # pass unet into DiffusionPipeline
         unet = UNet2DModel.from_pretrained(model_path)
-        ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet)
-
-        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)
+        ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet, scheduler=scheduler)
+        ddpm_from_hub_custom_model.to(torch_device)
 
-        ddpm_from_hub_custom_model.scheduler.num_timesteps = 10
-        ddpm_from_hub.scheduler.num_timesteps = 10
+        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
+        ddpm_from_hub.to(torch_device)
 
         generator = torch.manual_seed(0)
 
@@ -116,6 +120,7 @@ def test_output_format(self):
         model_path = "google/ddpm-cifar10-32"
 
         pipe = DDIMPipeline.from_pretrained(model_path)
+        pipe.to(torch_device)
 
         generator = torch.manual_seed(0)
         images = pipe(generator=generator, output_type="numpy")["sample"]
@@ -141,6 +146,7 @@ def test_ddpm_cifar10(self):
         scheduler = scheduler.set_format("pt")
 
         ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
+        ddpm.to(torch_device)
 
         generator = torch.manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy")["sample"]
@@ -159,6 +165,7 @@ def test_ddim_lsun(self):
         scheduler = DDIMScheduler.from_config(model_id)
 
         ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
+        ddpm.to(torch_device)
 
         generator = torch.manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy")["sample"]
@@ -177,6 +184,7 @@ def test_ddim_cifar10(self):
         scheduler = DDIMScheduler(tensor_format="pt")
 
         ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
+        ddim.to(torch_device)
 
         generator = torch.manual_seed(0)
         image = ddim(generator=generator, eta=0.0, output_type="numpy")["sample"]
@@ -195,6 +203,7 @@ def test_pndm_cifar10(self):
         scheduler = PNDMScheduler(tensor_format="pt")
 
         pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
+        pndm.to(torch_device)
         generator = torch.manual_seed(0)
         image = pndm(generator=generator, output_type="numpy")["sample"]
 
@@ -207,6 +216,7 @@ def test_pndm_cifar10(self):
     @slow
     def test_ldm_text2img(self):
         ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+        ldm.to(torch_device)
 
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
@@ -223,6 +233,7 @@ def test_ldm_text2img(self):
     @slow
     def test_ldm_text2img_fast(self):
         ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+        ldm.to(torch_device)
 
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
@@ -290,6 +301,7 @@ def test_score_sde_ve_pipeline(self):
         scheduler = ScoreSdeVeScheduler.from_config(model_id)
 
         sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
+        sde_ve.to(torch_device)
 
         torch.manual_seed(0)
         image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"]
@@ -304,6 +316,7 @@ def test_score_sde_ve_pipeline(self):
     @slow
     def test_ldm_uncond(self):
         ldm = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")
+        ldm.to(torch_device)
 
         generator = torch.manual_seed(0)
         image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]
@@ -323,7 +336,9 @@ def test_ddpm_ddim_equality(self):
         ddim_scheduler = DDIMScheduler(tensor_format="pt")
 
         ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
+        ddpm.to(torch_device)
         ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
+        ddim.to(torch_device)
 
         generator = torch.manual_seed(0)
         ddpm_image = ddpm(generator=generator, output_type="numpy")["sample"]
@@ -343,7 +358,10 @@ def test_ddpm_ddim_equality_batched(self):
         ddim_scheduler = DDIMScheduler(tensor_format="pt")
 
         ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
+        ddpm.to(torch_device)
+
         ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
+        ddim.to(torch_device)
 
         generator = torch.manual_seed(0)
         ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy")["sample"]
@@ -363,6 +381,7 @@ def test_karras_ve_pipeline(self):
         scheduler = KarrasVeScheduler(tensor_format="pt")
 
         pipe = KarrasVePipeline(unet=model, scheduler=scheduler)
+        pipe.to(torch_device)
 
         generator = torch.manual_seed(0)
         image = pipe(num_inference_steps=20, generator=generator, output_type="numpy")["sample"]
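
For the pipelines the recipe is the same, one level up: calling .to(torch_device) on a pipeline moves the modules it holds (the diff relies on this method existing on the pipeline classes), and the hub tests now pass a DDPMScheduler(num_train_timesteps=10) into from_pretrained instead of mutating scheduler.num_timesteps afterwards, which is where the advertised speed-up comes from. A sketch of the resulting call shape, reusing the google/ddpm-cifar10-32 checkpoint from the tests and the same torch_device assumption as above:

    import torch
    from diffusers import DDPMPipeline, DDPMScheduler

    torch_device = "cuda" if torch.cuda.is_available() else "cpu"

    # a 10-step scheduler keeps the test cheap; the checkpoint default is far longer
    scheduler = DDPMScheduler(num_train_timesteps=10)

    ddpm = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32", scheduler=scheduler)
    ddpm.to(torch_device)  # moves the pipeline's registered modules

    generator = torch.manual_seed(0)  # seeded CPU generator, as in the tests
    image = ddpm(generator=generator, output_type="numpy")["sample"]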
