@@ -55,8 +55,14 @@ def wait(self):

    def add_to_buffers(self, key, value):
        if key.shape != self._k_buffer[self._next_buffer_idx].shape:
-            self._k_buffer[self._next_buffer_idx][:, : key.shape[1], :, :] += key
-            self._v_buffer[self._next_buffer_idx][:, : key.shape[1], :, :] += value
+            k_buffer_chunk = paddle.slice(
+                self._k_buffer[self._next_buffer_idx], axes=[1], starts=[0], ends=[key.shape[1]]
+            )
+            v_buffer_chunk = paddle.slice(
+                self._v_buffer[self._next_buffer_idx], axes=[1], starts=[0], ends=[value.shape[1]]
+            )
+            k_buffer_chunk += key
+            v_buffer_chunk += value
        else:
            self._k_buffer[self._next_buffer_idx] += key
            self._v_buffer[self._next_buffer_idx] += value
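
Note: this rewrite relies on paddle.slice returning a view that shares storage with its input in dygraph mode, so the in-place += lands directly in the communication buffer instead of going through an advanced-indexing scatter write. A minimal sketch of the pattern (the shapes here are illustrative assumptions, not the module's real buffer sizes):

import paddle

# Buffer laid out as [batch, seq, num_heads, head_dim]; the incoming
# chunk is shorter along the sequence axis (axis 1).
buf = paddle.zeros([2, 8, 4, 16])
key = paddle.ones([2, 4, 4, 16])

# paddle.slice on axis 1 yields a view over buf[:, :4]; the in-place
# add below therefore updates buf itself.
chunk = paddle.slice(buf, axes=[1], starts=[0], ends=[key.shape[1]])
chunk += key

assert float(buf[:, : key.shape[1]].sum()) == float(key.sum())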
@@ -82,13 +88,13 @@ def update_out_and_lse(old_out, old_lse, block_out, block_lse, second_chunk_only
        return block_out.to("float32"), block_lse.to("float32")

    if second_chunk_only:
-        second_chunk_out = old_out[:, old_out.shape[1] // 2 :, :, :]
-        second_chunk_lse = old_lse[:, old_lse.shape[1] // 2 :, :, :]
+        second_chunk_out_ = paddle.slice(old_out, axes=[1], starts=[old_out.shape[1] // 2], ends=[old_out.shape[1]])
+        second_chunk_lse_ = paddle.slice(old_lse, axes=[1], starts=[old_lse.shape[1] // 2], ends=[old_lse.shape[1]])
        second_chunk_out, second_chunk_lse = update_out_and_lse(
-            second_chunk_out, second_chunk_lse, block_out, block_lse
+            second_chunk_out_, second_chunk_lse_, block_out, block_lse
        )
-        old_out[:, old_out.shape[1] // 2 :, :, :] = second_chunk_out
-        old_lse[:, old_lse.shape[1] // 2 :, :, :] = second_chunk_lse
+        paddle.assign(second_chunk_out, second_chunk_out_)
+        paddle.assign(second_chunk_lse, second_chunk_lse_)
        return old_out, old_lse
    else:
        block_out, block_lse = block_out.to("float32"), block_lse.to("float32")
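
Note: since the recursive update_out_and_lse call returns fresh tensors, the results still have to be copied back into old_out/old_lse; paddle.assign(x, output) performs that copy into the slice view, which writes through to the parent tensor. A small sketch of the write-back, with made-up shapes for illustration:

import paddle

out = paddle.zeros([1, 8, 2, 4])
half = out.shape[1] // 2

# View over the second half of the sequence axis (axis 1).
view = paddle.slice(out, axes=[1], starts=[half], ends=[out.shape[1]])
new_vals = paddle.ones_like(view)

# Copy new_vals into the view; the parent tensor sees the update.
paddle.assign(new_vals, view)
assert float(out[:, half:].sum()) == float(new_vals.sum())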
@@ -242,6 +248,9 @@ def balanced_ring_flash_attention_bwd_func(
    )
    lse_second_chunk = paddle.slice(lse, axes=[2], starts=[local_q_seq_len // 2], ends=[local_q_seq_len])
    out_grad_second_chunk = paddle.slice(out_grad, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len])
+    query_grad_buffer_second_chunk = paddle.slice(
+        query_grad_buffer, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]
+    )

    if attn_mask is not None:
        attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3)
@@ -285,7 +294,7 @@ def balanced_ring_flash_attention_bwd_func(
            dropout,
            False,
        )
-        query_grad_buffer[:, local_q_seq_len // 2 :, :, :] += block_q_grad
+        query_grad_buffer_second_chunk += block_q_grad
    else:
        block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd(
            local_query,
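
Note: the slice over the second half of query_grad_buffer is now built once, before the ring loop, and each backward step accumulates into that view, replacing the per-step indexed write query_grad_buffer[:, local_q_seq_len // 2 :, :, :] += block_q_grad. A rough sketch of the hoisting pattern, with a made-up step count and shapes:

import paddle

num_steps, seq_len = 4, 8
grad_buffer = paddle.zeros([1, seq_len, 2, 4])

# Hoist the view out of the loop; every iteration writes through it.
second_half = paddle.slice(grad_buffer, axes=[1], starts=[seq_len // 2], ends=[seq_len])
for _ in range(num_steps):
    block_grad = paddle.ones_like(second_half)  # stand-in for a flash_attn_bwd block gradient
    second_half += block_grad

assert float(grad_buffer[:, seq_len // 2 :].sum()) == num_steps * float(block_grad.sum())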