
Fix SpatialTransformer #578

Merged: 2 commits merged into main from fix_SpatialTransformer on Sep 27, 2022

Conversation

@ydshieh ydshieh commented Sep 19, 2022

hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)

should be

hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, -1)

where -1 is inner_dim rather than the channel count of the initial hidden_states, since the tensor has already been projected by

hidden_states = self.proj_in(hidden_states)
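
For illustration, here is a minimal standalone sketch of the failure mode with made-up sizes (this is not the diffusers module itself, just the same sequence of ops):

import torch
from torch import nn

# variable names follow the original code; sizes are arbitrary
batch, channel, height, weight = 1, 6, 16, 16
inner_dim = 16  # e.g. n_heads * d_head = 2 * 8

proj_in = nn.Conv2d(channel, inner_dim, kernel_size=1)  # 1x1 conv: channel -> inner_dim
hidden_states = proj_in(torch.randn(batch, channel, height, weight))

# old reshape still uses `channel`, which no longer matches the projected tensor:
#   hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)  # RuntimeError
# fixed reshape lets the last dimension be inferred (it is inner_dim now):
out = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, -1)
print(out.shape)  # torch.Size([1, 256, 16])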

@ydshieh ydshieh marked this pull request as ready for review September 19, 2022 19:32
@@ -144,10 +144,10 @@ def forward(self, hidden_states, context=None):
         residual = hidden_states
         hidden_states = self.norm(hidden_states)
         hidden_states = self.proj_in(hidden_states)
-        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
+        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, -1)
ydshieh (Contributor, Author) commented on the changed line:

self.proj_in(hidden_states) changes the number of channels.

@HuggingFaceDocBuilderDev commented Sep 19, 2022

The documentation is not available anymore as the PR was closed or merged.

@patil-suraj patil-suraj (Contributor) left a comment

Thanks for the PR. As can be seen below, d_head is computed as in_channels // n_head, so inner_dim = n_head * d_head = in_channels.

SpatialTransformer(
out_channels,
attn_num_head_channels,
out_channels // attn_num_head_channels,

So it's fine to leave it as is, because specifying the size as a variable is more readable than -1.

Maybe we could make it clearer that inner_dim = n_head * d_head = in_channels.
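
Put differently, a tiny sketch of that arithmetic with hypothetical numbers (assuming in_channels is divisible by the head count, which is how the UNet sets it up):

in_channels = 320                                 # hypothetical block width
attn_num_head_channels = 8                        # hypothetical head count
d_head = in_channels // attn_num_head_channels    # 40
inner_dim = attn_num_head_channels * d_head       # 8 * 40 = 320
assert inner_dim == in_channels                   # holds whenever the division is exact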

ydshieh commented Sep 20, 2022

Hi @patil-suraj:

We can force inner_dim = n_head * d_head = in_channels, if this is the only case that will ever be used. We just need to make this clear.

Otherwise, let's fix it :-)

I do agree that specifying the size as a variable is more readable than -1. In this case, we can do

        hidden_states = self.proj_in(hidden_states)
        inner_dim = hidden_states.shape[1]
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)

ydshieh commented Sep 20, 2022

I made the change to make the shape more clear.

@patrickvonplaten patrickvonplaten (Contributor) commented

Hmmm - I'm also not too sure about this one @ydshieh, are we fixing a bug here? If the current code is not buggy, it's the better, more readable option IMO. +1 on what @patil-suraj said

ydshieh commented Sep 22, 2022

Hi @patrickvonplaten, it is indeed buggy whenever SpatialTransformer gets an in_channels different from n_heads * d_head; see the To reproduce section below.

In terms of readability, the latest change should be fine.

    inner_dim = hidden_states.shape[1]
    hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)

Leaving the current code as it is makes it somewhat confusing, and it is also not good for testing purposes.

To reproduce

import numpy as np
import torch

from diffusers.models.attention import SpatialTransformer

N, H, W, C = (1, 16, 16, 6)
heads = 2
dim_head = 8
context_dim = 4
context_seq_len = 3

sample = np.random.default_rng().standard_normal(size=(N, C, H, W), dtype=np.float32)
context = np.random.default_rng().standard_normal(size=(N, context_seq_len, context_dim), dtype=np.float32)

pt_sample = torch.tensor(sample, dtype=torch.float32)
pt_context = torch.tensor(context, dtype=torch.float32)

pt_layer = SpatialTransformer(in_channels=C, context_dim=context_dim, n_heads=heads, d_head=dim_head, num_groups=3)

with torch.no_grad():
    pt_output = pt_layer(pt_sample, context=pt_context)

Error

    hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
RuntimeError: shape '[1, 256, 6]' is invalid for input of size 4096
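
(The numbers match the explanation above: after proj_in the tensor has n_heads * d_head = 2 * 8 = 16 channels, i.e. 1 * 256 * 16 = 4096 elements once flattened, whereas the old reshape asks for a [1, 256, 6] view that holds only 1536.)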

@patrickvonplaten patrickvonplaten (Contributor) commented

Hey @ydshieh,

Sorry, I think in_channels is defined by n_heads * d_head, so I don't really think this is an issue 😅

ydshieh commented Sep 27, 2022

Hi @patrickvonplaten

That does seem to be how it is used, but it is not enforced or even mentioned in the __init__ method, where in_channels has nothing to do with n_heads * d_head and we have inner_dim = n_heads * d_head.
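
Roughly, the relevant part of __init__ reads like this (paraphrased from the module at the time, not a verbatim quote):

inner_dim = n_heads * d_head  # computed without reference to in_channels
self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
# nothing here forces in_channels == inner_dim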

However, feel free to close this PR if you and @patil-suraj think the change is not really necessary :-).

@patrickvonplaten patrickvonplaten (Contributor) left a comment

Upon second reflection, this is actually quite clean!

@patrickvonplaten patrickvonplaten merged commit d886e49 into main Sep 27, 2022
@patil-suraj patil-suraj deleted the fix_SpatialTransformer branch September 28, 2022 13:14
prathikr pushed a commit to prathikr/diffusers that referenced this pull request Oct 26, 2022
* Fix SpatialTransformer

* Fix SpatialTransformer

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* Fix SpatialTransformer

* Fix SpatialTransformer

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>