From 927d7e2a8a207b7dabcd9871d73af5882ce4afc9 Mon Sep 17 00:00:00 2001
From: daspartho
Date: Fri, 9 Sep 2022 21:17:20 +0530
Subject: [PATCH 1/4] renamed variable names

q -> query
k -> key
v -> value
b -> batch
c -> channel
h -> height
w -> weight
---
 src/diffusers/models/attention.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 094ca0fb2299..0547bfec382d 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -139,14 +139,14 @@ def _set_attention_slice(self, slice_size):
 
     def forward(self, x, context=None):
         # note: if no context is given, cross-attention defaults to self-attention
-        b, c, h, w = x.shape
+        batch, channel, height, weight = x.shape
         x_in = x
         x = self.norm(x)
         x = self.proj_in(x)
-        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
+        x = x.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
         for block in self.transformer_blocks:
             x = block(x, context=context)
-        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
+        x = x.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
         x = self.proj_out(x)
         return x + x_in
 
@@ -250,19 +250,19 @@ def reshape_batch_dim_to_heads(self, tensor):
 
     def forward(self, x, context=None, mask=None):
         batch_size, sequence_length, dim = x.shape
 
-        q = self.to_q(x)
+        query = self.to_q(x)
         context = context if context is not None else x
-        k = self.to_k(context)
-        v = self.to_v(context)
+        key = self.to_k(context)
+        value = self.to_v(context)
 
-        q = self.reshape_heads_to_batch_dim(q)
-        k = self.reshape_heads_to_batch_dim(k)
-        v = self.reshape_heads_to_batch_dim(v)
+        query = self.reshape_heads_to_batch_dim(q)
+        key = self.reshape_heads_to_batch_dim(k)
+        value = self.reshape_heads_to_batch_dim(v)
 
         # TODO(PVP) - mask is currently never used. Remember to re-implement when used
 
         # attention, what we cannot get enough of
-        hidden_states = self._attention(q, k, v, sequence_length, dim)
+        hidden_states = self._attention(query, key, value, sequence_length, dim)
 
         return self.to_out(hidden_states)

From 1fc0cb8c312144df7e7e37b82579e20e671c2b3c Mon Sep 17 00:00:00 2001
From: daspartho
Date: Fri, 9 Sep 2022 21:28:17 +0530
Subject: [PATCH 2/4] rename variable names

missed some in the initial commit
---
 src/diffusers/models/attention.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 0547bfec382d..2d35dda4eb96 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -255,9 +255,9 @@ def forward(self, x, context=None, mask=None):
         key = self.to_k(context)
         value = self.to_v(context)
 
-        query = self.reshape_heads_to_batch_dim(q)
-        key = self.reshape_heads_to_batch_dim(k)
-        value = self.reshape_heads_to_batch_dim(v)
+        query = self.reshape_heads_to_batch_dim(query)
+        key = self.reshape_heads_to_batch_dim(key)
+        value = self.reshape_heads_to_batch_dim(value)
 
         # TODO(PVP) - mask is currently never used. Remember to re-implement when used
 

From 84c7eaace8b95f93b8316593c2d2827166da23bd Mon Sep 17 00:00:00 2001
From: daspartho
Date: Fri, 9 Sep 2022 21:54:54 +0530
Subject: [PATCH 3/4] renamed more variable names

As per code review suggestions, renamed x -> hidden_states and x_in -> residual
---
 src/diffusers/models/attention.py | 50 +++++++++++++++----------------
 1 file changed, 25 insertions(+), 25 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 2d35dda4eb96..35de035cfc7c 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -137,18 +137,18 @@ def _set_attention_slice(self, slice_size):
         for block in self.transformer_blocks:
             block._set_attention_slice(slice_size)
 
-    def forward(self, x, context=None):
+    def forward(self, hidden_states, context=None):
         # note: if no context is given, cross-attention defaults to self-attention
-        batch, channel, height, weight = x.shape
-        x_in = x
-        x = self.norm(x)
-        x = self.proj_in(x)
-        x = x.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
+        batch, channel, height, weight = hidden_states.shape
+        residual_in = hidden_states
+        hidden_states = self.norm(hidden_states)
+        hidden_states = self.proj_in(hidden_states)
+        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
         for block in self.transformer_blocks:
-            x = block(x, context=context)
-        x = x.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
-        x = self.proj_out(x)
-        return x + x_in
+            hidden_states = block(hidden_states, context=context)
+        hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
+        hidden_states = self.proj_out(hidden_states)
+        return hidden_states + residual
 
 
 class BasicTransformerBlock(nn.Module):
@@ -192,12 +192,12 @@ def _set_attention_slice(self, slice_size):
         self.attn1._slice_size = slice_size
         self.attn2._slice_size = slice_size
 
-    def forward(self, x, context=None):
-        x = x.contiguous() if x.device.type == "mps" else x
-        x = self.attn1(self.norm1(x)) + x
-        x = self.attn2(self.norm2(x), context=context) + x
-        x = self.ff(self.norm3(x)) + x
-        return x
+    def forward(self, hidden_states, context=None):
+        hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states
+        hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
+        hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
+        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+        return hidden_states
 
 
 class CrossAttention(nn.Module):
@@ -247,11 +247,11 @@ def reshape_batch_dim_to_heads(self, tensor):
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor
 
-    def forward(self, x, context=None, mask=None):
-        batch_size, sequence_length, dim = x.shape
+    def forward(self, hidden_states, context=None, mask=None):
+        batch_size, sequence_length, dim = hidden_states.shape
 
-        query = self.to_q(x)
-        context = context if context is not None else x
+        query = self.to_q(hidden_states)
+        context = context if context is not None else hidden_states
         key = self.to_k(context)
         value = self.to_v(context)
 
@@ -308,8 +308,8 @@ def __init__(
 
         self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
 
-    def forward(self, x):
-        return self.net(x)
+    def forward(self, hidden_states):
+        return self.net(hidden_states)
 
 
 # feedforward
@@ -326,6 +326,6 @@ def __init__(self, dim_in: int, dim_out: int):
         super().__init__()
         self.proj = nn.Linear(dim_in, dim_out * 2)
 
-    def forward(self, x):
-        x, gate = self.proj(x).chunk(2, dim=-1)
-        return x * F.gelu(gate)
+    def forward(self, hidden_states):
+        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
+        return hidden_states * F.gelu(gate)

From 075e91c13acd41a835013b39465317723e633a42 Mon Sep 17 00:00:00 2001
From: daspartho
Date: Fri, 9 Sep 2022 21:57:44 +0530
Subject: [PATCH 4/4] fixed minor typo

---
 src/diffusers/models/attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 35de035cfc7c..accddacdad89 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -140,7 +140,7 @@ def _set_attention_slice(self, slice_size):
     def forward(self, hidden_states, context=None):
         # note: if no context is given, cross-attention defaults to self-attention
         batch, channel, height, weight = hidden_states.shape
-        residual_in = hidden_states
+        residual = hidden_states
         hidden_states = self.norm(hidden_states)
         hidden_states = self.proj_in(hidden_states)
         hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
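For reference, a minimal sketch (not taken from the patches above) of the data flow that the renamed CrossAttention.forward follows, written with the post-rename names (hidden_states, query, key, value). Tensor sizes and head count are illustrative assumptions, and the scaled-dot-product step at the end stands in for self._attention(query, key, value, sequence_length, dim):

    import torch

    # Illustrative sizes only; these values are assumptions, not taken from the patch.
    batch_size, sequence_length, dim, heads = 2, 4, 8, 2
    hidden_states = torch.randn(batch_size, sequence_length, dim)

    to_q = torch.nn.Linear(dim, dim)
    to_k = torch.nn.Linear(dim, dim)
    to_v = torch.nn.Linear(dim, dim)

    def reshape_heads_to_batch_dim(tensor, head_size=heads):
        # (batch, seq, dim) -> (batch * heads, seq, dim // heads)
        b, s, d = tensor.shape
        tensor = tensor.reshape(b, s, head_size, d // head_size)
        return tensor.permute(0, 2, 1, 3).reshape(b * head_size, s, d // head_size)

    # With no context given, cross-attention defaults to self-attention,
    # so key and value are projections of hidden_states as well.
    query = reshape_heads_to_batch_dim(to_q(hidden_states))
    key = reshape_heads_to_batch_dim(to_k(hidden_states))
    value = reshape_heads_to_batch_dim(to_v(hidden_states))

    # Scaled dot-product attention over the head-batched tensors.
    scores = torch.softmax(query @ key.transpose(-1, -2) / query.shape[-1] ** 0.5, dim=-1)
    out = scores @ value
    print(out.shape)  # torch.Size([4, 4, 4]) == (batch * heads, seq, dim // heads)

Folding the heads into the batch dimension, as reshape_heads_to_batch_dim does, lets all per-head attention products run as a single batched matrix multiply before reshape_batch_dim_to_heads restores the original layout.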