Skip to content

Commit 11d18f3

Browse files
authored
Add single file loading support for AnimateDiff (#8819)
* update * update * update * update
1 parent d2df40c commit 11d18f3

File tree

5 files changed

+149
-2
lines changed

5 files changed

+149
-2
lines changed

docs/source/en/api/pipelines/animatediff.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,20 @@ export_to_gif(frames, "animatelcm-motion-lora.gif")
560560
</table>
561561

562562

563+
## Using `from_single_file` with the MotionAdapter
564+
565+
`diffusers>=0.30.0` supports loading the AnimateDiff checkpoints into the `MotionAdapter` in their original format via `from_single_file`.
566+
567+
```python
568+
import torch
from diffusers import AnimateDiffPipeline, MotionAdapter
569+
570+
ckpt_path = "https://huggingface.co/Lightricks/LongAnimateDiff/blob/main/lt_long_mm_32_frames.ckpt"
571+
572+
adapter = MotionAdapter.from_single_file(ckpt_path, torch_dtype=torch.float16)
573+
pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=adapter)
574+
575+
```
576+
563577
## AnimateDiffPipeline
564578

565579
[[autodoc]] AnimateDiffPipeline

src/diffusers/loaders/single_file_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from ..utils import deprecate, is_accelerate_available, logging
2323
from .single_file_utils import (
2424
SingleFileComponentError,
25+
convert_animatediff_checkpoint_to_diffusers,
2526
convert_controlnet_checkpoint,
2627
convert_ldm_unet_checkpoint,
2728
convert_ldm_vae_checkpoint,
@@ -70,6 +71,9 @@
7071
"checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
7172
"default_subfolder": "transformer",
7273
},
74+
"MotionAdapter": {
75+
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
76+
},
7377
}
7478

7579

src/diffusers/loaders/single_file_utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@
7474
"stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
7575
"stable_cascade_stage_c": "clip_txt_mapper.weight",
7676
"sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
77+
"animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe",
78+
"animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
79+
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
7780
}
7881

7982
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -103,6 +106,10 @@
103106
"sd3": {
104107
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
105108
},
109+
"animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"},
110+
"animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
111+
"animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
112+
"animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
106113
}
107114

108115
# Use to configure model sample size when original config is provided
@@ -485,6 +492,19 @@ def infer_diffusers_model_type(checkpoint):
485492
elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint:
486493
model_type = "sd3"
487494

495+
elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
496+
if CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint:
497+
model_type = "animatediff_v2"
498+
499+
elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"]].shape[-1] == 320:
500+
model_type = "animatediff_sdxl_beta"
501+
502+
elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff"]].shape[1] == 24:
503+
model_type = "animatediff_v1"
504+
505+
else:
506+
model_type = "animatediff_v3"
507+
488508
else:
489509
model_type = "v1"
490510

@@ -1822,3 +1842,22 @@ def create_diffusers_t5_model_from_checkpoint(
18221842
param.data = param.data.to(torch.float32)
18231843

18241844
return model
1845+
1846+
1847+
def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
    """Rename the keys of an original-format AnimateDiff state dict to diffusers names.

    Args:
        checkpoint: Mapping from original parameter names to their values.
        **kwargs: Extra keyword arguments are accepted and ignored.

    Returns:
        A new dict holding the same values under the renamed keys. Entries whose
        key contains ``"pos_encoder"`` are dropped. ``checkpoint`` itself is not
        modified.
    """
    # Substring substitutions applied, in this order, to every retained key.
    key_renames = (
        (".norms.0", ".norm1"),
        (".norms.1", ".norm2"),
        (".ff_norm", ".norm3"),
        (".attention_blocks.0", ".attn1"),
        (".attention_blocks.1", ".attn2"),
        (".temporal_transformer", ""),
    )

    converted_state_dict = {}
    for key, value in checkpoint.items():
        # Positional-encoder entries are not carried over to the converted dict.
        if "pos_encoder" in key:
            continue

        for old, new in key_renames:
            key = key.replace(old, new)
        converted_state_dict[key] = value

    return converted_state_dict

src/diffusers/models/unets/unet_motion_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import torch.utils.checkpoint
2020

2121
from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
22-
from ...loaders import UNet2DConditionLoadersMixin
22+
from ...loaders import FromOriginalModelMixin, UNet2DConditionLoadersMixin
2323
from ...utils import logging
2424
from ..attention_processor import (
2525
ADDED_KV_ATTENTION_PROCESSORS,
@@ -93,7 +93,7 @@ def __init__(
9393
)
9494

9595

96-
class MotionAdapter(ModelMixin, ConfigMixin):
96+
class MotionAdapter(ModelMixin, ConfigMixin, FromOriginalModelMixin):
9797
@register_to_config
9898
def __init__(
9999
self,
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# coding=utf-8
2+
# Copyright 2024 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import unittest
17+
18+
from diffusers import (
19+
MotionAdapter,
20+
)
21+
from diffusers.utils.testing_utils import (
22+
enable_full_determinism,
23+
)
24+
25+
26+
enable_full_determinism()
27+
28+
29+
class MotionAdapterSingleFileTests(unittest.TestCase):
    """Compare the config of a `from_single_file` model against its `from_pretrained` counterpart."""

    model_class = MotionAdapter

    # Config entries that are skipped when comparing the two loaded models.
    PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]

    def _assert_single_file_config_matches(self, ckpt_path, repo_id):
        """Load the model both ways and assert every non-ignored config entry is equal.

        Args:
            ckpt_path: Location of the original-format checkpoint passed to `from_single_file`.
            repo_id: Repository id passed to `from_pretrained`.
        """
        model = self.model_class.from_pretrained(repo_id)
        model_single_file = self.model_class.from_single_file(ckpt_path)

        for param_name, param_value in model_single_file.config.items():
            if param_name in self.PARAMS_TO_IGNORE:
                continue
            assert (
                model.config[param_name] == param_value
            ), f"{param_name} differs between pretrained loading and single file loading"

    def test_single_file_components_version_v1_5(self):
        self._assert_single_file_config_matches(
            ckpt_path="https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v15.ckpt",
            repo_id="guoyww/animatediff-motion-adapter-v1-5",
        )

    def test_single_file_components_version_v1_5_2(self):
        self._assert_single_file_config_matches(
            ckpt_path="https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v15_v2.ckpt",
            repo_id="guoyww/animatediff-motion-adapter-v1-5-2",
        )

    def test_single_file_components_version_v1_5_3(self):
        self._assert_single_file_config_matches(
            ckpt_path="https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_mm.ckpt",
            repo_id="guoyww/animatediff-motion-adapter-v1-5-3",
        )

    def test_single_file_components_version_sdxl_beta(self):
        self._assert_single_file_config_matches(
            ckpt_path="https://huggingface.co/guoyww/animatediff/blob/main/mm_sdxl_v10_beta.ckpt",
            repo_id="guoyww/animatediff-motion-adapter-sdxl-beta",
        )

0 commit comments

Comments
 (0)