@@ -434,16 +434,40 @@ def forward(
434
434
if self .config .rope :
435
435
if self .use_fused_rope :
436
436
assert past_key_value is None , "fuse rotary not support cache kv for now"
437
+ batch_size , seq_length , num_heads , head_dim = query_states .shape
438
+ _ , kv_seq_len , num_key_value_heads , _ = key_states .shape
437
439
cos , sin = self .rotary_emb (value_states , seq_len = kv_seq_len )
438
- query_states , key_states , _ = fused_rotary_position_embedding (
439
- query_states ,
440
- key_states ,
441
- v = None ,
442
- sin = sin ,
443
- cos = cos ,
444
- position_ids = position_ids ,
445
- use_neox_rotary_style = False ,
446
- )
440
+
441
+ paddle_version = float (paddle .__version__ [:3 ])
442
+ if ((paddle_version != 0.0 ) and (paddle_version <= 2.6 )) and (num_heads != num_key_value_heads ):
443
+ query_states , _ , _ = fused_rotary_position_embedding (
444
+ query_states ,
445
+ None ,
446
+ None ,
447
+ sin = sin ,
448
+ cos = cos ,
449
+ position_ids = position_ids ,
450
+ use_neox_rotary_style = False ,
451
+ )
452
+ key_states , _ , _ = fused_rotary_position_embedding (
453
+ key_states ,
454
+ None ,
455
+ None ,
456
+ sin = sin ,
457
+ cos = cos ,
458
+ position_ids = position_ids ,
459
+ use_neox_rotary_style = False ,
460
+ )
461
+ else :
462
+ query_states , key_states , _ = fused_rotary_position_embedding (
463
+ query_states ,
464
+ key_states ,
465
+ v = None ,
466
+ sin = sin ,
467
+ cos = cos ,
468
+ position_ids = position_ids ,
469
+ use_neox_rotary_style = False ,
470
+ )
447
471
else :
448
472
cos , sin = self .rotary_emb (value_states , seq_len = kv_seq_len )
449
473
# hack here, because elementwise infer spmd not support broadcast now
@@ -463,8 +487,11 @@ def forward(
463
487
464
488
# TODO(wj-Mcat): use broadcast strategy when n_kv_heads = 1
465
489
# repeat k/v heads if n_kv_heads < n_heads
466
- key_states = repeat_kv (key_states , self .num_key_value_groups )
467
- value_states = repeat_kv (value_states , self .num_key_value_groups )
490
+ # paddle version > 2.6 or develop support flash-attn with gqa/mqa
491
+ paddle_version = float (paddle .__version__ [:3 ])
492
+ if (paddle_version != 0.0 ) and (paddle_version <= 2.6 ):
493
+ key_states = repeat_kv (key_states , self .num_key_value_groups )
494
+ value_states = repeat_kv (value_states , self .num_key_value_groups )
468
495
469
496
has_gradient = not (query_states .stop_gradient and key_states .stop_gradient and value_states .stop_gradient )
470
497
if (
0 commit comments