Fix SpatialTransformer
#578
Conversation
src/diffusers/models/attention.py
```diff
@@ -144,10 +144,10 @@ def forward(self, hidden_states, context=None):
         residual = hidden_states
         hidden_states = self.norm(hidden_states)
         hidden_states = self.proj_in(hidden_states)
-        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
+        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, -1)
```
Review comment on the diff above: the last dimension of the reshape should be `-1` (which resolves to `inner_dim`) rather than `channel`, the number of channels of the initial `hidden_states`, because the tensor has already been projected by `self.proj_in(hidden_states)`, which changes the number of channels.
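A minimal sketch of why the original reshape can fail, using hypothetical shapes and a stand-in 1×1 convolution for `proj_in` (the assumption here is that `proj_in` maps `in_channels` to `inner_dim`); it is only an illustration of the shape mismatch, not the module itself:

```python
import torch

# Hypothetical shapes chosen so that inner_dim != in_channels.
batch, in_channels, height, width = 1, 6, 16, 16
n_heads, d_head = 2, 8
inner_dim = n_heads * d_head  # 16

# Assumed stand-in for SpatialTransformer.proj_in: a 1x1 conv mapping in_channels -> inner_dim.
proj_in = torch.nn.Conv2d(in_channels, inner_dim, kernel_size=1)

hidden_states = proj_in(torch.randn(batch, in_channels, height, width))
print(hidden_states.shape)  # torch.Size([1, 16, 16, 16]) -- the channel dim is now inner_dim

# Reshaping with the original `channel` (= in_channels) would raise a RuntimeError:
# hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, in_channels)

# Reshaping with -1 (or inner_dim) works:
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
print(hidden_states.shape)  # torch.Size([1, 256, 16])
```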
The documentation is not available anymore as the PR was closed or merged.
Thanks for the PR. As can be seen below, `d_head` is computed as `in_channels // n_head`, so `inner_dim = n_head * d_head = in_channels`:

diffusers/src/diffusers/models/unet_blocks.py, lines 507 to 510 in c01ec2d:

```python
SpatialTransformer(
    out_channels,
    attn_num_head_channels,
    out_channels // attn_num_head_channels,
```

So it's fine to leave it as is, because specifying the size as a variable is much more readable than `-1`.

Maybe we could make it more clear that `inner_dim = n_head * d_head = in_channels`.
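A quick sanity check of that argument (the numbers below are illustrative only, not taken from any real config): with the constructor call quoted above, `inner_dim` collapses back to the input channel count whenever the channel count is divisible by `attn_num_head_channels`, which is why `channel` and `-1` coincide in the current UNet usage:

```python
# Illustrative values: check that
# SpatialTransformer(out_channels, attn_num_head_channels, out_channels // attn_num_head_channels)
# implies inner_dim == out_channels, i.e. inner_dim == in_channels for this usage.
out_channels = 320
attn_num_head_channels = 8

n_heads = attn_num_head_channels
d_head = out_channels // attn_num_head_channels
inner_dim = n_heads * d_head

assert inner_dim == out_channels  # holds whenever out_channels % attn_num_head_channels == 0
```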
Hi @patil-suraj: We can force [...]. Otherwise, let's fix it :-) I do agree that [...]

```python
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
```

I made the change to make the shape more clear.
Hmmm - I'm also not too sure about this here @ydshieh, are we fixing a bug here? If the current code is not buggy, it's the better, more readable option IMO. +1 on what @patil-suraj said.
Hi @patrickvonplaten, it is indeed buggy as long as `d_head` is not `in_channels // n_heads`. In terms of readability, the latest change should be fine:

```python
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
```

Leaving the current code as it is makes the code somewhat confusing, and it is also not good for testing purposes.

To reproduce:

```python
import numpy as np
import torch

from diffusers.models.attention import SpatialTransformer

N, H, W, C = (1, 16, 16, 6)
heads = 2
dim_head = 8
context_dim = 4
context_seq_len = 3

sample = np.random.default_rng().standard_normal(size=(N, C, H, W), dtype=np.float32)
context = np.random.default_rng().standard_normal(size=(N, context_seq_len, context_dim), dtype=np.float32)

pt_sample = torch.tensor(sample, dtype=torch.float32)
pt_context = torch.tensor(context, dtype=torch.float32)

pt_layer = SpatialTransformer(in_channels=C, context_dim=context_dim, n_heads=heads, d_head=dim_head, num_groups=3)

with torch.no_grad():
    pt_output = pt_layer(pt_sample, context=pt_context)
```

Error:

```
    hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
RuntimeError: shape '[1, 256, 6]' is invalid for input of size 4096
```
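For context, the numbers in that error message follow directly from the shapes used in the repro; a quick check with the same values (nothing new assumed):

```python
# "input of size 4096" is the element count after proj_in, whose channel dimension is
# inner_dim = heads * dim_head = 16, while the requested shape [1, 256, 6] still uses
# the original channel count C = 6.
N, H, W, C = 1, 16, 16, 6
heads, dim_head = 2, 8
inner_dim = heads * dim_head                  # 16

numel_after_proj_in = N * H * W * inner_dim   # 4096 -> "input of size 4096"
requested_numel = N * (H * W) * C             # 1536 -> invalid shape [1, 256, 6]

assert numel_after_proj_in == 4096
assert requested_numel != numel_after_proj_in  # hence the RuntimeError
```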
Hey @ydshieh, sorry, I think [...]
The usage seems to be the case, but this is not mentioned in the [...]. However, feel free to close this PR if you and @patil-suraj think the change is not really necessary :-).
Upon second reflection this is actually quite clean!
Merged commit message:

* Fix SpatialTransformer
* Fix SpatialTransformer

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>