From 1fbb034cfd2f80679aa08d2c31db158d743a71fe Mon Sep 17 00:00:00 2001
From: ydshieh
Date: Mon, 19 Sep 2022 21:28:37 +0200
Subject: [PATCH 1/2] Fix SpatialTransformer

---
 src/diffusers/models/attention.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 25e1ea28dcf0..0782ec885233 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -144,10 +144,10 @@ def forward(self, hidden_states, context=None):
         residual = hidden_states
         hidden_states = self.norm(hidden_states)
         hidden_states = self.proj_in(hidden_states)
-        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
+        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, -1)
         for block in self.transformer_blocks:
             hidden_states = block(hidden_states, context=context)
-        hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
+        hidden_states = hidden_states.reshape(batch, height, weight, -1).permute(0, 3, 1, 2)
         hidden_states = self.proj_out(hidden_states)
         return hidden_states + residual

From 9a69d67568dd7bc125d816a2c6dc3faba7922399 Mon Sep 17 00:00:00 2001
From: ydshieh
Date: Tue, 20 Sep 2022 17:31:51 +0200
Subject: [PATCH 2/2] Fix SpatialTransformer

---
 src/diffusers/models/attention.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 0782ec885233..f963310f12eb 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -144,10 +144,11 @@ def forward(self, hidden_states, context=None):
         residual = hidden_states
         hidden_states = self.norm(hidden_states)
         hidden_states = self.proj_in(hidden_states)
-        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, -1)
+        inner_dim = hidden_states.shape[1]
+        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
         for block in self.transformer_blocks:
             hidden_states = block(hidden_states, context=context)
-        hidden_states = hidden_states.reshape(batch, height, weight, -1).permute(0, 3, 1, 2)
+        hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
         hidden_states = self.proj_out(hidden_states)
         return hidden_states + residual
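
Why the fix works: self.proj_in is a 1x1 convolution that maps the input
channel count to inner_dim (n_heads * d_head), so after proj_in the tensor no
longer has `channel` channels, and the pre-patch reshape with `channel` raised
a shape error whenever inner_dim != channel. PATCH 1/2 sidesteps this with -1;
PATCH 2/2 makes the dimension explicit by reading it back from the tensor. A
minimal standalone sketch of the shape logic, with illustrative sizes (the
proj_in stand-in and the value 64 below are hypothetical, not taken from the
model config):

    import torch
    import torch.nn as nn

    # Illustrative sizes; `weight` mirrors the patch's variable name for width.
    batch, channel, height, weight = 2, 32, 8, 8
    proj_in = nn.Conv2d(channel, 64, kernel_size=1)  # hypothetical stand-in for self.proj_in

    hidden_states = torch.randn(batch, channel, height, weight)
    hidden_states = proj_in(hidden_states)  # shape is now (batch, 64, height, weight)
    inner_dim = hidden_states.shape[1]      # what PATCH 2/2 reads back

    # The pre-patch code reshaped with `channel` here; that raises a
    # RuntimeError whenever inner_dim != channel, because the element counts
    # no longer match. Using inner_dim keeps both reshapes consistent.
    flat = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
    restored = flat.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
    assert restored.shape == (batch, inner_dim, height, weight)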