@@ -84,8 +84,8 @@ def __call__(
84
84
] = None ,
85
85
blur = 24 ,
86
86
blur_compose = 4 ,
87
- sample_mode = ' sample' ,
88
- ** kwargs
87
+ sample_mode = " sample" ,
88
+ ** kwargs ,
89
89
):
90
90
r"""
91
91
The call function to the pipeline for generation.
@@ -174,7 +174,6 @@ def __call__(
174
174
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`" ,
175
175
)
176
176
177
-
178
177
# 0. Check inputs. Raise error if not correct
179
178
self .check_inputs (
180
179
prompt ,
@@ -249,7 +248,6 @@ def __call__(
249
248
clip_skip = self .clip_skip ,
250
249
)
251
250
252
-
253
251
# 3. Preprocess image
254
252
input_image = image if image is not None else original_image
255
253
image = self .image_processor .preprocess (input_image )
@@ -282,25 +280,26 @@ def denoising_value_valid(dnv):
282
280
device ,
283
281
generator ,
284
282
add_noise ,
285
- sample_mode = sample_mode
283
+ sample_mode = sample_mode ,
286
284
)
287
285
288
286
# mean of the latent distribution
289
287
# it is multiplied by self.vae.config.scaling_factor
290
288
non_paint_latents = self .prepare_latents (
291
- original_image ,
292
- latent_timestep ,
293
- batch_size ,
294
- num_images_per_prompt ,
295
- prompt_embeds .dtype ,
296
- device ,
297
- generator ,
298
- add_noise = False ,
299
- sample_mode = "argmax" )
289
+ original_image ,
290
+ latent_timestep ,
291
+ batch_size ,
292
+ num_images_per_prompt ,
293
+ prompt_embeds .dtype ,
294
+ device ,
295
+ generator ,
296
+ add_noise = False ,
297
+ sample_mode = "argmax" ,
298
+ )
300
299
301
300
if self .debug_save :
302
301
init_img_from_latents = self .latents_to_img (non_paint_latents )
303
- init_img_from_latents [0 ].save (' non_paint_latents.png' )
302
+ init_img_from_latents [0 ].save (" non_paint_latents.png" )
304
303
# 6. create latent mask
305
304
latent_mask = self ._make_latent_mask (latents , mask )
306
305
@@ -359,7 +358,6 @@ def denoising_value_valid(dnv):
359
358
self .do_classifier_free_guidance ,
360
359
)
361
360
362
-
363
361
# 10. Denoising loop
364
362
num_warmup_steps = max (len (timesteps ) - num_inference_steps * self .scheduler .order , 0 )
365
363
@@ -406,15 +404,14 @@ def denoising_value_valid(dnv):
406
404
orig_latents_t = self .scheduler .add_noise (non_paint_latents , noise , t .unsqueeze (0 ))
407
405
408
406
# orig_latents_t (1 - latent_mask) + latents * latent_mask
409
- latents = torch .lerp (orig_latents_t , latents , latent_mask )
407
+ latents = torch .lerp (orig_latents_t , latents , latent_mask )
410
408
411
409
if self .debug_save :
412
410
img1 = self .latents_to_img (latents )
413
411
t_str = str (t .int ().item ())
414
412
for i in range (3 - len (t_str )):
415
- t_str = '0' + t_str
416
- img1 [0 ].save (f'step{ t_str } .png' )
417
-
413
+ t_str = "0" + t_str
414
+ img1 [0 ].save (f"step{ t_str } .png" )
418
415
419
416
# expand the latents if we are doing classifier free guidance
420
417
latent_model_input = torch .cat ([latents ] * 2 ) if self .do_classifier_free_guidance else latents
@@ -444,7 +441,6 @@ def denoising_value_valid(dnv):
444
441
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
445
442
noise_pred = rescale_noise_cfg (noise_pred , noise_pred_text , guidance_rescale = self .guidance_rescale )
446
443
447
-
448
444
# compute the previous noisy sample x_t -> x_t-1
449
445
latents_dtype = latents .dtype
450
446
latents = self .scheduler .step (noise_pred , t , latents , ** extra_step_kwargs , return_dict = False )[0 ]
@@ -501,7 +497,7 @@ def denoising_value_valid(dnv):
501
497
502
498
latents = self .denormalize (latents )
503
499
image = self .vae .decode (latents , return_dict = False )[0 ]
504
- m = mask_compose .permute (2 ,0 , 1 ).unsqueeze (0 ).to (image )
500
+ m = mask_compose .permute (2 , 0 , 1 ).unsqueeze (0 ).to (image )
505
501
img_compose = m * image + (1 - m ) * original_image .to (image )
506
502
image = img_compose
507
503
# cast back to fp16 if needed
@@ -519,7 +515,6 @@ def denoising_value_valid(dnv):
519
515
# Offload all models
520
516
self .maybe_free_model_hooks ()
521
517
522
-
523
518
if not return_dict :
524
519
return (image ,)
525
520
@@ -551,12 +546,17 @@ def _make_latent_mask(self, latents, mask):
551
546
return latent_mask
552
547
553
548
def prepare_latents (
554
- self , image , timestep , batch_size , num_images_per_prompt , dtype , device ,
549
+ self ,
550
+ image ,
551
+ timestep ,
552
+ batch_size ,
553
+ num_images_per_prompt ,
554
+ dtype ,
555
+ device ,
555
556
generator = None ,
556
557
add_noise = True ,
557
- sample_mode : str = "sample"
558
+ sample_mode : str = "sample" ,
558
559
):
559
-
560
560
if not isinstance (image , (torch .Tensor , Image .Image , list )):
561
561
raise ValueError (
562
562
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is { type (image )} "
@@ -573,7 +573,7 @@ def prepare_latents(
573
573
574
574
if image .shape [1 ] == 4 :
575
575
init_latents = image
576
- elif sample_mode == ' random' :
576
+ elif sample_mode == " random" :
577
577
height , width = image .shape [- 2 :]
578
578
num_channels_latents = self .unet .config .in_channels
579
579
latents = self .random_latents (
@@ -600,7 +600,9 @@ def prepare_latents(
600
600
601
601
elif isinstance (generator , list ):
602
602
init_latents = [
603
- retrieve_latents (self .vae .encode (image [i : i + 1 ]), generator = generator [i ], sample_mode = sample_mode )
603
+ retrieve_latents (
604
+ self .vae .encode (image [i : i + 1 ]), generator = generator [i ], sample_mode = sample_mode
605
+ )
604
606
for i in range (batch_size )
605
607
]
606
608
init_latents = torch .cat (init_latents , dim = 0 )
@@ -661,9 +663,7 @@ def denormalize(self, latents):
661
663
latents_mean = (
662
664
torch .tensor (self .vae .config .latents_mean ).view (1 , 4 , 1 , 1 ).to (latents .device , latents .dtype )
663
665
)
664
- latents_std = (
665
- torch .tensor (self .vae .config .latents_std ).view (1 , 4 , 1 , 1 ).to (latents .device , latents .dtype )
666
- )
666
+ latents_std = torch .tensor (self .vae .config .latents_std ).view (1 , 4 , 1 , 1 ).to (latents .device , latents .dtype )
667
667
latents = latents * latents_std / self .vae .config .scaling_factor + latents_mean
668
668
else :
669
669
latents = latents / self .vae .config .scaling_factor
@@ -673,10 +673,10 @@ def denormalize(self, latents):
673
673
def latents_to_img (self , latents ):
674
674
l1 = self .denormalize (latents )
675
675
img1 = self .vae .decode (l1 , return_dict = False )[0 ]
676
- img1 = self .image_processor .postprocess (img1 , output_type = ' pil' , do_denormalize = [True ])
676
+ img1 = self .image_processor .postprocess (img1 , output_type = " pil" , do_denormalize = [True ])
677
677
return img1
678
678
679
679
def blur_mask (self , pil_mask , blur ):
680
680
mask_blur = pil_mask .filter (ImageFilter .GaussianBlur (radius = blur ))
681
681
mask_blur = np .array (mask_blur )
682
- return torch .from_numpy (np .tile (mask_blur / mask_blur .max (), (3 , 1 , 1 )).transpose (1 ,2 , 0 ))
682
+ return torch .from_numpy (np .tile (mask_blur / mask_blur .max (), (3 , 1 , 1 )).transpose (1 , 2 , 0 ))
0 commit comments