Commit 8b45096

[CrossAttention] add different method for sliced attention (#446)
* add different method for sliced attention
* Update src/diffusers/models/attention.py
* Apply suggestions from code review
* Update src/diffusers/models/attention.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent 1a69c6f · commit 8b45096

1 file changed (+15, -2 lines)

src/diffusers/models/attention.py

Lines changed: 15 additions & 2 deletions
@@ -262,11 +262,24 @@ def forward(self, hidden_states, context=None, mask=None):
         # TODO(PVP) - mask is currently never used. Remember to re-implement when used

         # attention, what we cannot get enough of
-        hidden_states = self._attention(query, key, value, sequence_length, dim)
+
+        if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+            hidden_states = self._attention(query, key, value)
+        else:
+            hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)

         return self.to_out(hidden_states)

-    def _attention(self, query, key, value, sequence_length, dim):
+    def _attention(self, query, key, value):
+        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+        attention_probs = attention_scores.softmax(dim=-1)
+        # compute attention output
+        hidden_states = torch.matmul(attention_probs, value)
+        # reshape hidden_states
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        return hidden_states
+
+    def _sliced_attention(self, query, key, value, sequence_length, dim):
         batch_size_attention = query.shape[0]
         hidden_states = torch.zeros(
             (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
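The diff is truncated before the body of `_sliced_attention`, but the pattern it dispatches to is batch-sliced attention: the fused batch*heads dimension is processed in chunks of `self._slice_size`, so only one slice-sized score matrix is materialized at a time instead of the full one, trading a Python loop for lower peak memory. Below is a minimal, self-contained sketch of that pattern; the standalone `sliced_attention` function, its argument list, and the toy shapes are illustrative assumptions, not the exact diffusers implementation.

```python
import torch

def sliced_attention(query, key, value, slice_size, scale):
    # query/key/value: (batch * heads, seq_len, head_dim) tensors (illustrative shapes).
    # Only one (slice_size, seq_len, seq_len) score matrix is alive per loop iteration.
    batch_size_attention, sequence_length, head_dim = query.shape
    hidden_states = torch.zeros(
        (batch_size_attention, sequence_length, head_dim), device=query.device, dtype=query.dtype
    )
    for start in range(0, batch_size_attention, slice_size):
        end = start + slice_size
        scores = torch.matmul(query[start:end], key[start:end].transpose(-1, -2)) * scale
        probs = scores.softmax(dim=-1)
        hidden_states[start:end] = torch.matmul(probs, value[start:end])
    return hidden_states

# Toy usage: 8 rows (e.g. batch 4 * 2 heads), sliced two at a time.
q = torch.randn(8, 16, 40)
k = torch.randn(8, 16, 40)
v = torch.randn(8, 16, 40)
out = sliced_attention(q, k, v, slice_size=2, scale=40 ** -0.5)
assert out.shape == (8, 16, 40)
```

This same-output, lower-memory tradeoff is why the new `forward` logic falls back to the plain `_attention` path when slicing is disabled (`self._slice_size is None`) or would yield a single slice anyway, and calls `_sliced_attention` otherwise.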
