Skip to content

Commit 797b290

Browse files
authored
support bf16 for stable diffusion (#792)
* support bf16 for stable diffusion * fix typo * address review comments
1 parent 81bdbb5 commit 797b290

File tree

3 files changed

+17
-3
lines changed

3 files changed

+17
-3
lines changed

src/diffusers/models/resnet.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,24 @@ def forward(self, hidden_states, output_size=None):
4141
if self.use_conv_transpose:
4242
return self.conv(hidden_states)
4343

44+
# Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
45+
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
46+
# https://github.com/pytorch/pytorch/issues/86679
47+
dtype = hidden_states.dtype
48+
if dtype == torch.bfloat16:
49+
hidden_states = hidden_states.to(torch.float32)
50+
4451
# if `output_size` is passed we force the interpolation output
4552
# size and do not make use of `scale_factor=2`
4653
if output_size is None:
4754
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
4855
else:
4956
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
5057

58+
# If the input is bfloat16, we cast back to bfloat16
59+
if dtype == torch.bfloat16:
60+
hidden_states = hidden_states.to(dtype)
61+
5162
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
5263
if self.use_conv:
5364
if self.name == "conv":

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,9 @@ def __call__(
327327
image = self.vae.decode(latents).sample
328328

329329
image = (image / 2 + 0.5).clamp(0, 1)
330-
image = image.cpu().permute(0, 2, 3, 1).numpy()
330+
331+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
332+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
331333

332334
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
333335
image, has_nsfw_concept = self.safety_checker(

src/diffusers/pipelines/stable_diffusion/safety_checker.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,9 @@ def forward(self, clip_input, images):
3838
pooled_output = self.vision_model(clip_input)[1] # pooled_output
3939
image_embeds = self.visual_projection(pooled_output)
4040

41-
special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().numpy()
42-
cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().numpy()
41+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
42+
special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
43+
cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()
4344

4445
result = []
4546
batch_size = image_embeds.shape[0]

0 commit comments

Comments
 (0)