Skip to content

gqa fuse attention qkv #7890

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 1, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 50 additions & 13 deletions paddlenlp/transformers/llama/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,17 +588,15 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
self.head_dim = self.hidden_size // config.num_attention_heads

self.num_key_value_heads = config.num_key_value_heads
assert config.num_attention_heads // config.num_key_value_heads
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.gqa_or_mqa = config.num_attention_heads != config.num_key_value_heads

self.max_position_embeddings = config.max_position_embeddings
self.seq_length = config.seq_length
self.sequence_parallel = config.sequence_parallel

self.fuse_attention_qkv = config.fuse_attention_qkv
if self.fuse_attention_qkv and config.num_attention_heads != config.num_key_value_heads:
raise ValueError(
f"fuse_attention_qkv can't be True when num_attention_heads {config.num_attention_heads}!= num_key_value_heads {config.num_key_value_heads}"
)

self.kv_indices = None
# Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
Expand All @@ -615,6 +613,11 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
if self.num_key_value_heads % config.tensor_parallel_degree == 0:
self.num_key_value_heads = self.num_key_value_heads // config.tensor_parallel_degree
else:
if self.fuse_attention_qkv:
# TODO(Yuang): support fusion for kv when kv heads cannot be divided by mp
raise ValueError(
f"fuse_attention_qkv can't be True when num_key_value_heads {config.num_key_value_heads} % tensor_parallel_degree {config.tensor_parallel_degree} != 0"
)
logger.warning(
f"Get num_key_value_heads: {self.num_key_value_heads}, can't split to tensor_parallel_degree: {config.tensor_parallel_degree}, so we don't spilt key value weight."
)
Expand Down Expand Up @@ -644,7 +647,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
if self.fuse_attention_qkv:
self.qkv_proj = ColumnParallelLinear(
self.hidden_size,
3 * self.hidden_size,
self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
has_bias=False,
gather_output=False,
)
Expand Down Expand Up @@ -684,7 +687,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
if self.fuse_attention_qkv:
self.qkv_proj = nn.Linear(
self.hidden_size,
3 * self.hidden_size,
self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)
else:
Expand Down Expand Up @@ -771,12 +774,27 @@ def forward(

if self.fuse_attention_qkv:
mix_layer = self.qkv_proj(hidden_states)
# NOTE for GQA attention fusion (compatible with MHA and MQA):
# The weight for qkv_proj is in shape like [hidden_size, hidden_size + 2 * num_kv_heads * head_dim].
# After the projection, the mix_layer is in shape like [b, s, hidden_size + 2 * num_kv_heads * head_dim].
# Reshape the mix_layer into a shape like [b, s, num_kv_heads, (num_groups + 2) * head_dim],
# where num_groups = num_q_heads // num_kv_heads.
# Split the mix_layer on the last axis into three sections [num_groups * head_dim, head_dim, head_dim]
# to represent the q, k and v respectively.
# The q is in the shape like [b, s, num_kv_heads, num_groups * head_dim].
# The k and v are in the shape like [b, s, num_kv_heads, head_dim].
# Under MHA, the q is ready for the following calculation since num_kv_heads == num_q_heads,
# But for the GQA or MQA, q should be reshaped into [b, s, num_q_heads, head_dim].
if self.reshard_layer is not None:
if self.sequence_parallel:
assert self.seq_length % self.config.sep_parallel_degree == 0
mix_layer = paddle.reshape_(
mix_layer,
[-1, self.seq_length // self.config.sep_parallel_degree, 3 * self.num_heads * self.head_dim],
[
-1,
self.seq_length // self.config.sep_parallel_degree,
self.num_heads * self.head_dim + 2 * self.num_key_value_heads * self.head_dim,
],
)
# [bs, seq_len / sep, num_head, head_dim] -> [bs, seq_len, num_head / sep, head_dim]
mix_layer = self.reshard_layer(
Expand All @@ -785,15 +803,26 @@ def forward(
concat_axis=1,
)
mix_layer = paddle.reshape_(
mix_layer, [0, self.seq_length, -1, 3 * self.head_dim]
mix_layer, [0, self.seq_length, -1, (self.num_key_value_groups + 2) * self.head_dim]
) # [bs, seq_len, num_head/k, 3*head_dim], k is sep degree
else:
if self.sequence_parallel:
target_shape = [-1, self.seq_length, self.num_heads, 3 * self.head_dim]
target_shape = [
-1,
self.seq_length,
self.num_key_value_heads,
(self.num_key_value_groups + 2) * self.head_dim,
]
else:
target_shape = [0, 0, self.num_heads, 3 * self.head_dim]
target_shape = [0, 0, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim]
mix_layer = paddle.reshape_(mix_layer, target_shape)
query_states, key_states, value_states = paddle.split(mix_layer, num_or_sections=3, axis=-1)
query_states, key_states, value_states = paddle.split(
mix_layer,
num_or_sections=[self.num_key_value_groups * self.head_dim, self.head_dim, self.head_dim],
axis=-1,
)
if self.gqa_or_mqa:
query_states = paddle.reshape_(query_states, [0, 0, self.num_heads, self.head_dim])
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
Expand All @@ -807,11 +836,19 @@ def forward(
)
key_states = paddle.reshape(
key_states,
[-1, self.seq_length // self.config.sep_parallel_degree, self.num_heads * self.head_dim],
[
-1,
self.seq_length // self.config.sep_parallel_degree,
self.num_key_value_heads * self.head_dim,
],
)
value_states = paddle.reshape(
value_states,
[-1, self.seq_length // self.config.sep_parallel_degree, self.num_heads * self.head_dim],
[
-1,
self.seq_length // self.config.sep_parallel_degree,
self.num_key_value_heads * self.head_dim,
],
)
query_states = self.reshard_layer(
query_states,
Expand Down