Skip to content

Commit d9ee387

Browse files
authored
SD3 IP-Adapter runtime checkpoint conversion (#10718)
* Added runtime checkpoint conversion * Updated docs * Fix for quantized model
1 parent 454f82e commit d9ee387

File tree

2 files changed

+118
-39
lines changed

2 files changed

+118
-39
lines changed

docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ from diffusers import StableDiffusion3Pipeline
7777
from transformers import SiglipVisionModel, SiglipImageProcessor
7878

7979
image_encoder_id = "google/siglip-so400m-patch14-384"
80-
ip_adapter_id = "guiyrt/InstantX-SD3.5-Large-IP-Adapter-diffusers"
80+
ip_adapter_id = "InstantX/SD3.5-Large-IP-Adapter"
8181

8282
feature_extractor = SiglipImageProcessor.from_pretrained(
8383
image_encoder_id,

src/diffusers/loaders/transformer_sd3.py

Lines changed: 117 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -11,50 +11,66 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from contextlib import nullcontext
1415
from typing import Dict
1516

1617
from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
1718
from ..models.embeddings import IPAdapterTimeImageProjection
1819
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
20+
from ..utils import is_accelerate_available, is_torch_version, logging
21+
22+
23+
logger = logging.get_logger(__name__)
1924

2025

2126
class SD3Transformer2DLoadersMixin:
2227
"""Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`."""
2328

24-
def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
25-
"""Sets IP-Adapter attention processors, image projection, and loads state_dict.
29+
def _convert_ip_adapter_attn_to_diffusers(
30+
self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
31+
) -> Dict:
32+
if low_cpu_mem_usage:
33+
if is_accelerate_available():
34+
from accelerate import init_empty_weights
35+
36+
else:
37+
low_cpu_mem_usage = False
38+
logger.warning(
39+
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
40+
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
41+
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
42+
" install accelerate\n```\n."
43+
)
44+
45+
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
46+
raise NotImplementedError(
47+
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
48+
" `low_cpu_mem_usage=False`."
49+
)
2650

27-
Args:
28-
state_dict (`Dict`):
29-
State dict with keys "ip_adapter", which contains parameters for attention processors, and
30-
"image_proj", which contains parameters for image projection net.
31-
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
32-
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
33-
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
34-
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
35-
argument to `True` will raise an error.
36-
"""
3751
# IP-Adapter cross attention parameters
3852
hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
3953
ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
40-
timesteps_emb_dim = state_dict["ip_adapter"]["0.norm_ip.linear.weight"].shape[1]
54+
timesteps_emb_dim = state_dict["0.norm_ip.linear.weight"].shape[1]
4155

4256
# Dict where key is transformer layer index, value is attention processor's state dict
4357
# ip_adapter state dict keys example: "0.norm_ip.linear.weight"
4458
layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
45-
for key, weights in state_dict["ip_adapter"].items():
59+
for key, weights in state_dict.items():
4660
idx, name = key.split(".", maxsplit=1)
4761
layer_state_dict[int(idx)][name] = weights
4862

49-
# Create IP-Adapter attention processor
63+
# Create IP-Adapter attention processor & load state_dict
5064
attn_procs = {}
65+
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
5166
for idx, name in enumerate(self.attn_processors.keys()):
52-
attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
53-
hidden_size=hidden_size,
54-
ip_hidden_states_dim=ip_hidden_states_dim,
55-
head_dim=self.config.attention_head_dim,
56-
timesteps_emb_dim=timesteps_emb_dim,
57-
).to(self.device, dtype=self.dtype)
67+
with init_context():
68+
attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
69+
hidden_size=hidden_size,
70+
ip_hidden_states_dim=ip_hidden_states_dim,
71+
head_dim=self.config.attention_head_dim,
72+
timesteps_emb_dim=timesteps_emb_dim,
73+
)
5874

5975
if not low_cpu_mem_usage:
6076
attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
@@ -63,27 +79,90 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _
6379
attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
6480
)
6581

