
Commit ebe397e

Support fused_attention_qkv for auto_parallel llama (#8432)
* add * add * add * add * add
1 parent 562229c commit ebe397e

File tree

2 files changed: +20 -7 lines changed

llm/llama/auto_parallel/run_pretrain_auto.py

Lines changed: 9 additions & 0 deletions
@@ -86,6 +86,10 @@ class PreTrainingArguments(TrainingArguments):
             "help": "Enable fused_linear_param_grad pass, which should replace add_n_op with add_op for gradients accumulation."
         },
     )
+    fuse_allreduce_split_to_reducescatter: bool = field(
+        default=False,
+        metadata={"help": "Enable fuse_allreduce_split_to_reducescatter pass."},
+    )
     eliminate_transpose: bool = field(
         default=False,
         metadata={
@@ -138,6 +142,11 @@ def __post_init__(self):
             fused_passes.enable = True
             fused_passes.fused_passes_list.append("fused_linear_param_grad_add_pass")
 
+        if self.fuse_allreduce_split_to_reducescatter:
+            fused_passes = self.strategy.fused_passes
+            fused_passes.enable = True
+            fused_passes.fused_passes_list.append("fuse_allreduce_split_to_reducescatter_pass")
+
         if self.eliminate_transpose:
             fused_passes = self.strategy.fused_passes
             fused_passes.enable = True
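
For context, a minimal usage sketch (not part of the commit) of how the new switch reaches the fused-pass list; it assumes PreTrainingArguments from run_pretrain_auto.py is importable from the working directory and that output_dir is the only other argument we need to supply:

# Hypothetical usage sketch: enabling the new pass through the training arguments.
from run_pretrain_auto import PreTrainingArguments  # assumes the script is on sys.path

args = PreTrainingArguments(
    output_dir="./checkpoints",                 # illustrative value
    fuse_allreduce_split_to_reducescatter=True,
)

# __post_init__ (extended by this commit) enables fused passes and appends the
# new pass name to the auto-parallel strategy's pass list.
assert "fuse_allreduce_split_to_reducescatter_pass" in args.strategy.fused_passes.fused_passes_list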

paddlenlp/transformers/llama/modeling_auto.py

Lines changed: 11 additions & 7 deletions
@@ -272,16 +272,14 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
         self.head_dim = self.hidden_size // config.num_attention_heads
 
         self.num_key_value_heads = config.num_key_value_heads
+        assert config.num_attention_heads // config.num_key_value_heads
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+        self.gqa_or_mqa = config.num_attention_heads != config.num_key_value_heads
 
         self.max_position_embeddings = config.max_position_embeddings
         self.seq_length = config.seq_length
 
         self.fuse_attention_qkv = config.fuse_attention_qkv
-        if self.fuse_attention_qkv and config.num_attention_heads != config.num_key_value_heads:
-            raise ValueError(
-                f"fuse_attention_qkv can't be True when num_attention_heads {config.num_attention_heads}!= num_key_value_heads {config.num_key_value_heads}"
-            )
 
         self.kv_indices = None
         # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
@@ -303,7 +301,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
         if self.fuse_attention_qkv:
             self.qkv_proj = nn.Linear(
                 self.hidden_size,
-                3 * self.hidden_size,
+                self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
                 bias_attr=False,
             )
             self.qkv_proj.weight = dist.shard_tensor(
@@ -415,10 +413,16 @@ def forward(
             )
 
         if self.fuse_attention_qkv:
-            target_shape = [0, 0, self.num_heads, 3 * self.head_dim]
+            target_shape = [0, 0, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim]
             mix_layer = self.qkv_proj(hidden_states)
             mix_layer = paddle.reshape_(mix_layer, target_shape)
-            query_states, key_states, value_states = paddle.split(mix_layer, num_or_sections=3, axis=-1)
+            query_states, key_states, value_states = paddle.split(
+                mix_layer,
+                num_or_sections=[self.num_key_value_groups * self.head_dim, self.head_dim, self.head_dim],
+                axis=-1,
+            )
+            if self.gqa_or_mqa:
+                query_states = paddle.reshape(query_states, [0, 0, self.num_heads, self.head_dim])
         else:
             target_query_shape = [0, 0, self.num_heads, self.head_dim]
             target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim]
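
To make the new shapes concrete, here is a small shape-check sketch (not from the PR) that mirrors the GQA-aware fused-QKV projection, reshape, and split with illustrative sizes (hidden_size=4096, 32 query heads, 8 key/value heads); it relies on Paddle's convention that a 0 in a reshape target keeps the corresponding input dimension:

import paddle
import paddle.nn as nn

hidden_size = 4096
num_heads = 32                                            # query heads
num_key_value_heads = 8                                   # GQA: fewer K/V heads than query heads
head_dim = hidden_size // num_heads                       # 128
num_key_value_groups = num_heads // num_key_value_heads   # 4

# Fused projection width: Q takes hidden_size, K and V take num_key_value_heads * head_dim each.
qkv_proj = nn.Linear(hidden_size, hidden_size + 2 * num_key_value_heads * head_dim, bias_attr=False)

hidden_states = paddle.randn([2, 16, hidden_size])        # [batch, seq, hidden]
mix_layer = qkv_proj(hidden_states)                       # [2, 16, 6144]
mix_layer = paddle.reshape(mix_layer, [0, 0, num_key_value_heads, (num_key_value_groups + 2) * head_dim])

# Per K/V head, the last axis holds num_key_value_groups query slices plus one key and one value slice.
query_states, key_states, value_states = paddle.split(
    mix_layer,
    num_or_sections=[num_key_value_groups * head_dim, head_dim, head_dim],
    axis=-1,
)
# Regroup the query slices back to one slice per query head.
query_states = paddle.reshape(query_states, [0, 0, num_heads, head_dim])

print(query_states.shape, key_states.shape, value_states.shape)
# [2, 16, 32, 128] [2, 16, 8, 128] [2, 16, 8, 128]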
