Skip to content

Commit 245e097

Browse files
committed
add some mappings
1 parent 90dfcf6 commit 245e097

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

paddlenlp/transformers/qwen/modeling.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -525,15 +525,22 @@ def get_tensor_parallel_split_mappings(num_hidden_layers):
525525
base_actions = {
526526
# Column Linear
527527
"lm_head.weight": partial(fn, is_column=True),
528-
"qwen.h.0.mlp.w2.weight": partial(fn, is_column=True),
529-
"qwen.h.0.mlp.w1.weight": partial(fn, is_column=True),
530528
"qwen.h.0.attn.c_attn.weight": partial(fn, is_column=True, is_naive_3fuse=True),
531529
"qwen.h.0.attn.c_attn.bias": partial(fn, is_column=True, is_naive_3fuse=True),
532530
# Row Linear
533531
"qwen.wte.weight": partial(fn, is_column=False),
534532
"qwen.h.0.mlp.c_proj.weight": partial(fn, is_column=False),
535533
"qwen.h.0.attn.c_proj.weight": partial(fn, is_column=False),
536534
}
535+
536+
if config.fuse_attention_ffn:
537+
base_actions["layers.0.mlp.gate_up_fused_proj.weight"] = partial(
538+
fn, is_column=True, is_naive_2fuse=True
539+
)
540+
else:
541+
base_actions["qwen.h.0.mlp.w2.weight"] = partial(fn, is_column=True)
542+
base_actions["qwen.h.0.mlp.w1.weight"] = partial(fn, is_column=True)
543+
537544
for key, action in base_actions.items():
538545
if "h.0." in key:
539546
for i in range(num_hidden_layers):
@@ -569,8 +576,8 @@ def _get_name_mappings(cls, config: QWenConfig) -> List[StateDictNameMapping]:
569576
f"h.{layer_index}.attn.c_attn.bias",
570577
],
571578
[
572-
f"h.{layer_index}.attn.c_proj.weight",
573-
f"h.{layer_index}.attn.c_proj.weight",
579+
f"h.{layer_index}.attn.o_proj.weight",
580+
f"h.{layer_index}.attn.o_proj.weight",
574581
"transpose",
575582
],
576583
[

0 commit comments

Comments
 (0)