From 54977e240768bc8d0e36e62a8cd843cb868441ea Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 12 Jul 2024 09:21:13 +0200 Subject: [PATCH 1/4] add pipeline documentation. --- .../pipelines/aura_flow/pipeline_aura_flow.py | 105 +++++++++++++++++- 1 file changed, 104 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index 73b149e853cf..645b52130678 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -21,7 +21,7 @@ from ...models import AuraFlowTransformer2DModel, AutoencoderKL from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import logging +from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput @@ -29,6 +29,22 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AuraFlowPipeline + + >>> pipe = AuraFlowPipeline.from_pretrained( + ... "fal/AuraFlow", torch_dtype=torch.float16 + ... ).to("cuda) + >>> prompt = "A cat holding a sign that says hello world" + >>> image = pipe(prompt).images[0] + >>> image.save("aura_flow.png") + ``` +""" + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -90,6 +106,22 @@ def retrieve_timesteps( class AuraFlowPipeline(DiffusionPipeline): + r""" + Args: + tokenizer (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. AuraFlow uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [EleutherAI/pile-t5-xl](https://huggingface.co/EleutherAI/pile-t5-xl) variant. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + transformer ([`AuraFlowTransformer2DModel`]): + Conditional Transformer (MMDiT and DiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ _optional_components = [] model_cpu_offload_seq = "text_encoder->transformer->vae" @@ -201,8 +233,12 @@ def encode_prompt( prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt. 
""" if device is None: @@ -345,6 +381,7 @@ def upcast_vae(self): self.vae.decoder.mid_block.to(dtype) @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, @@ -366,6 +403,72 @@ def __call__( output_type: Optional[str] = "pil", return_dict: bool = True, ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 512 by default. + width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 512 by default. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. 
+                argument.
+            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+                Pre-generated attention mask for negative text embeddings.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+            max_sequence_length (`int`, *optional*, defaults to 256): Maximum sequence length to use with the `prompt`.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.ImagePipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
+                returned where the first element is a list with the generated images.
+        """
         # 1. Check inputs. Raise error if not correct
         height = height or self.transformer.config.sample_size * self.vae_scale_factor
         width = width or self.transformer.config.sample_size * self.vae_scale_factor

From 2bb9934a72fe7f1c89a13b89f97d976a941c94e7 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 12 Jul 2024 09:30:44 +0200
Subject: [PATCH 2/4] add api spec for pipeline

---
 docs/source/en/_toctree.yml                    |  4 ++-
 docs/source/en/api/pipelines/aura_flow.md      | 29 +++++++++++++++++++
 .../pipelines/aura_flow/pipeline_aura_flow.py  | 17 +++++------
 3 files changed, 40 insertions(+), 10 deletions(-)
 create mode 100644 docs/source/en/api/pipelines/aura_flow.md

diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 1a1a23e2938a..f375b2394872 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -280,6 +280,8 @@
         title: AudioLDM
       - local: api/pipelines/audioldm2
         title: AudioLDM 2
+      - local: api/pipelines/aura_flow
+        title: AuraFlow
       - local: api/pipelines/auto_pipeline
         title: AutoPipeline
       - local: api/pipelines/blip_diffusion
@@ -323,7 +325,7 @@
       - local: api/pipelines/kandinsky3
         title: Kandinsky 3
       - local: api/pipelines/kolors
-        title: Kolors 
+        title: Kolors
       - local: api/pipelines/latent_consistency_models
         title: Latent Consistency Models
       - local: api/pipelines/latent_diffusion
diff --git a/docs/source/en/api/pipelines/aura_flow.md b/docs/source/en/api/pipelines/aura_flow.md
new file mode 100644
index 000000000000..90b882051a12
--- /dev/null
+++ b/docs/source/en/api/pipelines/aura_flow.md
@@ -0,0 +1,29 @@
+<!--Copyright 2024 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# AuraFlow
+
+AuraFlow is inspired by [Stable Diffusion 3](../pipelines/stable_diffusion/stable_diffusion_3.md) and is by far the largest text-to-image generation model released under an Apache 2.0 license. This model achieves state-of-the-art results on the [GenEval](https://github.com/djghosh13/geneval) benchmark.
+
+It was developed by the Fal team, and more details about it can be found in [this blog post](https://blog.fal.ai/auraflow/).
+
+<Tip>
+
+AuraFlow can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details.
+
+</Tip>
+
+## AuraFlowPipeline
+
+[[autodoc]] AuraFlowPipeline
+  - all
+  - __call__
\ No newline at end of file
diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
index 645b52130678..47c765d5cbb5 100644
--- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
+++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
@@ -35,9 +35,8 @@
     >>> import torch
     >>> from diffusers import AuraFlowPipeline
 
-    >>> pipe = AuraFlowPipeline.from_pretrained(
-    ...     "fal/AuraFlow", torch_dtype=torch.float16
-    ... ).to("cuda)
+    >>> pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16)
+    >>> pipe = pipe.to("cuda")
     >>> prompt = "A cat holding a sign that says hello world"
    >>> image = pipe(prompt).images[0]
     >>> image.save("aura_flow.png")
@@ -122,6 +121,7 @@ class AuraFlowPipeline(DiffusionPipeline):
         scheduler ([`FlowMatchEulerDiscreteScheduler`]):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
     """
+
     _optional_components = []
     model_cpu_offload_seq = "text_encoder->transformer->vae"
 
@@ -233,7 +233,7 @@ def encode_prompt(
             prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                 not provided, text embeddings will be generated from `prompt` input argument.
-            prompt_attention_mask (`torch.Tensor`, *optional*): 
+            prompt_attention_mask (`torch.Tensor`, *optional*):
                 Pre-generated attention mask for text embeddings.
             negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings.
@@ -446,7 +446,7 @@ def __call__(
             prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                 not provided, text embeddings will be generated from `prompt` input argument.
-            prompt_attention_mask (`torch.Tensor`, *optional*): 
+            prompt_attention_mask (`torch.Tensor`, *optional*):
                 Pre-generated attention mask for text embeddings.
             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting.
@@ -464,10 +464,9 @@ def __call__(
 
         Examples:
 
-        Returns:
-            [`~pipelines.ImagePipelineOutput`] or `tuple`:
-                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
-                returned where the first element is a list with the generated images.
+        Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`:
+            If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned
+            where the first element is a list with the generated images.
         """
         # 1. Check inputs. Raise error if not correct
         height = height or self.transformer.config.sample_size * self.vae_scale_factor

From a0cd0054497277a7ff1415dd7139785e07b46510 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 12 Jul 2024 09:36:33 +0200
Subject: [PATCH 3/4] model documentation

---
 .../models/transformers/auraflow_transformer_2d.py | 20 +++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py
index eb3b749c88c5..342373b4c11d 100644
--- a/src/diffusers/models/transformers/auraflow_transformer_2d.py
+++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -233,6 +233,26 @@ def forward(
 
 
 class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
+    r"""
+    A 2D Transformer model as introduced in [AuraFlow](https://blog.fal.ai/auraflow/).
+
+    Parameters:
+        sample_size (`int`): The width of the latent images. This is fixed during training since
+            it is used to learn a number of position embeddings.
+        patch_size (`int`): Patch size to turn the input data into small patches.
+        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
+        num_mmdit_layers (`int`, *optional*, defaults to 4): The number of MMDiT Transformer blocks to use.
+        num_single_dit_layers (`int`, *optional*, defaults to 4):
+            The number of single DiT Transformer blocks to use. These blocks use concatenated image and text
+            representations.
+        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
+        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
+        joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+        caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
+        out_channels (`int`, defaults to 16): Number of output channels.
+        pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents.
+    """
+
     _supports_gradient_checkpointing = True
 
     @register_to_config

From 5c0a74bee1ef3de918346ae4265e7862ae101f22 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 12 Jul 2024 09:39:23 +0200
Subject: [PATCH 4/4] model spec

---
 docs/source/en/_toctree.yml                   |  2 ++
 .../en/api/models/aura_flow_transformer2d.md  | 19 +++++++++++++++++++
 2 files changed, 21 insertions(+)
 create mode 100644 docs/source/en/api/models/aura_flow_transformer2d.md

diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index f375b2394872..4ef5740da7d2 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -249,6 +249,8 @@
         title: DiTTransformer2DModel
       - local: api/models/hunyuan_transformer2d
         title: HunyuanDiT2DModel
+      - local: api/models/aura_flow_transformer2d
+        title: AuraFlowTransformer2DModel
       - local: api/models/latte_transformer3d
         title: LatteTransformer3DModel
       - local: api/models/lumina_nextdit2d
diff --git a/docs/source/en/api/models/aura_flow_transformer2d.md b/docs/source/en/api/models/aura_flow_transformer2d.md
new file mode 100644
index 000000000000..d07806bcc215
--- /dev/null
+++ b/docs/source/en/api/models/aura_flow_transformer2d.md
@@ -0,0 +1,19 @@
+<!--Copyright 2024 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# AuraFlowTransformer2DModel
+
+A Transformer model for image-like data from [AuraFlow](https://blog.fal.ai/auraflow/).
+
+## AuraFlowTransformer2DModel
+
+[[autodoc]] AuraFlowTransformer2DModel
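
The pipeline documented by this patch series can be exercised end-to-end with a short script. The sketch below sticks to what the patches themselves document: `num_inference_steps=50` and `guidance_scale=5.0` are the stated defaults, and `enable_model_cpu_offload()` leans on the pipeline's `model_cpu_offload_seq = "text_encoder->transformer->vae"` attribute; it is one of the memory optimizations the new `aura_flow.md` tip points to. The negative prompt and seed are illustrative choices, not values prescribed by the patches.

```py
import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16)

# Keep only the active submodule on the GPU, offloading the rest to CPU in the
# documented order (text_encoder -> transformer -> vae); trades speed for memory.
pipe.enable_model_cpu_offload()

image = pipe(
    prompt="A cat holding a sign that says hello world",
    negative_prompt="blurry, low quality",  # illustrative, not a documented default
    num_inference_steps=50,  # documented default
    guidance_scale=5.0,  # documented default
    generator=torch.Generator("cpu").manual_seed(0),  # fixed seed for reproducibility
).images[0]
image.save("aura_flow.png")
```

When memory allows, replacing the offload call with `pipe.to("cuda")` (as in the example docstring) keeps the whole fp16 pipeline resident on the GPU and runs faster.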