From 65d15999aa68887b2af593e63fc2d64594ddb6b0 Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Wed, 24 Jan 2024 09:26:30 +0800
Subject: [PATCH 1/2] gqa fuse attention qkv

---
 paddlenlp/transformers/llama/modeling.py | 52 ++++++++++++++++++------
 1 file changed, 39 insertions(+), 13 deletions(-)

diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index f810d7e17b9b..06c46fa5b85a 100644
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -588,17 +588,15 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
         self.head_dim = self.hidden_size // config.num_attention_heads
 
         self.num_key_value_heads = config.num_key_value_heads
+        assert config.num_attention_heads % config.num_key_value_heads == 0
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+        self.gqa_or_mqa = config.num_attention_heads != config.num_key_value_heads
 
         self.max_position_embeddings = config.max_position_embeddings
         self.seq_length = config.seq_length
         self.sequence_parallel = config.sequence_parallel
 
         self.fuse_attention_qkv = config.fuse_attention_qkv
-        if self.fuse_attention_qkv and config.num_attention_heads != config.num_key_value_heads:
-            raise ValueError(
-                f"fuse_attention_qkv can't be True when num_attention_heads {config.num_attention_heads}!= num_key_value_heads {config.num_key_value_heads}"
-            )
 
         self.kv_indices = None
         # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
@@ -615,6 +613,11 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             if self.num_key_value_heads % config.tensor_parallel_degree == 0:
                 self.num_key_value_heads = self.num_key_value_heads // config.tensor_parallel_degree
             else:
+                if self.fuse_attention_qkv:
+                    # TODO(Yuang): support fusion for kv when kv heads cannot be divided by mp
+                    raise ValueError(
+                        f"fuse_attention_qkv can't be True when num_key_value_heads {config.num_key_value_heads} % tensor_parallel_degree {config.tensor_parallel_degree} != 0"
+                    )
                 logger.warning(
                     f"Get num_key_value_heads: {self.num_key_value_heads}, can't split to tensor_parallel_degree: {config.tensor_parallel_degree}, so we don't spilt key value weight."
                 )
@@ -644,7 +647,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             if self.fuse_attention_qkv:
                 self.qkv_proj = ColumnParallelLinear(
                     self.hidden_size,
-                    3 * self.hidden_size,
+                    self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
                     has_bias=False,
                     gather_output=False,
                 )
@@ -684,7 +687,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             if self.fuse_attention_qkv:
                 self.qkv_proj = nn.Linear(
                     self.hidden_size,
-                    3 * self.hidden_size,
+                    self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )
             else:
@@ -776,7 +779,11 @@ def forward(
                     assert self.seq_length % self.config.sep_parallel_degree == 0
                     mix_layer = paddle.reshape_(
                         mix_layer,
-                        [-1, self.seq_length // self.config.sep_parallel_degree, 3 * self.num_heads * self.head_dim],
+                        [
+                            -1,
+                            self.seq_length // self.config.sep_parallel_degree,
+                            self.num_heads * self.head_dim + 2 * self.num_key_value_heads * self.head_dim,
+                        ],
                     )
                 # [bs, seq_len / sep, num_head, head_dim] -> [bs, seq_len, num_head / sep, head_dim]
                 mix_layer = self.reshard_layer(
@@ -785,15 +792,26 @@ def forward(
                     concat_axis=1,
                 )
                 mix_layer = paddle.reshape_(
-                    mix_layer, [0, self.seq_length, -1, 3 * self.head_dim]
+                    mix_layer, [0, self.seq_length, -1, (self.num_key_value_groups + 2) * self.head_dim]
                 )  # [bs, seq_len, num_head/k, 3*head_dim], k is sep degree
             else:
                 if self.sequence_parallel:
-                    target_shape = [-1, self.seq_length, self.num_heads, 3 * self.head_dim]
+                    target_shape = [
+                        -1,
+                        self.seq_length,
+                        self.num_key_value_heads,
+                        (self.num_key_value_groups + 2) * self.head_dim,
+                    ]
                 else:
-                    target_shape = [0, 0, self.num_heads, 3 * self.head_dim]
+                    target_shape = [0, 0, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim]
                 mix_layer = paddle.reshape_(mix_layer, target_shape)
-            query_states, key_states, value_states = paddle.split(mix_layer, num_or_sections=3, axis=-1)
+            query_states, key_states, value_states = paddle.split(
+                mix_layer,
+                num_or_sections=[self.num_key_value_groups * self.head_dim, self.head_dim, self.head_dim],
+                axis=-1,
+            )
+            if self.gqa_or_mqa:
+                query_states = paddle.reshape_(query_states, [0, 0, self.num_heads, self.head_dim])
         else:
             query_states = self.q_proj(hidden_states)
             key_states = self.k_proj(hidden_states)
@@ -807,11 +825,19 @@ def forward(
                     )
                     key_states = paddle.reshape(
                         key_states,
-                        [-1, self.seq_length // self.config.sep_parallel_degree, self.num_heads * self.head_dim],
+                        [
+                            -1,
+                            self.seq_length // self.config.sep_parallel_degree,
+                            self.num_key_value_heads * self.head_dim,
+                        ],
                    )
                     value_states = paddle.reshape(
                         value_states,
-                        [-1, self.seq_length // self.config.sep_parallel_degree, self.num_heads * self.head_dim],
+                        [
+                            -1,
+                            self.seq_length // self.config.sep_parallel_degree,
+                            self.num_key_value_heads * self.head_dim,
+                        ],
                     )
                 query_states = self.reshard_layer(
                     query_states,

From 5026243b3e7f332e42301d4a9ad9267e2e12480e Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Wed, 31 Jan 2024 13:49:47 +0800
Subject: [PATCH 2/2] add annotation for the fusion

---
 paddlenlp/transformers/llama/modeling.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 06c46fa5b85a..ebe8ff213d4b 100644
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -774,6 +774,17 @@ def forward(
 
         if self.fuse_attention_qkv:
             mix_layer = self.qkv_proj(hidden_states)
+            # NOTE for GQA attention fusion (compatible with MHA and MQA):
+            # The weight of qkv_proj has shape [hidden_size, hidden_size + 2 * num_kv_heads * head_dim].
+            # After the projection, mix_layer has shape [b, s, hidden_size + 2 * num_kv_heads * head_dim].
+            # Reshape mix_layer to [b, s, num_kv_heads, (num_groups + 2) * head_dim],
+            # where num_groups = num_q_heads // num_kv_heads.
+            # Split mix_layer on the last axis into three sections [num_groups * head_dim, head_dim, head_dim]
+            # to obtain q, k and v respectively.
+            # q then has shape [b, s, num_kv_heads, num_groups * head_dim],
+            # while k and v have shape [b, s, num_kv_heads, head_dim].
+            # Under MHA, q is ready for the following computation since num_kv_heads == num_q_heads,
+            # but under GQA or MQA, q must be reshaped to [b, s, num_q_heads, head_dim].
             if self.reshard_layer is not None:
                 if self.sequence_parallel:
                     assert self.seq_length % self.config.sep_parallel_degree == 0
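
Supplementary note (not part of the patches): a minimal, standalone sketch of the shape bookkeeping described in the NOTE comment above. The toy dimensions, the plain nn.Linear standing in for the fused qkv_proj, and the variable names are assumptions made up for illustration only.

    import paddle
    import paddle.nn as nn

    # Assumed toy configuration (hypothetical values, chosen only for the example).
    batch, seq_len = 2, 16
    hidden_size, num_q_heads, num_kv_heads = 64, 8, 2
    head_dim = hidden_size // num_q_heads     # 8
    num_groups = num_q_heads // num_kv_heads  # 4 query heads share each kv head

    # Fused projection: q consumes hidden_size, k and v consume num_kv_heads * head_dim each.
    qkv_proj = nn.Linear(hidden_size, hidden_size + 2 * num_kv_heads * head_dim, bias_attr=False)

    hidden_states = paddle.randn([batch, seq_len, hidden_size])
    mix_layer = qkv_proj(hidden_states)  # [b, s, hidden_size + 2 * num_kv_heads * head_dim]

    # Group each kv head with its query heads: [b, s, num_kv_heads, (num_groups + 2) * head_dim].
    mix_layer = paddle.reshape(mix_layer, [0, 0, num_kv_heads, (num_groups + 2) * head_dim])

    # Split the last axis into the q / k / v sections.
    q, k, v = paddle.split(
        mix_layer,
        num_or_sections=[num_groups * head_dim, head_dim, head_dim],
        axis=-1,
    )

    # k and v already have shape [b, s, num_kv_heads, head_dim]; under GQA/MQA the
    # grouped q still has to be unfolded to one slot per query head.
    if num_q_heads != num_kv_heads:
        q = paddle.reshape(q, [0, 0, num_q_heads, head_dim])

    print(q.shape, k.shape, v.shape)  # [2, 16, 8, 8] [2, 16, 2, 8] [2, 16, 2, 8]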