diff --git a/src/diffusers/models/controlnet_sd3.py b/src/diffusers/models/controlnet_sd3.py
index 629cb661eda5..86a6a9178514 100644
--- a/src/diffusers/models/controlnet_sd3.py
+++ b/src/diffusers/models/controlnet_sd3.py
@@ -239,16 +239,16 @@ def _set_gradient_checkpointing(self, module, value=False):
         module.gradient_checkpointing = value
 
     @classmethod
-    def from_transformer(cls, transformer, num_layers=None, load_weights_from_transformer=True):
+    def from_transformer(cls, transformer, num_layers=12, load_weights_from_transformer=True):
         config = transformer.config
         config["num_layers"] = num_layers or config.num_layers
         controlnet = cls(**config)
 
         if load_weights_from_transformer:
-            controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict(), strict=False)
-            controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict(), strict=False)
-            controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict(), strict=False)
-            controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict())
+            controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
+            controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
+            controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
+            controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
 
         controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input)
 