correct attention_head_dim for JointTransformerBlock #8608

Merged
merged 5 commits into main from attn-dim on Jul 2, 2024

Conversation

yiyixuxu
Collaborator

No description provided.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu yiyixuxu requested a review from DN6 June 18, 2024 02:48
@yiyixuxu
Collaborator Author

should finish #6893 and #7027

@@ -81,7 +81,7 @@ def __init__(
                 JointTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
-                    attention_head_dim=self.inner_dim,
+                    attention_head_dim=attention_head_dim,
Collaborator

Should this also be self.config.attention_head_dim to match transformer_sd3.py?

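For context on the question above: inside a diffusers model __init__ decorated with @register_to_config, the constructor argument and its self.config counterpart should hold the same value, so the suggestion is about matching the style of transformer_sd3.py rather than changing behavior. A minimal sketch of that pattern, assuming current diffusers behavior; TinyModel and its default values are hypothetical, only ConfigMixin, ModelMixin, and register_to_config are real diffusers APIs:

```python
# Sketch: @register_to_config records the __init__ kwargs on self.config before
# the decorated __init__ body runs, so `attention_head_dim` and
# `self.config.attention_head_dim` are interchangeable during construction.
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class TinyModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, num_attention_heads: int = 24, attention_head_dim: int = 64):
        super().__init__()
        # Same value, two spellings; transformer_sd3.py uses the self.config form.
        assert attention_head_dim == self.config.attention_head_dim
        self.inner_dim = num_attention_heads * attention_head_dim


model = TinyModel()
print(model.inner_dim)  # 1536
```
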
@@ -128,9 +128,9 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only
             query_dim=dim,
             cross_attention_dim=None,
             added_kv_proj_dim=dim,
-            dim_head=attention_head_dim // num_attention_heads,
+            dim_head=attention_head_dim,
Collaborator

This won't break? Wouldn't the value of dim_head be computed differently?

Collaborator Author

@yiyixuxu yiyixuxu Jul 1, 2024

Well, no.

Currently dim_head=attention_head_dim // num_attention_heads, with attention_head_dim and num_attention_heads passed from SD3ControlNetModel like this:

* attention_head_dim=self.inner_dim
* self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

So the attention_head_dim this block receives is really num_attention_heads * attention_head_dim, while num_attention_heads is unchanged. Dividing the two just gives back the attention_head_dim we used to configure the model, so if we pass it down correctly we can use it directly (see the sketch below).

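As a concrete illustration of the arithmetic above (the values are hypothetical, picked only to make the division visible):

```python
# Hypothetical config values, named after the constructor arguments discussed above.
num_attention_heads = 24
attention_head_dim = 64
inner_dim = num_attention_heads * attention_head_dim  # 1536

# Before this PR: SD3ControlNetModel passed attention_head_dim=self.inner_dim,
# so JointTransformerBlock recovered the per-head size by dividing:
old_dim_head = inner_dim // num_attention_heads  # 1536 // 24 == 64

# After this PR: the configured attention_head_dim is passed down directly,
# so it can be used as dim_head without the division:
new_dim_head = attention_head_dim  # 64

assert old_dim_head == new_dim_head  # same value either way, so nothing breaks
```
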
Collaborator

Ahh I see. Thanks for explaining! 🙏🏽

@yiyixuxu yiyixuxu merged commit d9f71ab into main Jul 2, 2024
18 checks passed
@yiyixuxu yiyixuxu deleted the attn-dim branch July 2, 2024 17:42
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
* add

* update sd3 controlnet

* Update src/diffusers/models/controlnet_sd3.py

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>