
Commit 095b2bb

update
1 parent 88bc460 commit 095b2bb

1 file changed

paddlenlp/transformers/ring_flash_attention.py

Lines changed: 50 additions & 31 deletions
@@ -50,9 +50,7 @@ def __init__(self, group, local_key, local_value):
         self._reqs = []
 
     def wait(self):
-        # for req in self._reqs:
-        #     req.wait()
-        # self._reqs = None
+        # TODO(zhangyuqin1998): under the batch_isend_irecv async stream, wait() does not work and needs to be fixed; this hurts performance.
         paddle.device.synchronize()
 
     def add_to_buffers(self, key, value):
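For context on the TODO above, here is a rough, hypothetical sketch (not code from this repository) of the kind of batched send/recv a ring communicator built on paddle.distributed.batch_isend_irecv typically posts, and of the per-request wait() that the removed comments describe versus the device-wide synchronize this commit falls back to. The helper name ring_exchange and its arguments are illustrative assumptions, and the P2POp/batch_isend_irecv signatures are assumed to follow the usual collective-communication pattern.

# Rough sketch only; assumes dist.init_parallel_env() has been called.
import paddle
import paddle.distributed as dist

def ring_exchange(send_k, send_v, recv_k, recv_v, group=None):
    """Post the send/recv pair for one ring step as a single batched call."""
    rank, world = dist.get_rank(), dist.get_world_size()
    next_rank, prev_rank = (rank + 1) % world, (rank - 1 + world) % world
    ops = [
        dist.P2POp(dist.isend, send_k, next_rank, group),
        dist.P2POp(dist.isend, send_v, next_rank, group),
        dist.P2POp(dist.irecv, recv_k, prev_rank, group),
        dist.P2POp(dist.irecv, recv_v, prev_rank, group),
    ]
    return dist.batch_isend_irecv(ops)  # list of async communication tasks

# What wait() was meant to do (per the removed lines), versus the workaround:
#     for req in reqs:
#         req.wait()               # cannot be used under the async comm stream (see TODO)
#     paddle.device.synchronize()  # coarse device-wide barrier actually used by this commit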
@@ -126,12 +124,14 @@ def balanced_ring_flash_attention_fwd_func(
     comm_buffer = RingCommunicator(group, local_key, local_value)
     local_q_seq_len = local_query.shape[1]
 
-    out, lse = None, None
+    out, lse, k_cache, v_cache = None, None, dict(), dict()
 
     if attn_mask is not None:
         attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3)
     if is_causal:
-        local_query_second_chunk = local_query[:, local_q_seq_len // 2 :, :, :].clone().contiguous()
+        local_query_second_chunk = paddle.slice(
+            local_query, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]
+        )
     for step in range(cp_size):
         block_k, block_v = comm_buffer.get_buffers()
 
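As a quick illustration (not part of the file), the paddle.slice form added above selects the same second half of the sequence axis as the removed indexing-plus-clone, just without the extra copy. The shapes below are made up for the example.

import paddle

# Illustrative shapes: [batch, seq_len, num_heads, head_dim]
x = paddle.randn([2, 8, 4, 16])
half = x.shape[1] // 2

old_chunk = x[:, half:, :, :].clone()                                     # removed pattern
new_chunk = paddle.slice(x, axes=[1], starts=[half], ends=[x.shape[1]])   # added pattern

assert old_chunk.shape == new_chunk.shape == [2, 4, 4, 16]
assert bool(paddle.allclose(old_chunk, new_chunk))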
@@ -153,14 +153,15 @@ def balanced_ring_flash_attention_fwd_func(
                 not training,
                 "",
             )
-            block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1)
+            block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
             out, lse = update_out_and_lse(out, lse, block_out, block_lse)
         else:
+            # block_k and block_v are from rank (group.rank - step) % cp_size
             if step == 0:
                 block_out, _, block_lse, _ = _C_ops.flash_attn(
                     local_query, block_k, block_v, fixed_seed_offset, None, dropout, True, False, not training, ""
                 )
-                block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1)
+                block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
                 out, lse = update_out_and_lse(out, lse, block_out, block_lse)
             elif step > rank:
                 block_out, _, block_lse, _ = _C_ops.flash_attn(
@@ -175,14 +176,16 @@ def balanced_ring_flash_attention_fwd_func(
                     not training,
                     "",
                 )
-                block_lse = block_lse[:, :, 0 : (local_q_seq_len // 2)]
-                block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1)
+                block_lse = paddle.slice(block_lse, axes=[1], starts=[0], ends=[local_q_seq_len // 2])
+                block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
                 out, lse = update_out_and_lse(out, lse, block_out, block_lse, True)
             else:
+                block_k = paddle.slice(block_k, axes=[1], starts=[0], ends=[local_q_seq_len // 2])
+                block_v = paddle.slice(block_v, axes=[1], starts=[0], ends=[local_q_seq_len // 2])
                 block_out, _, block_lse, _ = _C_ops.flash_attn(
                     local_query,
-                    block_k[:, : local_q_seq_len // 2, :, :],
-                    block_v[:, : local_q_seq_len // 2, :, :],
+                    block_k,
+                    block_v,
                     fixed_seed_offset,
                     None,
                     dropout,
@@ -191,20 +194,23 @@ def balanced_ring_flash_attention_fwd_func(
                     not training,
                     "",
                 )
-                block_lse = paddle.unsqueeze(paddle.transpose(block_lse, [0, 2, 1]), axis=-1)
+                block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
                 out, lse = update_out_and_lse(out, lse, block_out, block_lse)
+                k_cache[step] = block_k
+                v_cache[step] = block_v
 
-        # if step != cp_size - 1:
-        #     comm_buffer.wait()
+        # TODO(zhangyuqin1998): under the batch_isend_irecv async stream, wait() does not work and needs to be fixed; this hurts performance.
         paddle.device.synchronize()
 
     out = out.to(local_query.dtype)
-    lse = paddle.transpose(paddle.squeeze(lse, axis=-1), [0, 2, 1])
-    return out, lse
+    lse = paddle.transpose_(paddle.squeeze_(lse, axis=-1), [0, 2, 1])
+    return out, lse, k_cache, v_cache
 
 
 def balanced_ring_flash_attention_bwd_func(
     group,
+    k_cache,
+    v_cache,
     out_grad,
     local_query,
     local_key,
@@ -228,10 +234,14 @@ def balanced_ring_flash_attention_bwd_func(
     grad_comm_buffer = RingCommunicator(group, key_grad_buffer, value_grad_buffer)
 
     if is_causal:
-        local_query_second_chunk = local_query[:, local_q_seq_len // 2 :, :, :].clone().contiguous()
-        local_out_second_chunk = local_out[:, local_q_seq_len // 2 :, :, :].clone().contiguous()
-        lse_second_chunk = lse[:, :, local_q_seq_len // 2 :].clone().contiguous()
-        out_grad_second_chunk = out_grad[:, local_q_seq_len // 2 :, :, :].clone().contiguous()
+        local_query_second_chunk = paddle.slice(
+            local_query, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]
+        )
+        local_out_second_chunk = paddle.slice(
+            local_out, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]
+        )
+        lse_second_chunk = paddle.slice(lse, axes=[2], starts=[local_q_seq_len // 2], ends=[local_q_seq_len])
+        out_grad_second_chunk = paddle.slice(out_grad, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len])
 
     if attn_mask is not None:
         attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3)
@@ -279,8 +289,8 @@ def balanced_ring_flash_attention_bwd_func(
             else:
                 block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd(
                     local_query,
-                    block_k[:, : local_q_seq_len // 2, :, :],
-                    block_v[:, : local_q_seq_len // 2, :, :],
+                    k_cache[step],
+                    v_cache[step],
                     local_out,
                     lse,
                     fixed_seed_offset,
@@ -291,10 +301,7 @@ def balanced_ring_flash_attention_bwd_func(
                 )
                 query_grad_buffer += block_q_grad
 
-        # if step != cp_size - 1:
-        #     kv_comm_buffer.wait()
-        # if step != 0:
-        #     grad_comm_buffer.wait()
+        # TODO(zhangyuqin1998): under the batch_isend_irecv async stream, wait() does not work and needs to be fixed; this hurts performance.
         paddle.device.synchronize()
 
         grad_comm_buffer.add_to_buffers(block_k_grad, block_v_grad)
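A small standalone sketch (illustrative only, not from the repository) of the caching pattern these hunks introduce: the forward pass stores the exact, possibly halved, key/value block it used at each ring step in step-indexed dicts, so the backward pass can call flash_attn_bwd on identical inputs instead of re-slicing the communicated buffers. The halving condition below is a stand-in for the real causal-branch logic.

import paddle

def toy_forward(blocks, seq_len):
    k_cache, v_cache = {}, {}
    for step, (block_k, block_v) in enumerate(blocks):
        if step % 2 == 1:  # stand-in for the branch that only needs the first half of K/V
            block_k = paddle.slice(block_k, axes=[1], starts=[0], ends=[seq_len // 2])
            block_v = paddle.slice(block_v, axes=[1], starts=[0], ends=[seq_len // 2])
        k_cache[step], v_cache[step] = block_k, block_v
    return k_cache, v_cache

def toy_backward(k_cache, v_cache):
    # Backward consumes exactly what forward used; no re-slicing needed.
    for step in sorted(k_cache):
        block_k, block_v = k_cache[step], v_cache[step]
        # ... flash_attn_bwd(local_query, block_k, block_v, ...) would go here ...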
@@ -328,10 +335,10 @@ def forward(
         if attn_mask is not None:
             is_causal = False
 
-        out, lse = balanced_ring_flash_attention_fwd_func(
+        out, lse, k_cache, v_cache = balanced_ring_flash_attention_fwd_func(
             group, query, key, value, fixed_seed_offset, attn_mask, dropout, is_causal, training
         )
-        ctx.save_for_backward(query, key, value, out, lse, attn_mask)
+        ctx.save_for_backward(query, key, value, out, lse, attn_mask, k_cache, v_cache)
         ctx.group = group
         ctx.fixed_seed_offset = fixed_seed_offset
         ctx.dropout = dropout
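For readers unfamiliar with the plumbing above: objects passed to ctx.save_for_backward in a paddle.autograd.PyLayer come back from ctx.saved_tensor() in the same order, which is why the k_cache/v_cache added to the call here must also be unpacked in backward (next hunk). A minimal, unrelated PyLayer sketch of that round trip; the class and its math are hypothetical examples, not code from this file.

import paddle
from paddle.autograd import PyLayer

class ScaleByNorm(PyLayer):  # hypothetical example
    @staticmethod
    def forward(ctx, x):
        norm = paddle.linalg.norm(x)
        ctx.save_for_backward(x, norm)   # stashed for the backward pass
        return x * norm

    @staticmethod
    def backward(ctx, grad_out):
        x, norm = ctx.saved_tensor()     # returned in the order they were saved
        # VJP of f(x) = x * ||x||:  g * ||x|| + x * <g, x> / ||x||
        return grad_out * norm + x * paddle.sum(grad_out * x) / norm

x = paddle.randn([3, 4])
x.stop_gradient = False
y = ScaleByNorm.apply(x)
y.sum().backward()  # drives the backward() defined above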
@@ -340,17 +347,29 @@ def forward(
 
     @staticmethod
    def backward(ctx, out_grad):
-        query, key, value, out, lse, attn_mask = ctx.saved_tensor()
+        query, key, value, out, lse, attn_mask, k_cache, v_cache = ctx.saved_tensor()
         group = ctx.group
         fixed_seed_offset = ctx.fixed_seed_offset
         dropout = ctx.dropout
         is_causal = ctx.is_causal
 
         if fixed_seed_offset is None:
-            fixed_seed_offset = paddle.to_tensor([0, 0], place=paddle.CPUPlace(), dtype=paddle.int64).contiguous()
+            fixed_seed_offset = paddle.to_tensor([0, 0], place=paddle.CPUPlace(), dtype=paddle.int64)
 
         query_grad, key_grad, value_grad = balanced_ring_flash_attention_bwd_func(
-            group, out_grad, query, key, value, out, lse, fixed_seed_offset, attn_mask, dropout, is_causal
+            group,
+            k_cache,
+            v_cache,
+            out_grad,
+            query,
+            key,
+            value,
+            out,
+            lse,
+            fixed_seed_offset,
+            attn_mask,
+            dropout,
+            is_causal,
         )
         if attn_mask is not None and not attn_mask.stop_gradient:
             return query_grad, key_grad, value_grad, None
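One side note on a pattern repeated throughout this diff: paddle.transpose_/unsqueeze_/squeeze_ are the in-place counterparts of transpose/unsqueeze/squeeze, so the lse (log-sum-exp) reshaping no longer allocates intermediate tensors. A toy round trip, with an lse-like shape picked purely for illustration:

import paddle

# lse-style layout assumed for illustration: [batch, num_heads, seq_len]
lse = paddle.randn([2, 4, 8])

# Accumulation layout used inside the forward loop: [batch, seq_len, num_heads, 1], in place.
lse = paddle.unsqueeze_(paddle.transpose_(lse, [0, 2, 1]), axis=-1)
assert lse.shape == [2, 8, 4, 1]

# Back to [batch, num_heads, seq_len] before returning, also in place.
lse = paddle.transpose_(paddle.squeeze_(lse, axis=-1), [0, 2, 1])
assert lse.shape == [2, 4, 8]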
