Skip to content

Commit 166dfd3

Browse files
authored
[img2img, inpainting] fix fp16 inference (huggingface#769)
* handle dtype in vae and image2image pipeline * fix inpaint in fp16 * dtype should be handled in add_noise * style * address review comments * add simple fast tests to check fp16 * fix test name * put mask in fp16
1 parent db95a68 commit 166dfd3

File tree

3 files changed

+68
-60
lines changed

3 files changed

+68
-60
lines changed

models/vae.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,12 +337,16 @@ def __init__(self, parameters, deterministic=False):
337337
self.std = torch.exp(0.5 * self.logvar)
338338
self.var = torch.exp(self.logvar)
339339
if self.deterministic:
340-
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
340+
self.var = self.std = torch.zeros_like(
341+
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
342+
)
341343

342344
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
343345
device = self.parameters.device
344346
sample_device = "cpu" if device.type == "mps" else device
345-
sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device)
347+
sample = torch.randn(self.mean.shape, generator=generator, device=sample_device)
348+
# make sure sample is on the same device as the parameters and has same dtype
349+
sample = sample.to(device=device, dtype=self.parameters.dtype)
346350
x = self.mean + self.std * sample
347351
return x
348352

pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -217,26 +217,6 @@ def __call__(
217217
if isinstance(init_image, PIL.Image.Image):
218218
init_image = preprocess(init_image)
219219

220-
# encode the init image into latents and scale the latents
221-
init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
222-
init_latents = init_latent_dist.sample(generator=generator)
223-
init_latents = 0.18215 * init_latents
224-
225-
# expand init_latents for batch_size
226-
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
227-
228-
# get the original timestep using init_timestep
229-
offset = self.scheduler.config.get("steps_offset", 0)
230-
init_timestep = int(num_inference_steps * strength) + offset
231-
init_timestep = min(init_timestep, num_inference_steps)
232-
233-
timesteps = self.scheduler.timesteps[-init_timestep]
234-
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
235-
236-
# add noise to latents using the timesteps
237-
noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
238-
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
239-
240220
# get prompt text embeddings
241221
text_inputs = self.tokenizer(
242222
prompt,
@@ -297,6 +277,28 @@ def __call__(
297277
# to avoid doing two forward passes
298278
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
299279

280+
# encode the init image into latents and scale the latents
281+
latents_dtype = text_embeddings.dtype
282+
init_image = init_image.to(device=self.device, dtype=latents_dtype)
283+
init_latent_dist = self.vae.encode(init_image).latent_dist
284+
init_latents = init_latent_dist.sample(generator=generator)
285+
init_latents = 0.18215 * init_latents
286+
287+
# expand init_latents for batch_size
288+
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
289+
290+
# get the original timestep using init_timestep
291+
offset = self.scheduler.config.get("steps_offset", 0)
292+
init_timestep = int(num_inference_steps * strength) + offset
293+
init_timestep = min(init_timestep, num_inference_steps)
294+
295+
timesteps = self.scheduler.timesteps[-init_timestep]
296+
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
297+
298+
# add noise to latents using the timesteps
299+
noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
300+
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
301+
300302
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
301303
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
302304
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
@@ -341,7 +343,9 @@ def __call__(
341343
image = image.cpu().permute(0, 2, 3, 1).numpy()
342344

343345
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
344-
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
346+
image, has_nsfw_concept = self.safety_checker(
347+
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
348+
)
345349

346350
if output_type == "pil":
347351
image = self.numpy_to_pil(image)

pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -234,43 +234,6 @@ def __call__(
234234
# set timesteps
235235
self.scheduler.set_timesteps(num_inference_steps)
236236

237-
# preprocess image
238-
if not isinstance(init_image, torch.FloatTensor):
239-
init_image = preprocess_image(init_image)
240-
init_image = init_image.to(self.device)
241-
242-
# encode the init image into latents and scale the latents
243-
init_latent_dist = self.vae.encode(init_image).latent_dist
244-
init_latents = init_latent_dist.sample(generator=generator)
245-
246-
init_latents = 0.18215 * init_latents
247-
248-
# Expand init_latents for batch_size and num_images_per_prompt
249-
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
250-
init_latents_orig = init_latents
251-
252-
# preprocess mask
253-
if not isinstance(mask_image, torch.FloatTensor):
254-
mask_image = preprocess_mask(mask_image)
255-
mask_image = mask_image.to(self.device)
256-
mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
257-
258-
# check sizes
259-
if not mask.shape == init_latents.shape:
260-
raise ValueError("The mask and init_image should be the same size!")
261-
262-
# get the original timestep using init_timestep
263-
offset = self.scheduler.config.get("steps_offset", 0)
264-
init_timestep = int(num_inference_steps * strength) + offset
265-
init_timestep = min(init_timestep, num_inference_steps)
266-
267-
timesteps = self.scheduler.timesteps[-init_timestep]
268-
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
269-
270-
# add noise to latents using the timesteps
271-
noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
272-
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
273-
274237
# get prompt text embeddings
275238
text_inputs = self.tokenizer(
276239
prompt,
@@ -335,6 +298,43 @@ def __call__(
335298
# to avoid doing two forward passes
336299
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
337300

301+
# preprocess image
302+
if not isinstance(init_image, torch.FloatTensor):
303+
init_image = preprocess_image(init_image)
304+
305+
# encode the init image into latents and scale the latents
306+
latents_dtype = text_embeddings.dtype
307+
init_image = init_image.to(device=self.device, dtype=latents_dtype)
308+
init_latent_dist = self.vae.encode(init_image).latent_dist
309+
init_latents = init_latent_dist.sample(generator=generator)
310+
init_latents = 0.18215 * init_latents
311+
312+
# Expand init_latents for batch_size and num_images_per_prompt
313+
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
314+
init_latents_orig = init_latents
315+
316+
# preprocess mask
317+
if not isinstance(mask_image, torch.FloatTensor):
318+
mask_image = preprocess_mask(mask_image)
319+
mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
320+
mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
321+
322+
# check sizes
323+
if not mask.shape == init_latents.shape:
324+
raise ValueError("The mask and init_image should be the same size!")
325+
326+
# get the original timestep using init_timestep
327+
offset = self.scheduler.config.get("steps_offset", 0)
328+
init_timestep = int(num_inference_steps * strength) + offset
329+
init_timestep = min(init_timestep, num_inference_steps)
330+
331+
timesteps = self.scheduler.timesteps[-init_timestep]
332+
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
333+
334+
# add noise to latents using the timesteps
335+
noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
336+
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
337+
338338
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
339339
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
340340
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502

0 commit comments

Comments
 (0)