Commit c0c64fa

gqa fuse attention qkv (#7890)
* gqa fuse attention qkv
* add annotation for the fusion
1 parent 3a704ea commit c0c64fa

File tree

1 file changed: +50 -13 lines


paddlenlp/transformers/llama/modeling.py

Lines changed: 50 additions & 13 deletions
@@ -588,17 +588,15 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
         self.head_dim = self.hidden_size // config.num_attention_heads

         self.num_key_value_heads = config.num_key_value_heads
+        assert config.num_attention_heads // config.num_key_value_heads
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+        self.gqa_or_mqa = config.num_attention_heads != config.num_key_value_heads

         self.max_position_embeddings = config.max_position_embeddings
         self.seq_length = config.seq_length
         self.sequence_parallel = config.sequence_parallel

         self.fuse_attention_qkv = config.fuse_attention_qkv
-        if self.fuse_attention_qkv and config.num_attention_heads != config.num_key_value_heads:
-            raise ValueError(
-                f"fuse_attention_qkv can't be True when num_attention_heads {config.num_attention_heads}!= num_key_value_heads {config.num_key_value_heads}"
-            )

         self.kv_indices = None
         # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
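
For orientation, a minimal sketch of the quantities this hunk introduces, using hypothetical Llama-style sizes (hidden_size=4096, 32 query heads, 8 KV heads) that are not taken from the commit:

# Hypothetical config values, for illustration only.
hidden_size = 4096
num_attention_heads = 32   # query heads
num_key_value_heads = 8    # KV heads: equal to query heads -> MHA, 1 -> MQA, in between -> GQA

head_dim = hidden_size // num_attention_heads                       # 128
num_key_value_groups = num_attention_heads // num_key_value_heads   # 4 query heads share each KV head
gqa_or_mqa = num_attention_heads != num_key_value_heads             # True -> q needs an extra reshape later

print(head_dim, num_key_value_groups, gqa_or_mqa)  # 128 4 True
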
@@ -615,6 +613,11 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             if self.num_key_value_heads % config.tensor_parallel_degree == 0:
                 self.num_key_value_heads = self.num_key_value_heads // config.tensor_parallel_degree
             else:
+                if self.fuse_attention_qkv:
+                    # TODO(Yuang): support fusion for kv when kv heads cannot be divided by mp
+                    raise ValueError(
+                        f"fuse_attention_qkv can't be True when num_key_value_heads {config.num_key_value_heads} % tensor_parallel_degree {config.tensor_parallel_degree} != 0"
+                    )
                 logger.warning(
                     f"Get num_key_value_heads: {self.num_key_value_heads}, can't split to tensor_parallel_degree: {config.tensor_parallel_degree}, so we don't spilt key value weight."
                 )
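
A rough illustration of the new guard, with made-up degrees rather than values from the commit: KV heads that divide evenly across tensor-parallel ranks are sharded; otherwise fused QKV is rejected, while the unfused path only warns:

# Hypothetical numbers: 8 KV heads shard cleanly over 4 ranks but not over 3.
num_key_value_heads = 8

for tensor_parallel_degree, fuse_attention_qkv in [(4, True), (3, True), (3, False)]:
    if num_key_value_heads % tensor_parallel_degree == 0:
        per_rank = num_key_value_heads // tensor_parallel_degree
        print(f"tp={tensor_parallel_degree}: {per_rank} KV heads per rank")
    elif fuse_attention_qkv:
        print(f"tp={tensor_parallel_degree}: fused QKV unsupported here, a ValueError would be raised")
    else:
        print(f"tp={tensor_parallel_degree}: KV weights kept unsplit, only a warning is logged")
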
@@ -644,7 +647,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             if self.fuse_attention_qkv:
                 self.qkv_proj = ColumnParallelLinear(
                     self.hidden_size,
-                    3 * self.hidden_size,
+                    self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
                     has_bias=False,
                     gather_output=False,
                 )
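
The new output width appends one K slice and one V slice per KV head to the Q slice, so the fused projection no longer needs to be 3 * hidden_size wide. With the same hypothetical sizes as above (4096 hidden units, 8 KV heads, head_dim 128):

# Hypothetical sizes, for illustration only.
hidden_size, num_key_value_heads, head_dim = 4096, 8, 128

old_width = 3 * hidden_size                                    # 12288, valid only when q, k, v are all hidden_size wide
new_width = hidden_size + 2 * num_key_value_heads * head_dim   # 4096 + 2048 = 6144

print(old_width, new_width)  # 12288 6144
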
@@ -684,7 +687,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             if self.fuse_attention_qkv:
                 self.qkv_proj = nn.Linear(
                     self.hidden_size,
-                    3 * self.hidden_size,
+                    self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )
             else:
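
On the non-tensor-parallel path the same width lands in a plain paddle.nn.Linear; a minimal sketch with the assumed sizes above:

import paddle.nn as nn

# Hypothetical sizes, for illustration only.
hidden_size, num_key_value_heads, head_dim = 4096, 8, 128

qkv_proj = nn.Linear(hidden_size, hidden_size + 2 * num_key_value_heads * head_dim, bias_attr=False)
print(qkv_proj.weight.shape)  # [4096, 6144]
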
@@ -771,12 +774,27 @@ def forward(

         if self.fuse_attention_qkv:
             mix_layer = self.qkv_proj(hidden_states)
+            # NOTE for GQA attention fusion (compatible with MHA and MQA):
+            # The weight for qkv_proj is in shape like [hidden_size, hidden_size + 2 * num_kv_heads * head_dim].
+            # After the projection, the mix_layer is in shape like [b, s, hidden_size + 2 * num_kv_heads * head_dim].
+            # Reshape the mix_layer into a shape like [b, s, num_kv_heads, (num_groups + 2) * head_dim],
+            # where num_groups = num_q_heads // num_kv_heads.
+            # Split the mix_layer on the last axis into three sections [num_groups * head_dim, head_dim, head_dim]
+            # to represent the q, k and v respectively.
+            # The q is in the shape like [b, s, num_kv_heads, num_groups * head_dim].
+            # The k and v are in the shape like [b, s, num_kv_heads, head_dim].
+            # Under MHA, the q is ready for the following calculation since num_kv_heads == num_q_heads,
+            # but for GQA or MQA, q should be reshaped into [b, s, num_q_heads, head_dim].
             if self.reshard_layer is not None:
                 if self.sequence_parallel:
                     assert self.seq_length % self.config.sep_parallel_degree == 0
                     mix_layer = paddle.reshape_(
                         mix_layer,
-                        [-1, self.seq_length // self.config.sep_parallel_degree, 3 * self.num_heads * self.head_dim],
+                        [
+                            -1,
+                            self.seq_length // self.config.sep_parallel_degree,
+                            self.num_heads * self.head_dim + 2 * self.num_key_value_heads * self.head_dim,
+                        ],
                     )
                 # [bs, seq_len / sep, num_head, head_dim] -> [bs, seq_len, num_head / sep, head_dim]
                 mix_layer = self.reshard_layer(
@@ -785,15 +803,26 @@ def forward(
                     concat_axis=1,
                 )
                 mix_layer = paddle.reshape_(
-                    mix_layer, [0, self.seq_length, -1, 3 * self.head_dim]
+                    mix_layer, [0, self.seq_length, -1, (self.num_key_value_groups + 2) * self.head_dim]
                 )  # [bs, seq_len, num_head/k, 3*head_dim], k is sep degree
             else:
                 if self.sequence_parallel:
-                    target_shape = [-1, self.seq_length, self.num_heads, 3 * self.head_dim]
+                    target_shape = [
+                        -1,
+                        self.seq_length,
+                        self.num_key_value_heads,
+                        (self.num_key_value_groups + 2) * self.head_dim,
+                    ]
                 else:
-                    target_shape = [0, 0, self.num_heads, 3 * self.head_dim]
+                    target_shape = [0, 0, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim]
                 mix_layer = paddle.reshape_(mix_layer, target_shape)
-            query_states, key_states, value_states = paddle.split(mix_layer, num_or_sections=3, axis=-1)
+            query_states, key_states, value_states = paddle.split(
+                mix_layer,
+                num_or_sections=[self.num_key_value_groups * self.head_dim, self.head_dim, self.head_dim],
+                axis=-1,
+            )
+            if self.gqa_or_mqa:
+                query_states = paddle.reshape_(query_states, [0, 0, self.num_heads, self.head_dim])
         else:
             query_states = self.q_proj(hidden_states)
             key_states = self.k_proj(hidden_states)
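
Putting the NOTE and this hunk together, here is a minimal standalone sketch of the fused path on the simple branch (no reshard_layer, no sequence parallelism); all batch, sequence, and head sizes are assumptions for illustration, not values from the commit:

import paddle

# Assumed sizes: b=2, s=16, 32 query heads, 8 KV heads, head_dim=128.
b, s, num_heads, num_key_value_heads, head_dim = 2, 16, 32, 8, 128
num_key_value_groups = num_heads // num_key_value_heads                  # 4
fused_width = num_heads * head_dim + 2 * num_key_value_heads * head_dim  # 4096 + 2048

# Stand-in for self.qkv_proj(hidden_states): a random fused activation.
mix_layer = paddle.randn([b, s, fused_width])

# [b, s, fused] -> [b, s, num_kv_heads, (num_groups + 2) * head_dim]
mix_layer = paddle.reshape(mix_layer, [0, 0, num_key_value_heads, (num_key_value_groups + 2) * head_dim])

# Uneven split: q takes num_groups * head_dim, k and v take head_dim each.
query_states, key_states, value_states = paddle.split(
    mix_layer,
    num_or_sections=[num_key_value_groups * head_dim, head_dim, head_dim],
    axis=-1,
)
print(key_states.shape, value_states.shape)  # [2, 16, 8, 128] [2, 16, 8, 128]

# GQA/MQA only: regroup q so every query head gets its own slot.
query_states = paddle.reshape(query_states, [0, 0, num_heads, head_dim])
print(query_states.shape)  # [2, 16, 32, 128]

With num_key_value_heads == num_heads (plain MHA), the groups collapse to 1, the three sections become equal head_dim slices, and the final reshape of q is a no-op, which is why the fused layout stays compatible with MHA and MQA.
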
@@ -807,11 +836,19 @@ def forward(
                     )
                     key_states = paddle.reshape(
                         key_states,
-                        [-1, self.seq_length // self.config.sep_parallel_degree, self.num_heads * self.head_dim],
+                        [
+                            -1,
+                            self.seq_length // self.config.sep_parallel_degree,
+                            self.num_key_value_heads * self.head_dim,
+                        ],
                     )
                     value_states = paddle.reshape(
                         value_states,
-                        [-1, self.seq_length // self.config.sep_parallel_degree, self.num_heads * self.head_dim],
+                        [
+                            -1,
+                            self.seq_length // self.config.sep_parallel_degree,
+                            self.num_key_value_heads * self.head_dim,
+                        ],
                     )
                 query_states = self.reshard_layer(
                     query_states,
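
In the unfused branch on the sep-parallel path, k and v now flatten to num_key_value_heads * head_dim instead of num_heads * head_dim before being resharded; with the same assumed sizes plus a made-up seq_length and sep_parallel_degree:

# Assumed sizes, for illustration only.
seq_length, sep_parallel_degree = 2048, 2
num_heads, num_key_value_heads, head_dim = 32, 8, 128

q_width = num_heads * head_dim             # 4096, unchanged by this commit
kv_width = num_key_value_heads * head_dim  # 1024, previously num_heads * head_dim

print(seq_length // sep_parallel_degree, q_width, kv_width)  # 1024 4096 1024
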
