[DC-AE] Add the official Deep Compression Autoencoder code (32x, 64x, 128x compression ratio) #9708
@@ -0,0 +1,59 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderDC

*The 2D Autoencoder model used in [SANA](https://huggingface.co/papers/2410.10629) and introduced in [DCAE](https://huggingface.co/papers/2410.10733) by authors Junyu Chen\*, Han Cai\*, Junsong Chen, Enze Xie, Shang Yang, Haotian Tang, Muyang Li, Yao Lu, Song Han from MIT HAN Lab.*

The following DCAE models are released and supported in Diffusers:

| diffusers format | original format |
|:----------------:|:---------------:|
| [`mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-sana-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0) |
| [`mit-han-lab/dc-ae-f32c32-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-in-1.0) |
| [`mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-mix-1.0) |
| [`mit-han-lab/dc-ae-f64c128-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f64c128-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-in-1.0) |
| [`mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f64c128-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-mix-1.0) |
| [`mit-han-lab/dc-ae-f128c512-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f128c512-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0) |
| [`mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f128c512-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0) |
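The checkpoint names encode the compression settings: `f32c32` means a 32x spatial downsampling factor with 32 latent channels, `f64c128` means 64x with 128 channels, and `f128c512` means 128x with 512 channels. A minimal sketch of this naming convention (`parse_dcae_name` is a hypothetical helper for illustration, not part of diffusers):

```python
import re

# Hypothetical helper, not part of diffusers: extract the spatial
# downsampling factor F and latent channel count C from the f{F}c{C}
# segment of a DC-AE checkpoint name.
def parse_dcae_name(repo_id: str):
    match = re.search(r"f(\d+)c(\d+)", repo_id)
    if match is None:
        raise ValueError(f"not a DC-AE checkpoint name: {repo_id}")
    return int(match.group(1)), int(match.group(2))

print(parse_dcae_name("mit-han-lab/dc-ae-f32c32-sana-1.0"))   # (32, 32)
print(parse_dcae_name("mit-han-lab/dc-ae-f128c512-mix-1.0"))  # (128, 512)
```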

The models can be loaded with the following code snippet.

```python
import torch

from diffusers import AutoencoderDC

ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32).to("cuda")
```

Review discussion on this snippet:

- @lawrence-cj @chenjy2003 Will you be hosting the diffusers converted weights for the VAE as well?
- Oh yes. We will host the weight in the same place as the original ones.
- Cc: @chenjy2003, please add our official diffusers version checkpoint into our huggingface collection.
- Sure. We can hold the converted weight. By the way, I noticed @a-r-r-o-w's previous comments. Could you please elaborate more? Do you mean that it is not necessary to hold the converted weights? Thanks in advance for your answer.
- Usually, whenever there is a conversion required to go from the original implementation to the diffusers version, we make sure to host the diffusers-format weights as well, so that they are guaranteed to work out of the box. With single file loading (support was already added in this commit, and you can find the example in the docs), one can load the original checkpoints directly as well. This, however, relies on the
  Using this value, we choose the appropriate config for initializing the AE and rename the state dict on the fly.
- I suppose what he means is that we can load the original weight using `DCAE.from_single_file(...)`, but the `DCAE.from_pretrained(...)` command still needs a converted weight, right?
- Yes.
- Got it. Thanks!
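For intuition about the compression ratios in the table above, here is a sketch of the shape arithmetic (illustrative only, assuming spatial dimensions divisible by the downsampling factor; `latent_shape` is not a diffusers API):

```python
# Illustrative shape arithmetic, not a diffusers API: for dc-ae-f32c32,
# a (3, H, W) image maps to a (32, H // 32, W // 32) latent.
def latent_shape(image_shape, downsample_factor=32, latent_channels=32):
    _, height, width = image_shape
    if height % downsample_factor or width % downsample_factor:
        raise ValueError("spatial dims must be divisible by the downsampling factor")
    return (latent_channels, height // downsample_factor, width // downsample_factor)

print(latent_shape((3, 1024, 1024)))            # (32, 32, 32)
print(latent_shape((3, 1024, 1024), 128, 512))  # (512, 8, 8)
```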

## Single file loading

The `AutoencoderDC` implementation supports loading checkpoints shipped in the original format by MIT HAN Lab. The following example demonstrates how to load the `f128c512` checkpoint:

```python
from diffusers import AutoencoderDC

model_name = "dc-ae-f128c512-mix-1.0"
ae = AutoencoderDC.from_single_file(
    f"https://huggingface.co/mit-han-lab/{model_name}/resolve/main/model.safetensors",
    original_config=f"https://huggingface.co/mit-han-lab/{model_name}/resolve/main/config.json",
)
```

## AutoencoderDC

[[autodoc]] AutoencoderDC
  - decode
  - all

## DecoderOutput

[[autodoc]] models.autoencoders.vae.DecoderOutput

@@ -0,0 +1,323 @@
import argparse
from typing import Any, Dict

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from diffusers import AutoencoderDC

def remap_qkv_(key: str, state_dict: Dict[str, Any]) -> None:
    # Split the fused qkv conv weight into separate to_q/to_k/to_v projections.
    qkv = state_dict.pop(key)
    q, k, v = torch.chunk(qkv, 3, dim=0)
    parent_module, _, _ = key.rpartition(".qkv.conv.weight")
    state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
    state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
    state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()


def remap_proj_conv_(key: str, state_dict: Dict[str, Any]) -> None:
    parent_module, _, _ = key.rpartition(".proj.conv.weight")
    state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()

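The remap handlers above pop a fused weight and re-insert it under new keys of the same parent module. A pure-Python dry run of the same key surgery (lists stand in for tensors, so `torch.chunk` becomes slicing; the key and values are made up for illustration):

```python
def split_qkv(key, state_dict):
    # Same idea as remap_qkv_: pop the fused weight, split it into three
    # equal parts, and re-insert under to_q/to_k/to_v of the parent module.
    qkv = state_dict.pop(key)
    third = len(qkv) // 3
    parent_module, _, _ = key.rpartition(".qkv.conv.weight")
    state_dict[f"{parent_module}.to_q.weight"] = qkv[:third]
    state_dict[f"{parent_module}.to_k.weight"] = qkv[third : 2 * third]
    state_dict[f"{parent_module}.to_v.weight"] = qkv[2 * third :]

sd = {"down_blocks.3.attn.qkv.conv.weight": [1, 2, 3, 4, 5, 6]}
split_qkv("down_blocks.3.attn.qkv.conv.weight", sd)
print(sd["down_blocks.3.attn.to_q.weight"])  # [1, 2]
```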
AE_KEYS_RENAME_DICT = {
    # common
    "main.": "",
    "op_list.": "",
    "context_module": "attn",
    "local_module": "conv_out",
    # NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
    # If there were more scales, there would be more layers, so a loop would be better to handle this
    "aggreg.0.0": "to_qkv_multiscale.0.proj_in",
    "aggreg.0.1": "to_qkv_multiscale.0.proj_out",
    "depth_conv.conv": "conv_depth",
    "inverted_conv.conv": "conv_inverted",
    "point_conv.conv": "conv_point",
    "point_conv.norm": "norm",
    "conv.conv.": "conv.",
    "conv1.conv": "conv1",
    "conv2.conv": "conv2",
    "conv2.norm": "norm",
    "proj.norm": "norm_out",
    # encoder
    "encoder.project_in.conv": "encoder.conv_in",
    "encoder.project_out.0.conv": "encoder.conv_out",
    "encoder.stages": "encoder.down_blocks",
    # decoder
    "decoder.project_in.conv": "decoder.conv_in",
    "decoder.project_out.0": "decoder.norm_out",
    "decoder.project_out.2.conv": "decoder.conv_out",
    "decoder.stages": "decoder.up_blocks",
}

AE_F32C32_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}

AE_F64C128_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}

AE_F128C512_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}

AE_SPECIAL_KEYS_REMAP = {
    "qkv.conv.weight": remap_qkv_,
    "proj.conv.weight": remap_proj_conv_,
}

def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    # Peel off common wrapper keys so the raw weights are returned regardless
    # of how the checkpoint was saved. Each check is against the current level,
    # so nested wrappers (e.g. {"model": {"state_dict": ...}}) also unwrap.
    state_dict = saved_dict
    if "model" in state_dict.keys():
        state_dict = state_dict["model"]
    if "module" in state_dict.keys():
        state_dict = state_dict["module"]
    if "state_dict" in state_dict.keys():
        state_dict = state_dict["state_dict"]
    return state_dict


def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    state_dict[new_key] = state_dict.pop(old_key)

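Training frameworks sometimes nest the raw weights under wrapper keys; a toy illustration of the unwrapping that `get_state_dict` performs (the nesting here is fabricated for the example):

```python
weights = {"encoder.conv_in.weight": [0.0]}
wrapped = {"model": {"state_dict": weights}}

# Peel off wrapper keys level by level, as get_state_dict does.
state_dict = wrapped
for wrapper_key in ("model", "module", "state_dict"):
    if wrapper_key in state_dict:
        state_dict = state_dict[wrapper_key]

print(state_dict == weights)  # True
```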

def convert_ae(config_name: str, dtype: torch.dtype):
    config = get_ae_config(config_name)
    hub_id = f"mit-han-lab/{config_name}"
    ckpt_path = hf_hub_download(hub_id, "model.safetensors")
    original_state_dict = get_state_dict(load_file(ckpt_path))

    ae = AutoencoderDC(**config).to(dtype=dtype)

    # First pass: rename keys by plain substring replacement.
    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    # Second pass: keys that need structural changes (e.g. splitting fused qkv weights).
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    ae.load_state_dict(original_state_dict, strict=True)
    return ae

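The first renaming pass in `convert_ae` is plain substring replacement over every key. A dry run on one illustrative key with a subset of `AE_KEYS_RENAME_DICT` (the key is made up; real checkpoints contain many such keys):

```python
# Subset of the rename map above, applied in insertion order.
toy_rename_map = {
    "main.": "",
    "op_list.": "",
    "context_module": "attn",
    "depth_conv.conv": "conv_depth",
    "encoder.stages": "encoder.down_blocks",
}

key = "encoder.stages.0.main.op_list.0.context_module.depth_conv.conv.weight"
new_key = key
for old, new in toy_rename_map.items():
    new_key = new_key.replace(old, new)

print(new_key)  # encoder.down_blocks.0.0.attn.conv_depth.weight
```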

def get_ae_config(name: str):
    if name in ["dc-ae-f32c32-sana-1.0"]:
        config = {
            "latent_channels": 32,
            "encoder_block_types": (
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ),
            "decoder_block_types": (
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ),
            "encoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
            "decoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
            "encoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
            "decoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
            "encoder_layers_per_block": (2, 2, 2, 3, 3, 3),
            "decoder_layers_per_block": (3, 3, 3, 3, 3, 3),
            "downsample_block_type": "conv",
            "upsample_block_type": "interpolate",
            "decoder_norm_types": "rms_norm",
            "decoder_act_fns": "silu",
            "scaling_factor": 0.41407,
        }
    elif name in ["dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"]:
        AE_KEYS_RENAME_DICT.update(AE_F32C32_KEYS)
        config = {
            "latent_channels": 32,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), ()),
            "decoder_norm_types": ["batch_norm", "batch_norm", "batch_norm", "rms_norm", "rms_norm", "rms_norm"],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu"],
        }
        if name == "dc-ae-f32c32-in-1.0":
            config["scaling_factor"] = 0.3189
        elif name == "dc-ae-f32c32-mix-1.0":
            config["scaling_factor"] = 0.4552
    elif name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]:
        AE_KEYS_RENAME_DICT.update(AE_F64C128_KEYS)
        config = {
            "latent_channels": 128,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), (), ()),
            "decoder_norm_types": [
                "batch_norm",
                "batch_norm",
                "batch_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
            ],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu"],
        }
        if name == "dc-ae-f64c128-in-1.0":
            config["scaling_factor"] = 0.2889
        elif name == "dc-ae-f64c128-mix-1.0":
            config["scaling_factor"] = 0.4538
    elif name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]:
        AE_KEYS_RENAME_DICT.update(AE_F128C512_KEYS)
        config = {
            "latent_channels": 512,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
            "decoder_norm_types": [
                "batch_norm",
                "batch_norm",
                "batch_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
            ],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu", "silu"],
        }
        if name == "dc-ae-f128c512-in-1.0":
            config["scaling_factor"] = 0.4883
        elif name == "dc-ae-f128c512-mix-1.0":
            config["scaling_factor"] = 0.3620
    else:
        raise ValueError("Invalid config name provided.")

    return config


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_name",
        type=str,
        default="dc-ae-f32c32-sana-1.0",
        choices=[
            "dc-ae-f32c32-sana-1.0",
            "dc-ae-f32c32-in-1.0",
            "dc-ae-f32c32-mix-1.0",
            "dc-ae-f64c128-in-1.0",
            "dc-ae-f64c128-mix-1.0",
            "dc-ae-f128c512-in-1.0",
            "dc-ae-f128c512-mix-1.0",
        ],
        help="The DCAE checkpoint to convert",
    )
    parser.add_argument("--output_path", type=str, required=True, help="Path where the converted model should be saved")
    parser.add_argument("--dtype", default="fp32", choices=["fp32", "fp16", "bf16"], help="Torch dtype to save the model in.")
    return parser.parse_args()


DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}

VARIANT_MAPPING = {
    "fp32": None,
    "fp16": "fp16",
    "bf16": "bf16",
}


if __name__ == "__main__":
    args = get_args()

    dtype = DTYPE_MAPPING[args.dtype]
    variant = VARIANT_MAPPING[args.dtype]

    ae = convert_ae(args.config_name, dtype)
    ae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)