 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from contextlib import nullcontext
 from typing import Dict

 from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
 from ..models.embeddings import IPAdapterTimeImageProjection
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from ..utils import is_accelerate_available, is_torch_version, logging
+
+
+logger = logging.get_logger(__name__)


 class SD3Transformer2DLoadersMixin:
     """Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`."""

-    def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
-        """Sets IP-Adapter attention processors, image projection, and loads state_dict.
+    def _convert_ip_adapter_attn_to_diffusers(
+        self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
+    ) -> Dict:
+        if low_cpu_mem_usage:
+            if is_accelerate_available():
+                from accelerate import init_empty_weights
+
+            else:
+                low_cpu_mem_usage = False
+                logger.warning(
+                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                    " install accelerate\n```\n."
+                )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )

-        Args:
-            state_dict (`Dict`):
-                State dict with keys "ip_adapter", which contains parameters for attention processors, and
-                "image_proj", which contains parameters for image projection net.
-            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
-                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
-                argument to `True` will raise an error.
-        """
         # IP-Adapter cross attention parameters
         hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
         ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
-        timesteps_emb_dim = state_dict["ip_adapter"]["0.norm_ip.linear.weight"].shape[1]
+        timesteps_emb_dim = state_dict["0.norm_ip.linear.weight"].shape[1]

         # Dict where key is transformer layer index, value is attention processor's state dict
         # ip_adapter state dict keys example: "0.norm_ip.linear.weight"
         layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
-        for key, weights in state_dict["ip_adapter"].items():
+        for key, weights in state_dict.items():
             idx, name = key.split(".", maxsplit=1)
             layer_state_dict[int(idx)][name] = weights

-        # Create IP-Adapter attention processor
+        # Create IP-Adapter attention processor & load state_dict
         attn_procs = {}
+        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
         for idx, name in enumerate(self.attn_processors.keys()):
-            attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
-                hidden_size=hidden_size,
-                ip_hidden_states_dim=ip_hidden_states_dim,
-                head_dim=self.config.attention_head_dim,
-                timesteps_emb_dim=timesteps_emb_dim,
-            ).to(self.device, dtype=self.dtype)
+            with init_context():
+                attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
+                    hidden_size=hidden_size,
+                    ip_hidden_states_dim=ip_hidden_states_dim,
+                    head_dim=self.config.attention_head_dim,
+                    timesteps_emb_dim=timesteps_emb_dim,
+                )

             if not low_cpu_mem_usage:
                 attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
@@ -63,27 +79,90 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _
                     attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
                 )

-        self.set_attn_processor(attn_procs)
+        return attn_procs
+
+    def _convert_ip_adapter_image_proj_to_diffusers(
+        self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
+    ) -> IPAdapterTimeImageProjection:
+        if low_cpu_mem_usage:
+            if is_accelerate_available():
+                from accelerate import init_empty_weights
+
+            else:
+                low_cpu_mem_usage = False
+                logger.warning(
+                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                    " install accelerate\n```\n."
+                )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
+
+        # Convert to diffusers
+        updated_state_dict = {}
+        for key, value in state_dict.items():
+            # InstantX/SD3.5-Large-IP-Adapter
+            if key.startswith("layers."):
+                idx = key.split(".")[1]
+                key = key.replace(f"layers.{idx}.0.norm1", f"layers.{idx}.ln0")
+                key = key.replace(f"layers.{idx}.0.norm2", f"layers.{idx}.ln1")
+                key = key.replace(f"layers.{idx}.0.to_q", f"layers.{idx}.attn.to_q")
+                key = key.replace(f"layers.{idx}.0.to_kv", f"layers.{idx}.attn.to_kv")
+                key = key.replace(f"layers.{idx}.0.to_out", f"layers.{idx}.attn.to_out.0")
+                key = key.replace(f"layers.{idx}.1.0", f"layers.{idx}.adaln_norm")
+                key = key.replace(f"layers.{idx}.1.1", f"layers.{idx}.ff.net.0.proj")
+                key = key.replace(f"layers.{idx}.1.3", f"layers.{idx}.ff.net.2")
+                key = key.replace(f"layers.{idx}.2.1", f"layers.{idx}.adaln_proj")
+            updated_state_dict[key] = value

         # Image projection parameters
-        embed_dim = state_dict["image_proj"]["proj_in.weight"].shape[1]
-        output_dim = state_dict["image_proj"]["proj_out.weight"].shape[0]
-        hidden_dim = state_dict["image_proj"]["proj_in.weight"].shape[0]
-        heads = state_dict["image_proj"]["layers.0.attn.to_q.weight"].shape[0] // 64
-        num_queries = state_dict["image_proj"]["latents"].shape[1]
-        timestep_in_dim = state_dict["image_proj"]["time_embedding.linear_1.weight"].shape[1]
+        embed_dim = updated_state_dict["proj_in.weight"].shape[1]
+        output_dim = updated_state_dict["proj_out.weight"].shape[0]
+        hidden_dim = updated_state_dict["proj_in.weight"].shape[0]
+        heads = updated_state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
+        num_queries = updated_state_dict["latents"].shape[1]
+        timestep_in_dim = updated_state_dict["time_embedding.linear_1.weight"].shape[1]

         # Image projection
-        self.image_proj = IPAdapterTimeImageProjection(
-            embed_dim=embed_dim,
-            output_dim=output_dim,
-            hidden_dim=hidden_dim,
-            heads=heads,
-            num_queries=num_queries,
-            timestep_in_dim=timestep_in_dim,
-        ).to(device=self.device, dtype=self.dtype)
+        with init_context():
+            image_proj = IPAdapterTimeImageProjection(
+                embed_dim=embed_dim,
+                output_dim=output_dim,
+                hidden_dim=hidden_dim,
+                heads=heads,
+                num_queries=num_queries,
+                timestep_in_dim=timestep_in_dim,
+            )

         if not low_cpu_mem_usage:
-            self.image_proj.load_state_dict(state_dict["image_proj"], strict=True)
+            image_proj.load_state_dict(updated_state_dict, strict=True)
         else:
-            load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype)
+            load_model_dict_into_meta(image_proj, updated_state_dict, device=self.device, dtype=self.dtype)
+
+        return image_proj
+
+    def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
+        """Sets IP-Adapter attention processors, image projection, and loads state_dict.
+
+        Args:
+            state_dict (`Dict`):
+                State dict with keys "ip_adapter", which contains parameters for attention processors, and
+                "image_proj", which contains parameters for image projection net.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
+        """
+
+        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dict["ip_adapter"], low_cpu_mem_usage)
+        self.set_attn_processor(attn_procs)
+
+        self.image_proj = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"], low_cpu_mem_usage)
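
For reference, a minimal usage sketch of the refactored entry point. It assumes `transformer` is an already-instantiated `SD3Transformer2DModel` and that `ip_adapter.bin` is a hypothetical checkpoint whose top-level keys are the "ip_adapter" and "image_proj" groups described in the docstring:

import torch

# Hypothetical checkpoint path; the file must expose the "ip_adapter" and
# "image_proj" groups that _load_ip_adapter_weights expects.
state_dict = torch.load("ip_adapter.bin", map_location="cpu", weights_only=True)

# Converts the attention-processor weights, installs them via set_attn_processor,
# and attaches the image projection model as transformer.image_proj.
transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=True)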