From 21fc39b0b8a25735cf81b2455eb12185d4e1460c Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 7 Oct 2022 13:57:21 +0200 Subject: [PATCH 1/8] handle dtype in vae and image2image pipeline --- src/diffusers/models/vae.py | 8 +++- .../pipeline_stable_diffusion_img2img.py | 46 ++++++++++--------- 2 files changed, 31 insertions(+), 23 deletions(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index fe89b41c074e..55f1d757b8df 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -337,12 +337,16 @@ def __init__(self, parameters, deterministic=False): self.std = torch.exp(0.5 * self.logvar) self.var = torch.exp(self.logvar) if self.deterministic: - self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + self.var = self.std = torch.zeros_like(self.mean).to( + device=self.parameters.device, dtype=self.parameters.dtype + ) def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor: device = self.parameters.device sample_device = "cpu" if device.type == "mps" else device - sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device) + sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to( + device=device, dtype=self.parameters.dtype + ) x = self.mean + self.std * sample return x diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 15bdd0208825..37f4be658435 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -217,26 +217,6 @@ def __call__( if isinstance(init_image, PIL.Image.Image): init_image = preprocess(init_image) - # encode the init image into latents and scale the latents - init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist - init_latents = init_latent_dist.sample(generator=generator) - init_latents = 0.18215 * init_latents - - # expand init_latents for batch_size - init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0) - - # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) - - # add noise to latents using the timesteps - noise = torch.randn(init_latents.shape, generator=generator, device=self.device) - init_latents = self.scheduler.add_noise(init_latents, noise, timesteps) - # get prompt text embeddings text_inputs = self.tokenizer( prompt, @@ -297,6 +277,28 @@ def __call__( # to avoid doing two forward passes text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + # encode the init image into latents and scale the latents + latents_dtype = text_embeddings.dtype + init_image = init_image.to(device=self.device, dtype=latents_dtype) + init_latent_dist = self.vae.encode(init_image).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = 0.18215 * init_latents + + # expand init_latents for batch_size + init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0) + + # get the original timestep using init_timestep + offset = 
self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) + + # add noise to latents using the timesteps + noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype) + init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(latents_dtype) + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 @@ -341,7 +343,9 @@ def __call__( image = image.cpu().permute(0, 2, 3, 1).numpy() safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) + ) if output_type == "pil": image = self.numpy_to_pil(image) From 6482ac4a43b3021393e2611693b12e5431bd0fa5 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 7 Oct 2022 14:27:33 +0200 Subject: [PATCH 2/8] fix inpaint in fp16 --- .../pipeline_stable_diffusion_inpaint.py | 74 ++++++++++--------- 1 file changed, 38 insertions(+), 36 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 24f4bc99bddc..0d9d67f4bc71 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -234,42 +234,7 @@ def __call__( # set timesteps self.scheduler.set_timesteps(num_inference_steps) - # preprocess image - if not isinstance(init_image, torch.FloatTensor): - init_image = preprocess_image(init_image) - init_image = init_image.to(self.device) - - # encode the init image into latents and scale the latents - init_latent_dist = self.vae.encode(init_image).latent_dist - init_latents = init_latent_dist.sample(generator=generator) - - init_latents = 0.18215 * init_latents - - # Expand init_latents for batch_size and num_images_per_prompt - init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0) - init_latents_orig = init_latents - - # preprocess mask - if not isinstance(mask_image, torch.FloatTensor): - mask_image = preprocess_mask(mask_image) - mask_image = mask_image.to(self.device) - mask = torch.cat([mask_image] * batch_size * num_images_per_prompt) - - # check sizes - if not mask.shape == init_latents.shape: - raise ValueError("The mask and init_image should be the same size!") - - # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) - - # add noise to latents using the timesteps - noise = torch.randn(init_latents.shape, generator=generator, device=self.device) - init_latents = 
self.scheduler.add_noise(init_latents, noise, timesteps) + # get prompt text embeddings text_inputs = self.tokenizer( @@ -334,6 +299,43 @@ def __call__( # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # preprocess image + if not isinstance(init_image, torch.FloatTensor): + init_image = preprocess_image(init_image) + + # encode the init image into latents and scale the latents + latents_dtype = text_embeddings.dtype + init_image = init_image.to(device=self.device, dtype=latents_dtype) + init_latent_dist = self.vae.encode(init_image).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = 0.18215 * init_latents + + # Expand init_latents for batch_size and num_images_per_prompt + init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0) + init_latents_orig = init_latents + + # preprocess mask + if not isinstance(mask_image, torch.FloatTensor): + mask_image = preprocess_mask(mask_image) + mask_image = mask_image.to(self.device) + mask = torch.cat([mask_image] * batch_size * num_images_per_prompt) + + # check sizes + if not mask.shape == init_latents.shape: + raise ValueError("The mask and init_image should be the same size!") + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) + + # add noise to latents using the timesteps + noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype) + init_latents = self.scheduler.add_noise(init_latents, noise, timesteps) # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. From fc69ea4529afe54883f40c98f064eb0d10b4cbf8 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 7 Oct 2022 14:27:52 +0200 Subject: [PATCH 3/8] dtype should be handled in add_noise --- .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 37f4be658435..72e15f4f904b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -297,7 +297,7 @@ def __call__( # add noise to latents using the timesteps noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype) - init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(latents_dtype) + init_latents = self.scheduler.add_noise(init_latents, noise, timesteps) # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
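Taken together, the first three patches make half-precision inference work end-to-end: the init image, the sampled noise, and the latents are all cast to the text-embedding dtype, so an fp16 UNet and VAE never receive fp32 tensors. A minimal usage sketch of what this enables — the model id, input file, and CUDA device are assumptions for illustration, not part of the diff:

import torch
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline

# Load the pipeline in half precision (model id is a placeholder).
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

init_image = Image.open("input.png").convert("RGB").resize((512, 512))

# Before these fixes, the fp32 image tensor reaching the fp16 VAE (and the fp32
# noise passed to scheduler.add_noise) caused dtype mismatches; the pipeline now
# casts both to text_embeddings.dtype internally.
image = pipe(prompt="A fantasy landscape", init_image=init_image, strength=0.75).images[0]
image.save("out.png")
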
From 1dfa63215406b8c9f21c86af8cae92da92a9e8ef Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 7 Oct 2022 14:28:05 +0200 Subject: [PATCH 4/8] style --- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 0d9d67f4bc71..47eb60f5d80c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -234,8 +234,6 @@ def __call__( # set timesteps self.scheduler.set_timesteps(num_inference_steps) - - # get prompt text embeddings text_inputs = self.tokenizer( prompt, @@ -299,7 +297,7 @@ def __call__( # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - + # preprocess image if not isinstance(init_image, torch.FloatTensor): init_image = preprocess_image(init_image) From 7ca295c13cb8cf32683071f1148fbd1ffa277b40 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 7 Oct 2022 16:25:58 +0200 Subject: [PATCH 5/8] address review comments --- src/diffusers/models/vae.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 55f1d757b8df..7ce2f98eee27 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -337,16 +337,16 @@ def __init__(self, parameters, deterministic=False): self.std = torch.exp(0.5 * self.logvar) self.var = torch.exp(self.logvar) if self.deterministic: - self.var = self.std = torch.zeros_like(self.mean).to( - device=self.parameters.device, dtype=self.parameters.dtype + self.var = self.std = torch.zeros_like( + self.mean, device=self.parameters.device, dtype=self.parameters.dtype ) def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor: device = self.parameters.device sample_device = "cpu" if device.type == "mps" else device - sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to( - device=device, dtype=self.parameters.dtype - ) + sample = torch.randn(self.mean.shape, generator=generator, device=sample_device) + # make sure sample is on the same device as the parameters and has same dtype + sample = sample.to(device=device, dtype=self.parameters.dtype) x = self.mean + self.std * sample return x From a6a7c7b3b49f77725356104aaba9744d37082f1f Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 7 Oct 2022 16:37:24 +0200 Subject: [PATCH 6/8] add simple fast tests to check fp16 --- tests/test_pipelines.py | 118 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 4a0839ad490b..eb8fb907eedb 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -1012,6 +1012,124 @@ def test_stable_diffusion_inpaint_num_images_per_prompt(self): assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3) + @unittest.skipIf(torch_device == "cpu", "This test requires a GPU") + def test_stable_diffusion_fp16(self): + """Test that stable diffusion works with fp16""" + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = 
CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # put models in fp16 + unet = unet.half() + vae = vae.half() + bert = bert.half() + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=self.dummy_safety_checker, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images + + assert image.shape == (1, 128, 128, 3) + + @unittest.skipIf(torch_device == "cpu", "This test requires a GPU") + def test_stable_diffusion_img2img_fp16(self): + """Test that stable diffusion img2img works with fp16""" + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + init_image = self.dummy_image.to(torch_device) + + # put models in fp16 + unet = unet.half() + vae = vae.half() + bert = bert.half() + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionImg2ImgPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=self.dummy_safety_checker, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + image = sd_pipe( + [prompt], + generator=generator, + num_inference_steps=2, + output_type="np", + init_image=init_image, + ).images + + assert image.shape == (1, 32, 32, 3) + + @unittest.skipIf(torch_device == "cpu", "This test requires a GPU") + def test_stable_diffusion_inpaint(self): + """Test that stable diffusion inpaint works with fp16""" + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] + init_image = Image.fromarray(np.uint8(image)).convert("RGB") + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128)) + + # put models in fp16 + unet = unet.half() + vae = vae.half() + bert = bert.half() + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionInpaintPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=self.dummy_safety_checker, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + image = sd_pipe( + [prompt], + generator=generator, + num_inference_steps=2, + output_type="np", + init_image=init_image, + mask_image=mask_image, + ).images + + assert image.shape == (1, 32, 32, 3) + class PipelineTesterMixin(unittest.TestCase): def tearDown(self): From 76a6b2afe9452835b121e3ef4105aea939afeed6 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 7 Oct 2022 16:39:16 +0200 Subject: [PATCH 7/8] fix 
test name --- tests/test_pipelines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index eb8fb907eedb..2e894e97f9ef 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -1087,7 +1087,7 @@ def test_stable_diffusion_img2img_fp16(self): assert image.shape == (1, 32, 32, 3) @unittest.skipIf(torch_device == "cpu", "This test requires a GPU") - def test_stable_diffusion_inpaint(self): + def test_stable_diffusion_inpaint_fp16(self): """Test that stable diffusion inpaint works with fp16""" unet = self.dummy_cond_unet scheduler = PNDMScheduler(skip_prk_steps=True) From 9a1bd243af9f8bca4c67d9508e7765cac421c0b9 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 7 Oct 2022 16:56:14 +0200 Subject: [PATCH 8/8] put mask in fp16 --- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 47eb60f5d80c..30a588e754b3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -316,7 +316,7 @@ def __call__( # preprocess mask if not isinstance(mask_image, torch.FloatTensor): mask_image = preprocess_mask(mask_image) - mask_image = mask_image.to(self.device) + mask_image = mask_image.to(device=self.device, dtype=latents_dtype) mask = torch.cat([mask_image] * batch_size * num_images_per_prompt) # check sizes
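
The vae.py change from patches 1 and 5 can be sanity-checked in isolation: sampling from an fp16 DiagonalGaussianDistribution should return fp16 latents, because the noise drawn on the sampling device is cast back to the parameters' device and dtype before being combined with the mean and std. A short check along these lines — a sketch assuming a CUDA device (for fp16 exp/clamp support) and the patched vae module:

import torch
from diffusers.models.vae import DiagonalGaussianDistribution

# 8 channels = 4 for the mean + 4 for the logvar; the distribution chunks dim=1.
params = torch.randn(1, 8, 16, 16, dtype=torch.float16, device="cuda")
dist = DiagonalGaussianDistribution(params)

sample = dist.sample()
# With the fix, the randn noise (drawn on CPU only for "mps", otherwise on-device)
# is cast to the parameters' device and dtype before mean + std * noise:
assert sample.dtype == torch.float16 and sample.device.type == "cuda"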