Skip to content

Commit 7c25331

Browse files
haofanwangyiyixuxu
authored andcommitted
Allow from_transformer in SD3ControlNetModel (#8749)
* Update controlnet_sd3.py --------- Co-authored-by: YiYi Xu <yixu310@gmail.com>
1 parent a039005 commit 7c25331

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/diffusers/models/controlnet_sd3.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -239,16 +239,16 @@ def _set_gradient_checkpointing(self, module, value=False):
239239
module.gradient_checkpointing = value
240240

241241
@classmethod
242-
def from_transformer(cls, transformer, num_layers=None, load_weights_from_transformer=True):
242+
def from_transformer(cls, transformer, num_layers=12, load_weights_from_transformer=True):
243243
config = transformer.config
244244
config["num_layers"] = num_layers or config.num_layers
245245
controlnet = cls(**config)
246246

247247
if load_weights_from_transformer:
248-
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict(), strict=False)
249-
controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict(), strict=False)
250-
controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict(), strict=False)
251-
controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict())
248+
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
249+
controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
250+
controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
251+
controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
252252

253253
controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input)
254254

0 commit comments

Comments
 (0)