
Commit 6ca5a58

[Community Pipeline] Batched implementation of Flux with CFG (#9513)
* batched implementation of flux cfg.
* style.
* readme
* remove comments.
1 parent b52684c commit 6ca5a58

2 files changed: +110 -69 lines changed


examples/community/README.md

Lines changed: 31 additions & 0 deletions
@@ -10,6 +10,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
 
 | Example | Description | Code Example | Colab | Author |
 |:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
+|Flux with CFG|[Flux with CFG](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md) provides an implementation of using CFG in [Flux](https://blackforestlabs.ai/announcing-black-forest-labs/).|[Flux with CFG](#flux-with-cfg)|NA|[Linoy Tsaban](https://github.com/linoytsaban), [Apolinário](https://github.com/apolinario), and [Sayak Paul](https://github.com/sayakpaul)|
 |Differential Diffusion|[Differential Diffusion](https://github.com/exx8/differential-diffusion) modifies an image according to a text prompt, and according to a map that specifies the amount of change in each region.|[Differential Diffusion](#differential-diffusion)|[![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/exx8/differential-diffusion) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/exx8/differential-diffusion/blob/main/examples/SD2.ipynb)|[Eran Levin](https://github.com/exx8) and [Ohad Fried](https://www.ohadf.com/)|
 | HD-Painter | [HD-Painter](https://github.com/Picsart-AI-Research/HD-Painter) enables prompt-faithfull and high resolution (up to 2k) image inpainting upon any diffusion-based image inpainting method. | [HD-Painter](#hd-painter) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/PAIR/HD-Painter) | [Manukyan Hayk](https://github.com/haikmanukyan) and [Sargsyan Andranik](https://github.com/AndranikSargsyan) |
 | Marigold Monocular Depth Estimation | A universal monocular depth estimator, utilizing Stable Diffusion, delivering sharp predictions in the wild. (See the [project page](https://marigoldmonodepth.github.io) and [full codebase](https://github.com/prs-eth/marigold) for more details.) | [Marigold Depth Estimation](#marigold-depth-estimation) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/toshas/marigold) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/12G8reD13DdpMie5ZQlaFNo2WCGeNUH-u?usp=sharing) | [Bingxin Ke](https://github.com/markkua) and [Anton Obukhov](https://github.com/toshas) |
@@ -82,6 +83,36 @@ pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion
 
 ## Example usages
 
+### Flux with CFG
+
+Know more about Flux [here](https://blackforestlabs.ai/announcing-black-forest-labs/). Since Flux doesn't use CFG, this implementation provides one, inspired by the [PuLID Flux adaptation](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md).
+
+Example usage:
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    torch_dtype=torch.bfloat16,
+    custom_pipeline="pipeline_flux_with_cfg"
+)
+pipeline.enable_model_cpu_offload()
+prompt = "a watercolor painting of a unicorn"
+negative_prompt = "pink"
+
+img = pipeline(
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    true_cfg=1.5,
+    guidance_scale=3.5,
+    num_images_per_prompt=1,
+    generator=torch.manual_seed(0)
+).images[0]
+img.save("cfg_flux.png")
+```
+
 ### Differential Diffusion
 
 **Eran Levin, Ohad Fried**
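
A note on the two scales in the example above: `guidance_scale` still feeds Flux's distilled guidance embedding, while `true_cfg` weights the classical classifier-free-guidance combination that this community pipeline adds (both are visible in the pipeline diff below). As a minimal sketch of that combination, with random stand-in tensors in place of real model outputs:

```py
import torch

# Stand-in noise predictions; in the pipeline both come out of a single
# batched transformer call and are separated with .chunk(2).
neg_noise_pred = torch.randn(1, 16)  # negative-prompt branch
noise_pred = torch.randn(1, 16)      # positive-prompt branch
true_cfg = 1.5

# Classical CFG combine, as in the denoising loop below:
guided = neg_noise_pred + true_cfg * (noise_pred - neg_noise_pred)

# Equivalent linear-interpolation form: (1 - w) * neg + w * pos.
assert torch.allclose(guided, (1 - true_cfg) * neg_noise_pred + true_cfg * noise_pred)
```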

examples/community/pipeline_flux_with_cfg.py

Lines changed: 79 additions & 69 deletions
@@ -289,80 +289,104 @@ def encode_prompt(
         self,
         prompt: Union[str, List[str]],
         prompt_2: Union[str, List[str]],
+        negative_prompt: Union[str, List[str]] = None,
+        negative_prompt_2: Union[str, List[str]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         max_sequence_length: int = 512,
         lora_scale: Optional[float] = None,
+        do_true_cfg: bool = False,
     ):
-        r"""
-
-        Args:
-            prompt (`str` or `List[str]`, *optional*):
-                prompt to be encoded
-            prompt_2 (`str` or `List[str]`, *optional*):
-                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
-                used in all text-encoders
-            device: (`torch.device`):
-                torch device
-            num_images_per_prompt (`int`):
-                number of images that should be generated per prompt
-            prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
-                provided, text embeddings will be generated from `prompt` input argument.
-            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
-                If not provided, pooled text embeddings will be generated from `prompt` input argument.
-            lora_scale (`float`, *optional*):
-                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
-        """
         device = device or self._execution_device
 
-        # set lora scale so that monkey patched LoRA
-        # function of text encoder can correctly access it
+        # Set LoRA scale if applicable
         if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
            self._lora_scale = lora_scale
 
-            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)
            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder_2, lora_scale)
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+
+        if do_true_cfg and negative_prompt is not None:
+            negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_batch_size = len(negative_prompt)
+
+            if negative_batch_size != batch_size:
+                raise ValueError(
+                    f"Negative prompt batch size ({negative_batch_size}) does not match prompt batch size ({batch_size})"
+                )
+
+            # Concatenate prompts
+            prompts = prompt + negative_prompt
+            prompts_2 = (
+                prompt_2 + negative_prompt_2 if prompt_2 is not None and negative_prompt_2 is not None else None
+            )
+        else:
+            prompts = prompt
+            prompts_2 = prompt_2
 
         if prompt_embeds is None:
-            prompt_2 = prompt_2 or prompt
-            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+            if prompts_2 is None:
+                prompts_2 = prompts
 
-            # We only use the pooled prompt output from the CLIPTextModel
+            # Get pooled prompt embeddings from CLIPTextModel
             pooled_prompt_embeds = self._get_clip_prompt_embeds(
-                prompt=prompt,
+                prompt=prompts,
                 device=device,
                 num_images_per_prompt=num_images_per_prompt,
             )
             prompt_embeds = self._get_t5_prompt_embeds(
-                prompt=prompt_2,
+                prompt=prompts_2,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
                 device=device,
             )
 
+        if do_true_cfg and negative_prompt is not None:
+            # Split embeddings back into positive and negative parts
+            total_batch_size = batch_size * num_images_per_prompt
+            positive_indices = slice(0, total_batch_size)
+            negative_indices = slice(total_batch_size, 2 * total_batch_size)
+
+            positive_pooled_prompt_embeds = pooled_prompt_embeds[positive_indices]
+            negative_pooled_prompt_embeds = pooled_prompt_embeds[negative_indices]
+
+            positive_prompt_embeds = prompt_embeds[positive_indices]
+            negative_prompt_embeds = prompt_embeds[negative_indices]
+
+            pooled_prompt_embeds = positive_pooled_prompt_embeds
+            prompt_embeds = positive_prompt_embeds
+
+        # Unscale LoRA layers
         if self.text_encoder is not None:
             if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
-                # Retrieve the original scale by scaling back the LoRA layers
                 unscale_lora_layers(self.text_encoder, lora_scale)
 
         if self.text_encoder_2 is not None:
             if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
-                # Retrieve the original scale by scaling back the LoRA layers
                 unscale_lora_layers(self.text_encoder_2, lora_scale)
 
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
         text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
 
-        return prompt_embeds, pooled_prompt_embeds, text_ids
+        if do_true_cfg and negative_prompt is not None:
+            return (
+                prompt_embeds,
+                pooled_prompt_embeds,
+                text_ids,
+                negative_prompt_embeds,
+                negative_pooled_prompt_embeds,
+            )
+        else:
+            return prompt_embeds, pooled_prompt_embeds, text_ids, None, None
 
     def check_inputs(
         self,
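
The shape of the change above, stripped of pipeline plumbing: positive and negative prompts are concatenated into one list, encoded in a single pass, and the resulting embeddings are sliced back apart. A self-contained sketch with a toy encoder (the names here are illustrative stand-ins, not diffusers API):

```py
from typing import List

import torch

def toy_encoder(prompts: List[str]) -> torch.Tensor:
    # Stand-in for _get_clip_prompt_embeds / _get_t5_prompt_embeds:
    # one embedding row per input prompt.
    return torch.randn(len(prompts), 8)

prompt = ["a watercolor painting of a unicorn"]
negative_prompt = ["pink"]
batch_size = len(prompt)
num_images_per_prompt = 1

# One encoder pass over positives and negatives together.
embeds = toy_encoder(prompt + negative_prompt)

# Slice back into positive and negative halves, as encode_prompt does.
total = batch_size * num_images_per_prompt
prompt_embeds = embeds[:total]
negative_prompt_embeds = embeds[total : 2 * total]
```
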
@@ -687,38 +711,33 @@ def __call__(
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
+        do_true_cfg = true_cfg > 1 and negative_prompt is not None
         (
             prompt_embeds,
             pooled_prompt_embeds,
             text_ids,
+            negative_prompt_embeds,
+            negative_pooled_prompt_embeds,
         ) = self.encode_prompt(
             prompt=prompt,
             prompt_2=prompt_2,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             device=device,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
             lora_scale=lora_scale,
+            do_true_cfg=do_true_cfg,
         )
 
-        # perform "real" CFG as suggested for distilled Flux models in https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md
-        do_true_cfg = true_cfg > 1 and negative_prompt is not None
         if do_true_cfg:
-            (
-                negative_prompt_embeds,
-                negative_pooled_prompt_embeds,
-                negative_text_ids,
-            ) = self.encode_prompt(
-                prompt=negative_prompt,
-                prompt_2=negative_prompt_2,
-                prompt_embeds=negative_prompt_embeds,
-                pooled_prompt_embeds=negative_pooled_prompt_embeds,
-                device=device,
-                num_images_per_prompt=num_images_per_prompt,
-                max_sequence_length=max_sequence_length,
-                lora_scale=lora_scale,
-            )
+            # Concatenate embeddings
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
 
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
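
One detail worth flagging in the hunk above: the negative embeddings are concatenated in front of the positives. That ordering is what lets the denoising loop later unpack `noise_pred.chunk(2)` as `(neg_noise_pred, noise_pred)`. A quick torch check that the order round-trips:

```py
import torch

neg = torch.zeros(2, 4)  # stand-in negative embeddings
pos = torch.ones(2, 4)   # stand-in positive embeddings

batched = torch.cat([neg, pos], dim=0)      # negatives first, as in the pipeline
neg_out, pos_out = batched.chunk(2, dim=0)  # chunk splits along dim 0, in order

assert torch.equal(neg_out, neg) and torch.equal(pos_out, pos)
```
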
@@ -754,24 +773,26 @@
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
 
-        # handle guidance
-        if self.transformer.config.guidance_embeds:
-            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
-            guidance = guidance.expand(latents.shape[0])
-        else:
-            guidance = None
-
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
+                latent_model_input = torch.cat([latents] * 2) if do_true_cfg else latents
+
+                # handle guidance
+                if self.transformer.config.guidance_embeds:
+                    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+                    guidance = guidance.expand(latent_model_input.shape[0])
+                else:
+                    guidance = None
+
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latents.shape[0]).to(latents.dtype)
+                timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
 
                 noise_pred = self.transformer(
-                    hidden_states=latents,
+                    hidden_states=latent_model_input,
                     timestep=timestep / 1000,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
@@ -783,18 +804,7 @@
                 )[0]
 
                 if do_true_cfg:
-                    neg_noise_pred = self.transformer(
-                        hidden_states=latents,
-                        timestep=timestep / 1000,
-                        guidance=guidance,
-                        pooled_projections=negative_pooled_prompt_embeds,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        txt_ids=negative_text_ids,
-                        img_ids=latent_image_ids,
-                        joint_attention_kwargs=self.joint_attention_kwargs,
-                        return_dict=False,
-                    )[0]
-
+                    neg_noise_pred, noise_pred = noise_pred.chunk(2)
                     noise_pred = neg_noise_pred + true_cfg * (noise_pred - neg_noise_pred)
 
                 # compute the previous noisy sample x_t -> x_t-1
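
Taken together, the denoising-loop changes replace two transformer calls per step (one positive, one negative) with a single call on a doubled batch, at the cost of roughly twice the activation memory per step. A minimal sketch of the pattern with a stand-in model (nothing here is diffusers API):

```py
import torch

def toy_transformer(hidden_states: torch.Tensor) -> torch.Tensor:
    # Stand-in for the Flux transformer: one prediction per batch entry.
    return hidden_states * 0.5

latents = torch.randn(1, 16)
true_cfg = 1.5
do_true_cfg = true_cfg > 1

# Duplicate the latents so both branches share one forward pass.
latent_model_input = torch.cat([latents] * 2) if do_true_cfg else latents
noise_pred = toy_transformer(latent_model_input)

if do_true_cfg:
    # Recover the two branches and apply the CFG combine.
    neg_noise_pred, noise_pred = noise_pred.chunk(2)
    noise_pred = neg_noise_pred + true_cfg * (noise_pred - neg_noise_pred)
```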
