
Commit 983991b

sayakpaul authored and Disty0 committed

[Docs] add AuraFlow docs (huggingface#8851)

* add pipeline documentation
* add api spec for pipeline
* model documentation
* model spec

1 parent 35c475f

File tree (5 files changed, +176 −2 lines)

- docs/source/en/_toctree.yml
- docs/source/en/api/models/aura_flow_transformer2d.md
- docs/source/en/api/pipelines/aura_flow.md
- src/diffusers/models/transformers/auraflow_transformer_2d.py
- src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py

docs/source/en/_toctree.yml

Lines changed: 5 additions & 1 deletion
@@ -249,6 +249,8 @@
         title: DiTTransformer2DModel
       - local: api/models/hunyuan_transformer2d
         title: HunyuanDiT2DModel
+      - local: api/models/aura_flow_transformer2d
+        title: AuraFlowTransformer2DModel
       - local: api/models/latte_transformer3d
         title: LatteTransformer3DModel
       - local: api/models/lumina_nextdit2d
@@ -280,6 +282,8 @@
         title: AudioLDM
       - local: api/pipelines/audioldm2
         title: AudioLDM 2
+      - local: api/pipelines/aura_flow
+        title: AuraFlow
       - local: api/pipelines/auto_pipeline
         title: AutoPipeline
       - local: api/pipelines/blip_diffusion
@@ -323,7 +327,7 @@
       - local: api/pipelines/kandinsky3
         title: Kandinsky 3
       - local: api/pipelines/kolors
-          title: Kolors
+        title: Kolors
       - local: api/pipelines/latent_consistency_models
         title: Latent Consistency Models
       - local: api/pipelines/latent_diffusion
docs/source/en/api/models/aura_flow_transformer2d.md

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+<!--Copyright 2024 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# AuraFlowTransformer2DModel
+
+A Transformer model for image-like data from [AuraFlow](https://blog.fal.ai/auraflow/).
+
+## AuraFlowTransformer2DModel
+
+[[autodoc]] AuraFlowTransformer2DModel
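Since `[[autodoc]]` only renders the API reference on the built page, a hedged usage sketch for the model class on its own may help: loading the transformer from the `fal/AuraFlow` checkpoint used elsewhere in this commit. The `subfolder="transformer"` layout follows the usual diffusers convention and is an assumption, not something this diff states.

```py
# A minimal sketch, assuming the standard diffusers repo layout for "fal/AuraFlow".
import torch
from diffusers import AuraFlowTransformer2DModel

# Load only the denoiser; the subfolder name "transformer" is the assumed convention.
transformer = AuraFlowTransformer2DModel.from_pretrained(
    "fal/AuraFlow", subfolder="transformer", torch_dtype=torch.float16
)

# Constructor arguments are registered via @register_to_config, so the documented
# parameters (num_mmdit_layers, num_single_dit_layers, ...) are inspectable.
print(transformer.config.num_mmdit_layers, transformer.config.num_single_dit_layers)
```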
docs/source/en/api/pipelines/aura_flow.md

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+<!--Copyright 2024 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# AuraFlow
+
+AuraFlow is inspired by [Stable Diffusion 3](../pipelines/stable_diffusion/stable_diffusion_3.md) and is by far the largest text-to-image generation model that comes with an Apache 2.0 license. It achieves state-of-the-art results on the [GenEval](https://github.com/djghosh13/geneval) benchmark.
+
+It was developed by the Fal team; more details about it can be found in [this blog post](https://blog.fal.ai/auraflow/).
+
+<Tip>
+
+AuraFlow can be quite expensive to run on consumer hardware. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details.
+
+</Tip>
+
+## AuraFlowPipeline
+
+[[autodoc]] AuraFlowPipeline
+  - all
+  - __call__
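The Tip above links to SD3-oriented advice; as a hedged sketch, the generic diffusers memory levers should carry over to `AuraFlowPipeline`: half-precision weights plus model CPU offload, which matches the `model_cpu_offload_seq` declared on the pipeline later in this commit. Actual savings depend on hardware and are not measured here.

```py
# A minimal sketch of memory-friendly inference; assumes `accelerate` is installed.
import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16)

# Keep only the active component (text_encoder -> transformer -> vae) on the GPU,
# instead of moving the whole pipeline there with pipe.to("cuda").
pipe.enable_model_cpu_offload()

image = pipe("A cat holding a sign that says hello world").images[0]
image.save("aura_flow_offloaded.png")
```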

src/diffusers/models/transformers/auraflow_transformer_2d.py

Lines changed: 20 additions & 0 deletions
@@ -233,6 +233,26 @@ def forward(
 
 
 class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
+    r"""
+    A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).
+
+    Parameters:
+        sample_size (`int`): The width of the latent images. This is fixed during training since
+            it is used to learn a number of position embeddings.
+        patch_size (`int`): Patch size to turn the input data into small patches.
+        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
+        num_mmdit_layers (`int`, *optional*, defaults to 4): The number of MMDiT Transformer blocks to use.
+        num_single_dit_layers (`int`, *optional*, defaults to 4):
+            The number of single DiT Transformer blocks to use. These blocks use concatenated image and text
+            representations.
+        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
+        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
+        joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+        caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
+        out_channels (`int`, defaults to 16): Number of output channels.
+        pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents.
+    """
+
     _supports_gradient_checkpointing = True
 
     @register_to_config
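Because the class sets `_supports_gradient_checkpointing = True`, the standard `ModelMixin` toggle applies. A hedged sketch for fine-tuning setups follows; the method itself is core diffusers API, not something added by this commit.

```py
# A minimal sketch: trade compute for activation memory during training.
from diffusers import AuraFlowTransformer2DModel

transformer = AuraFlowTransformer2DModel.from_pretrained(
    "fal/AuraFlow", subfolder="transformer"  # assumed repo layout
)

# Unlocked by _supports_gradient_checkpointing = True above: activations are
# recomputed in the backward pass instead of being kept in memory.
transformer.enable_gradient_checkpointing()
transformer.train()
```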

src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,29 @@
2121
from ...models import AuraFlowTransformer2DModel, AutoencoderKL
2222
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
2323
from ...schedulers import FlowMatchEulerDiscreteScheduler
24-
from ...utils import logging
24+
from ...utils import logging, replace_example_docstring
2525
from ...utils.torch_utils import randn_tensor
2626
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
2727

2828

2929
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
3030

3131

32+
EXAMPLE_DOC_STRING = """
33+
Examples:
34+
```py
35+
>>> import torch
36+
>>> from diffusers import AuraFlowPipeline
37+
38+
>>> pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16)
39+
>>> pipe = pipe.to("cuda")
40+
>>> prompt = "A cat holding a sign that says hello world"
41+
>>> image = pipe(prompt).images[0]
42+
>>> image.save("aura_flow.png")
43+
```
44+
"""
45+
46+
3247
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
3348
def retrieve_timesteps(
3449
scheduler,
@@ -90,6 +105,23 @@ def retrieve_timesteps(
 
 
 class AuraFlowPipeline(DiffusionPipeline):
+    r"""
+    Args:
+        tokenizer (`T5TokenizerFast`):
+            Tokenizer of class
+            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+        text_encoder ([`T5EncoderModel`]):
+            Frozen text-encoder. AuraFlow uses
+            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
+            [EleutherAI/pile-t5-xl](https://huggingface.co/EleutherAI/pile-t5-xl) variant.
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        transformer ([`AuraFlowTransformer2DModel`]):
+            Conditional Transformer (MMDiT and DiT) architecture to denoise the encoded image latents.
+        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+    """
+
     _optional_components = []
     model_cpu_offload_seq = "text_encoder->transformer->vae"
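The component list above also implies the pipeline can be assembled by hand. A hedged sketch follows; all `subfolder` names assume the usual diffusers repo layout and are not stated in this diff.

```py
# A minimal sketch: build AuraFlowPipeline from its documented components.
from diffusers import (
    AuraFlowPipeline,
    AuraFlowTransformer2DModel,
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
)
from transformers import T5EncoderModel, T5TokenizerFast

repo = "fal/AuraFlow"
pipe = AuraFlowPipeline(
    tokenizer=T5TokenizerFast.from_pretrained(repo, subfolder="tokenizer"),
    text_encoder=T5EncoderModel.from_pretrained(repo, subfolder="text_encoder"),
    vae=AutoencoderKL.from_pretrained(repo, subfolder="vae"),
    transformer=AuraFlowTransformer2DModel.from_pretrained(repo, subfolder="transformer"),
    scheduler=FlowMatchEulerDiscreteScheduler.from_pretrained(repo, subfolder="scheduler"),
)
```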

@@ -201,8 +233,12 @@ def encode_prompt(
             prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
+            prompt_attention_mask (`torch.Tensor`, *optional*):
+                Pre-generated attention mask for text embeddings.
             negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings.
+            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+                Pre-generated attention mask for negative text embeddings.
             max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt.
         """
         if device is None:
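The new attention-mask arguments pair with the pre-generated embeddings, which suggests the usual precompute-and-reuse pattern. A hedged sketch follows, reusing `pipe` from the offload sketch earlier; it assumes `encode_prompt` returns the four tensors documented above in that order, which this hunk does not show, so check the full source before relying on it.

```py
# A minimal sketch: precompute prompt embeddings once and reuse them.
# ASSUMPTION: encode_prompt returns (prompt_embeds, prompt_attention_mask,
# negative_prompt_embeds, negative_prompt_attention_mask) in this order.
embeds, mask, neg_embeds, neg_mask = pipe.encode_prompt(
    prompt="A cat holding a sign that says hello world"
)

image = pipe(
    prompt=None,  # the precomputed embeddings replace the raw prompt
    prompt_embeds=embeds,
    prompt_attention_mask=mask,
    negative_prompt_embeds=neg_embeds,
    negative_prompt_attention_mask=neg_mask,
).images[0]
```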
@@ -345,6 +381,7 @@ def upcast_vae(self):
         self.vae.decoder.mid_block.to(dtype)
 
     @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
@@ -366,6 +403,71 @@ def __call__(
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
     ) -> Union[ImagePipelineOutput, Tuple]:
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image. This is set to 512 by default.
+            width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image. This is set to 512 by default.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+                `num_inference_steps` and `timesteps` must be `None`.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            guidance_scale (`float`, *optional*, defaults to 5.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            prompt_attention_mask (`torch.Tensor`, *optional*):
+                Pre-generated attention mask for text embeddings.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt` input
+                argument.
+            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+                Pre-generated attention mask for negative text embeddings.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return an [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+            max_sequence_length (`int`, defaults to 256): Maximum sequence length to use with the `prompt`.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.ImagePipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
+                returned where the first element is a list with the generated images.
+        """
         # 1. Check inputs. Raise error if not correct
         height = height or self.transformer.config.sample_size * self.vae_scale_factor
         width = width or self.transformer.config.sample_size * self.vae_scale_factor
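To tie the new `__call__` documentation together, here is a hedged end-to-end sketch exercising several of the documented arguments; the specific values are illustrative, not recommendations from this commit.

```py
# A minimal sketch: seeded, guided generation using only documented arguments.
import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

image = pipe(
    prompt="A cat holding a sign that says hello world",
    negative_prompt="blurry, low quality",  # only used when guidance_scale > 1
    height=512,                             # documented default resolution
    width=512,
    num_inference_steps=50,
    guidance_scale=5.0,
    generator=torch.Generator("cuda").manual_seed(0),  # deterministic latents
).images[0]
image.save("aura_flow_seeded.png")
```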
