diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 3bd7f1987a00..fc3022cf7b35 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -79,6 +79,8 @@ - sections: - local: using-diffusers/cogvideox title: CogVideoX + - local: using-diffusers/consisid + title: ConsisID - local: using-diffusers/sdxl title: Stable Diffusion XL - local: using-diffusers/sdxl_turbo @@ -270,6 +272,8 @@ title: AuraFlowTransformer2DModel - local: api/models/cogvideox_transformer3d title: CogVideoXTransformer3DModel + - local: api/models/consisid_transformer3d + title: ConsisIDTransformer3DModel - local: api/models/cogview3plus_transformer2d title: CogView3PlusTransformer2DModel - local: api/models/dit_transformer2d @@ -372,6 +376,8 @@ title: CogVideoX - local: api/pipelines/cogview3 title: CogView3 + - local: api/pipelines/consisid + title: ConsisID - local: api/pipelines/consistency_models title: Consistency Models - local: api/pipelines/controlnet diff --git a/docs/source/en/api/models/consisid_transformer3d.md b/docs/source/en/api/models/consisid_transformer3d.md new file mode 100644 index 000000000000..bca03c099b1d --- /dev/null +++ b/docs/source/en/api/models/consisid_transformer3d.md @@ -0,0 +1,30 @@ + + +# ConsisIDTransformer3DModel + +A Diffusion Transformer model for 3D data from [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) was introduced in [Identity-Preserving Text-to-Video Generation by Frequency Decomposition](https://arxiv.org/pdf/2411.17440) by Peking University & University of Rochester & etc. + +The model can be loaded with the following code snippet. + +```python +from diffusers import ConsisIDTransformer3DModel + +transformer = ConsisIDTransformer3DModel.from_pretrained("BestWishYsh/ConsisID-preview", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda") +``` + +## ConsisIDTransformer3DModel + +[[autodoc]] ConsisIDTransformer3DModel + +## Transformer2DModelOutput + +[[autodoc]] models.modeling_outputs.Transformer2DModelOutput diff --git a/docs/source/en/api/pipelines/consisid.md b/docs/source/en/api/pipelines/consisid.md new file mode 100644 index 000000000000..29ef3150f42d --- /dev/null +++ b/docs/source/en/api/pipelines/consisid.md @@ -0,0 +1,60 @@ + + +# ConsisID + +[Identity-Preserving Text-to-Video Generation by Frequency Decomposition](https://arxiv.org/abs/2411.17440) from Peking University & University of Rochester & etc, by Shenghai Yuan, Jinfa Huang, Xianyi He, Yunyang Ge, Yujun Shi, Liuhan Chen, Jiebo Luo, Li Yuan. + +The abstract from the paper is: + +*Identity-preserving text-to-video (IPT2V) generation aims to create high-fidelity videos with consistent human identity. It is an important task in video generation but remains an open problem for generative models. This paper pushes the technical frontier of IPT2V in two directions that have not been resolved in the literature: (1) A tuning-free pipeline without tedious case-by-case finetuning, and (2) A frequency-aware heuristic identity-preserving Diffusion Transformer (DiT)-based control scheme. To achieve these goals, we propose **ConsisID**, a tuning-free DiT-based controllable IPT2V model to keep human-**id**entity **consis**tent in the generated video. Inspired by prior findings in frequency analysis of vision/diffusion transformers, it employs identity-control signals in the frequency domain, where facial features can be decomposed into low-frequency global features (e.g., profile, proportions) and high-frequency intrinsic features (e.g., identity markers that remain unaffected by pose changes). First, from a low-frequency perspective, we introduce a global facial extractor, which encodes the reference image and facial key points into a latent space, generating features enriched with low-frequency information. These features are then integrated into the shallow layers of the network to alleviate training challenges associated with DiT. Second, from a high-frequency perspective, we design a local facial extractor to capture high-frequency details and inject them into the transformer blocks, enhancing the model's ability to preserve fine-grained features. To leverage the frequency information for identity preservation, we propose a hierarchical training strategy, transforming a vanilla pre-trained video generation model into an IPT2V model. Extensive experiments demonstrate that our frequency-aware heuristic scheme provides an optimal control solution for DiT-based models. Thanks to this scheme, our **ConsisID** achieves excellent results in generating high-quality, identity-preserving videos, making strides towards more effective IPT2V. The model weight of ConsID is publicly available at https://github.com/PKU-YuanGroup/ConsisID.* + + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. + + + +This pipeline was contributed by [SHYuanBest](https://github.com/SHYuanBest). The original codebase can be found [here](https://github.com/PKU-YuanGroup/ConsisID). The original weights can be found under [hf.co/BestWishYsh](https://huggingface.co/BestWishYsh). + +There are two official ConsisID checkpoints for identity-preserving text-to-video. + +| checkpoints | recommended inference dtype | +|:---:|:---:| +| [`BestWishYsh/ConsisID-preview`](https://huggingface.co/BestWishYsh/ConsisID-preview) | torch.bfloat16 | +| [`BestWishYsh/ConsisID-1.5`](https://huggingface.co/BestWishYsh/ConsisID-preview) | torch.bfloat16 | + +### Memory optimization + +ConsisID requires about 44 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with output resolution 720x480 (W x H), which makes it not possible to run on consumer GPUs or free-tier T4 Colab. The following memory optimizations could be used to reduce the memory footprint. For replication, you can refer to [this](https://gist.github.com/SHYuanBest/bc4207c36f454f9e969adbb50eaf8258) script. + +| Feature (overlay the previous) | Max Memory Allocated | Max Memory Reserved | +| :----------------------------- | :------------------- | :------------------ | +| - | 37 GB | 44 GB | +| enable_model_cpu_offload | 22 GB | 25 GB | +| enable_sequential_cpu_offload | 16 GB | 22 GB | +| vae.enable_slicing | 16 GB | 22 GB | +| vae.enable_tiling | 5 GB | 7 GB | + +## ConsisIDPipeline + +[[autodoc]] ConsisIDPipeline + + - all + - __call__ + +## ConsisIDPipelineOutput + +[[autodoc]] pipelines.consisid.pipeline_output.ConsisIDPipelineOutput diff --git a/docs/source/en/using-diffusers/consisid.md b/docs/source/en/using-diffusers/consisid.md new file mode 100644 index 000000000000..07c13c4c66b3 --- /dev/null +++ b/docs/source/en/using-diffusers/consisid.md @@ -0,0 +1,96 @@ + +# ConsisID + +[ConsisID](https://github.com/PKU-YuanGroup/ConsisID) is an identity-preserving text-to-video generation model that keeps the face consistent in the generated video by frequency decomposition. The main features of ConsisID are: + +- Frequency decomposition: The characteristics of the DiT architecture are analyzed from the frequency domain perspective, and based on these characteristics, a reasonable control information injection method is designed. +- Consistency training strategy: A coarse-to-fine training strategy, dynamic masking loss, and dynamic cross-face loss further enhance the model's generalization ability and identity preservation performance. +- Inference without finetuning: Previous methods required case-by-case finetuning of the input ID before inference, leading to significant time and computational costs. In contrast, ConsisID is tuning-free. + +This guide will walk you through using ConsisID for use cases. + +## Load Model Checkpoints + +Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method. + +```python +# !pip install consisid_eva_clip insightface facexlib +import torch +from diffusers import ConsisIDPipeline +from diffusers.pipelines.consisid.consisid_utils import prepare_face_models, process_face_embeddings_infer +from huggingface_hub import snapshot_download + +# Download ckpts +snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir="BestWishYsh/ConsisID-preview") + +# Load face helper model to preprocess input face image +face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = prepare_face_models("BestWishYsh/ConsisID-preview", device="cuda", dtype=torch.bfloat16) + +# Load consisid base model +pipe = ConsisIDPipeline.from_pretrained("BestWishYsh/ConsisID-preview", torch_dtype=torch.bfloat16) +pipe.to("cuda") +``` + +## Identity-Preserving Text-to-Video + +For identity-preserving text-to-video, pass a text prompt and an image contain clear face (e.g., preferably half-body or full-body). By default, ConsisID generates a 720x480 video for the best results. + +```python +from diffusers.utils import export_to_video + +prompt = "The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel." +image = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_input.png?download=true" + +id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer(face_helper_1, face_clip_model, face_helper_2, eva_transform_mean, eva_transform_std, face_main_model, "cuda", torch.bfloat16, image, is_align_face=True) + +video = pipe(image=image, prompt=prompt, num_inference_steps=50, guidance_scale=6.0, use_dynamic_cfg=False, id_vit_hidden=id_vit_hidden, id_cond=id_cond, kps_cond=face_kps, generator=torch.Generator("cuda").manual_seed(42)) +export_to_video(video.frames[0], "output.mp4", fps=8) +``` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Face ImageVideoDescription
The video, in a beautifully crafted animated style, features a confident woman riding a horse through a lush forest clearing. Her expression is focused yet serene as she adjusts her wide-brimmed hat with a practiced hand. She wears a flowy bohemian dress, which moves gracefully with the rhythm of the horse, the fabric flowing fluidly in the animated motion. The dappled sunlight filters through the trees, casting soft, painterly patterns on the forest floor. Her posture is poised, showing both control and elegance as she guides the horse with ease. The animation's gentle, fluid style adds a dreamlike quality to the scene, with the woman’s calm demeanor and the peaceful surroundings evoking a sense of freedom and harmony.
The video, in a captivating animated style, shows a woman standing in the center of a snowy forest, her eyes narrowed in concentration as she extends her hand forward. She is dressed in a deep blue cloak, her breath visible in the cold air, which is rendered with soft, ethereal strokes. A faint smile plays on her lips as she summons a wisp of ice magic, watching with focus as the surrounding trees and ground begin to shimmer and freeze, covered in delicate ice crystals. The animation’s fluid motion brings the magic to life, with the frost spreading outward in intricate, sparkling patterns. The environment is painted with soft, watercolor-like hues, enhancing the magical, dreamlike atmosphere. The overall mood is serene yet powerful, with the quiet winter air amplifying the delicate beauty of the frozen scene.
The animation features a whimsical portrait of a balloon seller standing in a gentle breeze, captured with soft, hazy brushstrokes that evoke the feel of a serene spring day. His face is framed by a gentle smile, his eyes squinting slightly against the sun, while a few wisps of hair flutter in the wind. He is dressed in a light, pastel-colored shirt, and the balloons around him sway with the wind, adding a sense of playfulness to the scene. The background blurs softly, with hints of a vibrant market or park, enhancing the light-hearted, yet tender mood of the moment.
The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel.
The video features a baby wearing a bright superhero cape, standing confidently with arms raised in a powerful pose. The baby has a determined look on their face, with eyes wide and lips pursed in concentration, as if ready to take on a challenge. The setting appears playful, with colorful toys scattered around and a soft rug underfoot, while sunlight streams through a nearby window, highlighting the fluttering cape and adding to the impression of heroism. The overall atmosphere is lighthearted and fun, with the baby's expressions capturing a mix of innocence and an adorable attempt at bravery, as if truly ready to save the day.
+ +## Resources + +Learn more about ConsisID with the following resources. +- A [video](https://www.youtube.com/watch?v=PhlgC-bI5SQ) demonstrating ConsisID's main features. +- The research paper, [Identity-Preserving Text-to-Video Generation by Frequency Decomposition](https://hf.co/papers/2411.17440) for more details. diff --git a/docs/source/zh/_toctree.yml b/docs/source/zh/_toctree.yml index 41d5e95a4230..6416c468a8e9 100644 --- a/docs/source/zh/_toctree.yml +++ b/docs/source/zh/_toctree.yml @@ -5,6 +5,8 @@ title: 快速入门 - local: stable_diffusion title: 有效和高效的扩散 + - local: consisid + title: 身份保持的文本到视频生成 - local: installation title: 安装 title: 开始 diff --git a/docs/source/zh/consisid.md b/docs/source/zh/consisid.md new file mode 100644 index 000000000000..2f404499fc69 --- /dev/null +++ b/docs/source/zh/consisid.md @@ -0,0 +1,100 @@ + +# ConsisID + +[ConsisID](https://github.com/PKU-YuanGroup/ConsisID)是一种身份保持的文本到视频生成模型,其通过频率分解在生成的视频中保持面部一致性。它具有以下特点: + +- 基于频率分解:将人物ID特征解耦为高频和低频部分,从频域的角度分析DIT架构的特性,并且基于此特性设计合理的控制信息注入方式。 + +- 一致性训练策略:我们提出粗到细训练策略、动态掩码损失、动态跨脸损失,进一步提高了模型的泛化能力和身份保持效果。 + + +- 推理无需微调:之前的方法在推理前,需要对输入id进行case-by-case微调,时间和算力开销较大,而我们的方法是tuning-free的。 + + +本指南将指导您使用 ConsisID 生成身份保持的视频。 + +## Load Model Checkpoints +模型权重可以存储在Hub上或本地的单独子文件夹中,在这种情况下,您应该使用 [`~DiffusionPipeline.from_pretrained`] 方法。 + + +```python +# !pip install consisid_eva_clip insightface facexlib +import torch +from diffusers import ConsisIDPipeline +from diffusers.pipelines.consisid.consisid_utils import prepare_face_models, process_face_embeddings_infer +from huggingface_hub import snapshot_download + +# Download ckpts +snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir="BestWishYsh/ConsisID-preview") + +# Load face helper model to preprocess input face image +face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = prepare_face_models("BestWishYsh/ConsisID-preview", device="cuda", dtype=torch.bfloat16) + +# Load consisid base model +pipe = ConsisIDPipeline.from_pretrained("BestWishYsh/ConsisID-preview", torch_dtype=torch.bfloat16) +pipe.to("cuda") +``` + +## Identity-Preserving Text-to-Video +对于身份保持的文本到视频生成,需要输入文本提示和包含清晰面部(例如,最好是半身或全身)的图像。默认情况下,ConsisID 会生成 720x480 的视频以获得最佳效果。 + +```python +from diffusers.utils import export_to_video + +prompt = "The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel." +image = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_input.png?download=true" + +id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer(face_helper_1, face_clip_model, face_helper_2, eva_transform_mean, eva_transform_std, face_main_model, "cuda", torch.bfloat16, image, is_align_face=True) + +video = pipe(image=image, prompt=prompt, num_inference_steps=50, guidance_scale=6.0, use_dynamic_cfg=False, id_vit_hidden=id_vit_hidden, id_cond=id_cond, kps_cond=face_kps, generator=torch.Generator("cuda").manual_seed(42)) +export_to_video(video.frames[0], "output.mp4", fps=8) +``` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Face ImageVideoDescription
The video, in a beautifully crafted animated style, features a confident woman riding a horse through a lush forest clearing. Her expression is focused yet serene as she adjusts her wide-brimmed hat with a practiced hand. She wears a flowy bohemian dress, which moves gracefully with the rhythm of the horse, the fabric flowing fluidly in the animated motion. The dappled sunlight filters through the trees, casting soft, painterly patterns on the forest floor. Her posture is poised, showing both control and elegance as she guides the horse with ease. The animation's gentle, fluid style adds a dreamlike quality to the scene, with the woman’s calm demeanor and the peaceful surroundings evoking a sense of freedom and harmony.
The video, in a captivating animated style, shows a woman standing in the center of a snowy forest, her eyes narrowed in concentration as she extends her hand forward. She is dressed in a deep blue cloak, her breath visible in the cold air, which is rendered with soft, ethereal strokes. A faint smile plays on her lips as she summons a wisp of ice magic, watching with focus as the surrounding trees and ground begin to shimmer and freeze, covered in delicate ice crystals. The animation’s fluid motion brings the magic to life, with the frost spreading outward in intricate, sparkling patterns. The environment is painted with soft, watercolor-like hues, enhancing the magical, dreamlike atmosphere. The overall mood is serene yet powerful, with the quiet winter air amplifying the delicate beauty of the frozen scene.
The animation features a whimsical portrait of a balloon seller standing in a gentle breeze, captured with soft, hazy brushstrokes that evoke the feel of a serene spring day. His face is framed by a gentle smile, his eyes squinting slightly against the sun, while a few wisps of hair flutter in the wind. He is dressed in a light, pastel-colored shirt, and the balloons around him sway with the wind, adding a sense of playfulness to the scene. The background blurs softly, with hints of a vibrant market or park, enhancing the light-hearted, yet tender mood of the moment.
The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel.
The video features a baby wearing a bright superhero cape, standing confidently with arms raised in a powerful pose. The baby has a determined look on their face, with eyes wide and lips pursed in concentration, as if ready to take on a challenge. The setting appears playful, with colorful toys scattered around and a soft rug underfoot, while sunlight streams through a nearby window, highlighting the fluttering cape and adding to the impression of heroism. The overall atmosphere is lighthearted and fun, with the baby's expressions capturing a mix of innocence and an adorable attempt at bravery, as if truly ready to save the day.
+ +## Resources + +通过以下资源了解有关 ConsisID 的更多信息: + +- 一段 [视频](https://www.youtube.com/watch?v=PhlgC-bI5SQ) 演示了 ConsisID 的主要功能; +- 有关更多详细信息,请参阅研究论文 [Identity-Preserving Text-to-Video Generation by Frequency Decomposition](https://hf.co/papers/2411.17440)。 diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 5e9ab2a117d1..b1801fbb2b4b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -92,6 +92,7 @@ "AutoencoderTiny", "CogVideoXTransformer3DModel", "CogView3PlusTransformer2DModel", + "ConsisIDTransformer3DModel", "ConsistencyDecoderVAE", "ControlNetModel", "ControlNetUnionModel", @@ -275,6 +276,7 @@ "CogVideoXPipeline", "CogVideoXVideoToVideoPipeline", "CogView3PlusPipeline", + "ConsisIDPipeline", "CycleDiffusionPipeline", "FluxControlImg2ImgPipeline", "FluxControlInpaintPipeline", @@ -602,6 +604,7 @@ AutoencoderTiny, CogVideoXTransformer3DModel, CogView3PlusTransformer2DModel, + ConsisIDTransformer3DModel, ConsistencyDecoderVAE, ControlNetModel, ControlNetUnionModel, @@ -764,6 +767,7 @@ CogVideoXPipeline, CogVideoXVideoToVideoPipeline, CogView3PlusPipeline, + ConsisIDPipeline, CycleDiffusionPipeline, FluxControlImg2ImgPipeline, FluxControlInpaintPipeline, diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 454496ff04d4..b35839b29ed2 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -47,6 +47,7 @@ "SD3Transformer2DModel": lambda model_cls, weights: weights, "FluxTransformer2DModel": lambda model_cls, weights: weights, "CogVideoXTransformer3DModel": lambda model_cls, weights: weights, + "ConsisIDTransformer3DModel": lambda model_cls, weights: weights, "MochiTransformer3DModel": lambda model_cls, weights: weights, "HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights, "LTXVideoTransformer3DModel": lambda model_cls, weights: weights, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 01e67b01d91a..e3f291ce2dc7 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -54,6 +54,7 @@ _import_structure["modeling_utils"] = ["ModelMixin"] _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"] _import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"] + _import_structure["transformers.consisid_transformer_3d"] = ["ConsisIDTransformer3DModel"] _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"] _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"] _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"] @@ -129,6 +130,7 @@ AuraFlowTransformer2DModel, CogVideoXTransformer3DModel, CogView3PlusTransformer2DModel, + ConsisIDTransformer3DModel, DiTTransformer2DModel, DualTransformer2DModel, FluxTransformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 3a33c8070c08..77e1698b8fc2 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -4,6 +4,7 @@ if is_torch_available(): from .auraflow_transformer_2d import AuraFlowTransformer2DModel from .cogvideox_transformer_3d import CogVideoXTransformer3DModel + from .consisid_transformer_3d import ConsisIDTransformer3DModel from .dit_transformer_2d import DiTTransformer2DModel from .dual_transformer_2d import DualTransformer2DModel from .hunyuan_transformer_2d import HunyuanDiT2DModel diff --git a/src/diffusers/models/transformers/consisid_transformer_3d.py b/src/diffusers/models/transformers/consisid_transformer_3d.py new file mode 100644 index 000000000000..86a6628b5161 --- /dev/null +++ b/src/diffusers/models/transformers/consisid_transformer_3d.py @@ -0,0 +1,801 @@ +# Copyright 2024 ConsisID Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import Attention, FeedForward +from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0 +from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class PerceiverAttention(nn.Module): + def __init__(self, dim: int, dim_head: int = 64, heads: int = 8, kv_dim: Optional[int] = None): + super().__init__() + + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, image_embeds: torch.Tensor, latents: torch.Tensor) -> torch.Tensor: + # Apply normalization + image_embeds = self.norm1(image_embeds) + latents = self.norm2(latents) + + batch_size, seq_len, _ = latents.shape # Get batch size and sequence length + + # Compute query, key, and value matrices + query = self.to_q(latents) + kv_input = torch.cat((image_embeds, latents), dim=-2) + key, value = self.to_kv(kv_input).chunk(2, dim=-1) + + # Reshape the tensors for multi-head attention + query = query.reshape(query.size(0), -1, self.heads, self.dim_head).transpose(1, 2) + key = key.reshape(key.size(0), -1, self.heads, self.dim_head).transpose(1, 2) + value = value.reshape(value.size(0), -1, self.heads, self.dim_head).transpose(1, 2) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (query * scale) @ (key * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + output = weight @ value + + # Reshape and return the final output + output = output.permute(0, 2, 1, 3).reshape(batch_size, seq_len, -1) + + return self.to_out(output) + + +class LocalFacialExtractor(nn.Module): + def __init__( + self, + id_dim: int = 1280, + vit_dim: int = 1024, + depth: int = 10, + dim_head: int = 64, + heads: int = 16, + num_id_token: int = 5, + num_queries: int = 32, + output_dim: int = 2048, + ff_mult: int = 4, + num_scale: int = 5, + ): + super().__init__() + + # Storing identity token and query information + self.num_id_token = num_id_token + self.vit_dim = vit_dim + self.num_queries = num_queries + assert depth % num_scale == 0 + self.depth = depth // num_scale + self.num_scale = num_scale + scale = vit_dim**-0.5 + + # Learnable latent query embeddings + self.latents = nn.Parameter(torch.randn(1, num_queries, vit_dim) * scale) + # Projection layer to map the latent output to the desired dimension + self.proj_out = nn.Parameter(scale * torch.randn(vit_dim, output_dim)) + + # Attention and ConsisIDFeedForward layer stack + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=vit_dim, dim_head=dim_head, heads=heads), # Perceiver Attention layer + nn.Sequential( + nn.LayerNorm(vit_dim), + nn.Linear(vit_dim, vit_dim * ff_mult, bias=False), + nn.GELU(), + nn.Linear(vit_dim * ff_mult, vit_dim, bias=False), + ), # ConsisIDFeedForward layer + ] + ) + ) + + # Mappings for each of the 5 different ViT features + for i in range(num_scale): + setattr( + self, + f"mapping_{i}", + nn.Sequential( + nn.Linear(vit_dim, vit_dim), + nn.LayerNorm(vit_dim), + nn.LeakyReLU(), + nn.Linear(vit_dim, vit_dim), + nn.LayerNorm(vit_dim), + nn.LeakyReLU(), + nn.Linear(vit_dim, vit_dim), + ), + ) + + # Mapping for identity embedding vectors + self.id_embedding_mapping = nn.Sequential( + nn.Linear(id_dim, vit_dim), + nn.LayerNorm(vit_dim), + nn.LeakyReLU(), + nn.Linear(vit_dim, vit_dim), + nn.LayerNorm(vit_dim), + nn.LeakyReLU(), + nn.Linear(vit_dim, vit_dim * num_id_token), + ) + + def forward(self, id_embeds: torch.Tensor, vit_hidden_states: List[torch.Tensor]) -> torch.Tensor: + # Repeat latent queries for the batch size + latents = self.latents.repeat(id_embeds.size(0), 1, 1) + + # Map the identity embedding to tokens + id_embeds = self.id_embedding_mapping(id_embeds) + id_embeds = id_embeds.reshape(-1, self.num_id_token, self.vit_dim) + + # Concatenate identity tokens with the latent queries + latents = torch.cat((latents, id_embeds), dim=1) + + # Process each of the num_scale visual feature inputs + for i in range(self.num_scale): + vit_feature = getattr(self, f"mapping_{i}")(vit_hidden_states[i]) + ctx_feature = torch.cat((id_embeds, vit_feature), dim=1) + + # Pass through the PerceiverAttention and ConsisIDFeedForward layers + for attn, ff in self.layers[i * self.depth : (i + 1) * self.depth]: + latents = attn(ctx_feature, latents) + latents + latents = ff(latents) + latents + + # Retain only the query latents + latents = latents[:, : self.num_queries] + # Project the latents to the output dimension + latents = latents @ self.proj_out + return latents + + +class PerceiverCrossAttention(nn.Module): + def __init__(self, dim: int = 3072, dim_head: int = 128, heads: int = 16, kv_dim: int = 2048): + super().__init__() + + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + # Layer normalization to stabilize training + self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim) + self.norm2 = nn.LayerNorm(dim) + + # Linear transformations to produce queries, keys, and values + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, image_embeds: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor: + # Apply layer normalization to the input image and latent features + image_embeds = self.norm1(image_embeds) + hidden_states = self.norm2(hidden_states) + + batch_size, seq_len, _ = hidden_states.shape + + # Compute queries, keys, and values + query = self.to_q(hidden_states) + key, value = self.to_kv(image_embeds).chunk(2, dim=-1) + + # Reshape tensors to split into attention heads + query = query.reshape(query.size(0), -1, self.heads, self.dim_head).transpose(1, 2) + key = key.reshape(key.size(0), -1, self.heads, self.dim_head).transpose(1, 2) + value = value.reshape(value.size(0), -1, self.heads, self.dim_head).transpose(1, 2) + + # Compute attention weights + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (query * scale) @ (key * scale).transpose(-2, -1) # More stable scaling than post-division + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + + # Compute the output via weighted combination of values + out = weight @ value + + # Reshape and permute to prepare for final linear transformation + out = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, -1) + + return self.to_out(out) + + +@maybe_allow_in_graph +class ConsisIDBlock(nn.Module): + r""" + Transformer block used in [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) model. + + Parameters: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + time_embed_dim (`int`): + The number of channels in timestep embedding. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to be used in feed-forward. + attention_bias (`bool`, defaults to `False`): + Whether or not to use bias in attention projection layers. + qk_norm (`bool`, defaults to `True`): + Whether or not to use normalization after query and key projections in Attention. + norm_elementwise_affine (`bool`, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_eps (`float`, defaults to `1e-5`): + Epsilon value for normalization layers. + final_dropout (`bool` defaults to `False`): + Whether to apply a final dropout after the last feed-forward layer. + ff_inner_dim (`int`, *optional*, defaults to `None`): + Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used. + ff_bias (`bool`, defaults to `True`): + Whether or not to use bias in Feed-forward layer. + attention_out_bias (`bool`, defaults to `True`): + Whether or not to use bias in Attention output projection layer. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + time_embed_dim: int, + dropout: float = 0.0, + activation_fn: str = "gelu-approximate", + attention_bias: bool = False, + qk_norm: bool = True, + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + final_dropout: bool = True, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + + # 1. Self Attention + self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) + + self.attn1 = Attention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + qk_norm="layer_norm" if qk_norm else None, + eps=1e-6, + bias=attention_bias, + out_bias=attention_out_bias, + processor=CogVideoXAttnProcessor2_0(), + ) + + # 2. Feed Forward + self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + # norm & modulate + norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( + hidden_states, encoder_hidden_states, temb + ) + + # attention + attn_hidden_states, attn_encoder_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = hidden_states + gate_msa * attn_hidden_states + encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states + + # norm & modulate + norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( + hidden_states, encoder_hidden_states, temb + ) + + # feed-forward + norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) + ff_output = self.ff(norm_hidden_states) + + hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:] + encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length] + + return hidden_states, encoder_hidden_states + + +class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): + """ + A Transformer model for video-like data in [ConsisID](https://github.com/PKU-YuanGroup/ConsisID). + + Parameters: + num_attention_heads (`int`, defaults to `30`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `64`): + The number of channels in each head. + in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `16`): + The number of channels in the output. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + time_embed_dim (`int`, defaults to `512`): + Output dimension of timestep embeddings. + text_embed_dim (`int`, defaults to `4096`): + Input dimension of text embeddings from the text encoder. + num_layers (`int`, defaults to `30`): + The number of layers of Transformer blocks to use. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + attention_bias (`bool`, defaults to `True`): + Whether to use bias in the attention projection layers. + sample_width (`int`, defaults to `90`): + The width of the input latents. + sample_height (`int`, defaults to `60`): + The height of the input latents. + sample_frames (`int`, defaults to `49`): + The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49 + instead of 13 because ConsisID processed 13 latent frames at once in its default and recommended settings, + but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with + K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1). + patch_size (`int`, defaults to `2`): + The size of the patches to use in the patch embedding layer. + temporal_compression_ratio (`int`, defaults to `4`): + The compression ratio across the temporal dimension. See documentation for `sample_frames`. + max_text_seq_length (`int`, defaults to `226`): + The maximum sequence length of the input text embeddings. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + timestep_activation_fn (`str`, defaults to `"silu"`): + Activation function to use when generating the timestep embeddings. + norm_elementwise_affine (`bool`, defaults to `True`): + Whether to use elementwise affine in normalization layers. + norm_eps (`float`, defaults to `1e-5`): + The epsilon value to use in normalization layers. + spatial_interpolation_scale (`float`, defaults to `1.875`): + Scaling factor to apply in 3D positional embeddings across spatial dimensions. + temporal_interpolation_scale (`float`, defaults to `1.0`): + Scaling factor to apply in 3D positional embeddings across temporal dimensions. + is_train_face (`bool`, defaults to `False`): + Whether to use enable the identity-preserving module during the training process. When set to `True`, the + model will focus on identity-preserving tasks. + is_kps (`bool`, defaults to `False`): + Whether to enable keypoint for global facial extractor. If `True`, keypoints will be in the model. + cross_attn_interval (`int`, defaults to `2`): + The interval between cross-attention layers in the Transformer architecture. A larger value may reduce the + frequency of cross-attention computations, which can help reduce computational overhead. + cross_attn_dim_head (`int`, optional, defaults to `128`): + The dimensionality of each attention head in the cross-attention layers of the Transformer architecture. A + larger value increases the capacity to attend to more complex patterns, but also increases memory and + computation costs. + cross_attn_num_heads (`int`, optional, defaults to `16`): + The number of attention heads in the cross-attention layers. More heads allow for more parallel attention + mechanisms, capturing diverse relationships between different components of the input, but can also + increase computational requirements. + LFE_id_dim (`int`, optional, defaults to `1280`): + The dimensionality of the identity vector used in the Local Facial Extractor (LFE). This vector represents + the identity features of a face, which are important for tasks like face recognition and identity + preservation across different frames. + LFE_vit_dim (`int`, optional, defaults to `1024`): + The dimension of the vision transformer (ViT) output used in the Local Facial Extractor (LFE). This value + dictates the size of the transformer-generated feature vectors that will be processed for facial feature + extraction. + LFE_depth (`int`, optional, defaults to `10`): + The number of layers in the Local Facial Extractor (LFE). Increasing the depth allows the model to capture + more complex representations of facial features, but also increases the computational load. + LFE_dim_head (`int`, optional, defaults to `64`): + The dimensionality of each attention head in the Local Facial Extractor (LFE). This parameter affects how + finely the model can process and focus on different parts of the facial features during the extraction + process. + LFE_num_heads (`int`, optional, defaults to `16`): + The number of attention heads in the Local Facial Extractor (LFE). More heads can improve the model's + ability to capture diverse facial features, but at the cost of increased computational complexity. + LFE_num_id_token (`int`, optional, defaults to `5`): + The number of identity tokens used in the Local Facial Extractor (LFE). This defines how many + identity-related tokens the model will process to ensure face identity preservation during feature + extraction. + LFE_num_querie (`int`, optional, defaults to `32`): + The number of query tokens used in the Local Facial Extractor (LFE). These tokens are used to capture + high-frequency face-related information that aids in accurate facial feature extraction. + LFE_output_dim (`int`, optional, defaults to `2048`): + The output dimension of the Local Facial Extractor (LFE). This dimension determines the size of the feature + vectors produced by the LFE module, which will be used for subsequent tasks such as face recognition or + tracking. + LFE_ff_mult (`int`, optional, defaults to `4`): + The multiplication factor applied to the feed-forward network's hidden layer size in the Local Facial + Extractor (LFE). A higher value increases the model's capacity to learn more complex facial feature + transformations, but also increases the computation and memory requirements. + LFE_num_scale (`int`, optional, defaults to `5`): + The number of different scales visual feature. A higher value increases the model's capacity to learn more + complex facial feature transformations, but also increases the computation and memory requirements. + local_face_scale (`float`, defaults to `1.0`): + A scaling factor used to adjust the importance of local facial features in the model. This can influence + how strongly the model focuses on high frequency face-related content. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 30, + attention_head_dim: int = 64, + in_channels: int = 16, + out_channels: Optional[int] = 16, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + time_embed_dim: int = 512, + text_embed_dim: int = 4096, + num_layers: int = 30, + dropout: float = 0.0, + attention_bias: bool = True, + sample_width: int = 90, + sample_height: int = 60, + sample_frames: int = 49, + patch_size: int = 2, + temporal_compression_ratio: int = 4, + max_text_seq_length: int = 226, + activation_fn: str = "gelu-approximate", + timestep_activation_fn: str = "silu", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + spatial_interpolation_scale: float = 1.875, + temporal_interpolation_scale: float = 1.0, + use_rotary_positional_embeddings: bool = False, + use_learned_positional_embeddings: bool = False, + is_train_face: bool = False, + is_kps: bool = False, + cross_attn_interval: int = 2, + cross_attn_dim_head: int = 128, + cross_attn_num_heads: int = 16, + LFE_id_dim: int = 1280, + LFE_vit_dim: int = 1024, + LFE_depth: int = 10, + LFE_dim_head: int = 64, + LFE_num_heads: int = 16, + LFE_num_id_token: int = 5, + LFE_num_querie: int = 32, + LFE_output_dim: int = 2048, + LFE_ff_mult: int = 4, + LFE_num_scale: int = 5, + local_face_scale: float = 1.0, + ): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + + if not use_rotary_positional_embeddings and use_learned_positional_embeddings: + raise ValueError( + "There are no ConsisID checkpoints available with disable rotary embeddings and learned positional " + "embeddings. If you're using a custom model and/or believe this should be supported, please open an " + "issue at https://github.com/huggingface/diffusers/issues." + ) + + # 1. Patch embedding + self.patch_embed = CogVideoXPatchEmbed( + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + text_embed_dim=text_embed_dim, + bias=True, + sample_width=sample_width, + sample_height=sample_height, + sample_frames=sample_frames, + temporal_compression_ratio=temporal_compression_ratio, + max_text_seq_length=max_text_seq_length, + spatial_interpolation_scale=spatial_interpolation_scale, + temporal_interpolation_scale=temporal_interpolation_scale, + use_positional_embeddings=not use_rotary_positional_embeddings, + use_learned_positional_embeddings=use_learned_positional_embeddings, + ) + self.embedding_dropout = nn.Dropout(dropout) + + # 2. Time embeddings + self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) + self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn) + + # 3. Define spatio-temporal transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + ConsisIDBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + time_embed_dim=time_embed_dim, + dropout=dropout, + activation_fn=activation_fn, + attention_bias=attention_bias, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + ) + for _ in range(num_layers) + ] + ) + self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) + + # 4. Output blocks + self.norm_out = AdaLayerNorm( + embedding_dim=time_embed_dim, + output_dim=2 * inner_dim, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + chunk_dim=1, + ) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) + + self.is_train_face = is_train_face + self.is_kps = is_kps + + # 5. Define identity-preserving config + if is_train_face: + # LFE configs + self.LFE_id_dim = LFE_id_dim + self.LFE_vit_dim = LFE_vit_dim + self.LFE_depth = LFE_depth + self.LFE_dim_head = LFE_dim_head + self.LFE_num_heads = LFE_num_heads + self.LFE_num_id_token = LFE_num_id_token + self.LFE_num_querie = LFE_num_querie + self.LFE_output_dim = LFE_output_dim + self.LFE_ff_mult = LFE_ff_mult + self.LFE_num_scale = LFE_num_scale + # cross configs + self.inner_dim = inner_dim + self.cross_attn_interval = cross_attn_interval + self.num_cross_attn = num_layers // cross_attn_interval + self.cross_attn_dim_head = cross_attn_dim_head + self.cross_attn_num_heads = cross_attn_num_heads + self.cross_attn_kv_dim = int(self.inner_dim / 3 * 2) + self.local_face_scale = local_face_scale + # face modules + self._init_face_inputs() + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + def _init_face_inputs(self): + self.local_facial_extractor = LocalFacialExtractor( + id_dim=self.LFE_id_dim, + vit_dim=self.LFE_vit_dim, + depth=self.LFE_depth, + dim_head=self.LFE_dim_head, + heads=self.LFE_num_heads, + num_id_token=self.LFE_num_id_token, + num_queries=self.LFE_num_querie, + output_dim=self.LFE_output_dim, + ff_mult=self.LFE_ff_mult, + num_scale=self.LFE_num_scale, + ) + self.perceiver_cross_attention = nn.ModuleList( + [ + PerceiverCrossAttention( + dim=self.inner_dim, + dim_head=self.cross_attn_dim_head, + heads=self.cross_attn_num_heads, + kv_dim=self.cross_attn_kv_dim, + ) + for _ in range(self.num_cross_attn) + ] + ) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: Union[int, float, torch.LongTensor], + timestep_cond: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + id_cond: Optional[torch.Tensor] = None, + id_vit_hidden: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + # fuse clip and insightface + valid_face_emb = None + if self.is_train_face: + id_cond = id_cond.to(device=hidden_states.device, dtype=hidden_states.dtype) + id_vit_hidden = [ + tensor.to(device=hidden_states.device, dtype=hidden_states.dtype) for tensor in id_vit_hidden + ] + valid_face_emb = self.local_facial_extractor( + id_cond, id_vit_hidden + ) # torch.Size([1, 1280]), list[5](torch.Size([1, 577, 1024])) -> torch.Size([1, 32, 2048]) + + batch_size, num_frames, channels, height, width = hidden_states.shape + + # 1. Time embedding + timesteps = timestep + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + # 2. Patch embedding + # torch.Size([1, 226, 4096]) torch.Size([1, 13, 32, 60, 90]) + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) # torch.Size([1, 17776, 3072]) + hidden_states = self.embedding_dropout(hidden_states) # torch.Size([1, 17776, 3072]) + + text_seq_length = encoder_hidden_states.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] # torch.Size([1, 226, 3072]) + hidden_states = hidden_states[:, text_seq_length:] # torch.Size([1, 17550, 3072]) + + # 3. Transformer blocks + ca_idx = 0 + for i, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) + + if self.is_train_face: + if i % self.cross_attn_interval == 0 and valid_face_emb is not None: + hidden_states = hidden_states + self.local_face_scale * self.perceiver_cross_attention[ca_idx]( + valid_face_emb, hidden_states + ) # torch.Size([2, 32, 2048]) torch.Size([2, 17550, 3072]) + ca_idx += 1 + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + hidden_states = self.norm_final(hidden_states) + hidden_states = hidden_states[:, text_seq_length:] + + # 4. Final block + hidden_states = self.norm_out(hidden_states, temb=emb) + hidden_states = self.proj_out(hidden_states) + + # 5. Unpatchify + # Note: we use `-1` instead of `channels`: + # - It is okay to `channels` use for ConsisID (number of input channels is equal to output channels) + p = self.config.patch_size + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index ce291e5ceb45..5829cf495dcc 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -154,6 +154,7 @@ "CogVideoXFunControlPipeline", ] _import_structure["cogview3"] = ["CogView3PlusPipeline"] + _import_structure["consisid"] = ["ConsisIDPipeline"] _import_structure["controlnet"].extend( [ "BlipDiffusionControlNetPipeline", @@ -496,6 +497,7 @@ CogVideoXVideoToVideoPipeline, ) from .cogview3 import CogView3PlusPipeline + from .consisid import ConsisIDPipeline from .controlnet import ( BlipDiffusionControlNetPipeline, StableDiffusionControlNetImg2ImgPipeline, diff --git a/src/diffusers/pipelines/consisid/__init__.py b/src/diffusers/pipelines/consisid/__init__.py new file mode 100644 index 000000000000..5052e146f1df --- /dev/null +++ b/src/diffusers/pipelines/consisid/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_consisid"] = ["ConsisIDPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_consisid import ConsisIDPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/consisid/consisid_utils.py b/src/diffusers/pipelines/consisid/consisid_utils.py new file mode 100644 index 000000000000..ec9e9aa49c0f --- /dev/null +++ b/src/diffusers/pipelines/consisid/consisid_utils.py @@ -0,0 +1,355 @@ +import importlib.util +import os + +import cv2 +import numpy as np +import torch +from PIL import Image, ImageOps +from torchvision.transforms import InterpolationMode +from torchvision.transforms.functional import normalize, resize + +from ...utils import load_image + + +_insightface_available = importlib.util.find_spec("insightface") is not None +_consisid_eva_clip_available = importlib.util.find_spec("consisid_eva_clip") is not None +_facexlib_available = importlib.util.find_spec("facexlib") is not None + +if _insightface_available: + import insightface + from insightface.app import FaceAnalysis +else: + raise ImportError("insightface is not available. Please install it using 'pip install insightface'.") + +if _consisid_eva_clip_available: + from consisid_eva_clip import create_model_and_transforms + from consisid_eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +else: + raise ImportError("consisid_eva_clip is not available. Please install it using 'pip install consisid_eva_clip'.") + +if _facexlib_available: + from facexlib.parsing import init_parsing_model + from facexlib.utils.face_restoration_helper import FaceRestoreHelper +else: + raise ImportError("facexlib is not available. Please install it using 'pip install facexlib'.") + + +def resize_numpy_image_long(image, resize_long_edge=768): + """ + Resize the input image to a specified long edge while maintaining aspect ratio. + + Args: + image (numpy.ndarray): Input image (H x W x C or H x W). + resize_long_edge (int): The target size for the long edge of the image. Default is 768. + + Returns: + numpy.ndarray: Resized image with the long edge matching `resize_long_edge`, while maintaining the aspect + ratio. + """ + + h, w = image.shape[:2] + if max(h, w) <= resize_long_edge: + return image + k = resize_long_edge / max(h, w) + h = int(h * k) + w = int(w * k) + image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4) + return image + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == "float64": + img = img.astype("float32") + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + return _totensor(imgs, bgr2rgb, float32) + + +def to_gray(img): + """ + Converts an RGB image to grayscale by applying the standard luminosity formula. + + Args: + img (torch.Tensor): The input image tensor with shape (batch_size, channels, height, width). + The image is expected to be in RGB format (3 channels). + + Returns: + torch.Tensor: The grayscale image tensor with shape (batch_size, 3, height, width). + The grayscale values are replicated across all three channels. + """ + x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3] + x = x.repeat(1, 3, 1, 1) + return x + + +def process_face_embeddings( + face_helper_1, + clip_vision_model, + face_helper_2, + eva_transform_mean, + eva_transform_std, + app, + device, + weight_dtype, + image, + original_id_image=None, + is_align_face=True, +): + """ + Process face embeddings from an image, extracting relevant features such as face embeddings, landmarks, and parsed + face features using a series of face detection and alignment tools. + + Args: + face_helper_1: Face helper object (first helper) for alignment and landmark detection. + clip_vision_model: Pre-trained CLIP vision model used for feature extraction. + face_helper_2: Face helper object (second helper) for embedding extraction. + eva_transform_mean: Mean values for image normalization before passing to EVA model. + eva_transform_std: Standard deviation values for image normalization before passing to EVA model. + app: Application instance used for face detection. + device: Device (CPU or GPU) where the computations will be performed. + weight_dtype: Data type of the weights for precision (e.g., `torch.float32`). + image: Input image in RGB format with pixel values in the range [0, 255]. + original_id_image: (Optional) Original image for feature extraction if `is_align_face` is False. + is_align_face: Boolean flag indicating whether face alignment should be performed. + + Returns: + Tuple: + - id_cond: Concatenated tensor of Ante face embedding and CLIP vision embedding + - id_vit_hidden: Hidden state of the CLIP vision model, a list of tensors. + - return_face_features_image_2: Processed face features image after normalization and parsing. + - face_kps: Keypoints of the face detected in the image. + """ + + face_helper_1.clean_all() + image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + # get antelopev2 embedding + face_info = app.get(image_bgr) + if len(face_info) > 0: + face_info = sorted(face_info, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))[ + -1 + ] # only use the maximum face + id_ante_embedding = face_info["embedding"] # (512,) + face_kps = face_info["kps"] + else: + id_ante_embedding = None + face_kps = None + + # using facexlib to detect and align face + face_helper_1.read_image(image_bgr) + face_helper_1.get_face_landmarks_5(only_center_face=True) + if face_kps is None: + face_kps = face_helper_1.all_landmarks_5[0] + face_helper_1.align_warp_face() + if len(face_helper_1.cropped_faces) == 0: + raise RuntimeError("facexlib align face fail") + align_face = face_helper_1.cropped_faces[0] # (512, 512, 3) # RGB + + # incase insightface didn't detect face + if id_ante_embedding is None: + print("fail to detect face using insightface, extract embedding on align face") + id_ante_embedding = face_helper_2.get_feat(align_face) + + id_ante_embedding = torch.from_numpy(id_ante_embedding).to(device, weight_dtype) # torch.Size([512]) + if id_ante_embedding.ndim == 1: + id_ante_embedding = id_ante_embedding.unsqueeze(0) # torch.Size([1, 512]) + + # parsing + if is_align_face: + input = img2tensor(align_face, bgr2rgb=True).unsqueeze(0) / 255.0 # torch.Size([1, 3, 512, 512]) + input = input.to(device) + parsing_out = face_helper_1.face_parse(normalize(input, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0] + parsing_out = parsing_out.argmax(dim=1, keepdim=True) # torch.Size([1, 1, 512, 512]) + bg_label = [0, 16, 18, 7, 8, 9, 14, 15] + bg = sum(parsing_out == i for i in bg_label).bool() + white_image = torch.ones_like(input) # torch.Size([1, 3, 512, 512]) + # only keep the face features + return_face_features_image = torch.where(bg, white_image, to_gray(input)) # torch.Size([1, 3, 512, 512]) + return_face_features_image_2 = torch.where(bg, white_image, input) # torch.Size([1, 3, 512, 512]) + else: + original_image_bgr = cv2.cvtColor(original_id_image, cv2.COLOR_RGB2BGR) + input = img2tensor(original_image_bgr, bgr2rgb=True).unsqueeze(0) / 255.0 # torch.Size([1, 3, 512, 512]) + input = input.to(device) + return_face_features_image = return_face_features_image_2 = input + + # transform img before sending to eva-clip-vit + face_features_image = resize( + return_face_features_image, clip_vision_model.image_size, InterpolationMode.BICUBIC + ) # torch.Size([1, 3, 336, 336]) + face_features_image = normalize(face_features_image, eva_transform_mean, eva_transform_std) + id_cond_vit, id_vit_hidden = clip_vision_model( + face_features_image.to(weight_dtype), return_all_features=False, return_hidden=True, shuffle=False + ) # torch.Size([1, 768]), list(torch.Size([1, 577, 1024])) + id_cond_vit_norm = torch.norm(id_cond_vit, 2, 1, True) + id_cond_vit = torch.div(id_cond_vit, id_cond_vit_norm) + + id_cond = torch.cat( + [id_ante_embedding, id_cond_vit], dim=-1 + ) # torch.Size([1, 512]), torch.Size([1, 768]) -> torch.Size([1, 1280]) + + return ( + id_cond, + id_vit_hidden, + return_face_features_image_2, + face_kps, + ) # torch.Size([1, 1280]), list(torch.Size([1, 577, 1024])) + + +def process_face_embeddings_infer( + face_helper_1, + clip_vision_model, + face_helper_2, + eva_transform_mean, + eva_transform_std, + app, + device, + weight_dtype, + img_file_path, + is_align_face=True, +): + """ + Process face embeddings from an input image for inference, including alignment, feature extraction, and embedding + concatenation. + + Args: + face_helper_1: Face helper object (first helper) for alignment and landmark detection. + clip_vision_model: Pre-trained CLIP vision model used for feature extraction. + face_helper_2: Face helper object (second helper) for embedding extraction. + eva_transform_mean: Mean values for image normalization before passing to EVA model. + eva_transform_std: Standard deviation values for image normalization before passing to EVA model. + app: Application instance used for face detection. + device: Device (CPU or GPU) where the computations will be performed. + weight_dtype: Data type of the weights for precision (e.g., `torch.float32`). + img_file_path: Path to the input image file (string) or a numpy array representing an image. + is_align_face: Boolean flag indicating whether face alignment should be performed (default: True). + + Returns: + Tuple: + - id_cond: Concatenated tensor of Ante face embedding and CLIP vision embedding. + - id_vit_hidden: Hidden state of the CLIP vision model, a list of tensors. + - image: Processed face image after feature extraction and alignment. + - face_kps: Keypoints of the face detected in the image. + """ + + # Load and preprocess the input image + if isinstance(img_file_path, str): + image = np.array(load_image(image=img_file_path).convert("RGB")) + else: + image = np.array(ImageOps.exif_transpose(Image.fromarray(img_file_path)).convert("RGB")) + + # Resize image to ensure the longer side is 1024 pixels + image = resize_numpy_image_long(image, 1024) + original_id_image = image + + # Process the image to extract face embeddings and related features + id_cond, id_vit_hidden, align_crop_face_image, face_kps = process_face_embeddings( + face_helper_1, + clip_vision_model, + face_helper_2, + eva_transform_mean, + eva_transform_std, + app, + device, + weight_dtype, + image, + original_id_image, + is_align_face, + ) + + # Convert the aligned cropped face image (torch tensor) to a numpy array + tensor = align_crop_face_image.cpu().detach() + tensor = tensor.squeeze() + tensor = tensor.permute(1, 2, 0) + tensor = tensor.numpy() * 255 + tensor = tensor.astype(np.uint8) + image = ImageOps.exif_transpose(Image.fromarray(tensor)) + + return id_cond, id_vit_hidden, image, face_kps + + +def prepare_face_models(model_path, device, dtype): + """ + Prepare all face models for the facial recognition task. + + Parameters: + - model_path: Path to the directory containing model files. + - device: The device (e.g., 'cuda', 'cpu') where models will be loaded. + - dtype: Data type (e.g., torch.float32) for model inference. + + Returns: + - face_helper_1: First face restoration helper. + - face_helper_2: Second face restoration helper. + - face_clip_model: CLIP model for face extraction. + - eva_transform_mean: Mean value for image normalization. + - eva_transform_std: Standard deviation value for image normalization. + - face_main_model: Main face analysis model. + """ + # get helper model + face_helper_1 = FaceRestoreHelper( + upscale_factor=1, + face_size=512, + crop_ratio=(1, 1), + det_model="retinaface_resnet50", + save_ext="png", + device=device, + model_rootpath=os.path.join(model_path, "face_encoder"), + ) + face_helper_1.face_parse = None + face_helper_1.face_parse = init_parsing_model( + model_name="bisenet", device=device, model_rootpath=os.path.join(model_path, "face_encoder") + ) + face_helper_2 = insightface.model_zoo.get_model( + f"{model_path}/face_encoder/models/antelopev2/glintr100.onnx", providers=["CUDAExecutionProvider"] + ) + face_helper_2.prepare(ctx_id=0) + + # get local facial extractor part 1 + model, _, _ = create_model_and_transforms( + "EVA02-CLIP-L-14-336", + os.path.join(model_path, "face_encoder", "EVA02_CLIP_L_336_psz14_s6B.pt"), + force_custom_clip=True, + ) + face_clip_model = model.visual + eva_transform_mean = getattr(face_clip_model, "image_mean", OPENAI_DATASET_MEAN) + eva_transform_std = getattr(face_clip_model, "image_std", OPENAI_DATASET_STD) + if not isinstance(eva_transform_mean, (list, tuple)): + eva_transform_mean = (eva_transform_mean,) * 3 + if not isinstance(eva_transform_std, (list, tuple)): + eva_transform_std = (eva_transform_std,) * 3 + eva_transform_mean = eva_transform_mean + eva_transform_std = eva_transform_std + + # get local facial extractor part 2 + face_main_model = FaceAnalysis( + name="antelopev2", root=os.path.join(model_path, "face_encoder"), providers=["CUDAExecutionProvider"] + ) + face_main_model.prepare(ctx_id=0, det_size=(640, 640)) + + # move face models to device + face_helper_1.face_det.eval() + face_helper_1.face_parse.eval() + face_clip_model.eval() + face_helper_1.face_det.to(device) + face_helper_1.face_parse.to(device) + face_clip_model.to(device, dtype=dtype) + + return face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std diff --git a/src/diffusers/pipelines/consisid/pipeline_consisid.py b/src/diffusers/pipelines/consisid/pipeline_consisid.py new file mode 100644 index 000000000000..0d4891cf17d7 --- /dev/null +++ b/src/diffusers/pipelines/consisid/pipeline_consisid.py @@ -0,0 +1,966 @@ +# Copyright 2024 ConsisID Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import cv2 +import numpy as np +import PIL +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import CogVideoXLoraLoaderMixin +from ...models import AutoencoderKLCogVideoX, ConsisIDTransformer3DModel +from ...models.embeddings import get_3d_rotary_pos_embed +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import CogVideoXDPMScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import ConsisIDPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import ConsisIDPipeline + >>> from diffusers.pipelines.consisid.consisid_utils import prepare_face_models, process_face_embeddings_infer + >>> from diffusers.utils import export_to_video + >>> from huggingface_hub import snapshot_download + + >>> snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir="BestWishYsh/ConsisID-preview") + >>> face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = ( + ... prepare_face_models("BestWishYsh/ConsisID-preview", device="cuda", dtype=torch.bfloat16) + ... ) + >>> pipe = ConsisIDPipeline.from_pretrained("BestWishYsh/ConsisID-preview", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> # ConsisID works well with long and well-described prompts. Make sure the face in the image is clearly visible (e.g., preferably half-body or full-body). + >>> prompt = "The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel." + >>> image = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_input.png?download=true" + + >>> id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer( + ... face_helper_1, + ... face_clip_model, + ... face_helper_2, + ... eva_transform_mean, + ... eva_transform_std, + ... face_main_model, + ... "cuda", + ... torch.bfloat16, + ... image, + ... is_align_face=True, + ... ) + + >>> video = pipe( + ... image=image, + ... prompt=prompt, + ... num_inference_steps=50, + ... guidance_scale=6.0, + ... use_dynamic_cfg=False, + ... id_vit_hidden=id_vit_hidden, + ... id_cond=id_cond, + ... kps_cond=face_kps, + ... generator=torch.Generator("cuda").manual_seed(42), + ... ) + >>> export_to_video(video.frames[0], "output.mp4", fps=8) + ``` +""" + + +def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]): + """ + This function draws keypoints and the limbs connecting them on an image. + + Parameters: + - image_pil (PIL.Image): Input image as a PIL object. + - kps (list of tuples): A list of keypoints where each keypoint is a tuple of (x, y) coordinates. + - color_list (list of tuples, optional): List of colors (in RGB format) for each keypoint. Default is a set of five + colors. + + Returns: + - PIL.Image: Image with the keypoints and limbs drawn. + """ + + stickwidth = 4 + limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]]) + kps = np.array(kps) + + w, h = image_pil.size + out_img = np.zeros([h, w, 3]) + + for i in range(len(limbSeq)): + index = limbSeq[i] + color = color_list[index[0]] + + x = kps[index][:, 0] + y = kps[index][:, 1] + length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])) + polygon = cv2.ellipse2Poly( + (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1 + ) + out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color) + out_img = (out_img * 0.6).astype(np.uint8) + + for idx_kp, kp in enumerate(kps): + color = color_list[idx_kp] + x, y = kp + out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1) + + out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8)) + return out_img_pil + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + """ + This function calculates the resize and crop region for an image to fit a target width and height while preserving + the aspect ratio. + + Parameters: + - src (tuple): A tuple containing the source image's height (h) and width (w). + - tgt_width (int): The target width to resize the image. + - tgt_height (int): The target height to resize the image. + + Returns: + - tuple: Two tuples representing the crop region: + 1. The top-left coordinates of the crop region. + 2. The bottom-right coordinates of the crop region. + """ + + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class ConsisIDPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): + r""" + Pipeline for image-to-video generation using ConsisID. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. ConsisID uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`ConsisIDTransformer3DModel`]): + A text conditioned `ConsisIDTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: ConsisIDTransformer3DModel, + scheduler: CogVideoXDPMScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + self.vae_scaling_factor_image = ( + self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, + image: torch.Tensor, + batch_size: int = 1, + num_channels_latents: int = 16, + num_frames: int = 13, + height: int = 60, + width: int = 90, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + kps_cond: Optional[torch.Tensor] = None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_frames, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + image = image.unsqueeze(2) # [B, C, F, H, W] + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + if kps_cond is not None: + kps_cond = kps_cond.unsqueeze(2) + kps_cond_latents = [ + retrieve_latents(self.vae.encode(kps_cond[i].unsqueeze(0)), generator[i]) + for i in range(batch_size) + ] + else: + image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image] + if kps_cond is not None: + kps_cond = kps_cond.unsqueeze(2) + kps_cond_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in kps_cond] + + image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + image_latents = self.vae_scaling_factor_image * image_latents + + if kps_cond is not None: + kps_cond_latents = torch.cat(kps_cond_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + kps_cond_latents = self.vae_scaling_factor_image * kps_cond_latents + + padding_shape = ( + batch_size, + num_frames - 2, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + else: + padding_shape = ( + batch_size, + num_frames - 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype) + if kps_cond is not None: + image_latents = torch.cat([image_latents, kps_cond_latents, latent_padding], dim=1) + else: + image_latents = torch.cat([image_latents, latent_padding], dim=1) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents, image_latents + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae_scaling_factor_image * latents + + frames = self.vae.decode(latents).sample + return frames + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, timesteps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + image, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + latents=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_width = self.transformer.config.sample_width // self.transformer.config.patch_size + base_size_height = self.transformer.config.sample_height // self.transformer.config.patch_size + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + device=device, + ) + + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + guidance_scale: float = 6.0, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + id_vit_hidden: Optional[torch.Tensor] = None, + id_cond: Optional[torch.Tensor] = None, + kps_cond: Optional[torch.Tensor] = None, + ) -> Union[ConsisIDPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The width in pixels of the generated image. This is set to 720 by default for the best results. + num_frames (`int`, defaults to `49`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because ConsisID is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 6): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + use_dynamic_cfg (`bool`, *optional*, defaults to `False`): + If True, dynamically adjusts the guidance scale during inference. This allows the model to use a + progressive guidance scale, improving the balance between text-guided generation and image quality over + the course of the inference steps. Typically, early inference steps use a higher guidance scale for + more faithful image generation, while later steps reduce it for more diverse and natural results. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + id_vit_hidden (`Optional[torch.Tensor]`, *optional*): + The tensor representing the hidden features extracted from the face model, which are used to condition + the local facial extractor. This is crucial for the model to obtain high-frequency information of the + face. If not provided, the local facial extractor will not run normally. + id_cond (`Optional[torch.Tensor]`, *optional*): + The tensor representing the hidden features extracted from the clip model, which are used to condition + the local facial extractor. This is crucial for the model to edit facial features If not provided, the + local facial extractor will not run normally. + kps_cond (`Optional[torch.Tensor]`, *optional*): + A tensor that determines whether the global facial extractor use keypoint information for conditioning. + If provided, this tensor controls whether facial keypoints such as eyes, nose, and mouth landmarks are + used during the generation process. This helps ensure the model retains more facial low-frequency + information. + + Examples: + + Returns: + [`~pipelines.consisid.pipeline_output.ConsisIDPipelineOutput`] or `tuple`: + [`~pipelines.consisid.pipeline_output.ConsisIDPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + num_frames = num_frames or self.transformer.config.sample_frames + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + image=image, + prompt=prompt, + height=height, + width=width, + negative_prompt=negative_prompt, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents + is_kps = getattr(self.transformer.config, "is_kps", False) + kps_cond = kps_cond if is_kps else None + if kps_cond is not None: + kps_cond = draw_kps(image, kps_cond) + kps_cond = self.video_processor.preprocess(kps_cond, height=height, width=width).to( + device, dtype=prompt_embeds.dtype + ) + + image = self.video_processor.preprocess(image, height=height, width=width).to( + device, dtype=prompt_embeds.dtype + ) + + latent_channels = self.transformer.config.in_channels // 2 + latents, image_latents = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + kps_cond, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + timesteps_cpu = timesteps.cpu() + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents + latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=False, + id_vit_hidden=id_vit_hidden, + id_cond=id_cond, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + ( + 1 + - math.cos( + math.pi + * ((num_inference_steps - timesteps_cpu[i].item()) / num_inference_steps) ** 5.0 + ) + ) + / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return ConsisIDPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/consisid/pipeline_output.py b/src/diffusers/pipelines/consisid/pipeline_output.py new file mode 100644 index 000000000000..dd4a63aa50b9 --- /dev/null +++ b/src/diffusers/pipelines/consisid/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class ConsisIDPipelineOutput(BaseOutput): + r""" + Output class for ConsisID pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 4b6ac10385cf..183d6beb35c3 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -227,6 +227,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class ConsisIDTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ConsistencyDecoderVAE(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 9b36be9e0604..b899915c3046 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -362,6 +362,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class ConsisIDPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class CycleDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/transformers/test_models_transformer_consisid.py b/tests/models/transformers/test_models_transformer_consisid.py new file mode 100644 index 000000000000..b848ed014074 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_consisid.py @@ -0,0 +1,105 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import ConsisIDTransformer3DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_device, +) + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class ConsisIDTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = ConsisIDTransformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 2 + num_channels = 4 + num_frames = 1 + height = 8 + width = 8 + embedding_dim = 8 + sequence_length = 8 + + hidden_states = torch.randn((batch_size, num_frames, num_channels, height, width)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + id_vit_hidden = [torch.ones([batch_size, 2, 2]).to(torch_device)] * 1 + id_cond = torch.ones(batch_size, 2).to(torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + "id_vit_hidden": id_vit_hidden, + "id_cond": id_cond, + } + + @property + def input_shape(self): + return (1, 4, 8, 8) + + @property + def output_shape(self): + return (1, 4, 8, 8) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "num_attention_heads": 2, + "attention_head_dim": 8, + "in_channels": 4, + "out_channels": 4, + "time_embed_dim": 2, + "text_embed_dim": 8, + "num_layers": 1, + "sample_width": 8, + "sample_height": 8, + "sample_frames": 8, + "patch_size": 2, + "temporal_compression_ratio": 4, + "max_text_seq_length": 8, + "cross_attn_interval": 1, + "is_kps": False, + "is_train_face": True, + "cross_attn_dim_head": 1, + "cross_attn_num_heads": 1, + "LFE_id_dim": 2, + "LFE_vit_dim": 2, + "LFE_depth": 5, + "LFE_dim_head": 8, + "LFE_num_heads": 2, + "LFE_num_id_token": 1, + "LFE_num_querie": 1, + "LFE_output_dim": 10, + "LFE_ff_mult": 1, + "LFE_num_scale": 1, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"ConsisIDTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/pipelines/consisid/__init__.py b/tests/pipelines/consisid/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/consisid/test_consisid.py b/tests/pipelines/consisid/test_consisid.py new file mode 100644 index 000000000000..31f2bc024af6 --- /dev/null +++ b/tests/pipelines/consisid/test_consisid.py @@ -0,0 +1,359 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import inspect +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import AutoencoderKLCogVideoX, ConsisIDPipeline, ConsisIDTransformer3DModel, DDIMScheduler +from diffusers.utils import load_image +from diffusers.utils.testing_utils import ( + enable_full_determinism, + numpy_cosine_similarity_distance, + require_torch_gpu, + slow, + torch_device, +) + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import ( + PipelineTesterMixin, + to_np, +) + + +enable_full_determinism() + + +class ConsisIDPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = ConsisIDPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"}) + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = ConsisIDTransformer3DModel( + num_attention_heads=2, + attention_head_dim=16, + in_channels=8, + out_channels=4, + time_embed_dim=2, + text_embed_dim=32, + num_layers=1, + sample_width=2, + sample_height=2, + sample_frames=9, + patch_size=2, + temporal_compression_ratio=4, + max_text_seq_length=16, + use_rotary_positional_embeddings=True, + use_learned_positional_embeddings=True, + cross_attn_interval=1, + is_kps=False, + is_train_face=True, + cross_attn_dim_head=1, + cross_attn_num_heads=1, + LFE_id_dim=2, + LFE_vit_dim=2, + LFE_depth=5, + LFE_dim_head=8, + LFE_num_heads=2, + LFE_num_id_token=1, + LFE_num_querie=1, + LFE_output_dim=21, + LFE_ff_mult=1, + LFE_num_scale=1, + ) + + torch.manual_seed(0) + vae = AutoencoderKLCogVideoX( + in_channels=3, + out_channels=3, + down_block_types=( + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + ), + up_block_types=( + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + ), + block_out_channels=(8, 8, 8, 8), + latent_channels=4, + layers_per_block=1, + norm_num_groups=2, + temporal_compression_ratio=4, + ) + + torch.manual_seed(0) + scheduler = DDIMScheduler() + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + image_height = 16 + image_width = 16 + image = Image.new("RGB", (image_width, image_height)) + id_vit_hidden = [torch.ones([1, 2, 2])] * 1 + id_cond = torch.ones(1, 2) + inputs = { + "image": image, + "prompt": "dance monkey", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "height": image_height, + "width": image_width, + "num_frames": 8, + "max_sequence_length": 16, + "id_vit_hidden": id_vit_hidden, + "id_cond": id_cond, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (8, 3, 16, 16)) + expected_video = torch.randn(8, 3, 16, 16) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_vae_tiling(self, expected_diff_max: float = 0.4): + generator_device = "cpu" + components = self.get_dummy_components() + + # The reason to modify it this way is because ConsisID Transformer limits the generation to resolutions used during initalization. + # This limitation comes from using learned positional embeddings which cannot be generated on-the-fly like sincos or RoPE embeddings. + # See the if-statement on "self.use_learned_positional_embeddings" in diffusers/models/embeddings.py + components["transformer"] = ConsisIDTransformer3DModel.from_config( + components["transformer"].config, + sample_height=16, + sample_width=16, + ) + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_overlap_factor_height=1 / 12, + tile_overlap_factor_width=1 / 12, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) + + +@slow +@require_torch_gpu +class ConsisIDPipelineIntegrationTests(unittest.TestCase): + prompt = "A painting of a squirrel eating a burger." + + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_consisid(self): + generator = torch.Generator("cpu").manual_seed(0) + + pipe = ConsisIDPipeline.from_pretrained("BestWishYsh/ConsisID-preview", torch_dtype=torch.bfloat16) + pipe.enable_model_cpu_offload() + + prompt = self.prompt + image = load_image("https://github.com/PKU-YuanGroup/ConsisID/blob/main/asserts/example_images/2.png?raw=true") + id_vit_hidden = [torch.ones([1, 2, 2])] * 1 + id_cond = torch.ones(1, 2) + + videos = pipe( + image=image, + prompt=prompt, + height=480, + width=720, + num_frames=16, + id_vit_hidden=id_vit_hidden, + id_cond=id_cond, + generator=generator, + num_inference_steps=1, + output_type="pt", + ).frames + + video = videos[0] + expected_video = torch.randn(1, 16, 480, 720, 3).numpy() + + max_diff = numpy_cosine_similarity_distance(video, expected_video) + assert max_diff < 1e-3, f"Max diff is too high. got {video}"