
[Img2Img2] Re-add K LMS scheduler #340

Merged · 1 commit · Sep 3, 2022
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

@@ -5,12 +5,11 @@
 import torch

 import PIL
-from tqdm.auto import tqdm
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
-from ...schedulers import DDIMScheduler, PNDMScheduler
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from .safety_checker import StableDiffusionSafetyChecker

@@ -31,7 +30,7 @@ def __init__(
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler],
+        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
     ):
@@ -93,12 +92,17 @@ def __call__(
         # get the original timestep using init_timestep
         init_timestep = int(num_inference_steps * strength) + offset
         init_timestep = min(init_timestep, num_inference_steps)
-        timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+        if isinstance(self.scheduler, LMSDiscreteScheduler):
+            timesteps = torch.tensor(
+                [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
+            )
+        else:
+            timesteps = self.scheduler.timesteps[-init_timestep]
+            timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)

         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
-        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(self.device)

         # get prompt text embeddings
         text_input = self.tokenizer(
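
Note: the `isinstance` branch above exists because the two scheduler families disagree about what a "timestep" is in this era of the API. DDIM and PNDM expect an actual training-schedule timestep (which `add_noise` feeds into `alphas_cumprod`), while the 0.2-era `LMSDiscreteScheduler` indexes its sigma table by position, so the pipeline hands it `num_inference_steps - init_timestep`. A minimal, self-contained sketch of the two conventions; all numeric values here are illustrative assumptions, not taken from the PR:

```python
import torch

# Illustrative setup (assumed values): 50 inference steps over a
# 1000-step training schedule, img2img strength 0.75.
num_inference_steps, strength, offset, batch_size = 50, 0.75, 0, 1

init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)  # 37

# DDIM/PNDM convention: a real training timestep taken from scheduler.timesteps
# (a stand-in arange here: 981, 961, ..., 1).
ddim_like_timesteps = torch.arange(981, 0, -20)
t_value = ddim_like_timesteps[-init_timestep].item()
timesteps = torch.tensor([t_value] * batch_size, dtype=torch.long)

# K-LMS convention (0.2-era API): a *position* in the sigma schedule, chosen so
# that init_timestep of the num_inference_steps positions are still left to run.
lms_index = num_inference_steps - init_timestep  # 13

print(timesteps[0].item(), lms_index)  # 721 13
```
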
@@ -137,11 +141,22 @@ def __call__(
extra_step_kwargs["eta"] = eta

latents = init_latents

t_start = max(num_inference_steps - init_timestep + offset, 0)
for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[t_start:])):
t_index = t_start + i

# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

# if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas
if isinstance(self.scheduler, LMSDiscreteScheduler):
sigma = self.scheduler.sigmas[t_index]
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
latent_model_input = latent_model_input.to(self.unet.dtype)
t = t.to(self.unet.dtype)

# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

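Note: the `1 / (sigma**2 + 1) ** 0.5` factor above is the K-LMS input preconditioning. The UNet was trained on roughly unit-variance inputs, but in the K-LMS parameterization a latent at noise level sigma is (signal + sigma * noise), whose variance is about sigma**2 + 1. A self-contained sketch of that variance argument; the sigma value is an assumption for illustration, not a number from the PR:

```python
import torch

sigma = 14.6  # assumed: a typical largest sigma for SD's scaled_linear schedule

# Simulate a K-LMS latent at noise level sigma: unit-variance signal plus
# sigma-scaled noise, so the total variance is roughly sigma**2 + 1.
clean = torch.randn(1, 4, 64, 64)
latents = clean + sigma * torch.randn_like(clean)
print(latents.std())  # ~ (sigma**2 + 1) ** 0.5 ~ 14.63

# The division applied in the denoising loop restores roughly unit variance,
# matching what the UNet saw during training.
latent_model_input = latents / ((sigma**2 + 1) ** 0.5)
print(latent_model_input.std())  # ~ 1.0
```
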
@@ -151,11 +166,14 @@ def __call__(
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs)["prev_sample"]
+            else:
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]

         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents)
+        image = self.vae.decode(latents.to(self.vae.dtype))

         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
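
Note on the decode block: Stable Diffusion's VAE latents are scaled by 0.18215 at encode time so that they are roughly unit-variance for the UNet, and `1 / 0.18215 * latents` undoes that before decoding. The added `.to(self.vae.dtype)` guards against a dtype mismatch when the scheduler's arithmetic (e.g. fp32 LMS sigmas) leaves the latents in a different precision than the VAE weights. A toy round-trip sketch; `ToyVAE` is a hypothetical stand-in with a `decode()` method, not the pipeline's AutoencoderKL:

```python
import torch

class ToyVAE(torch.nn.Module):
    """Stand-in for AutoencoderKL: one 1x1 conv so decode() has the right shapes."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(4, 3, kernel_size=1)

    def decode(self, z):
        return self.conv(z)

vae = ToyVAE()
latents = 0.18215 * torch.randn(1, 4, 64, 64)  # encode-time scale already applied

latents = 1 / 0.18215 * latents                # undo the training-time scale factor
image = vae.decode(latents.to(next(vae.parameters()).dtype))
image = (image / 2 + 0.5).clamp(0, 1)          # map the decoder's [-1, 1] range to [0, 1]
print(image.shape)                             # torch.Size([1, 3, 64, 64])
```
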
tests/test_pipelines.py · 75 additions, 4 deletions
@@ -445,6 +445,49 @@ def test_stable_diffusion_img2img(self):
         expected_slice = np.array([0.4492, 0.3865, 0.4222, 0.5854, 0.5139, 0.4379, 0.4193, 0.48, 0.4218])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

+    def test_stable_diffusion_img2img_k_lms(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        unet = self.dummy_cond_unet
+        scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        init_image = self.dummy_image.to(device)
+
+        # build the img2img pipeline with the K-LMS scheduler
+        sd_pipe = StableDiffusionImg2ImgPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.Generator(device=device).manual_seed(0)
+        output = sd_pipe(
+            [prompt],
+            generator=generator,
+            guidance_scale=6.0,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+        )
+
+        image = output["sample"]
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([0.4367, 0.4986, 0.4372, 0.6706, 0.5665, 0.444, 0.5864, 0.6019, 0.5203])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
     def test_stable_diffusion_inpaint(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         unet = self.dummy_cond_unet
@@ -892,7 +935,7 @@ def test_lms_stable_diffusion_pipeline(self):
     def test_stable_diffusion_img2img_pipeline(self):
         ds = load_dataset("hf-internal-testing/diffusers-images", split="train")

-        init_image = ds[1]["image"].resize((768, 512))
+        init_image = ds[2]["image"].resize((768, 512))
         output_image = ds[0]["image"].resize((768, 512))

         model_id = "CompVis/stable-diffusion-v1-4"
@@ -915,12 +958,40 @@ def test_stable_diffusion_img2img_pipeline(self):

     @slow
     @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
-    def test_stable_diffusion_in_paint_pipeline(self):
+    def test_stable_diffusion_img2img_pipeline_k_lms(self):
         ds = load_dataset("hf-internal-testing/diffusers-images", split="train")

         init_image = ds[2]["image"].resize((768, 512))
-        mask_image = ds[3]["image"].resize((768, 512))
-        output_image = ds[4]["image"].resize((768, 512))
+        output_image = ds[1]["image"].resize((768, 512))
+
+        lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+
+        model_id = "CompVis/stable-diffusion-v1-4"
+        pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id, scheduler=lms, use_auth_token=True)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A fantasy landscape, trending on artstation"
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        image = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5, generator=generator)[
+            "sample"
+        ][0]
+
+        expected_array = np.array(output_image)
+        sampled_array = np.array(image)
+
+        assert sampled_array.shape == (512, 768, 3)
+        assert np.max(np.abs(sampled_array - expected_array)) < 1e-4
+
+    @slow
+    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+    def test_stable_diffusion_in_paint_pipeline(self):
+        ds = load_dataset("hf-internal-testing/diffusers-images", split="train")
+
+        init_image = ds[3]["image"].resize((768, 512))
+        mask_image = ds[4]["image"].resize((768, 512))
+        output_image = ds[5]["image"].resize((768, 512))

         model_id = "CompVis/stable-diffusion-v1-4"
         pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, use_auth_token=True)
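
For reference, a usage sketch of what this PR enables, mirroring the slow test above and assuming the same 0.2-era API surface (the `init_image` keyword, dict-style pipeline output, and `use_auth_token` gating); the input file name is hypothetical:

```python
import torch
from PIL import Image

from diffusers import LMSDiscreteScheduler, StableDiffusionImg2ImgPipeline

# K-LMS scheduler with the same betas used by the Stable Diffusion checkpoints
lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", scheduler=lms, use_auth_token=True
).to("cuda")

# load and resize a starting image (the file name is a placeholder)
init_image = Image.open("sketch.png").convert("RGB").resize((768, 512))

generator = torch.Generator(device="cuda").manual_seed(0)
image = pipe(
    prompt="A fantasy landscape, trending on artstation",
    init_image=init_image,
    strength=0.75,
    guidance_scale=7.5,
    generator=generator,
)["sample"][0]
image.save("fantasy_landscape_k_lms.png")
```
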