
Commit 220d867

fix (#8234)
1 parent a5f69e4 commit 220d867


paddlenlp/transformers/llama/modeling_auto.py

Lines changed: 38 additions & 11 deletions
@@ -434,16 +434,40 @@ def forward(
         if self.config.rope:
             if self.use_fused_rope:
                 assert past_key_value is None, "fuse rotary not support cache kv for now"
+                batch_size, seq_length, num_heads, head_dim = query_states.shape
+                _, kv_seq_len, num_key_value_heads, _ = key_states.shape
                 cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-                query_states, key_states, _ = fused_rotary_position_embedding(
-                    query_states,
-                    key_states,
-                    v=None,
-                    sin=sin,
-                    cos=cos,
-                    position_ids=position_ids,
-                    use_neox_rotary_style=False,
-                )
+
+                paddle_version = float(paddle.__version__[:3])
+                if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads):
+                    query_states, _, _ = fused_rotary_position_embedding(
+                        query_states,
+                        None,
+                        None,
+                        sin=sin,
+                        cos=cos,
+                        position_ids=position_ids,
+                        use_neox_rotary_style=False,
+                    )
+                    key_states, _, _ = fused_rotary_position_embedding(
+                        key_states,
+                        None,
+                        None,
+                        sin=sin,
+                        cos=cos,
+                        position_ids=position_ids,
+                        use_neox_rotary_style=False,
+                    )
+                else:
+                    query_states, key_states, _ = fused_rotary_position_embedding(
+                        query_states,
+                        key_states,
+                        v=None,
+                        sin=sin,
+                        cos=cos,
+                        position_ids=position_ids,
+                        use_neox_rotary_style=False,
+                    )
             else:
                 cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
                 # hack here, because elementwise infer spmd not support broadcast now
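
Editor's note on the version gate in this hunk: a minimal standalone sketch of the check, assuming (from the `(paddle_version != 0.0)` guard) that develop builds of Paddle report a version string beginning "0.0", and presuming the older fused rotary kernel cannot take a q/k pair with mismatched head counts. The helper name needs_separate_qk_rope is hypothetical, not part of the commit.

import paddle

def needs_separate_qk_rope(num_heads: int, num_key_value_heads: int) -> bool:
    # "2.6.1"[:3] -> "2.6" -> 2.6; develop builds report "0.0.x" -> 0.0,
    # which the (paddle_version != 0.0) guard treats as "new enough".
    paddle_version = float(paddle.__version__[:3])
    on_old_release = (paddle_version != 0.0) and (paddle_version <= 2.6)
    # On old releases with GQA/MQA (fewer KV heads than query heads), the
    # commit rotates q and k in two separate fused_rotary_position_embedding
    # calls instead of one combined call.
    return on_old_release and (num_heads != num_key_value_heads)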
@@ -463,8 +487,11 @@ def forward(
 
         # TODO(wj-Mcat): use broadcast strategy when n_kv_heads = 1
         # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
+        # paddle version > 2.6 or develop support flash-attn with gqa/mqa
+        paddle_version = float(paddle.__version__[:3])
+        if (paddle_version != 0.0) and (paddle_version <= 2.6):
+            key_states = repeat_kv(key_states, self.num_key_value_groups)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         has_gradient = not (query_states.stop_gradient and key_states.stop_gradient and value_states.stop_gradient)
         if (
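
Editor's note on this hunk: repeat_kv expands grouped key/value heads so a kernel expecting num_heads == num_key_value_heads still works; per the comment in the diff, Paddle > 2.6 (and develop builds, which `paddle_version != 0.0` lets through) supports flash attention with GQA/MQA natively, so the copy is skipped there. A minimal sketch of such an expansion, assuming the [batch, seq_len, num_heads, head_dim] layout seen earlier in this file; PaddleNLP's actual repeat_kv may differ in details.

import paddle

def repeat_kv_sketch(hidden: paddle.Tensor, n_rep: int) -> paddle.Tensor:
    # hidden: [batch, seq_len, num_key_value_heads, head_dim]
    if n_rep == 1:
        return hidden  # MHA case: nothing to repeat
    batch, seq_len, num_kv_heads, head_dim = hidden.shape
    # Tile each KV head n_rep times so every query head in a GQA group
    # gets a matching KV head: [..., num_kv_heads * n_rep, head_dim].
    hidden = hidden.unsqueeze(3).tile([1, 1, 1, n_rep, 1])
    return hidden.reshape([batch, seq_len, num_kv_heads * n_rep, head_dim])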
