
Commit b2c9b54

patil-suraj authored and patrickvonplaten committed
[img2img, inpainting] fix fp16 inference (#769)
* handle dtype in vae and image2image pipeline
* fix inpaint in fp16
* dtype should be handled in add_noise
* style
* address review comments
* add simple fast tests to check fp16
* fix test name
* put mask in fp16
1 parent 7a6cf89 commit b2c9b54
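For context, the user-facing scenario this commit repairs is running the img2img and inpaint pipelines with half-precision weights. The sketch below is illustrative and not part of the commit: the model id, the fp16 loading via torch_dtype, and the input image are assumptions, and the init_image argument name reflects the pipeline signature at the time of this commit.

import torch
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline

# Load the pipeline with half-precision weights; before this commit the fp16 path
# failed with dtype mismatches between float32 noise/latents and the fp16 models.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    torch_dtype=torch.float16,
).to("cuda")

init_image = Image.open("sketch.png").convert("RGB").resize((512, 512))
result = pipe(
    prompt="A fantasy landscape, trending on artstation",
    init_image=init_image,   # argument name as of this diffusers version
    strength=0.75,
    guidance_scale=7.5,
).images[0]
result.save("fantasy_landscape.png")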

4 files changed: 186 additions and 60 deletions

src/diffusers/models/vae.py

Lines changed: 6 additions & 2 deletions
@@ -337,12 +337,16 @@ def __init__(self, parameters, deterministic=False):
         self.std = torch.exp(0.5 * self.logvar)
         self.var = torch.exp(self.logvar)
         if self.deterministic:
-            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+            self.var = self.std = torch.zeros_like(
+                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
+            )
 
     def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
         device = self.parameters.device
         sample_device = "cpu" if device.type == "mps" else device
-        sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device)
+        sample = torch.randn(self.mean.shape, generator=generator, device=sample_device)
+        # make sure sample is on the same device as the parameters and has same dtype
+        sample = sample.to(device=device, dtype=self.parameters.dtype)
         x = self.mean + self.std * sample
         return x
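The reason for the explicit dtype handling above: torch.randn produces float32 by default, and PyTorch type promotion turns a float16 + float32 sum back into float32, so the sampled latents would no longer match the dtype of the fp16 VAE and UNet. A minimal, self-contained sketch of that promotion (tensors and shapes are made up for illustration):

import torch

mean = torch.zeros(1, 4, 8, 8, dtype=torch.float16)
std = torch.ones_like(mean)

noise = torch.randn(mean.shape)                 # float32 by default
print((mean + std * noise).dtype)               # torch.float32 -- dtype has drifted

noise = noise.to(dtype=mean.dtype)              # what the fix does: match the parameters' dtype
print((mean + std * noise).dtype)               # torch.float16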

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 25 additions & 21 deletions
@@ -217,26 +217,6 @@ def __call__(
         if isinstance(init_image, PIL.Image.Image):
             init_image = preprocess(init_image)
 
-        # encode the init image into latents and scale the latents
-        init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
-        init_latents = init_latent_dist.sample(generator=generator)
-        init_latents = 0.18215 * init_latents
-
-        # expand init_latents for batch_size
-        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
-
-        # get the original timestep using init_timestep
-        offset = self.scheduler.config.get("steps_offset", 0)
-        init_timestep = int(num_inference_steps * strength) + offset
-        init_timestep = min(init_timestep, num_inference_steps)
-
-        timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
-
-        # add noise to latents using the timesteps
-        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
-        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
-
         # get prompt text embeddings
         text_inputs = self.tokenizer(
             prompt,
@@ -297,6 +277,28 @@ def __call__(
             # to avoid doing two forward passes
             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
 
+        # encode the init image into latents and scale the latents
+        latents_dtype = text_embeddings.dtype
+        init_image = init_image.to(device=self.device, dtype=latents_dtype)
+        init_latent_dist = self.vae.encode(init_image).latent_dist
+        init_latents = init_latent_dist.sample(generator=generator)
+        init_latents = 0.18215 * init_latents
+
+        # expand init_latents for batch_size
+        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
+
+        # get the original timestep using init_timestep
+        offset = self.scheduler.config.get("steps_offset", 0)
+        init_timestep = int(num_inference_steps * strength) + offset
+        init_timestep = min(init_timestep, num_inference_steps)
+
+        timesteps = self.scheduler.timesteps[-init_timestep]
+        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
+
+        # add noise to latents using the timesteps
+        noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
+        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
@@ -341,7 +343,9 @@ def __call__(
         image = image.cpu().permute(0, 2, 3, 1).numpy()
 
         safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
-        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+        image, has_nsfw_concept = self.safety_checker(
+            images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+        )
 
         if output_type == "pil":
             image = self.numpy_to_pil(image)
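The structural change above is that the init image is now encoded after the text embeddings are computed, so a single latents_dtype (taken from the text embeddings) drives the cast of the init image, the added noise, and the safety-checker input. A condensed, self-contained sketch of that pattern, with dummy tensors standing in for the model outputs (shapes and values are illustrative, not the pipeline's real intermediates):

import torch

text_embeddings = torch.randn(2, 77, 768, dtype=torch.float16)  # stand-in for the fp16 text encoder output
latents_dtype = text_embeddings.dtype

init_image = torch.rand(1, 3, 512, 512)                # preprocess() yields a float32 tensor
init_image = init_image.to(dtype=latents_dtype)        # cast before it reaches the fp16 VAE

init_latents = torch.randn(1, 4, 64, 64, dtype=latents_dtype)   # stand-in for the encoded latents
noise = torch.randn(init_latents.shape, dtype=latents_dtype)    # explicit dtype, as in the diff
noisy_latents = init_latents + noise                            # stays float16 end to end

pixel_values = torch.rand(1, 3, 224, 224)              # feature_extractor returns float32 pixel values
pixel_values = pixel_values.to(text_embeddings.dtype)  # cast before the fp16 safety checker
print(noisy_latents.dtype, pixel_values.dtype)         # torch.float16 torch.float16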

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 37 additions & 37 deletions
@@ -234,43 +234,6 @@ def __call__(
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
 
-        # preprocess image
-        if not isinstance(init_image, torch.FloatTensor):
-            init_image = preprocess_image(init_image)
-        init_image = init_image.to(self.device)
-
-        # encode the init image into latents and scale the latents
-        init_latent_dist = self.vae.encode(init_image).latent_dist
-        init_latents = init_latent_dist.sample(generator=generator)
-
-        init_latents = 0.18215 * init_latents
-
-        # Expand init_latents for batch_size and num_images_per_prompt
-        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
-        init_latents_orig = init_latents
-
-        # preprocess mask
-        if not isinstance(mask_image, torch.FloatTensor):
-            mask_image = preprocess_mask(mask_image)
-        mask_image = mask_image.to(self.device)
-        mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
-
-        # check sizes
-        if not mask.shape == init_latents.shape:
-            raise ValueError("The mask and init_image should be the same size!")
-
-        # get the original timestep using init_timestep
-        offset = self.scheduler.config.get("steps_offset", 0)
-        init_timestep = int(num_inference_steps * strength) + offset
-        init_timestep = min(init_timestep, num_inference_steps)
-
-        timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
-
-        # add noise to latents using the timesteps
-        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
-        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
-
         # get prompt text embeddings
         text_inputs = self.tokenizer(
             prompt,
@@ -335,6 +298,43 @@ def __call__(
             # to avoid doing two forward passes
             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
 
+        # preprocess image
+        if not isinstance(init_image, torch.FloatTensor):
+            init_image = preprocess_image(init_image)
+
+        # encode the init image into latents and scale the latents
+        latents_dtype = text_embeddings.dtype
+        init_image = init_image.to(device=self.device, dtype=latents_dtype)
+        init_latent_dist = self.vae.encode(init_image).latent_dist
+        init_latents = init_latent_dist.sample(generator=generator)
+        init_latents = 0.18215 * init_latents
+
+        # Expand init_latents for batch_size and num_images_per_prompt
+        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
+        init_latents_orig = init_latents
+
+        # preprocess mask
+        if not isinstance(mask_image, torch.FloatTensor):
+            mask_image = preprocess_mask(mask_image)
+        mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
+        mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
+
+        # check sizes
+        if not mask.shape == init_latents.shape:
+            raise ValueError("The mask and init_image should be the same size!")
+
+        # get the original timestep using init_timestep
+        offset = self.scheduler.config.get("steps_offset", 0)
+        init_timestep = int(num_inference_steps * strength) + offset
+        init_timestep = min(init_timestep, num_inference_steps)
+
+        timesteps = self.scheduler.timesteps[-init_timestep]
+        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
+
+        # add noise to latents using the timesteps
+        noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
+        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
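Inpainting gets the same reordering, plus the extra cast covered by the "put mask in fp16" item in the commit message: during denoising the pipeline blends the noised original latents with the current latents through the mask, and a float32 mask would promote that fp16 blend back to float32. A small self-contained sketch of the blend (shapes are illustrative; the masking expression mirrors the pipeline's step):

import torch

latents_dtype = torch.float16
init_latents_orig = torch.randn(1, 4, 64, 64, dtype=latents_dtype)
latents = torch.randn(1, 4, 64, 64, dtype=latents_dtype)

mask = torch.ones(1, 4, 64, 64)            # preprocess_mask() yields a float32 tensor
mask = mask.to(dtype=latents_dtype)        # the fix: move the mask to the latents' dtype

# roughly the masking step the inpaint pipeline performs each denoising iteration
latents = (init_latents_orig * mask) + (latents * (1 - mask))
print(latents.dtype)                       # torch.float16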

tests/test_pipelines.py

Lines changed: 118 additions & 0 deletions
@@ -1005,6 +1005,124 @@ def test_stable_diffusion_inpaint_num_images_per_prompt(self):
         assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
 
+    @unittest.skipIf(torch_device == "cpu", "This test requires a GPU")
+    def test_stable_diffusion_fp16(self):
+        """Test that stable diffusion works with fp16"""
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        # put models in fp16
+        unet = unet.half()
+        vae = vae.half()
+        bert = bert.half()
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images
+
+        assert image.shape == (1, 128, 128, 3)
+
+    @unittest.skipIf(torch_device == "cpu", "This test requires a GPU")
+    def test_stable_diffusion_img2img_fp16(self):
+        """Test that stable diffusion img2img works with fp16"""
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        init_image = self.dummy_image.to(torch_device)
+
+        # put models in fp16
+        unet = unet.half()
+        vae = vae.half()
+        bert = bert.half()
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionImg2ImgPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        image = sd_pipe(
+            [prompt],
+            generator=generator,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+        ).images
+
+        assert image.shape == (1, 32, 32, 3)
+
+    @unittest.skipIf(torch_device == "cpu", "This test requires a GPU")
+    def test_stable_diffusion_inpaint_fp16(self):
+        """Test that stable diffusion inpaint works with fp16"""
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
+        init_image = Image.fromarray(np.uint8(image)).convert("RGB")
+        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
+
+        # put models in fp16
+        unet = unet.half()
+        vae = vae.half()
+        bert = bert.half()
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionInpaintPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        image = sd_pipe(
+            [prompt],
+            generator=generator,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+            mask_image=mask_image,
+        ).images
+
+        assert image.shape == (1, 32, 32, 3)
+
 
 class PipelineTesterMixin(unittest.TestCase):
     def tearDown(self):
