Skip to content

[StableDiffusionInpaintPipeline] accept tensors for init and mask image #439

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 16, 2022
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ def __call__(
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

# preprocess image
init_image = preprocess_image(init_image).to(self.device)
if not isinstance(init_image, torch.FloatTensor):
init_image = preprocess_image(init_image)

# encode the init image into latents and scale the latents
init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
Expand All @@ -215,8 +216,9 @@ def __call__(
init_latents_orig = init_latents

# preprocess mask
mask = preprocess_mask(mask_image).to(self.device)
mask = torch.cat([mask] * batch_size)
if not isinstance(mask_image, torch.FloatTensor):
mask_image = preprocess_mask(mask_image).to(self.device)
mask = torch.cat([mask_image] * batch_size)

# check sizes
if not mask.shape == init_latents.shape:
Expand Down