Skip to content

Commit 6453942

Browse files
committed
style
1 parent fda02c3 commit 6453942

File tree

1 file changed

+33
-33
lines changed

1 file changed

+33
-33
lines changed

examples/community/masked_stable_diffusion_xl_img2img.py

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def __call__(
8484
] = None,
8585
blur=24,
8686
blur_compose=4,
87-
sample_mode='sample',
88-
**kwargs
87+
sample_mode="sample",
88+
**kwargs,
8989
):
9090
r"""
9191
The call function to the pipeline for generation.
@@ -174,7 +174,6 @@ def __call__(
174174
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
175175
)
176176

177-
178177
# 0. Check inputs. Raise error if not correct
179178
self.check_inputs(
180179
prompt,
@@ -249,7 +248,6 @@ def __call__(
249248
clip_skip=self.clip_skip,
250249
)
251250

252-
253251
# 3. Preprocess image
254252
input_image = image if image is not None else original_image
255253
image = self.image_processor.preprocess(input_image)
@@ -282,25 +280,26 @@ def denoising_value_valid(dnv):
282280
device,
283281
generator,
284282
add_noise,
285-
sample_mode=sample_mode
283+
sample_mode=sample_mode,
286284
)
287285

288286
# mean of the latent distribution
289287
# it is multiplied by self.vae.config.scaling_factor
290288
non_paint_latents = self.prepare_latents(
291-
original_image,
292-
latent_timestep,
293-
batch_size,
294-
num_images_per_prompt,
295-
prompt_embeds.dtype,
296-
device,
297-
generator,
298-
add_noise=False,
299-
sample_mode="argmax")
289+
original_image,
290+
latent_timestep,
291+
batch_size,
292+
num_images_per_prompt,
293+
prompt_embeds.dtype,
294+
device,
295+
generator,
296+
add_noise=False,
297+
sample_mode="argmax",
298+
)
300299

301300
if self.debug_save:
302301
init_img_from_latents = self.latents_to_img(non_paint_latents)
303-
init_img_from_latents[0].save('non_paint_latents.png')
302+
init_img_from_latents[0].save("non_paint_latents.png")
304303
# 6. create latent mask
305304
latent_mask = self._make_latent_mask(latents, mask)
306305

@@ -359,7 +358,6 @@ def denoising_value_valid(dnv):
359358
self.do_classifier_free_guidance,
360359
)
361360

362-
363361
# 10. Denoising loop
364362
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
365363

@@ -406,15 +404,14 @@ def denoising_value_valid(dnv):
406404
orig_latents_t = self.scheduler.add_noise(non_paint_latents, noise, t.unsqueeze(0))
407405

408406
# orig_latents_t (1 - latent_mask) + latents * latent_mask
409-
latents = torch.lerp(orig_latents_t , latents, latent_mask)
407+
latents = torch.lerp(orig_latents_t, latents, latent_mask)
410408

411409
if self.debug_save:
412410
img1 = self.latents_to_img(latents)
413411
t_str = str(t.int().item())
414412
for i in range(3 - len(t_str)):
415-
t_str = '0' + t_str
416-
img1[0].save(f'step{t_str}.png')
417-
413+
t_str = "0" + t_str
414+
img1[0].save(f"step{t_str}.png")
418415

419416
# expand the latents if we are doing classifier free guidance
420417
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
@@ -444,7 +441,6 @@ def denoising_value_valid(dnv):
444441
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
445442
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
446443

447-
448444
# compute the previous noisy sample x_t -> x_t-1
449445
latents_dtype = latents.dtype
450446
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
@@ -501,7 +497,7 @@ def denoising_value_valid(dnv):
501497

502498
latents = self.denormalize(latents)
503499
image = self.vae.decode(latents, return_dict=False)[0]
504-
m = mask_compose.permute(2,0,1).unsqueeze(0).to(image)
500+
m = mask_compose.permute(2, 0, 1).unsqueeze(0).to(image)
505501
img_compose = m * image + (1 - m) * original_image.to(image)
506502
image = img_compose
507503
# cast back to fp16 if needed
@@ -519,7 +515,6 @@ def denoising_value_valid(dnv):
519515
# Offload all models
520516
self.maybe_free_model_hooks()
521517

522-
523518
if not return_dict:
524519
return (image,)
525520

@@ -551,12 +546,17 @@ def _make_latent_mask(self, latents, mask):
551546
return latent_mask
552547

553548
def prepare_latents(
554-
self, image, timestep, batch_size, num_images_per_prompt, dtype, device,
549+
self,
550+
image,
551+
timestep,
552+
batch_size,
553+
num_images_per_prompt,
554+
dtype,
555+
device,
555556
generator=None,
556557
add_noise=True,
557-
sample_mode: str = "sample"
558+
sample_mode: str = "sample",
558559
):
559-
560560
if not isinstance(image, (torch.Tensor, Image.Image, list)):
561561
raise ValueError(
562562
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
@@ -573,7 +573,7 @@ def prepare_latents(
573573

574574
if image.shape[1] == 4:
575575
init_latents = image
576-
elif sample_mode == 'random':
576+
elif sample_mode == "random":
577577
height, width = image.shape[-2:]
578578
num_channels_latents = self.unet.config.in_channels
579579
latents = self.random_latents(
@@ -600,7 +600,9 @@ def prepare_latents(
600600

601601
elif isinstance(generator, list):
602602
init_latents = [
603-
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode)
603+
retrieve_latents(
604+
self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode
605+
)
604606
for i in range(batch_size)
605607
]
606608
init_latents = torch.cat(init_latents, dim=0)
@@ -661,9 +663,7 @@ def denormalize(self, latents):
661663
latents_mean = (
662664
torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
663665
)
664-
latents_std = (
665-
torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
666-
)
666+
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
667667
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
668668
else:
669669
latents = latents / self.vae.config.scaling_factor
@@ -673,10 +673,10 @@ def denormalize(self, latents):
673673
def latents_to_img(self, latents):
674674
l1 = self.denormalize(latents)
675675
img1 = self.vae.decode(l1, return_dict=False)[0]
676-
img1 = self.image_processor.postprocess(img1, output_type='pil', do_denormalize=[True])
676+
img1 = self.image_processor.postprocess(img1, output_type="pil", do_denormalize=[True])
677677
return img1
678678

679679
def blur_mask(self, pil_mask, blur):
680680
mask_blur = pil_mask.filter(ImageFilter.GaussianBlur(radius=blur))
681681
mask_blur = np.array(mask_blur)
682-
return torch.from_numpy(np.tile(mask_blur / mask_blur.max(), (3, 1, 1)).transpose(1,2,0))
682+
return torch.from_numpy(np.tile(mask_blur / mask_blur.max(), (3, 1, 1)).transpose(1, 2, 0))

0 commit comments

Comments
 (0)