
[DC-AE] Add the official Deep Compression Autoencoder code (32x, 64x, 128x compression ratios) #9708


Merged Dec 6, 2024 · 101 commits

Commits
6e616a9
first add a script for DC-AE;
lawrence-cj Oct 18, 2024
d2e187a
Merge remote-tracking branch 'upstream/main' into DC-AE
chenjy2003 Oct 23, 2024
90e8939
DC-AE init
chenjy2003 Oct 23, 2024
825c975
replace triton with custom implementation
chenjy2003 Oct 23, 2024
3a44fa4
1. rename file and remove un-used codes;
lawrence-cj Oct 23, 2024
55b2615
no longer rely on omegaconf and dataclass
chenjy2003 Oct 25, 2024
6fb7fdb
merge
chenjy2003 Oct 25, 2024
c323e76
Merge remote-tracking branch 'upstream/main' into DC-AE
chenjy2003 Oct 25, 2024
da7caa5
replace custom activation with diffuers activation
chenjy2003 Oct 25, 2024
fb6d92a
remove dc_ae attention in attention_processor.py
chenjy2003 Oct 25, 2024
5e63a1a
iinherit from ModelMixin
chenjy2003 Oct 25, 2024
72cce2b
inherit from ConfigMixin
chenjy2003 Oct 25, 2024
8f9b4e4
dc-ae reduce to one file
chenjy2003 Oct 31, 2024
b7f68f9
Merge remote-tracking branch 'upstream/main' into DC-AE
chenjy2003 Oct 31, 2024
6d96b95
Merge branch 'huggingface:main' into DC-AE
lawrence-cj Nov 4, 2024
3c3cc51
Merge remote-tracking branch 'refs/remotes/origin/main' into DC-AE
lawrence-cj Nov 6, 2024
1448681
update downsample and upsample
chenjy2003 Nov 9, 2024
bf40fe8
merge
chenjy2003 Nov 9, 2024
dd7718a
clean code
chenjy2003 Nov 9, 2024
19986a5
support DecoderOutput
chenjy2003 Nov 9, 2024
3481e23
Merge branch 'main' into DC-AE
lawrence-cj Nov 9, 2024
0e818df
Merge branch 'main' into DC-AE
lawrence-cj Nov 13, 2024
c6eb233
remove get_same_padding and val2tuple
chenjy2003 Nov 14, 2024
59de0a3
remove autocast and some assert
chenjy2003 Nov 14, 2024
ea604a4
update ResBlock
chenjy2003 Nov 14, 2024
80dce02
remove contents within super().__init__
chenjy2003 Nov 14, 2024
1752afd
Update src/diffusers/models/autoencoders/dc_ae.py
lawrence-cj Nov 16, 2024
883bcf4
remove opsequential
chenjy2003 Nov 16, 2024
25ae389
Merge branch 'DC-AE' of github.com:lawrence-cj/diffusers into DC-AE
chenjy2003 Nov 16, 2024
96e844b
update other blocks to support the removal of build_norm
chenjy2003 Nov 16, 2024
59b6e25
Merge branch 'main' into DC-AE
sayakpaul Nov 16, 2024
7ce9ff2
remove build encoder/decoder project in/out
chenjy2003 Nov 16, 2024
30d6308
Merge branch 'DC-AE' of github.com:lawrence-cj/diffusers into DC-AE
chenjy2003 Nov 16, 2024
cab56b1
remove inheritance of RMSNorm2d from LayerNorm
chenjy2003 Nov 16, 2024
b42bb54
remove reset_parameters for RMSNorm2d
chenjy2003 Nov 20, 2024
2e04a99
remove device and dtype in RMSNorm2d __init__
chenjy2003 Nov 20, 2024
b4f75f2
Update src/diffusers/models/autoencoders/dc_ae.py
lawrence-cj Nov 21, 2024
c82f828
Update src/diffusers/models/autoencoders/dc_ae.py
lawrence-cj Nov 21, 2024
22ea5fd
Update src/diffusers/models/autoencoders/dc_ae.py
lawrence-cj Nov 21, 2024
4f5cbb4
remove op_list & build_block
chenjy2003 Nov 26, 2024
2f6bbad
remove build_stage_main
chenjy2003 Nov 26, 2024
4495783
Merge branch 'main' into DC-AE
lawrence-cj Nov 26, 2024
4d3c026
change file name to autoencoder_dc
chenjy2003 Nov 28, 2024
e007057
Merge branch 'DC-AE' of github.com:lawrence-cj/diffusers into DC-AE
chenjy2003 Nov 28, 2024
d3d9c84
move LiteMLA to attention.py
chenjy2003 Nov 28, 2024
be9826c
align with other vae decode output;
lawrence-cj Nov 28, 2024
20da201
add DC-AE into init files;
lawrence-cj Nov 28, 2024
5ed50e9
update
a-r-r-o-w Nov 28, 2024
2d59056
make quality && make style;
lawrence-cj Nov 28, 2024
c1c02a2
quick push before dgx disappears again
a-r-r-o-w Nov 28, 2024
1f8a3b3
update
a-r-r-o-w Nov 28, 2024
7b9d7e5
make style
a-r-r-o-w Nov 28, 2024
bf6c211
update
a-r-r-o-w Nov 28, 2024
a2ec5f8
update
a-r-r-o-w Nov 28, 2024
f5876c5
fix
a-r-r-o-w Nov 28, 2024
44034a6
refactor
a-r-r-o-w Nov 29, 2024
6379241
refactor
a-r-r-o-w Nov 29, 2024
77571a8
refactor
a-r-r-o-w Nov 29, 2024
c4d0867
update
a-r-r-o-w Nov 30, 2024
0bdb7ef
possibly change to nn.Linear
a-r-r-o-w Nov 30, 2024
54e933b
refactor
a-r-r-o-w Nov 30, 2024
babc9f5
Merge branch 'main' into aryan-dcae
a-r-r-o-w Nov 30, 2024
3d5faaf
make fix-copies
a-r-r-o-w Nov 30, 2024
65edfa5
resolve conflicts & merge
a-r-r-o-w Dec 1, 2024
ca3ac4d
replace vae with ae
chenjy2003 Dec 3, 2024
9ef7b59
replace get_block_from_block_type to get_block
chenjy2003 Dec 3, 2024
074817c
replace downsample_block_type from Conv to conv for consistency
chenjy2003 Dec 3, 2024
64de66a
add scaling factors
chenjy2003 Dec 3, 2024
0bda5c5
incorporate changes for all checkpoints
a-r-r-o-w Dec 4, 2024
eb64d52
make style
a-r-r-o-w Dec 4, 2024
4a224ce
Merge branch 'main' into DC-AE
a-r-r-o-w Dec 4, 2024
30c3238
move mla to attention processor file; split qkv conv to linears
a-r-r-o-w Dec 4, 2024
39a947c
refactor
a-r-r-o-w Dec 4, 2024
68f817a
Merge branch 'main' into DC-AE
a-r-r-o-w Dec 4, 2024
da834d5
add tests
a-r-r-o-w Dec 4, 2024
632ad3b
Merge branch 'main' into DC-AE
lawrence-cj Dec 4, 2024
d6c748c
from original file loader
a-r-r-o-w Dec 4, 2024
46eb504
Merge branch 'main' into DC-AE
a-r-r-o-w Dec 4, 2024
31f9fc6
add docs
a-r-r-o-w Dec 4, 2024
6f29e2a
add standard autoencoder methods
a-r-r-o-w Dec 4, 2024
b6e8fba
combine attention processor
yiyixuxu Dec 4, 2024
f862bae
fix tests
a-r-r-o-w Dec 5, 2024
f9fce24
update
a-r-r-o-w Dec 5, 2024
e594745
Merge branch 'main' into DC-AE
a-r-r-o-w Dec 5, 2024
3c0b1ca
minor fix
chenjy2003 Dec 5, 2024
91057d4
minor fix
chenjy2003 Dec 5, 2024
67aa715
Merge branch 'main' into DC-AE
lawrence-cj Dec 5, 2024
eda66e1
minor fix & in/out shortcut rename
chenjy2003 Dec 5, 2024
e3d33e6
minor fix
chenjy2003 Dec 5, 2024
cc97502
Merge branch 'main' into DC-AE
a-r-r-o-w Dec 5, 2024
2b370df
make style
a-r-r-o-w Dec 5, 2024
94355ab
fix paper link
chenjy2003 Dec 6, 2024
a191f07
Merge branch 'main' into DC-AE
a-r-r-o-w Dec 6, 2024
116c049
update docs
a-r-r-o-w Dec 6, 2024
b6e0aba
update single file loading
a-r-r-o-w Dec 6, 2024
ec4e84f
Merge branch 'main' into DC-AE
a-r-r-o-w Dec 6, 2024
dbae8f1
make style
a-r-r-o-w Dec 6, 2024
042c2a0
remove single file loading support; todo for DN6
a-r-r-o-w Dec 6, 2024
f2525b9
Apply suggestions from code review
a-r-r-o-w Dec 6, 2024
d3d224c
Merge branch 'main' into DC-AE
a-r-r-o-w Dec 6, 2024
6122b84
add abstract
a-r-r-o-w Dec 6, 2024
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -314,6 +314,8 @@
title: AutoencoderKLMochi
- local: api/models/asymmetricautoencoderkl
title: AsymmetricAutoencoderKL
- local: api/models/autoencoder_dc
title: AutoencoderDC
- local: api/models/consistency_decoder_vae
title: ConsistencyDecoderVAE
- local: api/models/autoencoder_oobleck
59 changes: 59 additions & 0 deletions docs/source/en/api/models/autoencoder_dc.md
@@ -0,0 +1,59 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderDC

*The 2D Autoencoder model used in [SANA](https://huggingface.co/papers/2410.10629) and introduced in [DCAE](https://huggingface.co/papers/2410.10733) by authors Junyu Chen\*, Han Cai\*, Junsong Chen, Enze Xie, Shang Yang, Haotian Tang, Muyang Li, Yao Lu, Song Han from MIT HAN Lab.*

The following DCAE models are released and supported in Diffusers:

| diffusers format | original format |
|:----------------:|:---------------:|
| [`mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-sana-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0)
| [`mit-han-lab/dc-ae-f32c32-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-in-1.0)
| [`mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-mix-1.0)
| [`mit-han-lab/dc-ae-f64c128-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f64c128-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-in-1.0)
| [`mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f64c128-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-mix-1.0)
| [`mit-han-lab/dc-ae-f128c512-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f128c512-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0)
| [`mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f128c512-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0)
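The checkpoint names above encode the model variant: `f{N}` is the spatial downsampling factor and `c{M}` is the number of latent channels. As a rough illustration (the helper below is hypothetical, not part of diffusers), the latent shape for a given input resolution can be derived directly from the name:

```python
import re


def parse_dcae_name(repo_name):
    """Return (spatial downsample factor, latent channels) from a name like 'dc-ae-f32c32-sana-1.0'."""
    match = re.search(r"f(\d+)c(\d+)", repo_name)
    if match is None:
        raise ValueError(f"Not a DC-AE checkpoint name: {repo_name}")
    return int(match.group(1)), int(match.group(2))


def latent_shape(repo_name, height, width):
    # A height x width RGB image maps to (channels, height // f, width // f).
    f, c = parse_dcae_name(repo_name)
    return (c, height // f, width // f)


print(latent_shape("dc-ae-f32c32-sana-1.0", 1024, 1024))   # (32, 32, 32)
print(latent_shape("dc-ae-f128c512-mix-1.0", 1024, 1024))  # (512, 8, 8)
```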

The models can be loaded with the following code snippet.

```python
import torch

from diffusers import AutoencoderDC

ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32).to("cuda")
```

Member: @lawrence-cj @chenjy2003 Will you be hosting the diffusers-converted weights for the VAE as well?

Contributor Author: Oh yes. We will host the weights in the same place as the original ones.

Contributor Author: Cc: @chenjy2003, please add our official diffusers-format checkpoint to our Hugging Face collection.

Contributor: Sure, we can host the converted weights. By the way, I noticed @a-r-r-o-w's previous comment:

> cc @DN6 I think we could benefit from having `.from_single_file` loading for the autoencoders from this PR, because the original checkpoints have a good amount of downloads.

Could you please elaborate? Do you mean that it is not necessary to host the converted weights? Thanks in advance for your answer.

@a-r-r-o-w (Member, Dec 5, 2024): Usually, whenever a conversion is required to go from the original implementation to the diffusers version, we make sure to host the diffusers-format weights as well, so they are guaranteed to work out of the box.

With single-file loading (support was already added in this commit, and you can find the example in the docs as well as here), one can also load the original checkpoints directly. This, however, relies on the `config.json` file in the original repositories: https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0/blob/main/config.json

```json
{
  "model_name": "dc-ae-f128c512-mix-1.0"
}
```

Using this value, we choose the appropriate config for initializing the AE and rename the state dict on the fly.

Contributor Author: I suppose what he means is that we can load the original weights using `DCAE.from_single_file(xxx)`, but the `DCAE.from_pretrained(xxx)` path still needs converted weights, right?

@a-r-r-o-w (Member, Dec 5, 2024): Yes, `from_single_file` for original weights, `from_pretrained` for diffusers-converted weights.

Contributor: Got it. Thanks!

## Single file loading

The `AutoencoderDC` implementation supports loading checkpoints shipped in the original format by MIT HAN Lab. The following example demonstrates how to load the `f128c512` checkpoint:

```python
from diffusers import AutoencoderDC

model_name = "dc-ae-f128c512-mix-1.0"
ae = AutoencoderDC.from_single_file(
    f"https://huggingface.co/mit-han-lab/{model_name}/resolve/main/model.safetensors",
    original_config=f"https://huggingface.co/mit-han-lab/{model_name}/resolve/main/config.json",
)
```

## AutoencoderDC

[[autodoc]] AutoencoderDC
- decode
- all

## DecoderOutput

[[autodoc]] models.autoencoders.vae.DecoderOutput

323 changes: 323 additions & 0 deletions scripts/convert_dcae_to_diffusers.py
@@ -0,0 +1,323 @@
import argparse
from typing import Any, Dict

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from diffusers import AutoencoderDC


def remap_qkv_(key: str, state_dict: Dict[str, Any]):
    qkv = state_dict.pop(key)
    q, k, v = torch.chunk(qkv, 3, dim=0)
    parent_module, _, _ = key.rpartition(".qkv.conv.weight")
    state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
    state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
    state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()


def remap_proj_conv_(key: str, state_dict: Dict[str, Any]):
    parent_module, _, _ = key.rpartition(".proj.conv.weight")
    state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()
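The fused-qkv remap above can be illustrated without torch: the original checkpoint stores one 1x1-conv weight of shape `(3 * dim, dim, 1, 1)`, which is split into three `(dim, dim)` linear weights. A NumPy sketch (the `attn.*` keys are illustrative; `np.split` on equal parts matches `torch.chunk(..., 3, dim=0)`):

```python
import numpy as np

# Fused qkv 1x1-conv weight: (3 * dim, dim, 1, 1), here with dim = 4.
dim = 4
key = "attn.qkv.conv.weight"
state_dict = {key: np.arange(3 * dim * dim, dtype=np.float32).reshape(3 * dim, dim, 1, 1)}

qkv = state_dict.pop(key)
q, k, v = np.split(qkv, 3, axis=0)  # equal thirds along the output-channel axis
parent_module, _, _ = key.rpartition(".qkv.conv.weight")
# .squeeze() drops the trailing 1x1 spatial dims, giving linear-layer weights.
state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()

print(state_dict["attn.to_q.weight"].shape)  # (4, 4)
```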


AE_KEYS_RENAME_DICT = {
    # common
    "main.": "",
    "op_list.": "",
    "context_module": "attn",
    "local_module": "conv_out",
    # NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
    # If there were more scales, there would be more layers, so a loop would be better to handle this
    "aggreg.0.0": "to_qkv_multiscale.0.proj_in",
    "aggreg.0.1": "to_qkv_multiscale.0.proj_out",
    "depth_conv.conv": "conv_depth",
    "inverted_conv.conv": "conv_inverted",
    "point_conv.conv": "conv_point",
    "point_conv.norm": "norm",
    "conv.conv.": "conv.",
    "conv1.conv": "conv1",
    "conv2.conv": "conv2",
    "conv2.norm": "norm",
    "proj.norm": "norm_out",
    # encoder
    "encoder.project_in.conv": "encoder.conv_in",
    "encoder.project_out.0.conv": "encoder.conv_out",
    "encoder.stages": "encoder.down_blocks",
    # decoder
    "decoder.project_in.conv": "decoder.conv_in",
    "decoder.project_out.0": "decoder.norm_out",
    "decoder.project_out.2.conv": "decoder.conv_out",
    "decoder.stages": "decoder.up_blocks",
}

AE_F32C32_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}

AE_F64C128_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}

AE_F128C512_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}

AE_SPECIAL_KEYS_REMAP = {
    "qkv.conv.weight": remap_qkv_,
    "proj.conv.weight": remap_proj_conv_,
}


def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    state_dict = saved_dict
    if "model" in saved_dict.keys():
        state_dict = state_dict["model"]
    if "module" in saved_dict.keys():
        state_dict = state_dict["module"]
    if "state_dict" in saved_dict.keys():
        state_dict = state_dict["state_dict"]
    return state_dict


def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    state_dict[new_key] = state_dict.pop(old_key)

def convert_ae(config_name: str, dtype: torch.dtype):
    config = get_ae_config(config_name)
    hub_id = f"mit-han-lab/{config_name}"
    ckpt_path = hf_hub_download(hub_id, "model.safetensors")
    original_state_dict = get_state_dict(load_file(ckpt_path))

    ae = AutoencoderDC(**config).to(dtype=dtype)

    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    ae.load_state_dict(original_state_dict, strict=True)
    return ae
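The first renaming pass in `convert_ae` is plain substring replacement applied to every key. A toy illustration (pure Python; the keys and rename table below are a small hypothetical subset, not the full `AE_KEYS_RENAME_DICT`):

```python
# Two-pass remap, first pass only: substring renames over every key.
rename_dict = {
    "op_list.": "",
    "encoder.project_in.conv": "encoder.conv_in",
    "encoder.stages": "encoder.down_blocks",
}

state_dict = {
    "encoder.project_in.conv.weight": "w0",
    "encoder.stages.0.op_list.0.weight": "w1",
}

# list(...) snapshots the keys so we can pop/insert while iterating.
for key in list(state_dict.keys()):
    new_key = key
    for old, new in rename_dict.items():
        new_key = new_key.replace(old, new)
    state_dict[new_key] = state_dict.pop(key)

print(sorted(state_dict))
# ['encoder.conv_in.weight', 'encoder.down_blocks.0.0.weight']
```

The second pass (the `AE_SPECIAL_KEYS_REMAP` handlers) then rewrites keys that need more than a rename, such as splitting the fused qkv weight.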


def get_ae_config(name: str):
    if name in ["dc-ae-f32c32-sana-1.0"]:
        config = {
            "latent_channels": 32,
            "encoder_block_types": (
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ),
            "decoder_block_types": (
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ),
            "encoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
            "decoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
            "encoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
            "decoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
            "encoder_layers_per_block": (2, 2, 2, 3, 3, 3),
            "decoder_layers_per_block": [3, 3, 3, 3, 3, 3],
            "downsample_block_type": "conv",
            "upsample_block_type": "interpolate",
            "decoder_norm_types": "rms_norm",
            "decoder_act_fns": "silu",
            "scaling_factor": 0.41407,
        }
    elif name in ["dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"]:
        AE_KEYS_RENAME_DICT.update(AE_F32C32_KEYS)
        config = {
            "latent_channels": 32,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), ()),
            "decoder_norm_types": ["batch_norm", "batch_norm", "batch_norm", "rms_norm", "rms_norm", "rms_norm"],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu"],
        }
        if name == "dc-ae-f32c32-in-1.0":
            config["scaling_factor"] = 0.3189
        elif name == "dc-ae-f32c32-mix-1.0":
            config["scaling_factor"] = 0.4552
    elif name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]:
        AE_KEYS_RENAME_DICT.update(AE_F64C128_KEYS)
        config = {
            "latent_channels": 128,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), (), ()),
            "decoder_norm_types": [
                "batch_norm",
                "batch_norm",
                "batch_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
            ],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu"],
        }
        if name == "dc-ae-f64c128-in-1.0":
            config["scaling_factor"] = 0.2889
        elif name == "dc-ae-f64c128-mix-1.0":
            config["scaling_factor"] = 0.4538
    elif name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]:
        AE_KEYS_RENAME_DICT.update(AE_F128C512_KEYS)
        config = {
            "latent_channels": 512,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
            "decoder_norm_types": [
                "batch_norm",
                "batch_norm",
                "batch_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
            ],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu", "silu"],
        }
        if name == "dc-ae-f128c512-in-1.0":
            config["scaling_factor"] = 0.4883
        elif name == "dc-ae-f128c512-mix-1.0":
            config["scaling_factor"] = 0.3620
    else:
        raise ValueError("Invalid config name provided.")

    return config


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_name",
        type=str,
        default="dc-ae-f32c32-sana-1.0",
        choices=[
            "dc-ae-f32c32-sana-1.0",
            "dc-ae-f32c32-in-1.0",
            "dc-ae-f32c32-mix-1.0",
            "dc-ae-f64c128-in-1.0",
            "dc-ae-f64c128-mix-1.0",
            "dc-ae-f128c512-in-1.0",
            "dc-ae-f128c512-mix-1.0",
        ],
        help="The DCAE checkpoint to convert",
    )
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
    return parser.parse_args()


DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}

VARIANT_MAPPING = {
    "fp32": None,
    "fp16": "fp16",
    "bf16": "bf16",
}


if __name__ == "__main__":
    args = get_args()

    dtype = DTYPE_MAPPING[args.dtype]
    variant = VARIANT_MAPPING[args.dtype]

    ae = convert_ae(args.config_name, dtype)
    ae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
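Assuming the script is saved at `scripts/convert_dcae_to_diffusers.py` as in this PR, a conversion run might look like the following (the output path is illustrative, any local directory works):

```shell
# Convert the f32c32 SANA checkpoint to diffusers format in fp32.
python scripts/convert_dcae_to_diffusers.py \
  --config_name dc-ae-f32c32-sana-1.0 \
  --output_path ./dc-ae-f32c32-sana-1.0-diffusers \
  --dtype fp32
```

The resulting directory can then be loaded with `AutoencoderDC.from_pretrained`.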