Skip to content

Commit 06924c6

Browse files
authored
[StableDiffusionInpaintPipeline] accept tensors for init and mask image (#439)
* accept tensors * fix mask handling * make device placement cleaner * update doc for mask image
1 parent 761f029 commit 06924c6

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,9 @@ def __call__(
145145
process. This is the image whose masked region will be inpainted.
146146
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
147147
`Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
148-
replaced by noise and therefore repainted, while black pixels will be preserved. The mask image will be
149-
converted to a single channel (luminance) before use.
148+
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
149+
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
150+
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
150151
strength (`float`, *optional*, defaults to 0.8):
151152
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
152153
is 1, the denoising process will be run on the masked area for the full number of iterations specified
@@ -202,10 +203,12 @@ def __call__(
202203
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
203204

204205
# preprocess image
205-
init_image = preprocess_image(init_image).to(self.device)
206+
if not isinstance(init_image, torch.FloatTensor):
207+
init_image = preprocess_image(init_image)
208+
init_image.to(self.device)
206209

207210
# encode the init image into latents and scale the latents
208-
init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
211+
init_latent_dist = self.vae.encode(init_image).latent_dist
209212
init_latents = init_latent_dist.sample(generator=generator)
210213

211214
init_latents = 0.18215 * init_latents
@@ -215,8 +218,10 @@ def __call__(
215218
init_latents_orig = init_latents
216219

217220
# preprocess mask
218-
mask = preprocess_mask(mask_image).to(self.device)
219-
mask = torch.cat([mask] * batch_size)
221+
if not isinstance(mask_image, torch.FloatTensor):
222+
mask_image = preprocess_mask(mask_image)
223+
mask_image.to(self.device)
224+
mask = torch.cat([mask_image] * batch_size)
220225

221226
# check sizes
222227
if not mask.shape == init_latents.shape:

0 commit comments

Comments
 (0)