import argparse
from typing import Any, Dict

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from diffusers import AutoencoderDC


def remap_qkv_(key: str, state_dict: Dict[str, Any]):
    # Split the fused qkv 1x1-conv weight into separate to_q/to_k/to_v linear weights (in place).
    qkv = state_dict.pop(key)
    q, k, v = torch.chunk(qkv, 3, dim=0)
    parent_module, _, _ = key.rpartition(".qkv.conv.weight")
    state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
    state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
    state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()


def remap_proj_conv_(key: str, state_dict: Dict[str, Any]):
    # Flatten the attention output-projection 1x1-conv weight into a to_out linear weight (in place).
    parent_module, _, _ = key.rpartition(".proj.conv.weight")
    state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()
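The two helpers above run during the second remapping pass in `convert_ae`, after keys have already been renamed into diffusers form. A minimal sketch of what `remap_qkv_` does, on a hypothetical renamed attention key (the key name is illustrative, not taken from an actual checkpoint):

```python
import torch

# Fused qkv weight of a 1x1 conv with C = 64 channels: shape [3*C, C, 1, 1].
state_dict = {"decoder.up_blocks.4.0.attn.qkv.conv.weight": torch.randn(3 * 64, 64, 1, 1)}
remap_qkv_("decoder.up_blocks.4.0.attn.qkv.conv.weight", state_dict)

for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))
# decoder.up_blocks.4.0.attn.to_q.weight (64, 64)
# decoder.up_blocks.4.0.attn.to_k.weight (64, 64)
# decoder.up_blocks.4.0.attn.to_v.weight (64, 64)
```

Each chunk is squeezed from [C, C, 1, 1] down to the [C, C] shape a linear projection expects.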

AE_KEYS_RENAME_DICT = {
    # common
    "main.": "",
    "op_list.": "",
    "context_module": "attn",
    "local_module": "conv_out",
    # NOTE: The two aggreg mappings below work because every available config uses exactly one multiscale entry.
    # With more scales there would be more of these layers, and a loop would be needed to generate the mappings.
    "aggreg.0.0": "to_qkv_multiscale.0.proj_in",
    "aggreg.0.1": "to_qkv_multiscale.0.proj_out",
    "depth_conv.conv": "conv_depth",
    "inverted_conv.conv": "conv_inverted",
    "point_conv.conv": "conv_point",
    "point_conv.norm": "norm",
    "conv.conv.": "conv.",
    "conv1.conv": "conv1",
    "conv2.conv": "conv2",
    "conv2.norm": "norm",
    "proj.norm": "norm_out",
    # encoder
    "encoder.project_in.conv": "encoder.conv_in",
    "encoder.project_out.0.conv": "encoder.conv_out",
    "encoder.stages": "encoder.down_blocks",
    # decoder
    "decoder.project_in.conv": "decoder.conv_in",
    "decoder.project_out.0": "decoder.norm_out",
    "decoder.project_out.2.conv": "decoder.conv_out",
    "decoder.stages": "decoder.up_blocks",
}

AE_F32C32_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}

AE_F64C128_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}

AE_F128C512_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}

AE_SPECIAL_KEYS_REMAP = {
    "qkv.conv.weight": remap_qkv_,
    "proj.conv.weight": remap_proj_conv_,
}

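To make the renaming pass concrete, here is how a single original checkpoint key would flow through AE_KEYS_RENAME_DICT; the starting key is hypothetical, but shaped like the EfficientViT-style names the table above targets:

```python
# Hypothetical original DC-AE key; each (old, new) pair is applied as a plain substring replacement.
key = "decoder.stages.4.op_list.0.main.context_module.qkv.conv.weight"
for old, new in AE_KEYS_RENAME_DICT.items():
    key = key.replace(old, new)
print(key)  # decoder.up_blocks.4.0.attn.qkv.conv.weight
```

The renamed key still contains "qkv.conv.weight", so it matches AE_SPECIAL_KEYS_REMAP in the second pass and gets split by `remap_qkv_` as shown earlier.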

def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    state_dict = saved_dict
    # Unwrap common wrapper keys; check the current level so nested wrappers also unwrap.
    if "model" in state_dict:
        state_dict = state_dict["model"]
    if "module" in state_dict:
        state_dict = state_dict["module"]
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    return state_dict

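safetensors files already store a flat mapping of tensors, so for these checkpoints `get_state_dict` is effectively a no-op; the unwrapping only matters for checkpoints that nest their weights. A tiny illustration with made-up contents:

```python
import torch

flat = {"encoder.conv_in.weight": torch.zeros(1)}
assert get_state_dict(flat) is flat                  # already flat: returned as-is
assert get_state_dict({"state_dict": flat}) is flat  # common torch.save() wrapper is unwrapped
```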

def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    # Rename a single key in place.
    state_dict[new_key] = state_dict.pop(old_key)


def convert_ae(config_name: str, dtype: torch.dtype):
    config = get_ae_config(config_name)
    hub_id = f"mit-han-lab/{config_name}"
    ckpt_path = hf_hub_download(hub_id, "model.safetensors")
    original_state_dict = get_state_dict(load_file(ckpt_path))

    ae = AutoencoderDC(**config).to(dtype=dtype)

    # First pass: rename keys with plain substring replacements.
    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    # Second pass: special handlers that split or reshape weights (fused qkv, 1x1 proj convs).
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    ae.load_state_dict(original_state_dict, strict=True)
    return ae

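Because `load_state_dict` runs with `strict=True`, any key missed by the two passes fails loudly. A quick smoke test on the converted model is also easy to add; this is a sketch that assumes the diffusers AutoencoderDC encode/decode outputs expose `.latent` and `.sample`:

```python
ae = convert_ae("dc-ae-f32c32-sana-1.0", torch.float32).eval()

with torch.no_grad():
    image = torch.randn(1, 3, 512, 512)   # dummy RGB input
    latent = ae.encode(image).latent      # f32 spatial compression, 32 latent channels -> (1, 32, 16, 16)
    recon = ae.decode(latent).sample      # decoded back to (1, 3, 512, 512)
```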

def get_ae_config(name: str):
    if name in ["dc-ae-f32c32-sana-1.0"]:
        config = {
            "latent_channels": 32,
            "encoder_block_types": (
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ),
            "decoder_block_types": (
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ),
            "encoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
            "decoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
            "encoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
            "decoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
            "encoder_layers_per_block": (2, 2, 2, 3, 3, 3),
            "decoder_layers_per_block": [3, 3, 3, 3, 3, 3],
            "downsample_block_type": "conv",
            "upsample_block_type": "interpolate",
            "decoder_norm_types": "rms_norm",
            "decoder_act_fns": "silu",
            "scaling_factor": 0.41407,
        }
    elif name in ["dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"]:
        AE_KEYS_RENAME_DICT.update(AE_F32C32_KEYS)
        config = {
            "latent_channels": 32,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), ()),
            "decoder_norm_types": ["batch_norm", "batch_norm", "batch_norm", "rms_norm", "rms_norm", "rms_norm"],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu"],
        }
        if name == "dc-ae-f32c32-in-1.0":
            config["scaling_factor"] = 0.3189
        elif name == "dc-ae-f32c32-mix-1.0":
            config["scaling_factor"] = 0.4552
    elif name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]:
        AE_KEYS_RENAME_DICT.update(AE_F64C128_KEYS)
        config = {
            "latent_channels": 128,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), (), ()),
            "decoder_norm_types": [
                "batch_norm",
                "batch_norm",
                "batch_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
            ],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu"],
        }
        if name == "dc-ae-f64c128-in-1.0":
            config["scaling_factor"] = 0.2889
        elif name == "dc-ae-f64c128-mix-1.0":
            config["scaling_factor"] = 0.4538
    elif name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]:
        AE_KEYS_RENAME_DICT.update(AE_F128C512_KEYS)
        config = {
            "latent_channels": 512,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
            "decoder_norm_types": [
                "batch_norm",
                "batch_norm",
                "batch_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
            ],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu", "silu"],
        }
        if name == "dc-ae-f128c512-in-1.0":
            config["scaling_factor"] = 0.4883
        elif name == "dc-ae-f128c512-mix-1.0":
            config["scaling_factor"] = 0.3620
    else:
        raise ValueError("Invalid config name provided.")

    return config


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_name",
        type=str,
        default="dc-ae-f32c32-sana-1.0",
        choices=[
            "dc-ae-f32c32-sana-1.0",
            "dc-ae-f32c32-in-1.0",
            "dc-ae-f32c32-mix-1.0",
            "dc-ae-f64c128-in-1.0",
            "dc-ae-f64c128-mix-1.0",
            "dc-ae-f128c512-in-1.0",
            "dc-ae-f128c512-mix-1.0",
        ],
        help="The DC-AE checkpoint to convert",
    )
    parser.add_argument("--output_path", type=str, required=True, help="Path where the converted model should be saved")
    parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
    return parser.parse_args()


DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}

VARIANT_MAPPING = {
    "fp32": None,
    "fp16": "fp16",
    "bf16": "bf16",
}


if __name__ == "__main__":
    args = get_args()

    dtype = DTYPE_MAPPING[args.dtype]
    variant = VARIANT_MAPPING[args.dtype]

    ae = convert_ae(args.config_name, dtype)
    ae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
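Once saved, the converted weights load like any other diffusers model. A sketch, assuming an fp16 conversion was written to a hypothetical ./dc-ae-f32c32-sana-1.0-diffusers output path:

```python
import torch
from diffusers import AutoencoderDC

ae = AutoencoderDC.from_pretrained(
    "./dc-ae-f32c32-sana-1.0-diffusers",  # hypothetical --output_path passed to the script
    variant="fp16",                       # matches VARIANT_MAPPING["fp16"]
    torch_dtype=torch.float16,
)
```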