Commit 896c01f

gqa fuse attention qkv
1 parent 6e0ac44 commit 896c01f


paddlenlp/transformers/llama/modeling.py

Lines changed: 39 additions & 13 deletions
@@ -588,17 +588,15 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
         self.head_dim = self.hidden_size // config.num_attention_heads

         self.num_key_value_heads = config.num_key_value_heads
+        assert config.num_attention_heads // config.num_key_value_heads
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+        self.gqa_or_mqa = config.num_attention_heads != config.num_key_value_heads

         self.max_position_embeddings = config.max_position_embeddings
         self.seq_length = config.seq_length
         self.sequence_parallel = config.sequence_parallel

         self.fuse_attention_qkv = config.fuse_attention_qkv
-        if self.fuse_attention_qkv and config.num_attention_heads != config.num_key_value_heads:
-            raise ValueError(
-                f"fuse_attention_qkv can't be True when num_attention_heads {config.num_attention_heads}!= num_key_value_heads {config.num_key_value_heads}"
-            )

         self.kv_indices = None
         # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
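Note: grouped-query attention assumes the query-head count is an integer multiple of the key/value-head count; the assert added above only checks that the integer quotient is non-zero. A minimal sketch of the intended invariant, with hypothetical head counts:

    num_attention_heads, num_key_value_heads = 32, 8  # hypothetical GQA configuration

    # Query heads must divide evenly across KV heads for GQA/MQA.
    assert num_attention_heads % num_key_value_heads == 0

    num_key_value_groups = num_attention_heads // num_key_value_heads  # 4 query heads per KV head
    gqa_or_mqa = num_attention_heads != num_key_value_heads            # True here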
@@ -615,6 +613,11 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             if self.num_key_value_heads % config.tensor_parallel_degree == 0:
                 self.num_key_value_heads = self.num_key_value_heads // config.tensor_parallel_degree
             else:
+                if self.fuse_attention_qkv:
+                    # TODO(Yuang): support fusion for kv when kv heads cannot be divided by mp
+                    raise ValueError(
+                        f"fuse_attention_qkv can't be True when num_key_value_heads {config.num_key_value_heads} % tensor_parallel_degree {config.tensor_parallel_degree} != 0"
+                    )
                 logger.warning(
                     f"Get num_key_value_heads: {self.num_key_value_heads}, can't split to tensor_parallel_degree: {config.tensor_parallel_degree}, so we don't spilt key value weight."
                 )
@@ -644,7 +647,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             if self.fuse_attention_qkv:
                 self.qkv_proj = ColumnParallelLinear(
                     self.hidden_size,
-                    3 * self.hidden_size,
+                    self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
                     has_bias=False,
                     gather_output=False,
                 )
@@ -684,7 +687,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             if self.fuse_attention_qkv:
                 self.qkv_proj = nn.Linear(
                     self.hidden_size,
-                    3 * self.hidden_size,
+                    self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )
             else:
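To make the new fused projection width concrete, here is the arithmetic with hypothetical sizes (chosen for illustration, not taken from the commit):

    hidden_size = 4096
    num_attention_heads = 32
    num_key_value_heads = 8                                # GQA: 4 query heads share each KV head
    head_dim = hidden_size // num_attention_heads          # 128

    old_width = 3 * hidden_size                                    # 12288, assumes K and V are full width
    new_width = hidden_size + 2 * num_key_value_heads * head_dim   # 4096 + 2048 = 6144

    print(old_width, new_width)

The fused weight shrinks because K and V now project to num_key_value_heads * head_dim channels each instead of the full hidden size.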
@@ -776,7 +779,11 @@ def forward(
                 assert self.seq_length % self.config.sep_parallel_degree == 0
                 mix_layer = paddle.reshape_(
                     mix_layer,
-                    [-1, self.seq_length // self.config.sep_parallel_degree, 3 * self.num_heads * self.head_dim],
+                    [
+                        -1,
+                        self.seq_length // self.config.sep_parallel_degree,
+                        self.num_heads * self.head_dim + 2 * self.num_key_value_heads * self.head_dim,
+                    ],
                 )
                 # [bs, seq_len / sep, num_head, head_dim] -> [bs, seq_len, num_head / sep, head_dim]
                 mix_layer = self.reshard_layer(
@@ -785,15 +792,26 @@ def forward(
                     concat_axis=1,
                 )
                 mix_layer = paddle.reshape_(
-                    mix_layer, [0, self.seq_length, -1, 3 * self.head_dim]
+                    mix_layer, [0, self.seq_length, -1, (self.num_key_value_groups + 2) * self.head_dim]
                 )  # [bs, seq_len, num_head/k, 3*head_dim], k is sep degree
             else:
                 if self.sequence_parallel:
-                    target_shape = [-1, self.seq_length, self.num_heads, 3 * self.head_dim]
+                    target_shape = [
+                        -1,
+                        self.seq_length,
+                        self.num_key_value_heads,
+                        (self.num_key_value_groups + 2) * self.head_dim,
+                    ]
                 else:
-                    target_shape = [0, 0, self.num_heads, 3 * self.head_dim]
+                    target_shape = [0, 0, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim]
                 mix_layer = paddle.reshape_(mix_layer, target_shape)
-            query_states, key_states, value_states = paddle.split(mix_layer, num_or_sections=3, axis=-1)
+            query_states, key_states, value_states = paddle.split(
+                mix_layer,
+                num_or_sections=[self.num_key_value_groups * self.head_dim, self.head_dim, self.head_dim],
+                axis=-1,
+            )
+            if self.gqa_or_mqa:
+                query_states = paddle.reshape_(query_states, [0, 0, self.num_heads, self.head_dim])
         else:
             query_states = self.q_proj(hidden_states)
             key_states = self.k_proj(hidden_states)
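The split above carves each KV head's slice of the fused tensor into its query group plus one K and one V slice; a self-contained Paddle sketch with hypothetical sizes (not the module itself):

    import paddle

    bsz, seq_len = 2, 16
    num_heads, num_key_value_heads, head_dim = 8, 2, 64
    num_key_value_groups = num_heads // num_key_value_heads   # 4 query heads per KV head

    # Fused QKV output already reshaped to [bs, seq_len, num_kv_heads, (groups + 2) * head_dim]
    mix_layer = paddle.rand([bsz, seq_len, num_key_value_heads, (num_key_value_groups + 2) * head_dim])

    # Per KV head: groups * head_dim query channels, then one K slice, then one V slice.
    query_states, key_states, value_states = paddle.split(
        mix_layer,
        num_or_sections=[num_key_value_groups * head_dim, head_dim, head_dim],
        axis=-1,
    )
    # Flatten the query groups back out to the full set of query heads.
    query_states = paddle.reshape(query_states, [0, 0, num_heads, head_dim])

    print(query_states.shape)  # [2, 16, 8, 64]
    print(key_states.shape)    # [2, 16, 2, 64]
    print(value_states.shape)  # [2, 16, 2, 64]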
@@ -807,11 +825,19 @@ def forward(
                 )
                 key_states = paddle.reshape(
                     key_states,
-                    [-1, self.seq_length // self.config.sep_parallel_degree, self.num_heads * self.head_dim],
+                    [
+                        -1,
+                        self.seq_length // self.config.sep_parallel_degree,
+                        self.num_key_value_heads * self.head_dim,
+                    ],
                 )
                 value_states = paddle.reshape(
                     value_states,
-                    [-1, self.seq_length // self.config.sep_parallel_degree, self.num_heads * self.head_dim],
+                    [
+                        -1,
+                        self.seq_length // self.config.sep_parallel_degree,
+                        self.num_key_value_heads * self.head_dim,
+                    ],
                 )
                 query_states = self.reshard_layer(
                     query_states,
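Putting the pieces together without tensor or sequence parallelism, a minimal end-to-end sketch of the fused GQA projection and split (hypothetical sizes, plain paddle.nn.Linear only):

    import paddle
    import paddle.nn as nn

    hidden_size, num_heads, num_key_value_heads = 1024, 16, 4
    head_dim = hidden_size // num_heads                       # 64
    num_key_value_groups = num_heads // num_key_value_heads   # 4

    # Fused projection sized for GQA: full query width plus two KV widths.
    qkv_proj = nn.Linear(hidden_size, hidden_size + 2 * num_key_value_heads * head_dim, bias_attr=False)

    hidden_states = paddle.rand([2, 32, hidden_size])   # [bs, seq_len, hidden]
    mix_layer = qkv_proj(hidden_states)                 # [2, 32, 1536]
    mix_layer = paddle.reshape(mix_layer, [0, 0, num_key_value_heads, (num_key_value_groups + 2) * head_dim])

    q, k, v = paddle.split(mix_layer, [num_key_value_groups * head_dim, head_dim, head_dim], axis=-1)
    q = paddle.reshape(q, [0, 0, num_heads, head_dim])

    print(q.shape, k.shape, v.shape)  # [2, 32, 16, 64] [2, 32, 4, 64] [2, 32, 4, 64]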
