From 4cd6dbceb71fdaaa9de2c5d25ab99e8d87d6e9bf Mon Sep 17 00:00:00 2001 From: zhangyuqin1998 Date: Wed, 10 Jul 2024 11:38:02 +0800 Subject: [PATCH 1/3] fix ring attention --- .../transformers/ring_flash_attention.py | 127 +++++++----------- .../transformers/test_ring_flash_attention.py | 9 +- 2 files changed, 51 insertions(+), 85 deletions(-) diff --git a/paddlenlp/transformers/ring_flash_attention.py b/paddlenlp/transformers/ring_flash_attention.py index 3ff4d9def8d8..34162b948d24 100644 --- a/paddlenlp/transformers/ring_flash_attention.py +++ b/paddlenlp/transformers/ring_flash_attention.py @@ -55,17 +55,11 @@ def wait(self): def add_to_buffers(self, key, value): if key.shape != self._k_buffer[self._next_buffer_idx].shape: - k_buffer_chunk = paddle.slice( - self._k_buffer[self._next_buffer_idx], axes=[1], starts=[0], ends=[key.shape[1]] - ) - v_buffer_chunk = paddle.slice( - self._v_buffer[self._next_buffer_idx], axes=[1], starts=[0], ends=[value.shape[1]] - ) - k_buffer_chunk += key - v_buffer_chunk += value + self._k_buffer[self._next_buffer_idx][:, : key.shape[1], :, :].add_(key) + self._v_buffer[self._next_buffer_idx][:, : key.shape[1], :, :].add_(value) else: - self._k_buffer[self._next_buffer_idx] += key - self._v_buffer[self._next_buffer_idx] += value + self._k_buffer[self._next_buffer_idx].add_(key) + self._v_buffer[self._next_buffer_idx].add_(value) def get_buffers(self): return self._k_buffer[self._next_buffer_idx], self._v_buffer[self._next_buffer_idx] @@ -84,23 +78,19 @@ def send_recv(self): def update_out_and_lse(old_out, old_lse, block_out, block_lse, second_chunk_only=False): - if old_out is None and old_lse is None: - return block_out.to("float32"), block_lse.to("float32") - if second_chunk_only: - second_chunk_out_ = paddle.slice(old_out, axes=[1], starts=[old_out.shape[1] // 2], ends=[old_out.shape[1]]) - second_chunk_lse_ = paddle.slice(old_lse, axes=[1], starts=[old_lse.shape[1] // 2], ends=[old_lse.shape[1]]) + second_chunk_out = old_out[:, old_out.shape[1] // 2 :, :, :] + second_chunk_lse = old_lse[:, old_lse.shape[1] // 2 :, :, :] second_chunk_out, second_chunk_lse = update_out_and_lse( - second_chunk_out_, second_chunk_lse_, block_out, block_lse + second_chunk_out, second_chunk_lse, block_out, block_lse ) - paddle.assign(second_chunk_out, second_chunk_out_) - paddle.assign(second_chunk_lse, second_chunk_lse_) + old_out[:, old_out.shape[1] // 2 :, :, :] = second_chunk_out + old_lse[:, old_lse.shape[1] // 2 :, :, :] = second_chunk_lse return old_out, old_lse else: - block_out, block_lse = block_out.to("float32"), block_lse.to("float32") - with paddle.amp.auto_cast(enable=False, dtype="bfloat16"): - lse = old_lse - F.log_sigmoid(old_lse - block_lse) - return old_out - (old_out - block_out) * F.sigmoid(block_lse - old_lse), lse + return old_out - (old_out - block_out) * F.sigmoid(block_lse - old_lse), old_lse - F.log_sigmoid( + old_lse - block_lse + ) def get_chunk_id(rank, cp_size): @@ -130,14 +120,10 @@ def balanced_ring_flash_attention_fwd_func( comm_buffer = RingCommunicator(group, local_key, local_value) local_q_seq_len = local_query.shape[1] - out, lse, k_cache, v_cache = None, None, dict(), dict() - if attn_mask is not None: attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3) if is_causal: - local_query_second_chunk = paddle.slice( - local_query, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len] - ) + local_query_second_chunk = local_query[:, local_q_seq_len // 2 :, :, :] for step in range(cp_size): block_k, block_v = 
comm_buffer.get_buffers() @@ -159,16 +145,19 @@ def balanced_ring_flash_attention_fwd_func( not training, "", ) - block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1) - out, lse = update_out_and_lse(out, lse, block_out, block_lse) + paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1) + + if step == 0: + out, lse = block_out, block_lse + else: + out, lse = update_out_and_lse(out, lse, block_out, block_lse) else: - # block_k and block_v is from rank (group.rank - step) % cp_size if step == 0: block_out, _, block_lse, _ = _C_ops.flash_attn( local_query, block_k, block_v, fixed_seed_offset, None, dropout, True, False, not training, "" ) - block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1) - out, lse = update_out_and_lse(out, lse, block_out, block_lse) + paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1) + out, lse = block_out, block_lse elif step > rank: block_out, _, block_lse, _ = _C_ops.flash_attn( local_query_second_chunk, @@ -182,16 +171,14 @@ def balanced_ring_flash_attention_fwd_func( not training, "", ) - block_lse = paddle.slice(block_lse, axes=[1], starts=[0], ends=[local_q_seq_len // 2]) - block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1) + block_lse = block_lse[:, :, 0 : (local_q_seq_len // 2)] + paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1) out, lse = update_out_and_lse(out, lse, block_out, block_lse, True) else: - block_k = paddle.slice(block_k, axes=[1], starts=[0], ends=[local_q_seq_len // 2]) - block_v = paddle.slice(block_v, axes=[1], starts=[0], ends=[local_q_seq_len // 2]) block_out, _, block_lse, _ = _C_ops.flash_attn( local_query, - block_k, - block_v, + block_k[:, : local_q_seq_len // 2, :, :], + block_v[:, : local_q_seq_len // 2, :, :], fixed_seed_offset, None, dropout, @@ -200,23 +187,19 @@ def balanced_ring_flash_attention_fwd_func( not training, "", ) - block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1) + paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1) out, lse = update_out_and_lse(out, lse, block_out, block_lse) - k_cache[step] = block_k - v_cache[step] = block_v # TODO(zhangyuqin1998):batch_isend_irecv异步流下,无法wait,需要修复。对性能有影响。 + # if step != cp_size - 1: + # comm_buffer.wait() paddle.device.synchronize() - out = out.to(local_query.dtype) - lse = paddle.transpose_(paddle.squeeze_(lse, axis=-1), [0, 2, 1]) - return out, lse, k_cache, v_cache + return out.to(local_query.dtype), paddle.transpose_(paddle.squeeze(lse, axis=-1), [0, 2, 1]) def balanced_ring_flash_attention_bwd_func( group, - k_cache, - v_cache, out_grad, local_query, local_key, @@ -240,17 +223,10 @@ def balanced_ring_flash_attention_bwd_func( grad_comm_buffer = RingCommunicator(group, key_grad_buffer, value_grad_buffer) if is_causal: - local_query_second_chunk = paddle.slice( - local_query, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len] - ) - local_out_second_chunk = paddle.slice( - local_out, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len] - ) - lse_second_chunk = paddle.slice(lse, axes=[2], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]) - out_grad_second_chunk = paddle.slice(out_grad, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]) - query_grad_buffer_second_chunk = paddle.slice( - query_grad_buffer, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len] - ) + local_query_second_chunk = local_query[:, local_q_seq_len // 2 :, :, :] + 
local_out_second_chunk = local_out[:, local_q_seq_len // 2 :, :, :] + lse_second_chunk = lse[:, :, local_q_seq_len // 2 :] + out_grad_second_chunk = out_grad[:, local_q_seq_len // 2 :, :, :] if attn_mask is not None: attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3) @@ -274,13 +250,13 @@ def balanced_ring_flash_attention_bwd_func( dropout, False, ) - query_grad_buffer += block_q_grad + query_grad_buffer.add_(block_q_grad) else: if step == 0: block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd( local_query, block_k, block_v, local_out, lse, fixed_seed_offset, None, out_grad, dropout, True ) - query_grad_buffer += block_q_grad + query_grad_buffer.add_(block_q_grad) elif step > rank: block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd( local_query_second_chunk, @@ -294,12 +270,12 @@ def balanced_ring_flash_attention_bwd_func( dropout, False, ) - query_grad_buffer_second_chunk += block_q_grad + query_grad_buffer[:, local_q_seq_len // 2 :, :, :].add_(block_q_grad) else: block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd( local_query, - k_cache[step], - v_cache[step], + block_k[:, : local_q_seq_len // 2, :, :], + block_v[:, : local_q_seq_len // 2, :, :], local_out, lse, fixed_seed_offset, @@ -308,9 +284,12 @@ def balanced_ring_flash_attention_bwd_func( dropout, False, ) - query_grad_buffer += block_q_grad + query_grad_buffer.add_(block_q_grad) - # TODO(zhangyuqin1998):batch_isend_irecv异步流下,无法wait,需要修复。对性能有影响。 + # if step != cp_size - 1: + # kv_comm_buffer.wait() + # if step != 0: + # grad_comm_buffer.wait() paddle.device.synchronize() grad_comm_buffer.add_to_buffers(block_k_grad, block_v_grad) @@ -344,10 +323,10 @@ def forward( if attn_mask is not None: is_causal = False - out, lse, k_cache, v_cache = balanced_ring_flash_attention_fwd_func( + out, lse = balanced_ring_flash_attention_fwd_func( group, query, key, value, fixed_seed_offset, attn_mask, dropout, is_causal, training ) - ctx.save_for_backward(query, key, value, out, lse, attn_mask, k_cache, v_cache) + ctx.save_for_backward(query, key, value, out, lse, attn_mask) ctx.group = group ctx.fixed_seed_offset = fixed_seed_offset ctx.dropout = dropout @@ -356,7 +335,7 @@ def forward( @staticmethod def backward(ctx, out_grad): - query, key, value, out, lse, attn_mask, k_cache, v_cache = ctx.saved_tensor() + query, key, value, out, lse, attn_mask = ctx.saved_tensor() group = ctx.group fixed_seed_offset = ctx.fixed_seed_offset dropout = ctx.dropout @@ -366,19 +345,7 @@ def backward(ctx, out_grad): fixed_seed_offset = paddle.to_tensor([0, 0], place=paddle.CPUPlace(), dtype=paddle.int64) query_grad, key_grad, value_grad = balanced_ring_flash_attention_bwd_func( - group, - k_cache, - v_cache, - out_grad, - query, - key, - value, - out, - lse, - fixed_seed_offset, - attn_mask, - dropout, - is_causal, + group, out_grad, query, key, value, out, lse, fixed_seed_offset, attn_mask, dropout, is_causal ) if attn_mask is not None and not attn_mask.stop_gradient: return query_grad, key_grad, value_grad, None diff --git a/tests/transformers/test_ring_flash_attention.py b/tests/transformers/test_ring_flash_attention.py index 134d2f9c011a..ba49046070da 100644 --- a/tests/transformers/test_ring_flash_attention.py +++ b/tests/transformers/test_ring_flash_attention.py @@ -83,17 +83,16 @@ def single_test(self, bsz, seq_len_per_device, head_num, head_dim, is_causal, us ) ref_out = scaled_dot_product_attention(query, key, value, is_causal=is_causal, attn_mask=attn_mask) - local_out.mean().backward() - 
ref_out.mean().backward() + local_out.backward() + ref_out.backward() ref_local_query_grad = self.split_belanced_data(query.grad) ref_local_key_grad = self.split_belanced_data(key.grad) ref_local_value_grad = self.split_belanced_data(value.grad) ref_local_out = self.split_belanced_data(ref_out) - - rtol = 1e-04 - atol = 5e-03 + rtol = 1e-02 + atol = 1e-02 np.testing.assert_allclose( local_out.to("float32").numpy(), ref_local_out.to("float32").numpy(), rtol=rtol, atol=atol ) From 10cccd993d1a985419711cc966157fb5924a8def Mon Sep 17 00:00:00 2001 From: zhangyuqin1998 Date: Wed, 10 Jul 2024 21:19:03 +0800 Subject: [PATCH 2/3] fix --- paddlenlp/transformers/ring_flash_attention.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/paddlenlp/transformers/ring_flash_attention.py b/paddlenlp/transformers/ring_flash_attention.py index 34162b948d24..5e79cbe9fbb0 100644 --- a/paddlenlp/transformers/ring_flash_attention.py +++ b/paddlenlp/transformers/ring_flash_attention.py @@ -88,9 +88,11 @@ def update_out_and_lse(old_out, old_lse, block_out, block_lse, second_chunk_only old_lse[:, old_lse.shape[1] // 2 :, :, :] = second_chunk_lse return old_out, old_lse else: - return old_out - (old_out - block_out) * F.sigmoid(block_lse - old_lse), old_lse - F.log_sigmoid( - old_lse - block_lse - ) + block_out, block_lse = block_out.to("float32"), block_lse.to("float32") + with paddle.amp.auto_cast(enable=False): + return old_out - (old_out - block_out) * F.sigmoid(block_lse - old_lse), old_lse - F.log_sigmoid( + old_lse - block_lse + ) def get_chunk_id(rank, cp_size): @@ -195,7 +197,7 @@ def balanced_ring_flash_attention_fwd_func( # comm_buffer.wait() paddle.device.synchronize() - return out.to(local_query.dtype), paddle.transpose_(paddle.squeeze(lse, axis=-1), [0, 2, 1]) + return paddle.cast(out, local_query.dtype), paddle.transpose_(paddle.squeeze(lse, axis=-1), [0, 2, 1]) def balanced_ring_flash_attention_bwd_func( @@ -298,8 +300,7 @@ def balanced_ring_flash_attention_bwd_func( grad_comm_buffer.wait() key_grad_buffer, value_grad_buffer = grad_comm_buffer.get_buffers() - dtype = local_query.dtype - return query_grad_buffer.to(dtype), key_grad_buffer.to(dtype), value_grad_buffer.to(dtype) + return query_grad_buffer, key_grad_buffer, value_grad_buffer class RingFlashAttention(PyLayer): From 812d4cc4961b830af9470c2e217dca51d015dbbe Mon Sep 17 00:00:00 2001 From: zhangyuqin1998 Date: Wed, 10 Jul 2024 21:22:46 +0800 Subject: [PATCH 3/3] fix --- paddlenlp/transformers/ring_flash_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlenlp/transformers/ring_flash_attention.py b/paddlenlp/transformers/ring_flash_attention.py index 5e79cbe9fbb0..9fa8ea52b655 100644 --- a/paddlenlp/transformers/ring_flash_attention.py +++ b/paddlenlp/transformers/ring_flash_attention.py @@ -88,7 +88,7 @@ def update_out_and_lse(old_out, old_lse, block_out, block_lse, second_chunk_only old_lse[:, old_lse.shape[1] // 2 :, :, :] = second_chunk_lse return old_out, old_lse else: - block_out, block_lse = block_out.to("float32"), block_lse.to("float32") + block_out, block_lse = paddle.cast(block_out, "float32"), paddle.cast(block_lse, "float32") with paddle.amp.auto_cast(enable=False): return old_out - (old_out - block_out) * F.sigmoid(block_lse - old_lse), old_lse - F.log_sigmoid( old_lse - block_lse
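# ---------------------------------------------------------------------------
# Reference note (illustrative sketch, not part of the patch): the merge that
# update_out_and_lse performs is the standard log-sum-exp combination of two
# partial attention results. With w_i = exp(lse_i), the merged output is
# (w1 * out1 + w2 * out2) / (w1 + w2) and the merged lse is log(w1 + w2); the
# patch expresses this through sigmoid / log_sigmoid for numerical stability.
# The NumPy check below uses made-up helper names (lse_merge, naive_merge)
# and is not a PaddleNLP API.
import numpy as np

def _sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def _log_sigmoid(x):
    return -np.log1p(np.exp(-x))

def lse_merge(out1, lse1, out2, lse2):
    # Same algebra as update_out_and_lse in this patch.
    new_out = out1 - (out1 - out2) * _sigmoid(lse2 - lse1)
    new_lse = lse1 - _log_sigmoid(lse1 - lse2)
    return new_out, new_lse

def naive_merge(out1, lse1, out2, lse2):
    # Direct softmax re-normalization of the two partial results.
    w1, w2 = np.exp(lse1), np.exp(lse2)
    return (w1 * out1 + w2 * out2) / (w1 + w2), np.log(w1 + w2)

rng = np.random.default_rng(0)
out1, out2 = rng.normal(size=(4, 8)), rng.normal(size=(4, 8))
lse1, lse2 = rng.normal(size=(4, 1)), rng.normal(size=(4, 1))
for a, b in zip(lse_merge(out1, lse1, out2, lse2), naive_merge(out1, lse1, out2, lse2)):
    np.testing.assert_allclose(a, b, rtol=1e-6, atol=1e-8)
# ---------------------------------------------------------------------------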
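# ---------------------------------------------------------------------------
# Assumption for context (hypothetical sketch, not part of the patch): the
# test helper split_belanced_data is not shown in this diff. In balanced
# ("zigzag") ring attention each rank typically holds the sequence chunk pair
# (rank, 2 * cp_size - 1 - rank), which is the layout the step <= rank /
# step > rank branches of the causal path rely on. split_balanced_data below
# is a made-up stand-in illustrating that layout, not the actual test helper.
import numpy as np

def split_balanced_data(x, rank, cp_size):
    # x: [batch, seq_len, heads, dim]; seq_len must divide evenly into
    # 2 * cp_size chunks along axis 1.
    chunks = np.split(x, 2 * cp_size, axis=1)
    return np.concatenate([chunks[rank], chunks[2 * cp_size - 1 - rank]], axis=1)

# Toy sequence of length 16 with cp_size = 4: rank 0 gets chunks 0 and 7.
x = np.arange(16).reshape(1, 16, 1, 1)
print(split_balanced_data(x, rank=0, cp_size=4).reshape(-1))  # [ 0  1 14 15]
# ---------------------------------------------------------------------------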