@@ -59,10 +59,12 @@ def test_from_pretrained_save_pretrained(self):
59
59
schedular = DDPMScheduler (num_train_timesteps = 10 )
60
60
61
61
ddpm = DDPMPipeline (model , schedular )
62
+ ddpm .to (torch_device )
62
63
63
64
with tempfile .TemporaryDirectory () as tmpdirname :
64
65
ddpm .save_pretrained (tmpdirname )
65
66
new_ddpm = DDPMPipeline .from_pretrained (tmpdirname )
67
+ new_ddpm .to (torch_device )
66
68
67
69
generator = torch .manual_seed (0 )
68
70
@@ -76,11 +78,12 @@ def test_from_pretrained_save_pretrained(self):
76
78
def test_from_pretrained_hub (self ):
77
79
model_path = "google/ddpm-cifar10-32"
78
80
79
- ddpm = DDPMPipeline .from_pretrained (model_path )
80
- ddpm_from_hub = DiffusionPipeline .from_pretrained (model_path )
81
+ scheduler = DDPMScheduler (num_train_timesteps = 10 )
81
82
82
- ddpm .scheduler .num_timesteps = 10
83
- ddpm_from_hub .scheduler .num_timesteps = 10
83
+ ddpm = DDPMPipeline .from_pretrained (model_path , scheduler = scheduler )
84
+ ddpm .to (torch_device )
85
+ ddpm_from_hub = DiffusionPipeline .from_pretrained (model_path , scheduler = scheduler )
86
+ ddpm_from_hub .to (torch_device )
84
87
85
88
generator = torch .manual_seed (0 )
86
89
@@ -94,14 +97,15 @@ def test_from_pretrained_hub(self):
94
97
def test_from_pretrained_hub_pass_model (self ):
95
98
model_path = "google/ddpm-cifar10-32"
96
99
100
+ scheduler = DDPMScheduler (num_train_timesteps = 10 )
101
+
97
102
# pass unet into DiffusionPipeline
98
103
unet = UNet2DModel .from_pretrained (model_path )
99
- ddpm_from_hub_custom_model = DiffusionPipeline .from_pretrained (model_path , unet = unet )
100
-
101
- ddpm_from_hub = DiffusionPipeline .from_pretrained (model_path )
104
+ ddpm_from_hub_custom_model = DiffusionPipeline .from_pretrained (model_path , unet = unet , scheduler = scheduler )
105
+ ddpm_from_hub_custom_model .to (torch_device )
102
106
103
- ddpm_from_hub_custom_model . scheduler . num_timesteps = 10
104
- ddpm_from_hub .scheduler . num_timesteps = 10
107
+ ddpm_from_hub = DiffusionPipeline . from_pretrained ( model_path , scheduler = scheduler )
108
+ ddpm_from_hub .to ( torch_device )
105
109
106
110
generator = torch .manual_seed (0 )
107
111
@@ -116,6 +120,7 @@ def test_output_format(self):
116
120
model_path = "google/ddpm-cifar10-32"
117
121
118
122
pipe = DDIMPipeline .from_pretrained (model_path )
123
+ pipe .to (torch_device )
119
124
120
125
generator = torch .manual_seed (0 )
121
126
images = pipe (generator = generator , output_type = "numpy" )["sample" ]
@@ -141,6 +146,7 @@ def test_ddpm_cifar10(self):
141
146
scheduler = scheduler .set_format ("pt" )
142
147
143
148
ddpm = DDPMPipeline (unet = unet , scheduler = scheduler )
149
+ ddpm .to (torch_device )
144
150
145
151
generator = torch .manual_seed (0 )
146
152
image = ddpm (generator = generator , output_type = "numpy" )["sample" ]
@@ -159,6 +165,7 @@ def test_ddim_lsun(self):
159
165
scheduler = DDIMScheduler .from_config (model_id )
160
166
161
167
ddpm = DDIMPipeline (unet = unet , scheduler = scheduler )
168
+ ddpm .to (torch_device )
162
169
163
170
generator = torch .manual_seed (0 )
164
171
image = ddpm (generator = generator , output_type = "numpy" )["sample" ]
@@ -177,6 +184,7 @@ def test_ddim_cifar10(self):
177
184
scheduler = DDIMScheduler (tensor_format = "pt" )
178
185
179
186
ddim = DDIMPipeline (unet = unet , scheduler = scheduler )
187
+ ddim .to (torch_device )
180
188
181
189
generator = torch .manual_seed (0 )
182
190
image = ddim (generator = generator , eta = 0.0 , output_type = "numpy" )["sample" ]
@@ -195,6 +203,7 @@ def test_pndm_cifar10(self):
195
203
scheduler = PNDMScheduler (tensor_format = "pt" )
196
204
197
205
pndm = PNDMPipeline (unet = unet , scheduler = scheduler )
206
+ pndm .to (torch_device )
198
207
generator = torch .manual_seed (0 )
199
208
image = pndm (generator = generator , output_type = "numpy" )["sample" ]
200
209
@@ -207,6 +216,7 @@ def test_pndm_cifar10(self):
207
216
@slow
208
217
def test_ldm_text2img (self ):
209
218
ldm = LDMTextToImagePipeline .from_pretrained ("CompVis/ldm-text2im-large-256" )
219
+ ldm .to (torch_device )
210
220
211
221
prompt = "A painting of a squirrel eating a burger"
212
222
generator = torch .manual_seed (0 )
@@ -223,6 +233,7 @@ def test_ldm_text2img(self):
223
233
@slow
224
234
def test_ldm_text2img_fast (self ):
225
235
ldm = LDMTextToImagePipeline .from_pretrained ("CompVis/ldm-text2im-large-256" )
236
+ ldm .to (torch_device )
226
237
227
238
prompt = "A painting of a squirrel eating a burger"
228
239
generator = torch .manual_seed (0 )
@@ -290,6 +301,7 @@ def test_score_sde_ve_pipeline(self):
290
301
scheduler = ScoreSdeVeScheduler .from_config (model_id )
291
302
292
303
sde_ve = ScoreSdeVePipeline (unet = model , scheduler = scheduler )
304
+ sde_ve .to (torch_device )
293
305
294
306
torch .manual_seed (0 )
295
307
image = sde_ve (num_inference_steps = 300 , output_type = "numpy" )["sample" ]
@@ -304,6 +316,7 @@ def test_score_sde_ve_pipeline(self):
304
316
@slow
305
317
def test_ldm_uncond (self ):
306
318
ldm = LDMPipeline .from_pretrained ("CompVis/ldm-celebahq-256" )
319
+ ldm .to (torch_device )
307
320
308
321
generator = torch .manual_seed (0 )
309
322
image = ldm (generator = generator , num_inference_steps = 5 , output_type = "numpy" )["sample" ]
@@ -323,7 +336,9 @@ def test_ddpm_ddim_equality(self):
323
336
ddim_scheduler = DDIMScheduler (tensor_format = "pt" )
324
337
325
338
ddpm = DDPMPipeline (unet = unet , scheduler = ddpm_scheduler )
339
+ ddpm .to (torch_device )
326
340
ddim = DDIMPipeline (unet = unet , scheduler = ddim_scheduler )
341
+ ddim .to (torch_device )
327
342
328
343
generator = torch .manual_seed (0 )
329
344
ddpm_image = ddpm (generator = generator , output_type = "numpy" )["sample" ]
@@ -343,7 +358,10 @@ def test_ddpm_ddim_equality_batched(self):
343
358
ddim_scheduler = DDIMScheduler (tensor_format = "pt" )
344
359
345
360
ddpm = DDPMPipeline (unet = unet , scheduler = ddpm_scheduler )
361
+ ddpm .to (torch_device )
362
+
346
363
ddim = DDIMPipeline (unet = unet , scheduler = ddim_scheduler )
364
+ ddim .to (torch_device )
347
365
348
366
generator = torch .manual_seed (0 )
349
367
ddpm_images = ddpm (batch_size = 4 , generator = generator , output_type = "numpy" )["sample" ]
@@ -363,6 +381,7 @@ def test_karras_ve_pipeline(self):
363
381
scheduler = KarrasVeScheduler (tensor_format = "pt" )
364
382
365
383
pipe = KarrasVePipeline (unet = model , scheduler = scheduler )
384
+ pipe .to (torch_device )
366
385
367
386
generator = torch .manual_seed (0 )
368
387
image = pipe (num_inference_steps = 20 , generator = generator , output_type = "numpy" )["sample" ]
0 commit comments