from ...models import AuraFlowTransformer2DModel, AutoencoderKL
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
from ...schedulers import FlowMatchEulerDiscreteScheduler
- from ...utils import logging
+ from ...utils import logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+ EXAMPLE_DOC_STRING = """
+     Examples:
+         ```py
+         >>> import torch
+         >>> from diffusers import AuraFlowPipeline
+
+         >>> pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16)
+         >>> pipe = pipe.to("cuda")
+         >>> prompt = "A cat holding a sign that says hello world"
+         >>> image = pipe(prompt).images[0]
+         >>> image.save("aura_flow.png")
+         ```
+ """
+
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
@@ -90,6 +105,23 @@ def retrieve_timesteps(


class AuraFlowPipeline(DiffusionPipeline):
+     r"""
+     Args:
+         tokenizer (`T5TokenizerFast`):
+             Tokenizer of class
+             [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+         text_encoder ([`T5EncoderModel`]):
+             Frozen text-encoder. AuraFlow uses
+             [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
+             [EleutherAI/pile-t5-xl](https://huggingface.co/EleutherAI/pile-t5-xl) variant.
+         vae ([`AutoencoderKL`]):
+             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+         transformer ([`AuraFlowTransformer2DModel`]):
+             Conditional Transformer (MMDiT and DiT) architecture to denoise the encoded image latents.
+         scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+     """
+
    _optional_components = []
    model_cpu_offload_seq = "text_encoder->transformer->vae"

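The new class docstring above lists the five components the pipeline is assembled from. As a minimal sketch (not part of the diff), they can also be loaded individually and wired together by hand; the `fal/AuraFlow` subfolder names below follow the conventional diffusers repository layout and are assumptions here:

```py
# Sketch only: assemble AuraFlowPipeline from its documented components.
# Subfolder names assume the usual diffusers repo layout for "fal/AuraFlow".
import torch
from transformers import T5EncoderModel, T5TokenizerFast
from diffusers import (
    AuraFlowPipeline,
    AuraFlowTransformer2DModel,
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
)

repo = "fal/AuraFlow"
pipe = AuraFlowPipeline(
    tokenizer=T5TokenizerFast.from_pretrained(repo, subfolder="tokenizer"),
    text_encoder=T5EncoderModel.from_pretrained(repo, subfolder="text_encoder", torch_dtype=torch.float16),
    vae=AutoencoderKL.from_pretrained(repo, subfolder="vae", torch_dtype=torch.float16),
    transformer=AuraFlowTransformer2DModel.from_pretrained(repo, subfolder="transformer", torch_dtype=torch.float16),
    scheduler=FlowMatchEulerDiscreteScheduler.from_pretrained(repo, subfolder="scheduler"),
).to("cuda")
```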
@@ -201,8 +233,12 @@ def encode_prompt(
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
+             prompt_attention_mask (`torch.Tensor`, *optional*):
+                 Pre-generated attention mask for text embeddings.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings.
+             negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+                 Pre-generated attention mask for negative text embeddings.
            max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt.
        """
        if device is None:
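The two new attention-mask arguments mirror the existing embedding arguments: if embeddings are precomputed with `encode_prompt`, the corresponding masks can be fed back into the pipeline call. A heavily hedged sketch follows; this hunk does not show `encode_prompt`'s full signature or return order, so both are assumptions:

```py
# Hedged sketch: precompute prompt embeddings once and reuse them.
# The keyword names and the (embeds, mask, neg_embeds, neg_mask) return
# order are assumptions; neither is visible in this hunk.
import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")

prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = pipe.encode_prompt(
    prompt="A cat holding a sign that says hello world",
    negative_prompt="blurry, low quality",
)

image = pipe(
    prompt_embeds=prompt_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_embeds=negative_prompt_embeds,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
).images[0]
```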
@@ -345,6 +381,7 @@ def upcast_vae(self):
        self.vae.decoder.mid_block.to(dtype)

    @torch.no_grad()
+     @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
@@ -366,6 +403,71 @@ def __call__(
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
+         r"""
+         Function invoked when calling the pipeline for generation.
+
+         Args:
+             prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                 instead.
+             negative_prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                 less than `1`).
+             height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
+                 The height in pixels of the generated image. This is set to 512 by default.
+             width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
+                 The width in pixels of the generated image. This is set to 512 by default.
+             num_inference_steps (`int`, *optional*, defaults to 50):
+                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                 expense of slower inference.
+             sigmas (`List[float]`, *optional*):
+                 Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+                 `num_inference_steps` and `timesteps` must be `None`.
+             timesteps (`List[int]`, *optional*):
+                 Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                 in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                 passed will be used. Must be in descending order.
+             guidance_scale (`float`, *optional*, defaults to 5.0):
+                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                 1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+                 text `prompt`, usually at the expense of lower image quality.
+             num_images_per_prompt (`int`, *optional*, defaults to 1):
+                 The number of images to generate per prompt.
+             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                 to make generation deterministic.
+             latents (`torch.FloatTensor`, *optional*):
+                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                 generation. Can be used to tweak the same generation with different prompts. If not provided, a
+                 latents tensor will be generated by sampling using the supplied random `generator`.
+             prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
+                 not provided, text embeddings will be generated from `prompt` input argument.
+             prompt_attention_mask (`torch.Tensor`, *optional*):
+                 Pre-generated attention mask for text embeddings.
+             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                 weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt`
+                 input argument.
+             negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+                 Pre-generated attention mask for negative text embeddings.
+             output_type (`str`, *optional*, defaults to `"pil"`):
+                 The output format of the generated image. Choose between
+                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return an [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+             max_sequence_length (`int`, defaults to 256): Maximum sequence length to use with the `prompt`.
+
+         Examples:
+
+         Returns:
+             [`~pipelines.ImagePipelineOutput`] or `tuple`:
+                 If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
+                 returned where the first element is a list with the generated images.
+         """
        # 1. Check inputs. Raise error if not correct
        height = height or self.transformer.config.sample_size * self.vae_scale_factor
        width = width or self.transformer.config.sample_size * self.vae_scale_factor
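To make the parameter list above concrete, here is a minimal call-site sketch (not part of the diff); every argument shown is documented in the `__call__` docstring above, but the checkpoint id and all values are illustrative only:

```py
# Illustrative only: exercises several of the arguments documented in __call__.
import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")
generator = torch.Generator(device="cuda").manual_seed(0)  # make sampling deterministic

image = pipe(
    prompt="A cat holding a sign that says hello world",
    negative_prompt="blurry, low quality",  # only used when guidance_scale > 1
    height=512,
    width=512,
    num_inference_steps=50,
    guidance_scale=5.0,
    num_images_per_prompt=1,
    generator=generator,
    output_type="pil",
    return_dict=True,
).images[0]
image.save("aura_flow_example.png")
```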