66-
self.set_attn_processor(attn_procs)
82+
return attn_procs
83+
84+
def _convert_ip_adapter_image_proj_to_diffusers(
85+
self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
86+
) -> IPAdapterTimeImageProjection:
87+
if low_cpu_mem_usage:
88+
if is_accelerate_available():
89+
from accelerate import init_empty_weights
90+
91+
else:
92+
low_cpu_mem_usage = False
93+
logger.warning(
94+
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
95+
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
96+
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
97+
" install accelerate\n```\n."
98+
)
99+
100+
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
101+
raise NotImplementedError(
102+
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
103+
" `low_cpu_mem_usage=False`."
104+
)
105+
106+
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
107+
108+
# Convert to diffusers
109+
updated_state_dict = {}
110+
for key, value in state_dict.items():
111+
# InstantX/SD3.5-Large-IP-Adapter
112+
if key.startswith("layers."):
113+
idx = key.split(".")[1]
114+
key = key.replace(f"layers.{idx}.0.norm1", f"layers.{idx}.ln0")
115+
key = key.replace(f"layers.{idx}.0.norm2", f"layers.{idx}.ln1")
116+
key = key.replace(f"layers.{idx}.0.to_q", f"layers.{idx}.attn.to_q")
117+
key = key.replace(f"layers.{idx}.0.to_kv", f"layers.{idx}.attn.to_kv")
118+
key = key.replace(f"layers.{idx}.0.to_out", f"layers.{idx}.attn.to_out.0")
119+
key = key.replace(f"layers.{idx}.1.0", f"layers.{idx}.adaln_norm")
120+
key = key.replace(f"layers.{idx}.1.1", f"layers.{idx}.ff.net.0.proj")
121+
key = key.replace(f"layers.{idx}.1.3", f"layers.{idx}.ff.net.2")
122+
key = key.replace(f"layers.{idx}.2.1", f"layers.{idx}.adaln_proj")
123+
updated_state_dict[key] = value
67124

68125
# Image projetion parameters
69-
embed_dim = state_dict["image_proj"]["proj_in.weight"].shape[1]
70-
output_dim = state_dict["image_proj"]["proj_out.weight"].shape[0]
71-
hidden_dim = state_dict["image_proj"]["proj_in.weight"].shape[0]
72-
heads = state_dict["image_proj"]["layers.0.attn.to_q.weight"].shape[0] // 64
73-
num_queries = state_dict["image_proj"]["latents"].shape[1]
74-
timestep_in_dim = state_dict["image_proj"]["time_embedding.linear_1.weight"].shape[1]
126+
embed_dim = updated_state_dict["proj_in.weight"].shape[1]
127+
output_dim = updated_state_dict["proj_out.weight"].shape[0]
128+
hidden_dim = updated_state_dict["proj_in.weight"].shape[0]
129+
heads = updated_state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
130+
num_queries = updated_state_dict["latents"].shape[1]
131+
timestep_in_dim = updated_state_dict["time_embedding.linear_1.weight"].shape[1]
75132

76133
# Image projection
77-
self.image_proj = IPAdapterTimeImageProjection(
78-
embed_dim=embed_dim,
79-
output_dim=output_dim,
80-
hidden_dim=hidden_dim,
81-
heads=heads,
82-
num_queries=num_queries,
83-
timestep_in_dim=timestep_in_dim,
84-
).to(device=self.device, dtype=self.dtype)
134+
with init_context():
135+
image_proj = IPAdapterTimeImageProjection(
136+
embed_dim=embed_dim,
137+
output_dim=output_dim,
138+
hidden_dim=hidden_dim,
139+
heads=heads,
140+
num_queries=num_queries,
141+
timestep_in_dim=timestep_in_dim,
142+
)
85143

86144
if not low_cpu_mem_usage:
87-
self.image_proj.load_state_dict(state_dict["image_proj"], strict=True)
145+
image_proj.load_state_dict(updated_state_dict, strict=True)
88146
else:
89-
load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype)
147+
load_model_dict_into_meta(image_proj, updated_state_dict, device=self.device, dtype=self.dtype)
148+
149+
return image_proj
150+
151+
def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
152+
"""Sets IP-Adapter attention processors, image projection, and loads state_dict.
153+
154+
Args:
155+
state_dict (`Dict`):
156+
State dict with keys "ip_adapter", which contains parameters for attention processors, and
157+
"image_proj", which contains parameters for image projection net.
158+
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
159+
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
160+
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
161+
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
162+
argument to `True` will raise an error.
163+
"""
164+
165+
attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dict["ip_adapter"], low_cpu_mem_usage)
166+
self.set_attn_processor(attn_procs)
167+
168+
self.image_proj = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"], low_cpu_mem_usage)

0 commit comments

Comments
 (0)