
Commit 5adb0a7

use torch.matmul instead of einsum in attention. (#445)
* use torch.matmul instead of einsum
* fix softmax
1 parent b2b3b1a commit 5adb0a7

File tree: 1 file changed (+2, -4 lines)


src/diffusers/models/attention.py

Lines changed: 2 additions & 4 deletions
@@ -275,11 +275,9 @@ def _attention(self, query, key, value, sequence_length, dim):
         for i in range(hidden_states.shape[0] // slice_size):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
-            attn_slice = (
-                torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale
-            )
+            attn_slice = torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
             attn_slice = attn_slice.softmax(dim=-1)
-            attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
+            attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
 
             hidden_states[start_idx:end_idx] = attn_slice
 
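For context, the einsum and matmul formulations compute the same batched attention; the matmul form just transposes the key before a batched matrix multiply. Below is a minimal sketch (not part of the commit) checking that equivalence on random tensors, assuming the (batch * heads, sequence_length, head_dim) layout used in the sliced loop above; the shapes and scale value are illustrative only.

import torch

query = torch.randn(2, 8, 64)  # (batch * heads, seq_len, head_dim)
key = torch.randn(2, 8, 64)
value = torch.randn(2, 8, 64)
scale = 64 ** -0.5  # typical 1/sqrt(head_dim) scaling

# old formulation: contract over the head dimension with einsum
scores_einsum = torch.einsum("b i d, b j d -> b i j", query, key) * scale
out_einsum = torch.einsum("b i j, b j d -> b i d", scores_einsum.softmax(dim=-1), value)

# new formulation: batched matmul with an explicit transpose of key
scores_matmul = torch.matmul(query, key.transpose(1, 2)) * scale
out_matmul = torch.matmul(scores_matmul.softmax(dim=-1), value)

print(torch.allclose(out_einsum, out_matmul, atol=1e-6))  # expected: True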
