Skip to content

[Bug fixes] Fix ring attention #8740

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 47 additions & 80 deletions paddlenlp/transformers/ring_flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,11 @@

def add_to_buffers(self, key, value):
if key.shape != self._k_buffer[self._next_buffer_idx].shape:
k_buffer_chunk = paddle.slice(
self._k_buffer[self._next_buffer_idx], axes=[1], starts=[0], ends=[key.shape[1]]
)
v_buffer_chunk = paddle.slice(
self._v_buffer[self._next_buffer_idx], axes=[1], starts=[0], ends=[value.shape[1]]
)
k_buffer_chunk += key
v_buffer_chunk += value
self._k_buffer[self._next_buffer_idx][:, : key.shape[1], :, :].add_(key)
self._v_buffer[self._next_buffer_idx][:, : key.shape[1], :, :].add_(value)

Check warning on line 59 in paddlenlp/transformers/ring_flash_attention.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/ring_flash_attention.py#L58-L59

Added lines #L58 - L59 were not covered by tests
else:
self._k_buffer[self._next_buffer_idx] += key
self._v_buffer[self._next_buffer_idx] += value
self._k_buffer[self._next_buffer_idx].add_(key)
self._v_buffer[self._next_buffer_idx].add_(value)

Check warning on line 62 in paddlenlp/transformers/ring_flash_attention.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/ring_flash_attention.py#L61-L62

Added lines #L61 - L62 were not covered by tests

def get_buffers(self):
return self._k_buffer[self._next_buffer_idx], self._v_buffer[self._next_buffer_idx]
Expand All @@ -84,23 +78,19 @@


def update_out_and_lse(old_out, old_lse, block_out, block_lse, second_chunk_only=False):
if old_out is None and old_lse is None:
return block_out.to("float32"), block_lse.to("float32")

