@@ -50,9 +50,7 @@ def __init__(self, group, local_key, local_value):
         self._reqs = []

     def wait(self):
-        # for req in self._reqs:
-        #     req.wait()
-        # self._reqs = None
+        # TODO(zhangyuqin1998): under batch_isend_irecv async streams, wait() cannot be used and needs to be fixed; this hurts performance.
         paddle.device.synchronize()

     def add_to_buffers(self, key, value):
@@ -126,12 +124,14 @@ def balanced_ring_flash_attention_fwd_func(
     comm_buffer = RingCommunicator(group, local_key, local_value)
     local_q_seq_len = local_query.shape[1]

-    out, lse = None, None
+    out, lse, k_cache, v_cache = None, None, dict(), dict()

     if attn_mask is not None:
         attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3)
     if is_causal:
-        local_query_second_chunk = local_query[:, local_q_seq_len // 2 :, :, :].clone().contiguous()
+        local_query_second_chunk = paddle.slice(
+            local_query, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]
+        )

     for step in range(cp_size):
         block_k, block_v = comm_buffer.get_buffers()
@@ -153,14 +153,15 @@ def balanced_ring_flash_attention_fwd_func(
                 not training,
                 "",
             )
-            block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1)
+            block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
             out, lse = update_out_and_lse(out, lse, block_out, block_lse)
         else:
+            # block_k and block_v are from rank (group.rank - step) % cp_size
             if step == 0:
                 block_out, _, block_lse, _ = _C_ops.flash_attn(
                     local_query, block_k, block_v, fixed_seed_offset, None, dropout, True, False, not training, ""
                 )
-                block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1)
+                block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
                 out, lse = update_out_and_lse(out, lse, block_out, block_lse)
             elif step > rank:
                 block_out, _, block_lse, _ = _C_ops.flash_attn(
@@ -175,14 +176,16 @@ def balanced_ring_flash_attention_fwd_func(
                     not training,
                     "",
                 )
-                block_lse = block_lse[:, :, 0 : (local_q_seq_len // 2)]
-                block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1)
+                block_lse = paddle.slice(block_lse, axes=[1], starts=[0], ends=[local_q_seq_len // 2])
+                block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
                 out, lse = update_out_and_lse(out, lse, block_out, block_lse, True)
             else:
+                block_k = paddle.slice(block_k, axes=[1], starts=[0], ends=[local_q_seq_len // 2])
+                block_v = paddle.slice(block_v, axes=[1], starts=[0], ends=[local_q_seq_len // 2])
                 block_out, _, block_lse, _ = _C_ops.flash_attn(
                     local_query,
-                    block_k[:, : local_q_seq_len // 2, :, :],
-                    block_v[:, : local_q_seq_len // 2, :, :],
+                    block_k,
+                    block_v,
                     fixed_seed_offset,
                     None,
                     dropout,
@@ -191,20 +194,23 @@ def balanced_ring_flash_attention_fwd_func(
                     not training,
                     "",
                 )
-                block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1)
+                block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
                 out, lse = update_out_and_lse(out, lse, block_out, block_lse)
+                k_cache[step] = block_k
+                v_cache[step] = block_v

-        # if step != cp_size - 1:
-        #     comm_buffer.wait()
+        # TODO(zhangyuqin1998): under batch_isend_irecv async streams, wait() cannot be used and needs to be fixed; this hurts performance.
         paddle.device.synchronize()

     out = out.to(local_query.dtype)
-    lse = paddle.transpose(paddle.squeeze(lse, axis=-1), [0, 2, 1])
-    return out, lse
+    lse = paddle.transpose_(paddle.squeeze_(lse, axis=-1), [0, 2, 1])
+    return out, lse, k_cache, v_cache


 def balanced_ring_flash_attention_bwd_func(
     group,
+    k_cache,
+    v_cache,
     out_grad,
     local_query,
     local_key,
@@ -228,10 +234,14 @@ def balanced_ring_flash_attention_bwd_func(
     grad_comm_buffer = RingCommunicator(group, key_grad_buffer, value_grad_buffer)

     if is_causal:
-        local_query_second_chunk = local_query[:, local_q_seq_len // 2 :, :, :].clone().contiguous()
-        local_out_second_chunk = local_out[:, local_q_seq_len // 2 :, :, :].clone().contiguous()
-        lse_second_chunk = lse[:, :, local_q_seq_len // 2 :].clone().contiguous()
-        out_grad_second_chunk = out_grad[:, local_q_seq_len // 2 :, :, :].clone().contiguous()
+        local_query_second_chunk = paddle.slice(
+            local_query, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]
+        )
+        local_out_second_chunk = paddle.slice(
+            local_out, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]
+        )
+        lse_second_chunk = paddle.slice(lse, axes=[2], starts=[local_q_seq_len // 2], ends=[local_q_seq_len])
+        out_grad_second_chunk = paddle.slice(out_grad, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len])

     if attn_mask is not None:
         attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3)
@@ -279,8 +289,8 @@ def balanced_ring_flash_attention_bwd_func(
             else:
                 block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd(
                     local_query,
-                    block_k[:, : local_q_seq_len // 2, :, :],
-                    block_v[:, : local_q_seq_len // 2, :, :],
+                    k_cache[step],
+                    v_cache[step],
                     local_out,
                     lse,
                     fixed_seed_offset,
@@ -291,10 +301,7 @@ def balanced_ring_flash_attention_bwd_func(
                 )
                 query_grad_buffer += block_q_grad

-        # if step != cp_size - 1:
-        #     kv_comm_buffer.wait()
-        # if step != 0:
-        #     grad_comm_buffer.wait()
+        # TODO(zhangyuqin1998): under batch_isend_irecv async streams, wait() cannot be used and needs to be fixed; this hurts performance.
         paddle.device.synchronize()

         grad_comm_buffer.add_to_buffers(block_k_grad, block_v_grad)
@@ -328,10 +335,10 @@ def forward(
         if attn_mask is not None:
             is_causal = False

-        out, lse = balanced_ring_flash_attention_fwd_func(
+        out, lse, k_cache, v_cache = balanced_ring_flash_attention_fwd_func(
             group, query, key, value, fixed_seed_offset, attn_mask, dropout, is_causal, training
         )
-        ctx.save_for_backward(query, key, value, out, lse, attn_mask)
+        ctx.save_for_backward(query, key, value, out, lse, attn_mask, k_cache, v_cache)
         ctx.group = group
         ctx.fixed_seed_offset = fixed_seed_offset
         ctx.dropout = dropout
@@ -340,17 +347,29 @@ def forward(

     @staticmethod
     def backward(ctx, out_grad):
-        query, key, value, out, lse, attn_mask = ctx.saved_tensor()
+        query, key, value, out, lse, attn_mask, k_cache, v_cache = ctx.saved_tensor()
         group = ctx.group
         fixed_seed_offset = ctx.fixed_seed_offset
         dropout = ctx.dropout
         is_causal = ctx.is_causal

         if fixed_seed_offset is None:
-            fixed_seed_offset = paddle.to_tensor([0, 0], place=paddle.CPUPlace(), dtype=paddle.int64).contiguous()
+            fixed_seed_offset = paddle.to_tensor([0, 0], place=paddle.CPUPlace(), dtype=paddle.int64)

         query_grad, key_grad, value_grad = balanced_ring_flash_attention_bwd_func(
-            group, out_grad, query, key, value, out, lse, fixed_seed_offset, attn_mask, dropout, is_causal
+            group,
+            k_cache,
+            v_cache,
+            out_grad,
+            query,
+            key,
+            value,
+            out,
+            lse,
+            fixed_seed_offset,
+            attn_mask,
+            dropout,
+            is_causal,
         )
         if attn_mask is not None and not attn_mask.stop_gradient:
             return query_grad, key_grad, value_grad, None
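
For context on what the forward loop accumulates: each step folds a per-block flash-attention partial result into the running output via update_out_and_lse, whose implementation is not part of this diff. Below is a minimal sketch of the standard log-sum-exp merge such a helper typically performs; the name merge_attention_partials, the exact tensor shapes, and the omission of the second-chunk flag are assumptions for illustration, not code from this repository.

import paddle


def merge_attention_partials(out, lse, block_out, block_lse):
    # out:       running output, [batch, seq, heads, head_dim], kept in float32
    # lse:       running log-sum-exp of the softmax denominator, [batch, seq, heads, 1]
    # block_out / block_lse: partial result for the current KV block, same layouts
    if out is None:
        # first block: nothing to merge yet
        return block_out.astype("float32"), block_lse
    # combined denominator: log(exp(lse) + exp(block_lse)), computed stably
    new_lse = lse + paddle.log(1.0 + paddle.exp(block_lse - lse))
    # reweight both partial outputs by their share of the combined denominator
    out = paddle.exp(lse - new_lse) * out + paddle.exp(block_lse - new_lse) * block_out.astype("float32")
    return out, new_lse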