@@ -525,15 +525,22 @@ def get_tensor_parallel_split_mappings(num_hidden_layers):
525
525
base_actions = {
526
526
# Column Linear
527
527
"lm_head.weight" : partial (fn , is_column = True ),
528
- "qwen.h.0.mlp.w2.weight" : partial (fn , is_column = True ),
529
- "qwen.h.0.mlp.w1.weight" : partial (fn , is_column = True ),
530
528
"qwen.h.0.attn.c_attn.weight" : partial (fn , is_column = True , is_naive_3fuse = True ),
531
529
"qwen.h.0.attn.c_attn.bias" : partial (fn , is_column = True , is_naive_3fuse = True ),
532
530
# Row Linear
533
531
"qwen.wte.weight" : partial (fn , is_column = False ),
534
532
"qwen.h.0.mlp.c_proj.weight" : partial (fn , is_column = False ),
535
533
"qwen.h.0.attn.c_proj.weight" : partial (fn , is_column = False ),
536
534
}
535
+
536
+ if config .fuse_attention_ffn :
537
+ base_actions ["layers.0.mlp.gate_up_fused_proj.weight" ] = partial (
538
+ fn , is_column = True , is_naive_2fuse = True
539
+ )
540
+ else :
541
+ base_actions ["qwen.h.0.mlp.w2.weight" ] = partial (fn , is_column = True )
542
+ base_actions ["qwen.h.0.mlp.w1.weight" ] = partial (fn , is_column = True )
543
+
537
544
for key , action in base_actions .items ():
538
545
if "h.0." in key :
539
546
for i in range (num_hidden_layers ):
@@ -569,8 +576,8 @@ def _get_name_mappings(cls, config: QWenConfig) -> List[StateDictNameMapping]:
569
576
f"h.{ layer_index } .attn.c_attn.bias" ,
570
577
],
571
578
[
572
- f"h.{ layer_index } .attn.c_proj .weight" ,
573
- f"h.{ layer_index } .attn.c_proj .weight" ,
579
+ f"h.{ layer_index } .attn.o_proj .weight" ,
580
+ f"h.{ layer_index } .attn.o_proj .weight" ,
574
581
"transpose" ,
575
582
],
576
583
[
0 commit comments