
Commit 5843487

Renamed variables from single letter to better naming (#449)
* renamed variable names: q -> query, k -> key, v -> value, b -> batch, c -> channel, h -> height, w -> weight
* rename variable names: missed some in the initial commit
* renamed more variable names: as per code review suggestions, renamed x -> hidden_states and x_in -> residual
* fixed minor typo
1 parent 5adb0a7 commit 5843487

File tree

1 file changed

+31 -31 lines changed
src/diffusers/models/attention.py

Lines changed: 31 additions & 31 deletions
@@ -137,18 +137,18 @@ def _set_attention_slice(self, slice_size):
         for block in self.transformer_blocks:
             block._set_attention_slice(slice_size)
 
-    def forward(self, x, context=None):
+    def forward(self, hidden_states, context=None):
         # note: if no context is given, cross-attention defaults to self-attention
-        b, c, h, w = x.shape
-        x_in = x
-        x = self.norm(x)
-        x = self.proj_in(x)
-        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
+        batch, channel, height, weight = hidden_states.shape
+        residual = hidden_states
+        hidden_states = self.norm(hidden_states)
+        hidden_states = self.proj_in(hidden_states)
+        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
         for block in self.transformer_blocks:
-            x = block(x, context=context)
-        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
-        x = self.proj_out(x)
-        return x + x_in
+            hidden_states = block(hidden_states, context=context)
+        hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
+        hidden_states = self.proj_out(hidden_states)
+        return hidden_states + residual
 
 
 class BasicTransformerBlock(nn.Module):
@@ -192,12 +192,12 @@ def _set_attention_slice(self, slice_size):
         self.attn1._slice_size = slice_size
         self.attn2._slice_size = slice_size
 
-    def forward(self, x, context=None):
-        x = x.contiguous() if x.device.type == "mps" else x
-        x = self.attn1(self.norm1(x)) + x
-        x = self.attn2(self.norm2(x), context=context) + x
-        x = self.ff(self.norm3(x)) + x
-        return x
+    def forward(self, hidden_states, context=None):
+        hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states
+        hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
+        hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
+        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+        return hidden_states
 
 
 class CrossAttention(nn.Module):
@@ -247,22 +247,22 @@ def reshape_batch_dim_to_heads(self, tensor):
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor
 
-    def forward(self, x, context=None, mask=None):
-        batch_size, sequence_length, dim = x.shape
+    def forward(self, hidden_states, context=None, mask=None):
+        batch_size, sequence_length, dim = hidden_states.shape
 
-        q = self.to_q(x)
-        context = context if context is not None else x
-        k = self.to_k(context)
-        v = self.to_v(context)
+        query = self.to_q(hidden_states)
+        context = context if context is not None else hidden_states
+        key = self.to_k(context)
+        value = self.to_v(context)
 
-        q = self.reshape_heads_to_batch_dim(q)
-        k = self.reshape_heads_to_batch_dim(k)
-        v = self.reshape_heads_to_batch_dim(v)
+        query = self.reshape_heads_to_batch_dim(query)
+        key = self.reshape_heads_to_batch_dim(key)
+        value = self.reshape_heads_to_batch_dim(value)
 
         # TODO(PVP) - mask is currently never used. Remember to re-implement when used
 
         # attention, what we cannot get enough of
-        hidden_states = self._attention(q, k, v, sequence_length, dim)
+        hidden_states = self._attention(query, key, value, sequence_length, dim)
 
         return self.to_out(hidden_states)
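
For context on where the renamed query/key/value go, below is a minimal standalone sketch of folding attention heads into the batch dimension and running standard scaled dot-product attention. The head count and shapes are invented for illustration, and the attention math follows the usual formulation rather than copying the library's _attention.

import torch

batch_size, sequence_length, dim, heads = 2, 64, 320, 8
head_dim = dim // heads

query = torch.randn(batch_size, sequence_length, dim)
key = torch.randn(batch_size, sequence_length, dim)
value = torch.randn(batch_size, sequence_length, dim)

def reshape_heads_to_batch_dim(tensor):
    # (B, S, heads*head_dim) -> (B*heads, S, head_dim): heads become extra batch entries
    b, s, _ = tensor.shape
    return tensor.reshape(b, s, heads, head_dim).permute(0, 2, 1, 3).reshape(b * heads, s, head_dim)

def reshape_batch_dim_to_heads(tensor):
    # (B*heads, S, head_dim) -> (B, S, heads*head_dim): undo the folding before the output projection
    bh, s, d = tensor.shape
    return tensor.reshape(bh // heads, heads, s, d).permute(0, 2, 1, 3).reshape(bh // heads, s, heads * d)

query, key, value = map(reshape_heads_to_batch_dim, (query, key, value))

# standard scaled dot-product attention over the folded tensors
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * head_dim**-0.5
attention_probs = attention_scores.softmax(dim=-1)
hidden_states = torch.matmul(attention_probs, value)       # (B*heads, S, head_dim)

hidden_states = reshape_batch_dim_to_heads(hidden_states)  # (B, S, dim), ready for the output projection
assert hidden_states.shape == (batch_size, sequence_length, dim)
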

@@ -308,8 +308,8 @@ def __init__(
 
         self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
 
-    def forward(self, x):
-        return self.net(x)
+    def forward(self, hidden_states):
+        return self.net(hidden_states)
 
 
 # feedforward
@@ -326,6 +326,6 @@ def __init__(self, dim_in: int, dim_out: int):
         super().__init__()
         self.proj = nn.Linear(dim_in, dim_out * 2)
 
-    def forward(self, x):
-        x, gate = self.proj(x).chunk(2, dim=-1)
-        return x * F.gelu(gate)
+    def forward(self, hidden_states):
+        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
+        return hidden_states * F.gelu(gate)
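
As a small illustration of the GEGLU hunk above: a single projection produces twice the output width, the result is split in half, and the second half gates the first through GELU. The sketch below is standalone with invented shapes, not the library module itself.

import torch
from torch import nn
from torch.nn import functional as F

dim_in, dim_out = 320, 1280
proj = nn.Linear(dim_in, dim_out * 2)

hidden_states = torch.randn(2, 64, dim_in)          # (batch, sequence, dim_in)
hidden_states, gate = proj(hidden_states).chunk(2, dim=-1)
out = hidden_states * F.gelu(gate)                  # (batch, sequence, dim_out)
assert out.shape == (2, 64, dim_out)
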
