@@ -588,17 +588,15 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
         self.head_dim = self.hidden_size // config.num_attention_heads
 
         self.num_key_value_heads = config.num_key_value_heads
+        assert config.num_attention_heads % config.num_key_value_heads == 0
591
592
self .num_key_value_groups = config .num_attention_heads // config .num_key_value_heads
593
+ self .gqa_or_mqa = config .num_attention_heads != config .num_key_value_heads
592
594
593
595
self .max_position_embeddings = config .max_position_embeddings
594
596
self .seq_length = config .seq_length
595
597
self .sequence_parallel = config .sequence_parallel
596
598
597
599
self .fuse_attention_qkv = config .fuse_attention_qkv
598
- if self .fuse_attention_qkv and config .num_attention_heads != config .num_key_value_heads :
599
- raise ValueError (
600
- f"fuse_attention_qkv can't be True when num_attention_heads { config .num_attention_heads } != num_key_value_heads { config .num_key_value_heads } "
601
- )
602
600
603
601
self .kv_indices = None
604
602
# Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
@@ -615,6 +613,11 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             if self.num_key_value_heads % config.tensor_parallel_degree == 0:
                 self.num_key_value_heads = self.num_key_value_heads // config.tensor_parallel_degree
             else:
+                if self.fuse_attention_qkv:
+                    # TODO(Yuang): support fusion for kv when kv heads cannot be divided by mp
+                    raise ValueError(
+                        f"fuse_attention_qkv can't be True when num_key_value_heads {config.num_key_value_heads} % tensor_parallel_degree {config.tensor_parallel_degree} != 0"
+                    )
                 logger.warning(
                     f"Get num_key_value_heads: {self.num_key_value_heads}, can't split to tensor_parallel_degree: {config.tensor_parallel_degree}, so we don't split key value weight."
                 )
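A quick illustration of the constraint added above, using made-up numbers (not part of the patch): when the kv heads cannot be split evenly across the tensor-parallel ranks, the fused QKV weight cannot be column-partitioned, so the combination is rejected.

# Hypothetical values chosen only to illustrate the divisibility check above.
num_key_value_heads = 4
tensor_parallel_degree = 8
fuse_attention_qkv = True

if num_key_value_heads % tensor_parallel_degree != 0 and fuse_attention_qkv:
    # Mirrors the spirit of the new ValueError: the kv projection cannot be
    # evenly column-split across ranks, so QKV fusion is refused.
    print("fuse_attention_qkv is not supported for this head/parallelism combination")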
@@ -644,7 +647,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             if self.fuse_attention_qkv:
                 self.qkv_proj = ColumnParallelLinear(
                     self.hidden_size,
-                    3 * self.hidden_size,
+                    self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
                     has_bias=False,
                     gather_output=False,
                 )
@@ -684,7 +687,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             if self.fuse_attention_qkv:
                 self.qkv_proj = nn.Linear(
                     self.hidden_size,
-                    3 * self.hidden_size,
+                    self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )
             else:
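For intuition about the new projection width, here is a tiny standalone sketch (not part of the patch); the head counts and head_dim below are made-up illustrative values.

# Illustrative GQA configuration (hypothetical values, not from the patch).
num_attention_heads = 32
num_key_value_heads = 4
head_dim = 128
hidden_size = num_attention_heads * head_dim                         # 4096

# Previous fused width assumed MHA: q, k and v each span hidden_size columns.
mha_fused_width = 3 * hidden_size                                    # 12288
# New fused width: full-width q, plus k and v sized by the kv heads only.
gqa_fused_width = hidden_size + 2 * num_key_value_heads * head_dim   # 4096 + 1024 = 5120

print(mha_fused_width, gqa_fused_width)  # 12288 5120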
@@ -771,12 +774,27 @@ def forward(
 
         if self.fuse_attention_qkv:
             mix_layer = self.qkv_proj(hidden_states)
+            # NOTE for GQA attention fusion (compatible with MHA and MQA):
+            # The weight of qkv_proj has a shape like [hidden_size, hidden_size + 2 * num_kv_heads * head_dim].
+            # After the projection, mix_layer has a shape like [b, s, hidden_size + 2 * num_kv_heads * head_dim].
+            # Reshape mix_layer into a shape like [b, s, num_kv_heads, (num_groups + 2) * head_dim],
+            # where num_groups = num_q_heads // num_kv_heads.
+            # Split mix_layer on the last axis into three sections [num_groups * head_dim, head_dim, head_dim]
+            # to represent q, k and v respectively.
+            # q then has a shape like [b, s, num_kv_heads, num_groups * head_dim];
+            # k and v have a shape like [b, s, num_kv_heads, head_dim].
+            # Under MHA, q is ready for the following computation since num_kv_heads == num_q_heads,
+            # but for GQA or MQA, q must be reshaped into [b, s, num_q_heads, head_dim].
             if self.reshard_layer is not None:
                 if self.sequence_parallel:
                     assert self.seq_length % self.config.sep_parallel_degree == 0
                     mix_layer = paddle.reshape_(
                         mix_layer,
-                        [-1, self.seq_length // self.config.sep_parallel_degree, 3 * self.num_heads * self.head_dim],
+                        [
+                            -1,
+                            self.seq_length // self.config.sep_parallel_degree,
+                            self.num_heads * self.head_dim + 2 * self.num_key_value_heads * self.head_dim,
+                        ],
                     )
                 # [bs, seq_len / sep, num_head, head_dim] -> [bs, seq_len, num_head / sep, head_dim]
                 mix_layer = self.reshard_layer(
@@ -785,15 +803,26 @@ def forward(
                     concat_axis=1,
                 )
                 mix_layer = paddle.reshape_(
-                    mix_layer, [0, self.seq_length, -1, 3 * self.head_dim]
+                    mix_layer, [0, self.seq_length, -1, (self.num_key_value_groups + 2) * self.head_dim]
                 )  # [bs, seq_len, num_head/k, 3*head_dim], k is sep degree
             else:
                 if self.sequence_parallel:
-                    target_shape = [-1, self.seq_length, self.num_heads, 3 * self.head_dim]
+                    target_shape = [
+                        -1,
+                        self.seq_length,
+                        self.num_key_value_heads,
+                        (self.num_key_value_groups + 2) * self.head_dim,
+                    ]
                 else:
-                    target_shape = [0, 0, self.num_heads, 3 * self.head_dim]
+                    target_shape = [0, 0, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim]
                 mix_layer = paddle.reshape_(mix_layer, target_shape)
-            query_states, key_states, value_states = paddle.split(mix_layer, num_or_sections=3, axis=-1)
+            query_states, key_states, value_states = paddle.split(
+                mix_layer,
+                num_or_sections=[self.num_key_value_groups * self.head_dim, self.head_dim, self.head_dim],
+                axis=-1,
+            )
+            if self.gqa_or_mqa:
+                query_states = paddle.reshape_(query_states, [0, 0, self.num_heads, self.head_dim])
         else:
             query_states = self.q_proj(hidden_states)
             key_states = self.k_proj(hidden_states)
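To make the shapes in the NOTE above concrete, here is a minimal self-contained sketch of the fused reshape-and-split under GQA (not the patch itself); the batch size, sequence length, head counts and head_dim are assumed example values.

import paddle

# Hypothetical example sizes, chosen only for illustration.
b, s = 2, 8
num_q_heads, num_kv_heads, head_dim = 8, 2, 16
num_groups = num_q_heads // num_kv_heads                  # 4 query heads share each kv head
hidden_size = num_q_heads * head_dim                      # 128
fused_width = hidden_size + 2 * num_kv_heads * head_dim   # 128 + 64 = 192

# Stand-in for qkv_proj(hidden_states): [b, s, fused_width].
mix_layer = paddle.rand([b, s, fused_width])

# Group the fused columns by kv head: each kv head owns num_groups q slices plus one k and one v slice.
mix_layer = paddle.reshape(mix_layer, [b, s, num_kv_heads, (num_groups + 2) * head_dim])
q, k, v = paddle.split(
    mix_layer,
    num_or_sections=[num_groups * head_dim, head_dim, head_dim],
    axis=-1,
)
# q: [b, s, num_kv_heads, num_groups * head_dim] -> flatten the groups back into query heads.
q = paddle.reshape(q, [b, s, num_q_heads, head_dim])
print(q.shape, k.shape, v.shape)  # [2, 8, 8, 16] [2, 8, 2, 16] [2, 8, 2, 16]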
@@ -807,11 +836,19 @@ def forward(
                     )
                     key_states = paddle.reshape(
                         key_states,
-                        [-1, self.seq_length // self.config.sep_parallel_degree, self.num_heads * self.head_dim],
+                        [
+                            -1,
+                            self.seq_length // self.config.sep_parallel_degree,
+                            self.num_key_value_heads * self.head_dim,
+                        ],
                     )
                     value_states = paddle.reshape(
                         value_states,
-                        [-1, self.seq_length // self.config.sep_parallel_degree, self.num_heads * self.head_dim],
+                        [
+                            -1,
+                            self.seq_length // self.config.sep_parallel_degree,
+                            self.num_key_value_heads * self.head_dim,
+                        ],
                     )
                 query_states = self.reshard_layer(
                     query_states,