Commit a360468

update
1 parent 095b2bb commit a360468

File tree

1 file changed (+17, -8)


paddlenlp/transformers/ring_flash_attention.py

Lines changed: 17 additions & 8 deletions
@@ -55,8 +55,14 @@ def wait(self):
 
     def add_to_buffers(self, key, value):
         if key.shape != self._k_buffer[self._next_buffer_idx].shape:
-            self._k_buffer[self._next_buffer_idx][:, : key.shape[1], :, :] += key
-            self._v_buffer[self._next_buffer_idx][:, : key.shape[1], :, :] += value
+            k_buffer_chunk = paddle.slice(
+                self._k_buffer[self._next_buffer_idx], axes=[1], starts=[0], ends=[key.shape[1]]
+            )
+            v_buffer_chunk = paddle.slice(
+                self._v_buffer[self._next_buffer_idx], axes=[1], starts=[0], ends=[value.shape[1]]
+            )
+            k_buffer_chunk += key
+            v_buffer_chunk += value
         else:
             self._k_buffer[self._next_buffer_idx] += key
             self._v_buffer[self._next_buffer_idx] += value
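
Note: the hunk above swaps basic-index slicing on the communication buffers for explicit paddle.slice calls. A minimal standalone sketch of that pattern, with made-up [batch, seq, heads, head_dim] shapes that are not from this file:

import paddle

# Illustrative buffers only; the real ones live in the ring-attention comm helper.
k_buffer = paddle.zeros([2, 8, 4, 16])
key = paddle.ones([2, 4, 4, 16])  # shorter along the sequence axis (axis 1)

# paddle.slice(input, axes, starts, ends) picks the first key.shape[1] positions
# on axis 1, the same region the old code addressed as [:, : key.shape[1], :, :].
k_chunk = paddle.slice(k_buffer, axes=[1], starts=[0], ends=[key.shape[1]])

# Accumulate into the sliced chunk, as the new add_to_buffers code does.
k_chunk += key

Whether the in-place add is visible through k_buffer depends on Paddle's slice/view semantics in the execution mode this code targets; the commit itself only shows the call pattern.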
@@ -82,13 +88,13 @@ def update_out_and_lse(old_out, old_lse, block_out, block_lse, second_chunk_only
         return block_out.to("float32"), block_lse.to("float32")
 
     if second_chunk_only:
-        second_chunk_out = old_out[:, old_out.shape[1] // 2 :, :, :]
-        second_chunk_lse = old_lse[:, old_lse.shape[1] // 2 :, :, :]
+        second_chunk_out_ = paddle.slice(old_out, axes=[1], starts=[old_out.shape[1] // 2], ends=[old_out.shape[1]])
+        second_chunk_lse_ = paddle.slice(old_lse, axes=[1], starts=[old_lse.shape[1] // 2], ends=[old_lse.shape[1]])
         second_chunk_out, second_chunk_lse = update_out_and_lse(
-            second_chunk_out, second_chunk_lse, block_out, block_lse
+            second_chunk_out_, second_chunk_lse_, block_out, block_lse
         )
-        old_out[:, old_out.shape[1] // 2 :, :, :] = second_chunk_out
-        old_lse[:, old_lse.shape[1] // 2 :, :, :] = second_chunk_lse
+        paddle.assign(second_chunk_out, second_chunk_out_)
+        paddle.assign(second_chunk_lse, second_chunk_lse_)
         return old_out, old_lse
     else:
         block_out, block_lse = block_out.to("float32"), block_lse.to("float32")
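
The hunk above reads the second half of old_out/old_lse with paddle.slice and writes the merged result back with paddle.assign(x, output) instead of slice assignment. A rough, self-contained sketch of that read-modify-write pattern, using dummy tensors rather than the real out/lse values:

import paddle

old_out = paddle.zeros([2, 8, 4, 16])  # illustrative [batch, seq, heads, dim] shape

# Second half along the sequence axis (axis 1).
half = old_out.shape[1] // 2
second_chunk = paddle.slice(old_out, axes=[1], starts=[half], ends=[old_out.shape[1]])

# Stand-in for the merged result; in the real code this comes from update_out_and_lse.
merged = paddle.ones(second_chunk.shape)

# paddle.assign(x, output) copies x into the given output tensor, replacing the old
# `old_out[:, half:, :, :] = ...` slice assignment.
paddle.assign(merged, second_chunk)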
@@ -242,6 +248,9 @@ def balanced_ring_flash_attention_bwd_func(
     )
     lse_second_chunk = paddle.slice(lse, axes=[2], starts=[local_q_seq_len // 2], ends=[local_q_seq_len])
     out_grad_second_chunk = paddle.slice(out_grad, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len])
+    query_grad_buffer_second_chunk = paddle.slice(
+        query_grad_buffer, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]
+    )
 
     if attn_mask is not None:
         attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3)
@@ -285,7 +294,7 @@ def balanced_ring_flash_attention_bwd_func(
                 dropout,
                 False,
             )
-            query_grad_buffer[:, local_q_seq_len // 2 :, :, :] += block_q_grad
+            query_grad_buffer_second_chunk += block_q_grad
         else:
             block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd(
                 local_query,
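
The last two hunks work together: query_grad_buffer is sliced once along axis 1 (the second half of the local sequence) before the loop, and each block gradient is accumulated into that pre-sliced chunk instead of re-indexing the buffer with [:, local_q_seq_len // 2 :, :, :] on every step. A sketch under assumed shapes, with dummy block gradients:

import paddle

local_q_seq_len = 8
query_grad_buffer = paddle.zeros([2, local_q_seq_len, 4, 16])  # assumed [b, s, h, d]

# Slice the second-half chunk once, up front.
query_grad_second_chunk = paddle.slice(
    query_grad_buffer, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]
)

for _ in range(3):  # stand-in for the ring communication steps
    block_q_grad = paddle.ones(query_grad_second_chunk.shape)  # dummy block gradient
    query_grad_second_chunk += block_q_grad  # accumulate into the pre-sliced chunk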
