Describe the bug
I have tested PixArt-Sigma with the code below, where I load the text_encoder separately since I will fine-tune it later. I found that T5EncoderModel.from_pretrained(..., torch_dtype=torch.float16) behaves very differently from T5EncoderModel.from_pretrained(...).to(dtype=torch.float16): the latter produces corrupted (pure-noise) images. What happens when we pass the torch_dtype argument to from_pretrained?
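One plausible explanation (an assumption on my part, not verified against the transformers source here): from_pretrained(torch_dtype=...) can keep numerically sensitive modules in float32 (T5 declares such modules via a keep-in-fp32 mechanism), whereas .to(dtype=torch.float16) blindly casts every parameter. T5's feed-forward activations can exceed float16's maximum representable value (about 65504), so a blanket cast overflows to inf and yields noise. A minimal sketch of the overflow itself:

```python
import torch

# float16 can represent magnitudes only up to ~65504;
# larger float32 values overflow to inf when cast down.
x = torch.tensor([70000.0], dtype=torch.float32)
y = x.to(torch.float16)
print(y)  # tensor([inf], dtype=torch.float16)
```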
Reproduction
from diffusers import PixArtSigmaPipeline
import torch
from transformers import T5EncoderModel

# text_encoder = T5EncoderModel.from_pretrained("PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", subfolder="text_encoder", torch_dtype=torch.float16)  # good result
text_encoder = T5EncoderModel.from_pretrained("PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", subfolder="text_encoder").to(dtype=torch.float16)  # noise

pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
    text_encoder=text_encoder,
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

prompts = ["a space elevator, cinematic scifi art"]  # list, so the loop iterates prompts rather than characters
for idx, prompt in enumerate(prompts):
    image = pipe(prompt=prompt, num_inference_steps=50, generator=torch.manual_seed(1)).images[0]
    image.save(f"{idx}.png")  # one file per prompt instead of overwriting the same path
Logs
No response
System Info
- 🤗 Diffusers version: 0.29.0
- Platform: Linux-5.15.0-60-generic-x86_64-with-glibc2.35
- Running on a notebook?: No
- Running on Google Colab?: No
- Python version: 3.10.11
- PyTorch version (GPU?): 2.1.2+cu118 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.23.3
- Transformers version: 4.41.2
- Accelerate version: 0.23.0
- PEFT version: 0.7.0
- Bitsandbytes version: not installed
- Safetensors version: 0.4.2
- xFormers version: 0.0.23.post1+cu118
- Accelerator: 8× NVIDIA A100-SXM4-80GB, 81920 MiB VRAM each
- Using GPU in script?:
- Using distributed or parallel set-up in script?: