@@ -145,8 +145,9 @@ def __call__(
145
145
process. This is the image whose masked region will be inpainted.
146
146
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
147
147
`Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
148
- replaced by noise and therefore repainted, while black pixels will be preserved. The mask image will be
149
- converted to a single channel (luminance) before use.
148
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
149
+ PIL image, it will be converted to a single channel (luminance) before use. If it is a tensor, it is used
150
+ as-is (no preprocessing is applied), so it should already contain one color channel (L) instead of 3; the expected shape would be `(B, H, W, 1)`.
150
151
strength (`float`, *optional*, defaults to 0.8):
151
152
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
152
153
is 1, the denoising process will be run on the masked area for the full number of iterations specified
@@ -202,10 +203,12 @@ def __call__(
202
203
self .scheduler .set_timesteps (num_inference_steps , ** extra_set_kwargs )
203
204
204
205
# preprocess image
205
- init_image = preprocess_image (init_image ).to (self .device )
206
+ if not isinstance (init_image , torch .FloatTensor ):
207
+ init_image = preprocess_image (init_image )
208
+ init_image .to (self .device )
206
209
207
210
# encode the init image into latents and scale the latents
208
- init_latent_dist = self .vae .encode (init_image . to ( self . device ) ).latent_dist
211
+ init_latent_dist = self .vae .encode (init_image ).latent_dist
209
212
init_latents = init_latent_dist .sample (generator = generator )
210
213
211
214
init_latents = 0.18215 * init_latents
@@ -215,8 +218,10 @@ def __call__(
215
218
init_latents_orig = init_latents
216
219
217
220
# preprocess mask
218
- mask = preprocess_mask (mask_image ).to (self .device )
219
- mask = torch .cat ([mask ] * batch_size )
221
+ if not isinstance (mask_image , torch .FloatTensor ):
222
+ mask_image = preprocess_mask (mask_image )
223
+ mask_image .to (self .device )
224
+ mask = torch .cat ([mask_image ] * batch_size )
220
225
221
226
# check sizes
222
227
if not mask .shape == init_latents .shape :
0 commit comments