From f5d3c3b9337a63dc12c02d3009e6cf2d5bb691ad Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Mon, 8 Apr 2024 08:20:29 +0000
Subject: [PATCH] fix

---
 paddlenlp/transformers/llama/modeling_auto.py | 49 ++++++++++++++-----
 1 file changed, 38 insertions(+), 11 deletions(-)

diff --git a/paddlenlp/transformers/llama/modeling_auto.py b/paddlenlp/transformers/llama/modeling_auto.py
index 36a5826e7fe2..21635da46cca 100644
--- a/paddlenlp/transformers/llama/modeling_auto.py
+++ b/paddlenlp/transformers/llama/modeling_auto.py
@@ -434,16 +434,40 @@ def forward(
         if self.config.rope:
             if self.use_fused_rope:
                 assert past_key_value is None, "fuse rotary not support cache kv for now"
+                batch_size, seq_length, num_heads, head_dim = query_states.shape
+                _, kv_seq_len, num_key_value_heads, _ = key_states.shape
                 cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-                query_states, key_states, _ = fused_rotary_position_embedding(
-                    query_states,
-                    key_states,
-                    v=None,
-                    sin=sin,
-                    cos=cos,
-                    position_ids=position_ids,
-                    use_neox_rotary_style=False,
-                )
+
+                paddle_version = float(paddle.__version__[:3])
+                if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads):
+                    query_states, _, _ = fused_rotary_position_embedding(
+                        query_states,
+                        None,
+                        None,
+                        sin=sin,
+                        cos=cos,
+                        position_ids=position_ids,
+                        use_neox_rotary_style=False,
+                    )
+                    key_states, _, _ = fused_rotary_position_embedding(
+                        key_states,
+                        None,
+                        None,
+                        sin=sin,
+                        cos=cos,
+                        position_ids=position_ids,
+                        use_neox_rotary_style=False,
+                    )
+                else:
+                    query_states, key_states, _ = fused_rotary_position_embedding(
+                        query_states,
+                        key_states,
+                        v=None,
+                        sin=sin,
+                        cos=cos,
+                        position_ids=position_ids,
+                        use_neox_rotary_style=False,
+                    )
             else:
                 cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
                 # hack here, because elementwise infer spmd not support broadcast now
@@ -463,8 +487,11 @@ def forward(
 
         # TODO(wj-Mcat): use broadcast strategy when n_kv_heads = 1
         # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
+        # paddle version > 2.6 or develop support flash-attn with gqa/mqa
+        paddle_version = float(paddle.__version__[:3])
+        if (paddle_version != 0.0) and (paddle_version <= 2.6):
+            key_states = repeat_kv(key_states, self.num_key_value_groups)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         has_gradient = not (query_states.stop_gradient and key_states.stop_gradient and value_states.stop_gradient)
         if (
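
Note on the version gate used in both hunks (a minimal sketch, not part of the patch):
float(paddle.__version__[:3]) keeps only the leading "major.minor" digits of the version
string, and the "!= 0.0" check exists because develop/nightly builds of Paddle are assumed
to report __version__ as "0.0.0" (the in-diff comment "paddle version > 2.6 or develop"
points the same way). The helper name below is hypothetical, for illustration only.

    def needs_pre26_fallback(version: str) -> bool:
        """True when the pre-2.6 workarounds apply: separate fused-rope calls
        for q/k when head counts differ, and manual repeat_kv before attention."""
        parsed = float(version[:3])  # "2.6.1" -> 2.6, "0.0.0" -> 0.0 (develop build)
        return parsed != 0.0 and parsed <= 2.6

    print(needs_pre26_fallback("2.6.1"))  # True  -> repeat k/v heads, split rope calls
    print(needs_pre26_fallback("0.0.0"))  # False -> develop build, fused kernels handle GQA/MQA
    print(needs_pre26_fallback("3.0.0"))  # False -> release newer than 2.6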