From b38aac9bb25eef2cae2b9ec6467e0649c636ec43 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Fri, 9 Sep 2022 17:22:33 +0530
Subject: [PATCH 1/4] add different method for sliced attention

---
 src/diffusers/models/attention.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 094ca0fb2299..e73e093bc9c7 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -262,11 +262,23 @@ def forward(self, x, context=None, mask=None):
         # TODO(PVP) - mask is currently never used. Remember to re-implement when used
 
         # attention, what we cannot get enough of
-        hidden_states = self._attention(q, k, v, sequence_length, dim)
+        if self._slice_size is None:
+            hidden_states = self._attention(q, k, v)
+        else:
+            hidden_states = self._sliced_attention(q, k, v, sequence_length, dim)
 
         return self.to_out(hidden_states)
 
-    def _attention(self, query, key, value, sequence_length, dim):
+    def _attention(self, query, key, value):
+        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+        attention_probs = attention_scores.softmax(dim=-1)
+        # compute attention output
+        hidden_states = torch.matmul(attention_probs, value)
+        # reshape hidden_states
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        return hidden_states
+
+    def _sliced_attention(self, query, key, value, sequence_length, dim):
         batch_size_attention = query.shape[0]
         hidden_states = torch.zeros(
             (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype

From 43ca8fb3143791845b38cb218333abcda1434c52 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 14 Sep 2022 15:13:47 +0200
Subject: [PATCH 2/4] Update src/diffusers/models/attention.py

---
 src/diffusers/models/attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index e73e093bc9c7..870d8de70991 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -262,7 +262,7 @@ def forward(self, x, context=None, mask=None):
         # TODO(PVP) - mask is currently never used. Remember to re-implement when used
 
         # attention, what we cannot get enough of
-        if self._slice_size is None:
+        if self._slice_size is None or q.shape[0] // slice_size == 1:
             hidden_states = self._attention(q, k, v)
         else:
             hidden_states = self._sliced_attention(q, k, v, sequence_length, dim)

From 0decdc208070ecfd3accc266c78a4d34d6f877f0 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 14 Sep 2022 15:18:55 +0200
Subject: [PATCH 3/4] Apply suggestions from code review

---
 src/diffusers/models/attention.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 3eb317fb9936..12d3991ba58c 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -263,8 +263,8 @@ def forward(self, hidden_states, context=None, mask=None):
 
         # attention, what we cannot get enough of
-        if self._slice_size is None or q.shape[0] // slice_size == 1:
-            hidden_states = self._attention(query, key, value, )
+        if self._slice_size is None or query.shape[0] // slice_size == 1:
+            hidden_states = self._attention(query, key, value)
         else:
             hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)

From 5298cbe08411ca9ddff66e5ce43a97a325287c0e Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 14 Sep 2022 15:24:27 +0200
Subject: [PATCH 4/4] Update src/diffusers/models/attention.py

---
 src/diffusers/models/attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 12d3991ba58c..55062c322e4a 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -263,7 +263,7 @@ def forward(self, hidden_states, context=None, mask=None):
 
         # attention, what we cannot get enough of
-        if self._slice_size is None or query.shape[0] // slice_size == 1:
+        if self._slice_size is None or query.shape[0] // self._slice_size == 1:
             hidden_states = self._attention(query, key, value)
         else:
             hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
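
Taken together, the series splits attention into a full path (_attention) and a memory-saving sliced path (_sliced_attention), selected via self._slice_size; patch 2 adds a fast path that skips slicing when the whole batch fits in a single slice, and patches 3 and 4 clean up the review fallout (q renamed to query, the undefined slice_size reference replaced with self._slice_size). Below is a minimal, self-contained sketch of the slicing idea only, not the diffusers implementation: it chunks the batch-times-heads dimension so that only a (slice, seq_len, seq_len) score matrix is materialized at a time, and checks the result against unsliced attention. All names here (sliced_attention, slice_size, the toy shapes) are illustrative assumptions.

# Sketch of sliced attention over the batch*heads dimension (illustrative, not the diffusers API).
import torch


def sliced_attention(query, key, value, slice_size):
    # query/key/value: (batch * heads, seq_len, head_dim)
    batch_heads, seq_len, head_dim = query.shape
    scale = head_dim ** -0.5
    out = torch.empty_like(query)
    for start in range(0, batch_heads, slice_size):
        end = min(start + slice_size, batch_heads)
        # score matrix for this slice only: (end - start, seq_len, seq_len)
        scores = torch.matmul(query[start:end], key[start:end].transpose(-1, -2)) * scale
        probs = scores.softmax(dim=-1)
        out[start:end] = torch.matmul(probs, value[start:end])
    return out


if __name__ == "__main__":
    # e.g. 2 images x 8 heads, 256 tokens, 40 channels per head
    q = torch.randn(16, 256, 40)
    k, v = torch.randn_like(q), torch.randn_like(q)
    full = torch.matmul((torch.matmul(q, k.transpose(-1, -2)) * 40 ** -0.5).softmax(dim=-1), v)
    assert torch.allclose(sliced_attention(q, k, v, slice_size=4), full, atol=1e-5)

Because each slice's softmax is independent of the others, the sliced result matches unsliced attention exactly; only peak memory changes.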