[Flux] fix: encode_prompt when called separately. #9049

Closed
sayakpaul wants to merge 1 commit into main from encode-prompt-flux-fix

Conversation

sayakpaul
Member

What does this PR do?

This PR allows the Flux pipeline to run in under 24 GB of VRAM.

Code:

from diffusers import FluxPipeline, AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch 
import gc

def flush():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()

def bytes_to_giga_bytes(bytes):
    return bytes / 1024 / 1024 / 1024

flush()

ckpt_id = "black-forest-labs/FLUX.1-schnell"
prompt = "a photo of a dog with cat-like look"

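# Stage 1: load only the text encoders and tokenizers (no transformer, no VAE)
# so that prompt encoding alone fits in memory.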
text_encoder = CLIPTextModel.from_pretrained(ckpt_id, revision="refs/pr/1", subfolder="text_encoder", torch_dtype=torch.bfloat16)
text_encoder_2 = T5EncoderModel.from_pretrained(ckpt_id, revision="refs/pr/1", subfolder="text_encoder_2", torch_dtype=torch.bfloat16)
tokenizer = CLIPTokenizer.from_pretrained(ckpt_id, subfolder="tokenizer", revision="refs/pr/1")
tokenizer_2 = T5TokenizerFast.from_pretrained(ckpt_id, subfolder="tokenizer_2", revision="refs/pr/1")

pipeline = FluxPipeline.from_pretrained(
    ckpt_id, text_encoder=text_encoder, text_encoder_2=text_encoder_2,
    tokenizer=tokenizer, tokenizer_2=tokenizer_2, transformer=None, vae=None,
    revision="refs/pr/1"
).to("cuda")

with torch.no_grad():
    print("Encoding prompts.")
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(prompt=prompt, prompt_2=None, max_sequence_length=256)

del text_encoder
del text_encoder_2
del tokenizer
del tokenizer_2
del pipeline

flush()

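# Stage 2: reload the pipeline with only the transformer and scheduler to run
# denoising on the precomputed prompt embeddings.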
pipeline = FluxPipeline.from_pretrained(
    ckpt_id, text_encoder=None, text_encoder_2=None,
    tokenizer=None, tokenizer_2=None, vae=None,
    revision="refs/pr/1",
    torch_dtype=torch.bfloat16
).to("cuda")

print("Running denoising.")
height, width = 768, 1360
# No need to wrap this in `torch.no_grad()`: the pipeline's `__call__` method
# is already decorated with it.
latents = pipeline(
    prompt_embeds=prompt_embeds, 
    pooled_prompt_embeds=pooled_prompt_embeds,
    num_inference_steps=4, guidance_scale=0.0, 
    height=height, width=width,
    output_type="latent"
).images
print(f"{latents.shape=}")

del pipeline.transformer
del pipeline

flush()

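# Stage 3: load the VAE on its own and decode the latents into an image.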
vae = AutoencoderKL.from_pretrained(
    ckpt_id, 
    revision="refs/pr/1",
    subfolder="vae",
    torch_dtype=torch.bfloat16
).to("cuda")
vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

with torch.no_grad():
    print("Running decoding.")
    latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor

    image = vae.decode(latents, return_dict=False)[0]
    image = image_processor.postprocess(image, output_type="pil")
    image[0].save("image.png")
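
Note that the snippet defines bytes_to_giga_bytes but never calls it. If you want to verify the VRAM claim per stage, one optional addition (a sketch that only reuses the helpers defined above) is to print the peak allocation right before each flush() call:

print(f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated()):.2f} GB")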

@sayakpaul sayakpaul requested a review from DN6 August 2, 2024 03:48
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul
Member Author

Closing in favor of #9050.

@sayakpaul sayakpaul closed this Aug 2, 2024
@sayakpaul sayakpaul deleted the encode-prompt-flux-fix branch August 2, 2024 06:21
@tin2tin

tin2tin commented Aug 2, 2024

@sayakpaul
In spite of the #9048 commit, I'm still getting a dtype error running the above script:

Downloading shards: 100%|████████████████████████████████████████████████████████████████████████| 2/2 [00:00<?, ?it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.97it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████| 5/5 [00:00<00:00, 4995.60it/s]
Encoding prompts.
Loading pipeline components...: 100%|████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.83it/s]
Running denoising.
Error: Python: Traceback (most recent call last):
  File "python\Lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "...\AppData\Roaming\Python\Python311\site-packages\diffusers\pipelines\flux\pipeline_flux.py", line 645, in __call__
    pooled_prompt_embeds,
  File "...\AppData\Roaming\Python\Python311\site-packages\diffusers\pipelines\flux\pipeline_flux.py", line 378, in encode_prompt
    dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
                                                                                          ^^^^^^^
AttributeError: 'NoneType' object has no attribute 'dtype'
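
(For context: the failing line derives the embedding dtype from self.text_encoder and otherwise from self.transformer, so it raises as soon as the fallback it reaches is None. A hypothetical standalone helper, resolve_dtype, sketches the safer lookup; the actual change that landed on main via #9048/#9050 may differ in detail.)

import torch

def resolve_dtype(text_encoder, transformer, default=torch.bfloat16):
    # Prefer the text encoder's dtype, then the transformer's, then a default,
    # instead of assuming one of the two modules is always loaded.
    if text_encoder is not None:
        return text_encoder.dtype
    if transformer is not None:
        return transformer.dtype
    return default

print(resolve_dtype(None, None))  # torch.bfloat16 instead of an AttributeError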

@sayakpaul
Member Author

Can you try installing from main and using this snippet?

@tin2tin

tin2tin commented Aug 2, 2024

My bad. After a reboot, the previous version of Diffusers was released from memory, and it is now working with diffusers installed from main. Thank you.
Your code brings the inference time on a 4090 down from 5+ minutes to 11 seconds!

@sayakpaul
Member Author

Share it with your network so that more people are aware :)
