@@ -234,43 +234,6 @@ def __call__(
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)

-        # preprocess image
-        if not isinstance(init_image, torch.FloatTensor):
-            init_image = preprocess_image(init_image)
-        init_image = init_image.to(self.device)
-
-        # encode the init image into latents and scale the latents
-        init_latent_dist = self.vae.encode(init_image).latent_dist
-        init_latents = init_latent_dist.sample(generator=generator)
-
-        init_latents = 0.18215 * init_latents
-
-        # Expand init_latents for batch_size and num_images_per_prompt
-        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
-        init_latents_orig = init_latents
-
-        # preprocess mask
-        if not isinstance(mask_image, torch.FloatTensor):
-            mask_image = preprocess_mask(mask_image)
-        mask_image = mask_image.to(self.device)
-        mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
-
-        # check sizes
-        if not mask.shape == init_latents.shape:
-            raise ValueError("The mask and init_image should be the same size!")
-
-        # get the original timestep using init_timestep
-        offset = self.scheduler.config.get("steps_offset", 0)
-        init_timestep = int(num_inference_steps * strength) + offset
-        init_timestep = min(init_timestep, num_inference_steps)
-
-        timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
-
-        # add noise to latents using the timesteps
-        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
-        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
-
         # get prompt text embeddings
         text_inputs = self.tokenizer(
             prompt,
@@ -335,6 +298,43 @@ def __call__(
         # to avoid doing two forward passes
         text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

+        # preprocess image
+        if not isinstance(init_image, torch.FloatTensor):
+            init_image = preprocess_image(init_image)
+
+        # encode the init image into latents and scale the latents
+        latents_dtype = text_embeddings.dtype
+        init_image = init_image.to(device=self.device, dtype=latents_dtype)
+        init_latent_dist = self.vae.encode(init_image).latent_dist
+        init_latents = init_latent_dist.sample(generator=generator)
+        init_latents = 0.18215 * init_latents
+
+        # Expand init_latents for batch_size and num_images_per_prompt
+        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
+        init_latents_orig = init_latents
+
+        # preprocess mask
+        if not isinstance(mask_image, torch.FloatTensor):
+            mask_image = preprocess_mask(mask_image)
+        mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
+        mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
+
+        # check sizes
+        if not mask.shape == init_latents.shape:
+            raise ValueError("The mask and init_image should be the same size!")
+
+        # get the original timestep using init_timestep
+        offset = self.scheduler.config.get("steps_offset", 0)
+        init_timestep = int(num_inference_steps * strength) + offset
+        init_timestep = min(init_timestep, num_inference_steps)
+
+        timesteps = self.scheduler.timesteps[-init_timestep]
+        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
+
+        # add noise to latents using the timesteps
+        noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
+        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
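
Viewed on its own, the relocated block does two things: it derives `latents_dtype` from `text_embeddings.dtype` so the init image, mask, and sampled noise all match the model precision (which is why the block now has to run after the text encoder), and it turns `strength` into an initial timestep before `add_noise` is called. The snippet below is a minimal standalone sketch of that logic using plain torch; the helper name `compute_init_timestep`, the toy tensor shapes, and the linearly spaced `scheduler_timesteps` stand-in are illustrative assumptions, not the pipeline's actual objects.

```python
import torch

def compute_init_timestep(num_inference_steps: int, strength: float, steps_offset: int = 0) -> int:
    # Hypothetical helper mirroring the strength -> timestep math in the diff:
    # a higher strength starts denoising from an earlier (noisier) timestep.
    init_timestep = int(num_inference_steps * strength) + steps_offset
    return min(init_timestep, num_inference_steps)

# Toy stand-ins (assumed shapes and values, not the real pipeline state).
num_inference_steps, strength = 50, 0.75
batch_size, num_images_per_prompt = 1, 2
text_embeddings = torch.randn(2, 77, 768, dtype=torch.float16)  # decides the working dtype
init_latents = torch.randn(1, 4, 64, 64)                        # pretend VAE-encoded init image

# Cast to the text-embedding dtype, as the moved block does via latents_dtype.
latents_dtype = text_embeddings.dtype
init_latents = init_latents.to(dtype=latents_dtype)
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)

# Stand-in for scheduler.timesteps (descending), then pick the starting timestep.
scheduler_timesteps = torch.linspace(999, 0, num_inference_steps).long()
init_timestep = compute_init_timestep(num_inference_steps, strength)
t_start = int(scheduler_timesteps[-init_timestep])
timesteps = torch.tensor([t_start] * batch_size * num_images_per_prompt)

# Draw the noise in the same dtype so the subsequent add_noise step stays in fp16.
noise = torch.randn(init_latents.shape, dtype=latents_dtype)
print(init_timestep, timesteps.tolist(), init_latents.shape, noise.dtype)
```

In the diff itself, that dtype propagation shows up as the new `dtype=latents_dtype` arguments on the `init_image.to(...)`, `mask_image.to(...)`, and `torch.randn(...)` calls; the rest of the block is moved unchanged.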