@@ -137,18 +137,18 @@ def _set_attention_slice(self, slice_size):
         for block in self.transformer_blocks:
             block._set_attention_slice(slice_size)

-    def forward(self, x, context=None):
+    def forward(self, hidden_states, context=None):
         # note: if no context is given, cross-attention defaults to self-attention
-        b, c, h, w = x.shape
-        x_in = x
-        x = self.norm(x)
-        x = self.proj_in(x)
-        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
+        batch, channel, height, weight = hidden_states.shape
+        residual = hidden_states
+        hidden_states = self.norm(hidden_states)
+        hidden_states = self.proj_in(hidden_states)
+        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
         for block in self.transformer_blocks:
-            x = block(x, context=context)
-        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
-        x = self.proj_out(x)
-        return x + x_in
+            hidden_states = block(hidden_states, context=context)
+        hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
+        hidden_states = self.proj_out(hidden_states)
+        return hidden_states + residual


 class BasicTransformerBlock(nn.Module):
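For reference, here is a minimal standalone sketch (not part of the diff; sizes are made up) of the shape round-trip the renamed forward performs: the (batch, channel, height, width) feature map is flattened into a (batch, height * width, channel) token sequence for the transformer blocks and then restored. The diff names the width dimension `weight`; the sketch uses `width` for clarity.

```python
# Illustrative only: the permute/reshape pair from the forward above.
import torch

batch, channel, height, width = 2, 320, 64, 64
hidden_states = torch.randn(batch, channel, height, width)

# (B, C, H, W) -> (B, H*W, C): each spatial position becomes one token
tokens = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, channel)
assert tokens.shape == (batch, height * width, channel)

# (B, H*W, C) -> (B, C, H, W): restore the spatial layout after the blocks
restored = tokens.reshape(batch, height, width, channel).permute(0, 3, 1, 2)
assert torch.equal(restored, hidden_states)
```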
@@ -192,12 +192,12 @@ def _set_attention_slice(self, slice_size):
         self.attn1._slice_size = slice_size
         self.attn2._slice_size = slice_size

-    def forward(self, x, context=None):
-        x = x.contiguous() if x.device.type == "mps" else x
-        x = self.attn1(self.norm1(x)) + x
-        x = self.attn2(self.norm2(x), context=context) + x
-        x = self.ff(self.norm3(x)) + x
-        return x
+    def forward(self, hidden_states, context=None):
+        hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states
+        hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
+        hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
+        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+        return hidden_states


 class CrossAttention(nn.Module):
@@ -247,22 +247,22 @@ def reshape_batch_dim_to_heads(self, tensor):
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor

-    def forward(self, x, context=None, mask=None):
-        batch_size, sequence_length, dim = x.shape
+    def forward(self, hidden_states, context=None, mask=None):
+        batch_size, sequence_length, dim = hidden_states.shape

-        q = self.to_q(x)
-        context = context if context is not None else x
-        k = self.to_k(context)
-        v = self.to_v(context)
+        query = self.to_q(hidden_states)
+        context = context if context is not None else hidden_states
+        key = self.to_k(context)
+        value = self.to_v(context)

-        q = self.reshape_heads_to_batch_dim(q)
-        k = self.reshape_heads_to_batch_dim(k)
-        v = self.reshape_heads_to_batch_dim(v)
+        query = self.reshape_heads_to_batch_dim(query)
+        key = self.reshape_heads_to_batch_dim(key)
+        value = self.reshape_heads_to_batch_dim(value)

         # TODO(PVP) - mask is currently never used. Remember to re-implement when used

         # attention, what we cannot get enough of
-        hidden_states = self._attention(q, k, v, sequence_length, dim)
+        hidden_states = self._attention(query, key, value, sequence_length, dim)

         return self.to_out(hidden_states)

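The body of reshape_heads_to_batch_dim is not shown in this hunk, so the sketch below is an assumption inferred from the reshape_batch_dim_to_heads context line above: the head dimension of query/key/value is folded into the batch dimension so _attention can treat each head as an independent batch entry, and the result is unfolded afterwards.

```python
# Assumed behaviour, not taken from this diff: fold heads into the batch dim
# and invert it, mirroring the reshape_batch_dim_to_heads line shown above.
import torch

batch_size, seq_len, dim, heads = 2, 77, 320, 8
tensor = torch.randn(batch_size, seq_len, dim)

# (B, S, D) -> (B * heads, S, D // heads): one batch entry per head
folded = (
    tensor.reshape(batch_size, seq_len, heads, dim // heads)
    .permute(0, 2, 1, 3)
    .reshape(batch_size * heads, seq_len, dim // heads)
)

# inverse: (B * heads, S, D // heads) -> (B, S, D)
unfolded = (
    folded.reshape(batch_size, heads, seq_len, dim // heads)
    .permute(0, 2, 1, 3)
    .reshape(batch_size, seq_len, dim)
)
assert torch.equal(unfolded, tensor)
```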
@@ -308,8 +308,8 @@ def __init__(

         self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))

-    def forward(self, x):
-        return self.net(x)
+    def forward(self, hidden_states):
+        return self.net(hidden_states)


 # feedforward
@@ -326,6 +326,6 @@ def __init__(self, dim_in: int, dim_out: int):
         super().__init__()
         self.proj = nn.Linear(dim_in, dim_out * 2)

-    def forward(self, x):
-        x, gate = self.proj(x).chunk(2, dim=-1)
-        return x * F.gelu(gate)
+    def forward(self, hidden_states):
+        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
+        return hidden_states * F.gelu(gate)
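A standalone sketch of the GEGLU gating shown in the last hunk (sizes are illustrative, not from the diff): the projection doubles the feature dimension, chunk splits the result in half, and one half gates the other through F.gelu.

```python
# Illustrative shapes only; mirrors the renamed GEGLU.forward above.
import torch
import torch.nn as nn
import torch.nn.functional as F

dim_in, dim_out = 320, 1280
proj = nn.Linear(dim_in, dim_out * 2)

hidden_states = torch.randn(2, 4096, dim_in)
hidden_states, gate = proj(hidden_states).chunk(2, dim=-1)   # two (2, 4096, dim_out) halves
out = hidden_states * F.gelu(gate)                           # GELU-gated linear unit
assert out.shape == (2, 4096, dim_out)
```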