
Commit b3ef512

support GQA
1 parent 082dc52 commit b3ef512

1 file changed: +70 −27


paddlenlp/transformers/llama/modeling.py

Lines changed: 70 additions & 27 deletions
@@ -588,17 +588,15 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
         self.head_dim = self.hidden_size // config.num_attention_heads
 
         self.num_key_value_heads = config.num_key_value_heads
+        assert config.num_attention_heads // config.num_key_value_heads
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+        self.gqa_or_mqa = config.num_attention_heads != config.num_key_value_heads
 
         self.max_position_embeddings = config.max_position_embeddings
         self.seq_length = config.seq_length
         self.sequence_parallel = config.sequence_parallel
 
         self.fuse_attention_qkv = config.fuse_attention_qkv
-        if self.fuse_attention_qkv and config.num_attention_heads != config.num_key_value_heads:
-            raise ValueError(
-                f"fuse_attention_qkv can't be True when num_attention_heads {config.num_attention_heads}!= num_key_value_heads {config.num_key_value_heads}"
-            )
 
         self.kv_indices = None
         # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
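
For reference, a minimal sketch of the head bookkeeping these added lines set up, using hypothetical 70B-style numbers (64 query heads, 8 KV heads); only the variable names come from the diff, the concrete values are illustrative:

num_attention_heads = 64                                              # query heads (hypothetical)
num_key_value_heads = 8                                               # shared key/value heads (hypothetical)
head_dim = 8192 // num_attention_heads                                # 128

num_key_value_groups = num_attention_heads // num_key_value_heads    # 8 query heads share each KV head
gqa_or_mqa = num_attention_heads != num_key_value_heads              # True for GQA/MQA, False for classic MHA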
@@ -615,6 +613,11 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             if self.num_key_value_heads % config.tensor_parallel_degree == 0:
                 self.num_key_value_heads = self.num_key_value_heads // config.tensor_parallel_degree
             else:
+                if self.fuse_attention_qkv:
+                    # TODO(Yuang): support fusion for kv when kv heads cannot be divided by mp
+                    raise ValueError(
+                        f"fuse_attention_qkv can't be True when num_key_value_heads {config.num_key_value_heads} % tensor_parallel_degree {config.tensor_parallel_degree} != 0"
+                    )
                 logger.warning(
                     f"Get num_key_value_heads: {self.num_key_value_heads}, can't split to tensor_parallel_degree: {config.tensor_parallel_degree}, so we don't spilt key value weight."
                 )
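
A small illustration of the divisibility rule this hunk enforces; the head count and parallel degrees below are hypothetical, not taken from the commit:

num_key_value_heads = 8
for tensor_parallel_degree in (4, 16):
    if num_key_value_heads % tensor_parallel_degree == 0:
        # KV heads are split across ranks, e.g. 8 // 4 = 2 per rank
        print(tensor_parallel_degree, "->", num_key_value_heads // tensor_parallel_degree, "KV heads per rank")
    else:
        # KV weights stay unsplit; combined with fuse_attention_qkv this path now raises a ValueError
        print(tensor_parallel_degree, "-> keep KV weight unsplit")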
@@ -644,7 +647,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             if self.fuse_attention_qkv:
                 self.qkv_proj = ColumnParallelLinear(
                     self.hidden_size,
-                    3 * self.hidden_size,
+                    self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
                     has_bias=False,
                     gather_output=False,
                 )
@@ -684,7 +687,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
         if self.fuse_attention_qkv:
             self.qkv_proj = nn.Linear(
                 self.hidden_size,
-                3 * self.hidden_size,
+                self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
                 bias_attr=False,
             )
         else:
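
Both the ColumnParallelLinear and nn.Linear branches now size the fused QKV projection as hidden_size + 2 * num_key_value_heads * head_dim instead of 3 * hidden_size. A quick arithmetic sketch with hypothetical 7B-style numbers (not taken from this commit):

hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8                                  # GQA; use 32 for MHA, 1 for MQA
head_dim = hidden_size // num_attention_heads            # 128

q_width = num_attention_heads * head_dim                 # 4096 == hidden_size
kv_width = 2 * num_key_value_heads * head_dim            # 2048 (one K block + one V block)
fused_width = hidden_size + kv_width                     # 6144; the old 3 * hidden_size would be 12288

For plain MHA (num_key_value_heads == num_attention_heads) the new formula reduces to the old 3 * hidden_size, so the change is backward compatible.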
@@ -776,7 +779,11 @@ def forward(
                     assert self.seq_length % self.config.sep_parallel_degree == 0
                     mix_layer = paddle.reshape_(
                         mix_layer,
-                        [-1, self.seq_length // self.config.sep_parallel_degree, 3 * self.num_heads * self.head_dim],
+                        [
+                            -1,
+                            self.seq_length // self.config.sep_parallel_degree,
+                            self.num_heads * self.head_dim + 2 * self.num_key_value_heads * self.head_dim,
+                        ],
                     )
                 # [bs, seq_len / sep, num_head, head_dim] -> [bs, seq_len, num_head / sep, head_dim]
                 mix_layer = self.reshard_layer(
@@ -785,15 +792,26 @@ def forward(
                     concat_axis=1,
                 )
                 mix_layer = paddle.reshape_(
-                    mix_layer, [0, self.seq_length, -1, 3 * self.head_dim]
+                    mix_layer, [0, self.seq_length, -1, (self.num_key_value_groups + 2) * self.head_dim]
                 )  # [bs, seq_len, num_head/k, 3*head_dim], k is sep degree
             else:
                 if self.sequence_parallel:
-                    target_shape = [-1, self.seq_length, self.num_heads, 3 * self.head_dim]
+                    target_shape = [
+                        -1,
+                        self.seq_length,
+                        self.num_key_value_heads,
+                        (self.num_key_value_groups + 2) * self.head_dim,
+                    ]
                 else:
-                    target_shape = [0, 0, self.num_heads, 3 * self.head_dim]
+                    target_shape = [0, 0, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim]
                 mix_layer = paddle.reshape_(mix_layer, target_shape)
-            query_states, key_states, value_states = paddle.split(mix_layer, num_or_sections=3, axis=-1)
+            query_states, key_states, value_states = paddle.split(
+                mix_layer,
+                num_or_sections=[self.num_key_value_groups * self.head_dim, self.head_dim, self.head_dim],
+                axis=-1,
+            )
+            if self.gqa_or_mqa:
+                query_states = paddle.reshape_(query_states, [0, 0, self.num_heads, self.head_dim])
         else:
             query_states = self.q_proj(hidden_states)
             key_states = self.k_proj(hidden_states)
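
A standalone sketch of the split this hunk performs on the fused activation: each KV head's row carries num_key_value_groups query sub-heads plus one K and one V block, which paddle.split separates by explicit section sizes. Shapes reuse the hypothetical 32/8-head configuration above:

import paddle

bs, seq_len, num_heads, num_key_value_heads, head_dim = 2, 16, 32, 8, 128
num_key_value_groups = num_heads // num_key_value_heads  # 4

# Fused output reshaped to one row per KV head: q-group + one K + one V per KV head.
mix_layer = paddle.randn([bs, seq_len, num_key_value_heads, (num_key_value_groups + 2) * head_dim])

q, k, v = paddle.split(
    mix_layer,
    num_or_sections=[num_key_value_groups * head_dim, head_dim, head_dim],
    axis=-1,
)
# q: [2, 16, 8, 512] -> regrouped to [2, 16, 32, 128]; k, v stay [2, 16, 8, 128]
q = paddle.reshape(q, [0, 0, num_heads, head_dim])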
@@ -807,11 +825,19 @@ def forward(
                    )
                    key_states = paddle.reshape(
                        key_states,
-                        [-1, self.seq_length // self.config.sep_parallel_degree, self.num_heads * self.head_dim],
+                        [
+                            -1,
+                            self.seq_length // self.config.sep_parallel_degree,
+                            self.num_key_value_heads * self.head_dim,
+                        ],
                    )
                    value_states = paddle.reshape(
                        value_states,
-                        [-1, self.seq_length // self.config.sep_parallel_degree, self.num_heads * self.head_dim],
+                        [
+                            -1,
+                            self.seq_length // self.config.sep_parallel_degree,
+                            self.num_key_value_heads * self.head_dim,
+                        ],
                    )
                query_states = self.reshard_layer(
                    query_states,
@@ -855,16 +881,38 @@ def forward(
                position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
            if self.use_fused_rope:
                assert past_key_value is None, "fuse rotary not support cache kv for now"
+                batch_size, seq_length, num_heads, head_dim = query_states.shape
+                _, kv_seq_len, num_key_value_heads, _ = key_states.shape
                cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-                query_states, key_states, _ = fused_rotary_position_embedding(
-                    query_states,
-                    key_states,
-                    v=None,
-                    sin=sin,
-                    cos=cos,
-                    position_ids=position_ids,
-                    use_neox_rotary_style=False,
-                )
+                if num_heads != num_key_value_heads:
+                    query_states, _, _ = fused_rotary_position_embedding(
+                        query_states,
+                        None,
+                        None,
+                        sin=sin,
+                        cos=cos,
+                        position_ids=position_ids,
+                        use_neox_rotary_style=False,
+                    )
+                    key_states, _, _ = fused_rotary_position_embedding(
+                        key_states,
+                        None,
+                        None,
+                        sin=sin,
+                        cos=cos,
+                        position_ids=position_ids,
+                        use_neox_rotary_style=False,
+                    )
+                else:
+                    query_states, key_states, _ = fused_rotary_position_embedding(
+                        query_states,
+                        key_states,
+                        v=None,
+                        sin=sin,
+                        cos=cos,
+                        position_ids=position_ids,
+                        use_neox_rotary_style=False,
+                    )
            else:
                cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
                query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
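
With GQA the query tensor carries num_heads heads while the key tensor carries only num_key_value_heads, so this hunk applies the fused rotary kernel to each tensor separately instead of in one (q, k) call. A shape-level illustration follows; the sizes are hypothetical and the standard rotate-half formulation below merely stands in for fused_rotary_position_embedding, with no claim to match its exact memory layout:

import math

import paddle

bs, seq_len, head_dim = 2, 16, 128
q = paddle.randn([bs, seq_len, 32, head_dim])   # 32 query heads
k = paddle.randn([bs, seq_len, 8, head_dim])    # only 8 KV heads under GQA

# cos/sin depend only on position and head_dim, so q and k can be rotated independently.
pos = paddle.arange(seq_len, dtype="float32")
inv_freq = paddle.exp(paddle.arange(0, head_dim, 2, dtype="float32") / head_dim * -math.log(10000.0))
freqs = pos.unsqueeze(1) * inv_freq.unsqueeze(0)        # [seq_len, head_dim // 2]
emb = paddle.concat([freqs, freqs], axis=-1)            # [seq_len, head_dim]
cos = paddle.cos(emb).unsqueeze(0).unsqueeze(2)         # [1, seq_len, 1, head_dim], broadcasts over batch/heads
sin = paddle.sin(emb).unsqueeze(0).unsqueeze(2)

def rotate_half(x):
    x1, x2 = paddle.split(x, 2, axis=-1)
    return paddle.concat([-x2, x1], axis=-1)

q_rot = q * cos + rotate_half(q) * sin   # works for 32 heads
k_rot = k * cos + rotate_half(k) * sin   # and equally for 8 heads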
@@ -881,11 +929,6 @@ def forward(
            key_states = paddle.index_select(key_states, self.kv_indices, axis=2)
            value_states = paddle.index_select(value_states, self.kv_indices, axis=2)
 
-        # TODO(wj-Mcat): use broadcast strategy when n_kv_heads = 1
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
        has_gradient = not (query_states.stop_gradient and key_states.stop_gradient and value_states.stop_gradient)
        if (
            self.enable_recompute
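
The unconditional repeat_kv expansion is dropped here; the removed comments describe it as repeating K/V heads so that n_kv_heads matches n_heads before attention. For reference, a minimal sketch of what such a repeat-KV helper does (a hypothetical standalone version, not the function removed from this file):

import paddle

def repeat_kv_like(hidden_states, n_rep):
    # [bs, seq_len, num_kv_heads, head_dim] -> [bs, seq_len, num_kv_heads * n_rep, head_dim]
    if n_rep == 1:
        return hidden_states
    bs, seq_len, num_kv_heads, head_dim = hidden_states.shape
    hidden_states = hidden_states.unsqueeze(-2).tile([1, 1, 1, n_rep, 1])
    return hidden_states.reshape([bs, seq_len, num_kv_heads * n_rep, head_dim])

k = paddle.randn([2, 16, 8, 128])
print(repeat_kv_like(k, 4).shape)  # [2, 16, 32, 128]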