if second_chunk_only:
second_chunk_out_ = paddle.slice(old_out, axes=[1], starts=[old_out.shape[1] // 2], ends=[old_out.shape[1]])
second_chunk_lse_ = paddle.slice(old_lse, axes=[1], starts=[old_lse.shape[1] // 2], ends=[old_lse.shape[1]])
second_chunk_out = old_out[:, old_out.shape[1] // 2 :, :, :]
second_chunk_lse = old_lse[:, old_lse.shape[1] // 2 :, :, :]

Check warning on line 83 in paddlenlp/transformers/ring_flash_attention.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/ring_flash_attention.py#L82-L83

Added lines #L82 - L83 were not covered by tests
second_chunk_out, second_chunk_lse = update_out_and_lse(
second_chunk_out_, second_chunk_lse_, block_out, block_lse
second_chunk_out, second_chunk_lse, block_out, block_lse
)
paddle.assign(second_chunk_out, second_chunk_out_)
paddle.assign(second_chunk_lse, second_chunk_lse_)
old_out[:, old_out.shape[1] // 2 :, :, :] = second_chunk_out
old_lse[:, old_lse.shape[1] // 2 :, :, :] = second_chunk_lse

Check warning on line 88 in paddlenlp/transformers/ring_flash_attention.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/ring_flash_attention.py#L87-L88

Added lines #L87 - L88 were not covered by tests
return old_out, old_lse
else:
block_out, block_lse = block_out.to("float32"), block_lse.to("float32")
with paddle.amp.auto_cast(enable=False, dtype="bfloat16"):
lse = old_lse - F.log_sigmoid(old_lse - block_lse)
return old_out - (old_out - block_out) * F.sigmoid(block_lse - old_lse), lse
return old_out - (old_out - block_out) * F.sigmoid(block_lse - old_lse), old_lse - F.log_sigmoid(

Check warning on line 91 in paddlenlp/transformers/ring_flash_attention.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/ring_flash_attention.py#L91

Added line #L91 was not covered by tests
old_lse - block_lse
)


def get_chunk_id(rank, cp_size):
Expand Down Expand Up @@ -130,14 +120,10 @@
comm_buffer = RingCommunicator(group, local_key, local_value)
local_q_seq_len = local_query.shape[1]

out, lse, k_cache, v_cache = None, None, dict(), dict()

if attn_mask is not None:
attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3)
if is_causal:
local_query_second_chunk = paddle.slice(
local_query, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]
)
local_query_second_chunk = local_query[:, local_q_seq_len // 2 :, :, :]

Check warning on line 126 in paddlenlp/transformers/ring_flash_attention.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/ring_flash_attention.py#L126

Added line #L126 was not covered by tests
for step in range(cp_size):
block_k, block_v = comm_buffer.get_buffers()

Expand All @@ -159,16 +145,19 @@
not training,
"",
)
block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
out, lse = update_out_and_lse(out, lse, block_out, block_lse)
paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)

Check warning on line 148 in paddlenlp/transformers/ring_flash_attention.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/ring_flash_attention.py#L148

Added line #L148 was not covered by tests

if step == 0:
out, lse = block_out, block_lse

Check warning on line 151 in paddlenlp/transformers/ring_flash_attention.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/ring_flash_attention.py#L150-L151

Added lines #L150 - L151 were not covered by tests
else:
out, lse = update_out_and_lse(out, lse, block_out, block_lse)

Check warning on line 153 in paddlenlp/transformers/ring_flash_attention.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/ring_flash_attention.py#L153

Added line #L153 was not covered by tests
else:
# block_k and block_v is from rank (group.rank - step) % cp_size
if step == 0:
block_out, _, block_lse, _ = _C_ops.flash_attn(
local_query, block_k, block_v, fixed_seed_offset, None, dropout, True, False, not training, ""
)
block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
out, lse = update_out_and_lse(out, lse, block_out, block_lse)
paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
out, lse = block_out, block_lse

Check warning on line 160 in paddlenlp/transformers/ring_flash_attention.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/ring_flash_attention.py#L159-L160

Added lines #L159 - L160 were not covered by tests
elif step > rank:
block_out, _, block_lse, _ = _C_ops.flash_attn(
local_query_second_chunk,
Expand All @@ -182,16 +171,14 @@
not training,
"",
)
block_lse = paddle.slice(block_lse, axes=[1], starts=[0], ends=[local_q_seq_len // 2])
block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
block_lse = block_lse[:, :, 0 : (local_q_seq_len // 2)]
paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)

Check warning on line 175 in paddlenlp/transformers/ring_flash_attention.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/ring_flash_attention.py#L174-L175

Added lines #L174 - L175 were not covered by tests
out, lse = update_out_and_lse(out, lse, block_out, block_lse, True)
else:
block_k = paddle.slice(block_k, axes=[1], starts=[0], ends=[local_q_seq_len // 2])
block_v = paddle.slice(block_v, axes=[1], starts=[0], ends=[local_q_seq_len // 2])
block_out, _, block_lse, _ = _C_ops.flash_attn(
local_query,
block_k,
block_v,
block_k[:, : local_q_seq_len // 2, :, :],
block_v[:, : local_q_seq_len // 2, :, :],
fixed_seed_offset,
None,
dropout,
Expand All @@ -200,23 +187,19 @@
not training,
"",
)
block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)

Check warning on line 190 in paddlenlp/transformers/ring_flash_attention.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/ring_flash_attention.py#L190

Added line #L190 was not covered by tests
out, lse = update_out_and_lse(out, lse, block_out, block_lse)
k_cache[step] = block_k
v_cache[step] = block_v

# TODO(zhangyuqin1998):batch_isend_irecv异步流下,无法wait,需要修复。对性能有影响。
# if step != cp_size - 1:
# comm_buffer.wait()
paddle.device.synchronize()

out = out.to(local_query.dtype)
lse = paddle.transpose_(paddle.squeeze_(lse, axis=-1), [0, 2, 1])
return out, lse, k_cache, v_cache
return out.to(local_query.dtype), paddle.transpose_(paddle.squeeze(lse, axis=-1), [0, 2, 1])

Check warning on line 198 in paddlenlp/transformers/ring_flash_attention.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/ring_flash_attention.py#L198

Added line #L198 was not covered by tests


def balanced_ring_flash_attention_bwd_func(
group,
k_cache,
v_cache,
out_grad,
local_query,
local_key,
Expand All @@ -240,17 +223,10 @@
grad_comm_buffer = RingCommunicator(group, key_grad_buffer, value_grad_buffer)

if is_causal:
local_query_second_chunk = paddle.slice(
local_query, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]
)
local_out_second_chunk = paddle.slice(
local_out, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]
)
lse_second_chunk = paddle.slice(lse, axes=[2], starts=[local_q_seq_len // 2], ends=[local_q_seq_len])
out_grad_second_chunk = paddle.slice(out_grad, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len])
query_grad_buffer_second_chunk = paddle.slice(
query_grad_buffer, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]
)
local_query_second_chunk = local_query[:, local_q_seq_len // 2 :, :, :]
local_out_second_chunk = local_out[:, local_q_seq_len // 2 :, :, :]
lse_second_chunk = lse[:, :, local_q_seq_len // 2 :]
out_grad_second_chunk = out_grad[:, local_q_seq_len // 2 :, :, :]

Check warning on line 229 in paddlenlp/transformers/ring_flash_attention.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/ring_flash_attention.py#L226-L229

Added lines #L226 - L229 were not covered by tests

if attn_mask is not None:
attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3)
Expand All @@ -274,13 +250,13 @@
dropout,
False,
)
query_grad_buffer += block_q_grad
query_grad_buffer.add_(block_q_grad)

Check warning on line 253 in paddlenlp/transformers/ring_flash_attention.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/ring_flash_attention.py#L253

Added line #L253 was not covered by tests
else:
if step == 0:
block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd(
local_query, block_k, block_v, local_out, lse, fixed_seed_offset, None, out_grad, dropout, True
)
query_grad_buffer += block_q_grad
query_grad_buffer.add_(block_q_grad)

Check warning on line 259 in paddlenlp/transformers/ring_flash_attention.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/ring_flash_attention.py#L259

Added line #L259 was not covered by tests
elif step > rank:
block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd(
local_query_second_chunk,
Expand All @@ -294,12 +270,12 @@
dropout,
False,
)
query_grad_buffer_second_chunk += block_q_grad
query_grad_buffer[:, local_q_seq_len // 2 :, :, :].add_(block_q_grad)

Check warning on line 273 in paddlenlp/transformers/ring_flash_attention.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/ring_flash_attention.py#L273

Added line #L273 was not covered by tests
else:
block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd(
local_query,
k_cache[step],
v_cache[step],
block_k[:, : local_q_seq_len // 2, :, :],
block_v[:, : local_q_seq_len // 2, :, :],
local_out,
lse,
fixed_seed_offset,
Expand All @@ -308,9 +284,12 @@
dropout,
False,
)
query_grad_buffer += block_q_grad
query_grad_buffer.add_(block_q_grad)

Check warning on line 287 in paddlenlp/transformers/ring_flash_attention.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/ring_flash_attention.py#L287

Added line #L287 was not covered by tests

# TODO(zhangyuqin1998):batch_isend_irecv异步流下,无法wait,需要修复。对性能有影响。
# if step != cp_size - 1:
# kv_comm_buffer.wait()
# if step != 0:
# grad_comm_buffer.wait()
paddle.device.synchronize()

grad_comm_buffer.add_to_buffers(block_k_grad, block_v_grad)
Expand Down Expand Up @@ -344,10 +323,10 @@
if attn_mask is not None:
is_causal = False

out, lse, k_cache, v_cache = balanced_ring_flash_attention_fwd_func(
out, lse = balanced_ring_flash_attention_fwd_func(

Check warning on line 326 in paddlenlp/transformers/ring_flash_attention.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/ring_flash_attention.py#L326

Added line #L326 was not covered by tests
group, query, key, value, fixed_seed_offset, attn_mask, dropout, is_causal, training
)
ctx.save_for_backward(query, key, value, out, lse, attn_mask, k_cache, v_cache)
ctx.save_for_backward(query, key, value, out, lse, attn_mask)

Check warning on line 329 in paddlenlp/transformers/ring_flash_attention.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/ring_flash_attention.py#L329

Added line #L329 was not covered by tests
ctx.group = group
ctx.fixed_seed_offset = fixed_seed_offset
ctx.dropout = dropout
Expand All @@ -356,7 +335,7 @@

@staticmethod
def backward(ctx, out_grad):
query, key, value, out, lse, attn_mask, k_cache, v_cache = ctx.saved_tensor()
query, key, value, out, lse, attn_mask = ctx.saved_tensor()

Check warning on line 338 in paddlenlp/transformers/ring_flash_attention.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/ring_flash_attention.py#L338

Added line #L338 was not covered by tests
group = ctx.group
fixed_seed_offset = ctx.fixed_seed_offset
dropout = ctx.dropout
Expand All @@ -366,19 +345,7 @@
fixed_seed_offset = paddle.to_tensor([0, 0], place=paddle.CPUPlace(), dtype=paddle.int64)

query_grad, key_grad, value_grad = balanced_ring_flash_attention_bwd_func(
group,
k_cache,
v_cache,
out_grad,
query,
key,
value,
out,
lse,
fixed_seed_offset,
attn_mask,
dropout,
is_causal,
group, out_grad, query, key, value, out, lse, fixed_seed_offset, attn_mask, dropout, is_causal
)
if attn_mask is not None and not attn_mask.stop_gradient:
return query_grad, key_grad, value_grad, None
Expand Down
9 changes: 4 additions & 5 deletions tests/transformers/test_ring_flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,16 @@ def single_test(self, bsz, seq_len_per_device, head_num, head_dim, is_causal, us
)
ref_out = scaled_dot_product_attention(query, key, value, is_causal=is_causal, attn_mask=attn_mask)

local_out.mean().backward()
ref_out.mean().backward()
local_out.backward()
ref_out.backward()

ref_local_query_grad = self.split_belanced_data(query.grad)
ref_local_key_grad = self.split_belanced_data(key.grad)
ref_local_value_grad = self.split_belanced_data(value.grad)

ref_local_out = self.split_belanced_data(ref_out)

rtol = 1e-04
atol = 5e-03
rtol = 1e-02
atol = 1e-02
np.testing.assert_allclose(
local_out.to("float32").numpy(), ref_local_out.to("float32").numpy(), rtol=rtol, atol=atol
)
Expand Down
Loading