@@ -588,17 +588,15 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
         self.head_dim = self.hidden_size // config.num_attention_heads
 
         self.num_key_value_heads = config.num_key_value_heads
+        assert config.num_attention_heads // config.num_key_value_heads
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+        self.gqa_or_mqa = config.num_attention_heads != config.num_key_value_heads
 
         self.max_position_embeddings = config.max_position_embeddings
         self.seq_length = config.seq_length
         self.sequence_parallel = config.sequence_parallel
 
         self.fuse_attention_qkv = config.fuse_attention_qkv
-        if self.fuse_attention_qkv and config.num_attention_heads != config.num_key_value_heads:
-            raise ValueError(
-                f"fuse_attention_qkv can't be True when num_attention_heads {config.num_attention_heads} != num_key_value_heads {config.num_key_value_heads}"
-            )
 
         self.kv_indices = None
         # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
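The hard error for fused QKV under grouped-query attention is gone: the constructor now just asserts that the query heads divide over the key/value heads and records a `gqa_or_mqa` flag. A minimal sketch of that head-count bookkeeping, with hypothetical values (the modulo check merely spells out the divisibility assumption the new assert relies on):

```python
# Hypothetical GQA head counts, not taken from any particular Llama checkpoint.
num_attention_heads = 32   # query heads
num_key_value_heads = 8    # shared key/value heads

assert num_attention_heads % num_key_value_heads == 0  # query heads group evenly over kv heads
num_key_value_groups = num_attention_heads // num_key_value_heads   # 4 query heads share each kv head
gqa_or_mqa = num_attention_heads != num_key_value_heads             # True for GQA/MQA, False for plain MHA
print(num_key_value_groups, gqa_or_mqa)  # 4 True
```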
@@ -615,6 +613,11 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             if self.num_key_value_heads % config.tensor_parallel_degree == 0:
                 self.num_key_value_heads = self.num_key_value_heads // config.tensor_parallel_degree
             else:
+                if self.fuse_attention_qkv:
+                    # TODO(Yuang): support fusion for kv when kv heads cannot be divided by mp
+                    raise ValueError(
+                        f"fuse_attention_qkv can't be True when num_key_value_heads {config.num_key_value_heads} % tensor_parallel_degree {config.tensor_parallel_degree} != 0"
+                    )
                 logger.warning(
                     f"Get num_key_value_heads: {self.num_key_value_heads}, can't split to tensor_parallel_degree: {config.tensor_parallel_degree}, so we don't spilt key value weight."
                 )
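For context on the new guard: key/value heads are sharded only when they divide evenly by the tensor-parallel degree; otherwise the kv weights stay whole and, per the added branch, fused QKV is rejected, presumably because the fused weight could not be partitioned consistently across ranks. A small sketch of that bookkeeping with hypothetical values:

```python
# Hypothetical kv head count and tensor-parallel degrees; illustrates only the divisibility check.
num_key_value_heads = 8
for tensor_parallel_degree in (2, 4, 8, 3):
    if num_key_value_heads % tensor_parallel_degree == 0:
        print(tensor_parallel_degree, "-> kv heads per rank:", num_key_value_heads // tensor_parallel_degree)
    else:
        # Mirrors the new branch: kv weights stay unsplit and fuse_attention_qkv would raise.
        print(tensor_parallel_degree, "-> kv heads not split; fuse_attention_qkv unsupported")
```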
@@ -644,7 +647,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             if self.fuse_attention_qkv:
                 self.qkv_proj = ColumnParallelLinear(
                     self.hidden_size,
-                    3 * self.hidden_size,
+                    self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
                     has_bias=False,
                     gather_output=False,
                 )
@@ -684,7 +687,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             if self.fuse_attention_qkv:
                 self.qkv_proj = nn.Linear(
                     self.hidden_size,
-                    3 * self.hidden_size,
+                    self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )
             else:
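Both the tensor-parallel ColumnParallelLinear branch and the plain nn.Linear branch now size the fused projection as `hidden_size + 2 * num_key_value_heads * head_dim` instead of `3 * hidden_size`: one full-width query block plus one key and one value block per kv head. A quick arithmetic check with hypothetical sizes:

```python
# Hypothetical model sizes; checks the width of the fused QKV projection under GQA.
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8
head_dim = hidden_size // num_attention_heads  # 128

q_width = num_attention_heads * head_dim        # 4096 == hidden_size
kv_width = 2 * num_key_value_heads * head_dim   # 2048 for keys plus values
fused_width = hidden_size + 2 * num_key_value_heads * head_dim  # 6144

assert fused_width == q_width + kv_width
# When num_key_value_heads == num_attention_heads the expression reduces to
# 3 * hidden_size, so plain MHA checkpoints keep the old width.
```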
@@ -776,7 +779,11 @@ def forward(
                     assert self.seq_length % self.config.sep_parallel_degree == 0
                     mix_layer = paddle.reshape_(
                         mix_layer,
-                        [-1, self.seq_length // self.config.sep_parallel_degree, 3 * self.num_heads * self.head_dim],
+                        [
+                            -1,
+                            self.seq_length // self.config.sep_parallel_degree,
+                            self.num_heads * self.head_dim + 2 * self.num_key_value_heads * self.head_dim,
+                        ],
                     )
                 # [bs, seq_len / sep, num_head, head_dim] -> [bs, seq_len, num_head / sep, head_dim]
                 mix_layer = self.reshard_layer(
@@ -785,15 +792,26 @@ def forward(
                     concat_axis=1,
                 )
                 mix_layer = paddle.reshape_(
-                    mix_layer, [0, self.seq_length, -1, 3 * self.head_dim]
+                    mix_layer, [0, self.seq_length, -1, (self.num_key_value_groups + 2) * self.head_dim]
                 )  # [bs, seq_len, num_head/k, 3*head_dim], k is sep degree
             else:
                 if self.sequence_parallel:
-                    target_shape = [-1, self.seq_length, self.num_heads, 3 * self.head_dim]
+                    target_shape = [
+                        -1,
+                        self.seq_length,
+                        self.num_key_value_heads,
+                        (self.num_key_value_groups + 2) * self.head_dim,
+                    ]
                 else:
-                    target_shape = [0, 0, self.num_heads, 3 * self.head_dim]
+                    target_shape = [0, 0, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim]
                 mix_layer = paddle.reshape_(mix_layer, target_shape)
-            query_states, key_states, value_states = paddle.split(mix_layer, num_or_sections=3, axis=-1)
+            query_states, key_states, value_states = paddle.split(
+                mix_layer,
+                num_or_sections=[self.num_key_value_groups * self.head_dim, self.head_dim, self.head_dim],
+                axis=-1,
+            )
+            if self.gqa_or_mqa:
+                query_states = paddle.reshape_(query_states, [0, 0, self.num_heads, self.head_dim])
         else:
             query_states = self.q_proj(hidden_states)
             key_states = self.k_proj(hidden_states)
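The reshape-and-split now works per key/value head: the last axis of `mix_layer` is viewed as `num_key_value_heads` blocks of `(num_key_value_groups + 2) * head_dim`, and `paddle.split` peels the query group, the key, and the value off each block. The sketch below replays that slicing on random data with hypothetical sizes; it assumes the fused weight is laid out per kv head (query group, then key, then value), which is what the reshape implies:

```python
import paddle

# Hypothetical shapes; mirrors the split layout in the diff, not the full forward pass.
batch, seq_len = 2, 16
num_heads, num_key_value_heads, head_dim = 32, 8, 128
num_key_value_groups = num_heads // num_key_value_heads  # 4

fused_width = num_heads * head_dim + 2 * num_key_value_heads * head_dim
mix_layer = paddle.randn([batch, seq_len, fused_width])

# View the last axis as one block per kv head: its query group plus one key and one value slice.
mix_layer = paddle.reshape(mix_layer, [0, 0, num_key_value_heads, (num_key_value_groups + 2) * head_dim])
q, k, v = paddle.split(
    mix_layer,
    num_or_sections=[num_key_value_groups * head_dim, head_dim, head_dim],
    axis=-1,
)
# Flatten the query groups back to a per-head layout for GQA/MQA.
q = paddle.reshape(q, [0, 0, num_heads, head_dim])
print(q.shape, k.shape, v.shape)  # [2, 16, 32, 128] [2, 16, 8, 128] [2, 16, 8, 128]
```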
@@ -807,11 +825,19 @@ def forward(
                     )
                     key_states = paddle.reshape(
                         key_states,
-                        [-1, self.seq_length // self.config.sep_parallel_degree, self.num_heads * self.head_dim],
+                        [
+                            -1,
+                            self.seq_length // self.config.sep_parallel_degree,
+                            self.num_key_value_heads * self.head_dim,
+                        ],
                     )
                     value_states = paddle.reshape(
                         value_states,
-                        [-1, self.seq_length // self.config.sep_parallel_degree, self.num_heads * self.head_dim],
+                        [
+                            -1,
+                            self.seq_length // self.config.sep_parallel_degree,
+                            self.num_key_value_heads * self.head_dim,
+                        ],
                     )
                 query_states = self.reshard_layer(
                     query_states,
@@ -855,16 +881,38 @@ def forward(
                 position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
             if self.use_fused_rope:
                 assert past_key_value is None, "fuse rotary not support cache kv for now"
+                batch_size, seq_length, num_heads, head_dim = query_states.shape
+                _, kv_seq_len, num_key_value_heads, _ = key_states.shape
                 cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-                query_states, key_states, _ = fused_rotary_position_embedding(
-                    query_states,
-                    key_states,
-                    v=None,
-                    sin=sin,
-                    cos=cos,
-                    position_ids=position_ids,
-                    use_neox_rotary_style=False,
-                )
+                if num_heads != num_key_value_heads:
+                    query_states, _, _ = fused_rotary_position_embedding(
+                        query_states,
+                        None,
+                        None,
+                        sin=sin,
+                        cos=cos,
+                        position_ids=position_ids,
+                        use_neox_rotary_style=False,
+                    )
+                    key_states, _, _ = fused_rotary_position_embedding(
+                        key_states,
+                        None,
+                        None,
+                        sin=sin,
+                        cos=cos,
+                        position_ids=position_ids,
+                        use_neox_rotary_style=False,
+                    )
+                else:
+                    query_states, key_states, _ = fused_rotary_position_embedding(
+                        query_states,
+                        key_states,
+                        v=None,
+                        sin=sin,
+                        cos=cos,
+                        position_ids=position_ids,
+                        use_neox_rotary_style=False,
+                    )
             else:
                 cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
                 query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
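Under GQA the query and key tensors no longer share a head count, so the rotary step calls `fused_rotary_position_embedding` once per tensor instead of rotating both in a single call, presumably because the fused kernel expects q and k with matching shapes. Splitting the calls is safe: rotary embedding acts elementwise per head and per position, so queries and keys never interact during the rotation. A non-fused, plain-Paddle sketch of the rotate-half math (hypothetical shapes; this is not the fused kernel's API):

```python
import paddle

def rotate_half(x):
    # Standard rotate-half used by non-neox-interleaved rotary embeddings.
    x1, x2 = paddle.chunk(x, 2, axis=-1)
    return paddle.concat([-x2, x1], axis=-1)

def apply_rope(x, cos, sin):
    # x: [batch, seq_len, num_heads, head_dim]; cos/sin: [1, seq_len, 1, head_dim].
    # The rotation touches each head independently, so q (32 heads) and k (8 heads)
    # can be rotated in separate calls with identical results.
    return x * cos + rotate_half(x) * sin

batch, seq_len, head_dim = 2, 16, 128
q = paddle.randn([batch, seq_len, 32, head_dim])  # query heads
k = paddle.randn([batch, seq_len, 8, head_dim])   # key heads (GQA)

pos = paddle.arange(seq_len, dtype="float32").unsqueeze(-1)              # [seq_len, 1]
exponent = paddle.arange(0, head_dim, 2, dtype="float32") / head_dim
inv_freq = 1.0 / paddle.pow(paddle.to_tensor(10000.0), exponent)         # [head_dim / 2]
freqs = pos * inv_freq                                                   # [seq_len, head_dim / 2]
emb = paddle.concat([freqs, freqs], axis=-1).reshape([1, seq_len, 1, head_dim])
cos, sin = paddle.cos(emb), paddle.sin(emb)

q_rot, k_rot = apply_rope(q, cos, sin), apply_rope(k, cos, sin)
print(q_rot.shape, k_rot.shape)  # [2, 16, 32, 128] [2, 16, 8, 128]
```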
@@ -881,11 +929,6 @@ def forward(
             key_states = paddle.index_select(key_states, self.kv_indices, axis=2)
             value_states = paddle.index_select(value_states, self.kv_indices, axis=2)
 
-        # TODO(wj-Mcat): use broadcast strategy when n_kv_heads = 1
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
         has_gradient = not (query_states.stop_gradient and key_states.stop_gradient and value_states.stop_gradient)
         if (
             self.enable_recompute
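The explicit repeat_kv expansion is dropped at this point, so key and value tensors stay at `num_key_value_heads` heads here; this hunk alone does not show where (or whether) the group expansion now happens downstream. For reference, a minimal sketch of what the deleted lines used to materialize, assuming the [batch, seq_len, num_heads, head_dim] layout used elsewhere in this diff:

```python
import paddle

def repeat_kv_sketch(hidden_states, n_rep):
    # Repeat each kv head n_rep times along the head axis:
    # [batch, seq_len, num_kv_heads, head_dim] -> [batch, seq_len, num_kv_heads * n_rep, head_dim].
    batch, seq_len, num_kv_heads, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states.unsqueeze(3).tile([1, 1, 1, n_rep, 1])
    return hidden_states.reshape([batch, seq_len, num_kv_heads * n_rep, head_dim])

k = paddle.randn([2, 16, 8, 128])
print(repeat_kv_sketch(k, 4).shape)  # [2, 16, 32, 128]
```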