diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index e19b087431a2..2a81f357d48b 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -128,9 +128,9 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_onl
             query_dim=dim,
             cross_attention_dim=None,
             added_kv_proj_dim=dim,
-            dim_head=attention_head_dim // num_attention_heads,
+            dim_head=attention_head_dim,
             heads=num_attention_heads,
-            out_dim=attention_head_dim,
+            out_dim=dim,
             context_pre_only=context_pre_only,
             bias=True,
             processor=processor,
diff --git a/src/diffusers/models/controlnet_sd3.py b/src/diffusers/models/controlnet_sd3.py
index 86a6a9178514..2b4dd0fa8b72 100644
--- a/src/diffusers/models/controlnet_sd3.py
+++ b/src/diffusers/models/controlnet_sd3.py
@@ -81,7 +81,7 @@ def __init__(
                 JointTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
-                    attention_head_dim=self.inner_dim,
+                    attention_head_dim=self.config.attention_head_dim,
                     context_pre_only=False,
                 )
                 for i in range(num_layers)
diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index d514a43537d8..1b9126b3b849 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -97,7 +97,7 @@ def __init__(
                 JointTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.inner_dim,
+                    attention_head_dim=self.config.attention_head_dim,
                     context_pre_only=i == num_layers - 1,
                 )
                 for i in range(self.config.num_layers)