
Commit 1d9eb2c

lawrence-cj, chenjy2003, yiyixuxu, a-r-r-o-w, and stevhliu authored and committed
[DC-AE] Add the official Deep Compression Autoencoder code (32x, 64x, 128x compression ratios) (#9708)
* first add a script for DC-AE
* DC-AE init
* replace triton with custom implementation
* rename file and remove unused code
* no longer rely on omegaconf and dataclass
* replace custom activation with diffusers activation
* remove dc_ae attention in attention_processor.py
* inherit from ModelMixin
* inherit from ConfigMixin
* dc-ae reduce to one file
* update downsample and upsample
* clean code
* support DecoderOutput
* remove get_same_padding and val2tuple
* remove autocast and some asserts
* update ResBlock
* remove contents within super().__init__
* Update src/diffusers/models/autoencoders/dc_ae.py (Co-authored-by: YiYi Xu <yixu310@gmail.com>)
* remove OpSequential
* update other blocks to support the removal of build_norm
* remove build encoder/decoder project in/out
* remove inheritance of RMSNorm2d from LayerNorm
* remove reset_parameters for RMSNorm2d (Co-authored-by: YiYi Xu <yixu310@gmail.com>)
* remove device and dtype in RMSNorm2d __init__ (Co-authored-by: YiYi Xu <yixu310@gmail.com>)
* Update src/diffusers/models/autoencoders/dc_ae.py (Co-authored-by: YiYi Xu <yixu310@gmail.com>)
* Update src/diffusers/models/autoencoders/dc_ae.py (Co-authored-by: YiYi Xu <yixu310@gmail.com>)
* Update src/diffusers/models/autoencoders/dc_ae.py (Co-authored-by: YiYi Xu <yixu310@gmail.com>)
* remove op_list & build_block
* remove build_stage_main
* change file name to autoencoder_dc
* move LiteMLA to attention.py
* align with other vae decode output
* add DC-AE into init files
* update
* make quality && make style
* quick push before dgx disappears again
* update
* make style
* update
* update
* fix
* refactor
* refactor
* refactor
* update
* possibly change to nn.Linear
* refactor
* make fix-copies
* replace vae with ae
* rename get_block_from_block_type to get_block
* rename downsample_block_type from Conv to conv for consistency
* add scaling factors
* incorporate changes for all checkpoints
* make style
* move mla to attention processor file; split qkv conv to linears
* refactor
* add tests
* from original file loader
* add docs
* add standard autoencoder methods
* combine attention processor
* fix tests
* update
* minor fix
* minor fix
* minor fix & in/out shortcut rename
* minor fix
* make style
* fix paper link
* update docs
* update single file loading
* make style
* remove single file loading support; todo for DN6
* Apply suggestions from code review (Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>)
* add abstract

---------

Co-authored-by: Junyu Chen <chenjydl2003@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: chenjy2003 <70215701+chenjy2003@users.noreply.github.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
1 parent d54c70b commit 1d9eb2c


12 files changed, +1322 −3 lines


docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
```diff
@@ -314,6 +314,8 @@
   title: AutoencoderKLMochi
 - local: api/models/asymmetricautoencoderkl
   title: AsymmetricAutoencoderKL
+- local: api/models/autoencoder_dc
+  title: AutoencoderDC
 - local: api/models/consistency_decoder_vae
   title: ConsistencyDecoderVAE
 - local: api/models/autoencoder_oobleck
```
docs/source/en/api/models/autoencoder_dc.md

Lines changed: 50 additions & 0 deletions
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderDC

The 2D Autoencoder model used in [SANA](https://huggingface.co/papers/2410.10629) and introduced in [DCAE](https://huggingface.co/papers/2410.10733) by Junyu Chen\*, Han Cai\*, Junsong Chen, Enze Xie, Shang Yang, Haotian Tang, Muyang Li, Yao Lu, and Song Han from MIT HAN Lab.

The abstract from the paper is:

*We present Deep Compression Autoencoder (DC-AE), a new family of autoencoder models for accelerating high-resolution diffusion models. Existing autoencoder models have demonstrated impressive results at a moderate spatial compression ratio (e.g., 8x), but fail to maintain satisfactory reconstruction accuracy for high spatial compression ratios (e.g., 64x). We address this challenge by introducing two key techniques: (1) Residual Autoencoding, where we design our models to learn residuals based on the space-to-channel transformed features to alleviate the optimization difficulty of high spatial-compression autoencoders; (2) Decoupled High-Resolution Adaptation, an efficient decoupled three-phases training strategy for mitigating the generalization penalty of high spatial-compression autoencoders. With these designs, we improve the autoencoder's spatial compression ratio up to 128 while maintaining the reconstruction quality. Applying our DC-AE to latent diffusion models, we achieve significant speedup without accuracy drop. For example, on ImageNet 512x512, our DC-AE provides 19.1x inference speedup and 17.9x training speedup on H100 GPU for UViT-H while achieving a better FID, compared with the widely used SD-VAE-f8 autoencoder. Our code is available at [this https URL](https://github.com/mit-han-lab/efficientvit).*

The following DCAE models are released and supported in Diffusers.

| Diffusers format | Original format |
|:----------------:|:---------------:|
| [`mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-sana-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0) |
| [`mit-han-lab/dc-ae-f32c32-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-in-1.0) |
| [`mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-mix-1.0) |
| [`mit-han-lab/dc-ae-f64c128-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f64c128-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-in-1.0) |
| [`mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f64c128-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-mix-1.0) |
| [`mit-han-lab/dc-ae-f128c512-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f128c512-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0) |
| [`mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f128c512-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0) |

Load a model in Diffusers format with [`~ModelMixin.from_pretrained`].

```python
import torch

from diffusers import AutoencoderDC

ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32).to("cuda")
```

## AutoencoderDC

[[autodoc]] AutoencoderDC
  - encode
  - decode
  - all

## DecoderOutput

[[autodoc]] models.autoencoders.vae.DecoderOutput
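For orientation on the checkpoint naming and the `encode`/`decode` methods referenced above (not part of the committed documentation file): `fNcM` means N× spatial downsampling into M latent channels, so the f32c32 SANA checkpoint maps a 512x512 RGB image to a 16x16x32 latent. Below is a minimal round-trip sketch; it assumes a CUDA device, that `encode` returns an output exposing the latent tensor as `.latent`, and that `decode` returns the `DecoderOutput` documented above; exact field names should be checked against the final API.

```python
import torch

from diffusers import AutoencoderDC

ae = AutoencoderDC.from_pretrained(
    "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32
).to("cuda")

# f32c32: 32x spatial compression into 32 latent channels,
# so a 1x3x512x512 image should produce a 1x32x16x16 latent.
image = torch.randn(1, 3, 512, 512, device="cuda")

with torch.no_grad():
    latent = ae.encode(image).latent           # assumed field name on the encoder output
    reconstruction = ae.decode(latent).sample  # DecoderOutput.sample

# When pairing the autoencoder with a diffusion model, latents are typically
# multiplied by ae.config.scaling_factor (0.41407 for this checkpoint) before
# denoising and divided by it again before decoding.
print(latent.shape)          # expected: torch.Size([1, 32, 16, 16])
print(reconstruction.shape)  # expected: torch.Size([1, 3, 512, 512])
```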

scripts/convert_dcae_to_diffusers.py

Lines changed: 323 additions & 0 deletions
```python
import argparse
from typing import Any, Dict

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from diffusers import AutoencoderDC


def remap_qkv_(key: str, state_dict: Dict[str, Any]):
    qkv = state_dict.pop(key)
    q, k, v = torch.chunk(qkv, 3, dim=0)
    parent_module, _, _ = key.rpartition(".qkv.conv.weight")
    state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
    state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
    state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()


def remap_proj_conv_(key: str, state_dict: Dict[str, Any]):
    parent_module, _, _ = key.rpartition(".proj.conv.weight")
    state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()


AE_KEYS_RENAME_DICT = {
    # common
    "main.": "",
    "op_list.": "",
    "context_module": "attn",
    "local_module": "conv_out",
    # NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
    # If there were more scales, there would be more layers, so a loop would be better to handle this
    "aggreg.0.0": "to_qkv_multiscale.0.proj_in",
    "aggreg.0.1": "to_qkv_multiscale.0.proj_out",
    "depth_conv.conv": "conv_depth",
    "inverted_conv.conv": "conv_inverted",
    "point_conv.conv": "conv_point",
    "point_conv.norm": "norm",
    "conv.conv.": "conv.",
    "conv1.conv": "conv1",
    "conv2.conv": "conv2",
    "conv2.norm": "norm",
    "proj.norm": "norm_out",
    # encoder
    "encoder.project_in.conv": "encoder.conv_in",
    "encoder.project_out.0.conv": "encoder.conv_out",
    "encoder.stages": "encoder.down_blocks",
    # decoder
    "decoder.project_in.conv": "decoder.conv_in",
    "decoder.project_out.0": "decoder.norm_out",
    "decoder.project_out.2.conv": "decoder.conv_out",
    "decoder.stages": "decoder.up_blocks",
}

AE_F32C32_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}

AE_F64C128_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}

AE_F128C512_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}

AE_SPECIAL_KEYS_REMAP = {
    "qkv.conv.weight": remap_qkv_,
    "proj.conv.weight": remap_proj_conv_,
}


def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    state_dict = saved_dict
    if "model" in saved_dict.keys():
        state_dict = state_dict["model"]
    if "module" in saved_dict.keys():
        state_dict = state_dict["module"]
    if "state_dict" in saved_dict.keys():
        state_dict = state_dict["state_dict"]
    return state_dict


def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    state_dict[new_key] = state_dict.pop(old_key)


def convert_ae(config_name: str, dtype: torch.dtype):
    config = get_ae_config(config_name)
    hub_id = f"mit-han-lab/{config_name}"
    ckpt_path = hf_hub_download(hub_id, "model.safetensors")
    original_state_dict = get_state_dict(load_file(ckpt_path))

    ae = AutoencoderDC(**config).to(dtype=dtype)

    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    ae.load_state_dict(original_state_dict, strict=True)
    return ae


def get_ae_config(name: str):
    if name in ["dc-ae-f32c32-sana-1.0"]:
        config = {
            "latent_channels": 32,
            "encoder_block_types": (
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ),
            "decoder_block_types": (
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ),
            "encoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
            "decoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
            "encoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
            "decoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
            "encoder_layers_per_block": (2, 2, 2, 3, 3, 3),
            "decoder_layers_per_block": [3, 3, 3, 3, 3, 3],
            "downsample_block_type": "conv",
            "upsample_block_type": "interpolate",
            "decoder_norm_types": "rms_norm",
            "decoder_act_fns": "silu",
            "scaling_factor": 0.41407,
        }
    elif name in ["dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"]:
        AE_KEYS_RENAME_DICT.update(AE_F32C32_KEYS)
        config = {
            "latent_channels": 32,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), ()),
            "decoder_norm_types": ["batch_norm", "batch_norm", "batch_norm", "rms_norm", "rms_norm", "rms_norm"],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu"],
        }
        if name == "dc-ae-f32c32-in-1.0":
            config["scaling_factor"] = 0.3189
        elif name == "dc-ae-f32c32-mix-1.0":
            config["scaling_factor"] = 0.4552
    elif name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]:
        AE_KEYS_RENAME_DICT.update(AE_F64C128_KEYS)
        config = {
            "latent_channels": 128,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), (), ()),
            "decoder_norm_types": [
                "batch_norm",
                "batch_norm",
                "batch_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
            ],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu"],
        }
        if name == "dc-ae-f64c128-in-1.0":
            config["scaling_factor"] = 0.2889
        elif name == "dc-ae-f64c128-mix-1.0":
            config["scaling_factor"] = 0.4538
    elif name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]:
        AE_KEYS_RENAME_DICT.update(AE_F128C512_KEYS)
        config = {
            "latent_channels": 512,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
            "decoder_norm_types": [
                "batch_norm",
                "batch_norm",
                "batch_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
            ],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu", "silu"],
        }
        if name == "dc-ae-f128c512-in-1.0":
            config["scaling_factor"] = 0.4883
        elif name == "dc-ae-f128c512-mix-1.0":
            config["scaling_factor"] = 0.3620
    else:
        raise ValueError("Invalid config name provided.")

    return config


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_name",
        type=str,
        default="dc-ae-f32c32-sana-1.0",
        choices=[
            "dc-ae-f32c32-sana-1.0",
            "dc-ae-f32c32-in-1.0",
            "dc-ae-f32c32-mix-1.0",
            "dc-ae-f64c128-in-1.0",
            "dc-ae-f64c128-mix-1.0",
            "dc-ae-f128c512-in-1.0",
            "dc-ae-f128c512-mix-1.0",
        ],
        help="The DCAE checkpoint to convert",
    )
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
    return parser.parse_args()


DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}

VARIANT_MAPPING = {
    "fp32": None,
    "fp16": "fp16",
    "bf16": "bf16",
}


if __name__ == "__main__":
    args = get_args()

    dtype = DTYPE_MAPPING[args.dtype]
    variant = VARIANT_MAPPING[args.dtype]

    ae = convert_ae(args.config_name, dtype)
    ae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
```
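To make the conversion script's renaming pass more concrete, here is a small self-contained sketch (not part of the script; the example key is hypothetical) showing the two steps it applies: substring renaming via the rename dictionary, then splitting a fused `qkv` convolution weight into separate `to_q`/`to_k`/`to_v` weights as `remap_qkv_` does.

```python
import torch

# Subset of AE_KEYS_RENAME_DICT from the script above.
RENAME = {
    "decoder.stages": "decoder.up_blocks",
    "context_module": "attn",
    "main.": "",
    "op_list.": "",
}

# Hypothetical original checkpoint key, used only to illustrate the mechanics.
key = "decoder.stages.4.op_list.1.main.context_module.qkv.conv.weight"
new_key = key
for old, new in RENAME.items():
    new_key = new_key.replace(old, new)
print(new_key)  # decoder.up_blocks.4.1.attn.qkv.conv.weight

# A fused (3*C, C, 1, 1) qkv conv kernel becomes three (C, C) linear weights.
state_dict = {new_key: torch.randn(3 * 64, 64, 1, 1)}
qkv = state_dict.pop(new_key)
q, k, v = torch.chunk(qkv, 3, dim=0)
parent, _, _ = new_key.rpartition(".qkv.conv.weight")
state_dict[f"{parent}.to_q.weight"] = q.squeeze()
state_dict[f"{parent}.to_k.weight"] = k.squeeze()
state_dict[f"{parent}.to_v.weight"] = v.squeeze()
print(state_dict[f"{parent}.to_q.weight"].shape)  # torch.Size([64, 64])
```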

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -80,6 +80,7 @@
         "AllegroTransformer3DModel",
         "AsymmetricAutoencoderKL",
         "AuraFlowTransformer2DModel",
+        "AutoencoderDC",
         "AutoencoderKL",
         "AutoencoderKLAllegro",
         "AutoencoderKLCogVideoX",
@@ -572,6 +573,7 @@
         AllegroTransformer3DModel,
         AsymmetricAutoencoderKL,
         AuraFlowTransformer2DModel,
+        AutoencoderDC,
         AutoencoderKL,
         AutoencoderKLAllegro,
         AutoencoderKLCogVideoX,
```
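As a quick sanity check of what these two `__init__.py` entries enable (a hedged sketch, assuming a diffusers build that includes this commit): the quoted entry presumably feeds the lazy `_import_structure` mapping and the bare name the type-checking import block, so the class becomes importable from the package root.

```python
import diffusers
from diffusers import AutoencoderDC

# With the new entries in place, the top-level lazy import and direct
# attribute access should resolve to the same class object.
assert AutoencoderDC is diffusers.AutoencoderDC
print(AutoencoderDC.__name__)  # AutoencoderDC
```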
