Skip to content

Commit 23b467c

Browse files
SHYuanBesthlkystevhliua-r-r-o-w
authored
[core] ConsisID (#10140)
* Update __init__.py * add consisid * update consisid * update consisid * make style * make_style * Update src/diffusers/pipelines/consisid/pipeline_consisid.py Co-authored-by: hlky <hlky@hlky.ac> * Update src/diffusers/pipelines/consisid/pipeline_consisid.py Co-authored-by: hlky <hlky@hlky.ac> * Update src/diffusers/pipelines/consisid/pipeline_consisid.py Co-authored-by: hlky <hlky@hlky.ac> * Update src/diffusers/pipelines/consisid/pipeline_consisid.py Co-authored-by: hlky <hlky@hlky.ac> * Update src/diffusers/pipelines/consisid/pipeline_consisid.py Co-authored-by: hlky <hlky@hlky.ac> * Update src/diffusers/pipelines/consisid/pipeline_consisid.py Co-authored-by: hlky <hlky@hlky.ac> * add doc * make style * Rename consisid .md to consisid.md * Update geodiff_molecule_conformation.ipynb * Update geodiff_molecule_conformation.ipynb * Update geodiff_molecule_conformation.ipynb * Update demo.ipynb * Update pipeline_consisid.py * make fix-copies * Update docs/source/en/using-diffusers/consisid.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update src/diffusers/pipelines/consisid/pipeline_consisid.py Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update src/diffusers/pipelines/consisid/pipeline_consisid.py Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/using-diffusers/consisid.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/using-diffusers/consisid.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * update doc & pipeline code * fix typo * make style * update example * Update docs/source/en/using-diffusers/consisid.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * update example * update example * Update src/diffusers/pipelines/consisid/pipeline_consisid.py Co-authored-by: hlky <hlky@hlky.ac> * Update src/diffusers/pipelines/consisid/pipeline_consisid.py Co-authored-by: hlky <hlky@hlky.ac> * update * add test and update * remove some changes from docs * refactor * fix * undo changes to examples * remove save/load and fuse methods * update * link hf-doc-img & make test extremely small * update * add lora * fix test * update * update * change expected_diff_max to 0.4 * fix typo * fix link * fix typo * update docs * update * remove consisid lora tests --------- Co-authored-by: hlky <hlky@hlky.ac> Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: Aryan <aryan@huggingface.co>
1 parent aeac0a0 commit 23b467c

File tree

21 files changed

+2988
-0
lines changed

21 files changed

+2988
-0
lines changed

docs/source/en/_toctree.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@
7979
- sections:
8080
- local: using-diffusers/cogvideox
8181
title: CogVideoX
82+
- local: using-diffusers/consisid
83+
title: ConsisID
8284
- local: using-diffusers/sdxl
8385
title: Stable Diffusion XL
8486
- local: using-diffusers/sdxl_turbo
@@ -270,6 +272,8 @@
270272
title: AuraFlowTransformer2DModel
271273
- local: api/models/cogvideox_transformer3d
272274
title: CogVideoXTransformer3DModel
275+
- local: api/models/consisid_transformer3d
276+
title: ConsisIDTransformer3DModel
273277
- local: api/models/cogview3plus_transformer2d
274278
title: CogView3PlusTransformer2DModel
275279
- local: api/models/dit_transformer2d
@@ -372,6 +376,8 @@
372376
title: CogVideoX
373377
- local: api/pipelines/cogview3
374378
title: CogView3
379+
- local: api/pipelines/consisid
380+
title: ConsisID
375381
- local: api/pipelines/consistency_models
376382
title: Consistency Models
377383
- local: api/pipelines/controlnet
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
4+
the License. You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
9+
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
10+
specific language governing permissions and limitations under the License. -->
11+
12+
# ConsisIDTransformer3DModel
13+
14+
A Diffusion Transformer model for 3D data from [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) was introduced in [Identity-Preserving Text-to-Video Generation by Frequency Decomposition](https://arxiv.org/pdf/2411.17440) by Peking University & University of Rochester & etc.
15+
16+
The model can be loaded with the following code snippet.
17+
18+
```python
19+
from diffusers import ConsisIDTransformer3DModel
20+
21+
transformer = ConsisIDTransformer3DModel.from_pretrained("BestWishYsh/ConsisID-preview", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
22+
```
23+
24+
## ConsisIDTransformer3DModel
25+
26+
[[autodoc]] ConsisIDTransformer3DModel
27+
28+
## Transformer2DModelOutput
29+
30+
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
-->
15+
16+
# ConsisID
17+
18+
[Identity-Preserving Text-to-Video Generation by Frequency Decomposition](https://arxiv.org/abs/2411.17440) from Peking University & University of Rochester & etc, by Shenghai Yuan, Jinfa Huang, Xianyi He, Yunyang Ge, Yujun Shi, Liuhan Chen, Jiebo Luo, Li Yuan.
19+
20+
The abstract from the paper is:
21+
22+
*Identity-preserving text-to-video (IPT2V) generation aims to create high-fidelity videos with consistent human identity. It is an important task in video generation but remains an open problem for generative models. This paper pushes the technical frontier of IPT2V in two directions that have not been resolved in the literature: (1) A tuning-free pipeline without tedious case-by-case finetuning, and (2) A frequency-aware heuristic identity-preserving Diffusion Transformer (DiT)-based control scheme. To achieve these goals, we propose **ConsisID**, a tuning-free DiT-based controllable IPT2V model to keep human-**id**entity **consis**tent in the generated video. Inspired by prior findings in frequency analysis of vision/diffusion transformers, it employs identity-control signals in the frequency domain, where facial features can be decomposed into low-frequency global features (e.g., profile, proportions) and high-frequency intrinsic features (e.g., identity markers that remain unaffected by pose changes). First, from a low-frequency perspective, we introduce a global facial extractor, which encodes the reference image and facial key points into a latent space, generating features enriched with low-frequency information. These features are then integrated into the shallow layers of the network to alleviate training challenges associated with DiT. Second, from a high-frequency perspective, we design a local facial extractor to capture high-frequency details and inject them into the transformer blocks, enhancing the model's ability to preserve fine-grained features. To leverage the frequency information for identity preservation, we propose a hierarchical training strategy, transforming a vanilla pre-trained video generation model into an IPT2V model. Extensive experiments demonstrate that our frequency-aware heuristic scheme provides an optimal control solution for DiT-based models. Thanks to this scheme, our **ConsisID** achieves excellent results in generating high-quality, identity-preserving videos, making strides towards more effective IPT2V. The model weight of ConsID is publicly available at https://github.com/PKU-YuanGroup/ConsisID.*
23+
24+
<Tip>
25+
26+
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
27+
28+
</Tip>
29+
30+
This pipeline was contributed by [SHYuanBest](https://github.com/SHYuanBest). The original codebase can be found [here](https://github.com/PKU-YuanGroup/ConsisID). The original weights can be found under [hf.co/BestWishYsh](https://huggingface.co/BestWishYsh).
31+
32+
There are two official ConsisID checkpoints for identity-preserving text-to-video.
33+
34+
| checkpoints | recommended inference dtype |
35+
|:---:|:---:|
36+
| [`BestWishYsh/ConsisID-preview`](https://huggingface.co/BestWishYsh/ConsisID-preview) | torch.bfloat16 |
37+
| [`BestWishYsh/ConsisID-1.5`](https://huggingface.co/BestWishYsh/ConsisID-preview) | torch.bfloat16 |
38+
39+
### Memory optimization
40+
41+
ConsisID requires about 44 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with output resolution 720x480 (W x H), which makes it not possible to run on consumer GPUs or free-tier T4 Colab. The following memory optimizations could be used to reduce the memory footprint. For replication, you can refer to [this](https://gist.github.com/SHYuanBest/bc4207c36f454f9e969adbb50eaf8258) script.
42+
43+
| Feature (overlay the previous) | Max Memory Allocated | Max Memory Reserved |
44+
| :----------------------------- | :------------------- | :------------------ |
45+
| - | 37 GB | 44 GB |
46+
| enable_model_cpu_offload | 22 GB | 25 GB |
47+
| enable_sequential_cpu_offload | 16 GB | 22 GB |
48+
| vae.enable_slicing | 16 GB | 22 GB |
49+
| vae.enable_tiling | 5 GB | 7 GB |
50+
51+
## ConsisIDPipeline
52+
53+
[[autodoc]] ConsisIDPipeline
54+
55+
- all
56+
- __call__
57+
58+
## ConsisIDPipelineOutput
59+
60+
[[autodoc]] pipelines.consisid.pipeline_output.ConsisIDPipelineOutput
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
4+
the License. You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
9+
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
10+
specific language governing permissions and limitations under the License.
11+
-->
12+
# ConsisID
13+
14+
[ConsisID](https://github.com/PKU-YuanGroup/ConsisID) is an identity-preserving text-to-video generation model that keeps the face consistent in the generated video by frequency decomposition. The main features of ConsisID are:
15+
16+
- Frequency decomposition: The characteristics of the DiT architecture are analyzed from the frequency domain perspective, and based on these characteristics, a reasonable control information injection method is designed.
17+
- Consistency training strategy: A coarse-to-fine training strategy, dynamic masking loss, and dynamic cross-face loss further enhance the model's generalization ability and identity preservation performance.
18+
- Inference without finetuning: Previous methods required case-by-case finetuning of the input ID before inference, leading to significant time and computational costs. In contrast, ConsisID is tuning-free.
19+
20+
This guide will walk you through using ConsisID for use cases.
21+
22+
## Load Model Checkpoints
23+
24+
Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method.
25+
26+
```python
27+
# !pip install consisid_eva_clip insightface facexlib
28+
import torch
29+
from diffusers import ConsisIDPipeline
30+
from diffusers.pipelines.consisid.consisid_utils import prepare_face_models, process_face_embeddings_infer
31+
from huggingface_hub import snapshot_download
32+
33+
# Download ckpts
34+
snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir="BestWishYsh/ConsisID-preview")
35+
36+
# Load face helper model to preprocess input face image
37+
face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = prepare_face_models("BestWishYsh/ConsisID-preview", device="cuda", dtype=torch.bfloat16)
38+
39+
# Load consisid base model
40+
pipe = ConsisIDPipeline.from_pretrained("BestWishYsh/ConsisID-preview", torch_dtype=torch.bfloat16)
41+
pipe.to("cuda")
42+
```
43+
44+
## Identity-Preserving Text-to-Video
45+
46+
For identity-preserving text-to-video, pass a text prompt and an image contain clear face (e.g., preferably half-body or full-body). By default, ConsisID generates a 720x480 video for the best results.
47+
48+
```python
49+
from diffusers.utils import export_to_video
50+
51+
prompt = "The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel."
52+
image = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_input.png?download=true"
53+
54+
id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer(face_helper_1, face_clip_model, face_helper_2, eva_transform_mean, eva_transform_std, face_main_model, "cuda", torch.bfloat16, image, is_align_face=True)
55+
56+
video = pipe(image=image, prompt=prompt, num_inference_steps=50, guidance_scale=6.0, use_dynamic_cfg=False, id_vit_hidden=id_vit_hidden, id_cond=id_cond, kps_cond=face_kps, generator=torch.Generator("cuda").manual_seed(42))
57+
export_to_video(video.frames[0], "output.mp4", fps=8)
58+
```
59+
<table>
60+
<tr>
61+
<th style="text-align: center;">Face Image</th>
62+
<th style="text-align: center;">Video</th>
63+
<th style="text-align: center;">Description</th
64+
</tr>
65+
<tr>
66+
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_image_0.png?download=true" style="height: auto; width: 600px;"></td>
67+
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_output_0.gif?download=true" style="height: auto; width: 2000px;"></td>
68+
<td>The video, in a beautifully crafted animated style, features a confident woman riding a horse through a lush forest clearing. Her expression is focused yet serene as she adjusts her wide-brimmed hat with a practiced hand. She wears a flowy bohemian dress, which moves gracefully with the rhythm of the horse, the fabric flowing fluidly in the animated motion. The dappled sunlight filters through the trees, casting soft, painterly patterns on the forest floor. Her posture is poised, showing both control and elegance as she guides the horse with ease. The animation's gentle, fluid style adds a dreamlike quality to the scene, with the woman’s calm demeanor and the peaceful surroundings evoking a sense of freedom and harmony.</td>
69+
</tr>
70+
<tr>
71+
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_image_1.png?download=true" style="height: auto; width: 600px;"></td>
72+
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_output_1.gif?download=true" style="height: auto; width: 2000px;"></td>
73+
<td>The video, in a captivating animated style, shows a woman standing in the center of a snowy forest, her eyes narrowed in concentration as she extends her hand forward. She is dressed in a deep blue cloak, her breath visible in the cold air, which is rendered with soft, ethereal strokes. A faint smile plays on her lips as she summons a wisp of ice magic, watching with focus as the surrounding trees and ground begin to shimmer and freeze, covered in delicate ice crystals. The animation’s fluid motion brings the magic to life, with the frost spreading outward in intricate, sparkling patterns. The environment is painted with soft, watercolor-like hues, enhancing the magical, dreamlike atmosphere. The overall mood is serene yet powerful, with the quiet winter air amplifying the delicate beauty of the frozen scene.</td>
74+
</tr>
75+
<tr>
76+
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_image_2.png?download=true" style="height: auto; width: 600px;"></td>
77+
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_output_2.gif?download=true" style="height: auto; width: 2000px;"></td>
78+
<td>The animation features a whimsical portrait of a balloon seller standing in a gentle breeze, captured with soft, hazy brushstrokes that evoke the feel of a serene spring day. His face is framed by a gentle smile, his eyes squinting slightly against the sun, while a few wisps of hair flutter in the wind. He is dressed in a light, pastel-colored shirt, and the balloons around him sway with the wind, adding a sense of playfulness to the scene. The background blurs softly, with hints of a vibrant market or park, enhancing the light-hearted, yet tender mood of the moment.</td>
79+
</tr>
80+
<tr>
81+
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_image_3.png?download=true" style="height: auto; width: 600px;"></td>
82+
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_output_3.gif?download=true" style="height: auto; width: 2000px;"></td>
83+
<td>The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel.</td>
84+
</tr>
85+
<tr>
86+
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_image_4.png?download=true" style="height: auto; width: 600px;"></td>
87+
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_output_4.gif?download=true" style="height: auto; width: 2000px;"></td>
88+
<td>The video features a baby wearing a bright superhero cape, standing confidently with arms raised in a powerful pose. The baby has a determined look on their face, with eyes wide and lips pursed in concentration, as if ready to take on a challenge. The setting appears playful, with colorful toys scattered around and a soft rug underfoot, while sunlight streams through a nearby window, highlighting the fluttering cape and adding to the impression of heroism. The overall atmosphere is lighthearted and fun, with the baby's expressions capturing a mix of innocence and an adorable attempt at bravery, as if truly ready to save the day.</td>
89+
</tr>
90+
</table>
91+
92+
## Resources
93+
94+
Learn more about ConsisID with the following resources.
95+
- A [video](https://www.youtube.com/watch?v=PhlgC-bI5SQ) demonstrating ConsisID's main features.
96+
- The research paper, [Identity-Preserving Text-to-Video Generation by Frequency Decomposition](https://hf.co/papers/2411.17440) for more details.

docs/source/zh/_toctree.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
title: 快速入门
66
- local: stable_diffusion
77
title: 有效和高效的扩散
8+
- local: consisid
9+
title: 身份保持的文本到视频生成
810
- local: installation
911
title: 安装
1012
title: 开始

0 commit comments

Comments
 (0)