     GenerationInferenceModel,
 )
 from paddlenlp.transformers import LlamaConfig, LlamaPretrainedModel
+from paddlenlp.transformers.conversion_utils import split_param_func
 from paddlenlp.transformers.llama.modeling import LlamaLMHead
 from paddlenlp.transformers.model_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
@@ -473,48 +474,66 @@ def forward(
     def set_state_dict(self, state_dict):
         unfused_state_dict = {}
         head_size = self.hidden_size // self.num_attention_heads
+        split_fn = split_param_func()
 
-        self.embed_tokens.weight.set_value(paddle.to_tensor(state_dict["llama.embed_tokens.weight"]))
-        self.norm.weight.set_value(paddle.to_tensor(state_dict["llama.norm.weight"], dtype=self.norm.weight.dtype))
+        self.embed_tokens.weight.set_value(
+            paddle.to_tensor(state_dict["llama.embed_tokens.weight"]).cast(self.embed_tokens.weight.dtype)
+        )
+        self.norm.weight.set_value(paddle.to_tensor(state_dict["llama.norm.weight"]).cast(self.norm.weight.dtype))
 
         for idx in range(self.config.num_hidden_layers):
             logger.info(f"set state for layer {idx}")
 
             if self.use_weight_only:
                 logger.info("weight only is enabled")
-            unfused_state_dict = {}
-            unfused_state_dict["self_attn.q_proj.weight"] = state_dict[
-                "llama.layers.{}.self_attn.q_proj.weight".format(idx)
-            ]
-            unfused_state_dict["self_attn.k_proj.weight"] = state_dict[
-                "llama.layers.{}.self_attn.k_proj.weight".format(idx)
-            ]
-            unfused_state_dict["self_attn.v_proj.weight"] = state_dict[
-                "llama.layers.{}.self_attn.v_proj.weight".format(idx)
-            ]
-
-            concated_qkv_weight = (
-                np.concatenate(
-                    [
-                        unfused_state_dict["self_attn.q_proj.weight"],
-                        unfused_state_dict["self_attn.k_proj.weight"],
-                        unfused_state_dict["self_attn.v_proj.weight"],
-                    ],
+            if "llama.layers.{}.self_attn.qkv_proj.weight".format(idx) in state_dict.keys():
+                concated_qkv_weight = np.concatenate(
+                    split_fn(
+                        state_dict["llama.layers.{}.self_attn.qkv_proj.weight".format(idx)],
+                        is_qkv=True,
+                        num_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
+                        num_key_value_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
+                    ),
                     axis=-1,
                 )
-                .transpose(1, 0)
-                .reshape(
-                    3 * (self.num_attention_heads // self.config.tensor_parallel_degree) * (head_size),
-                    self.hidden_size,
+            else:
+                unfused_state_dict = {}
+                unfused_state_dict["self_attn.q_proj.weight"] = state_dict[
+                    "llama.layers.{}.self_attn.q_proj.weight".format(idx)
+                ]
+                unfused_state_dict["self_attn.k_proj.weight"] = state_dict[
+                    "llama.layers.{}.self_attn.k_proj.weight".format(idx)
+                ]
+                unfused_state_dict["self_attn.v_proj.weight"] = state_dict[
+                    "llama.layers.{}.self_attn.v_proj.weight".format(idx)
+                ]
+                concated_qkv_weight = (
+                    np.concatenate(
+                        [
+                            unfused_state_dict["self_attn.q_proj.weight"],
+                            unfused_state_dict["self_attn.k_proj.weight"],
+                            unfused_state_dict["self_attn.v_proj.weight"],
+                        ],
+                        axis=-1,
+                    )
+                    .transpose(1, 0)
+                    .reshape(
+                        3 * (self.num_attention_heads // self.config.tensor_parallel_degree) * (head_size),
+                        self.hidden_size,
+                    )
+                )  # reshape(3, self.num_attention_heads // self.config.tensor_parallel_degree, head_size, self.hidden_size)
+            if "llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx) in state_dict.keys():
+                ffn1_weight_tensor = np.concatenate(
+                    split_fn(state_dict["llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx)]), axis=-1
+                )
+            else:
+                unfused_state_dict["mlp.gate_proj.weight"] = state_dict[
+                    "llama.layers.{}.mlp.gate_proj.weight".format(idx)
+                ]
+                unfused_state_dict["mlp.up_proj.weight"] = state_dict["llama.layers.{}.mlp.up_proj.weight".format(idx)]
+                concated_ffn1_weight = np.concatenate(
+                    [unfused_state_dict["mlp.gate_proj.weight"], unfused_state_dict["mlp.up_proj.weight"]], axis=-1
                 )
-            )  # reshape(3, self.num_attention_heads // self.config.tensor_parallel_degree, head_size, self.hidden_size)
-
-            unfused_state_dict["mlp.gate_proj.weight"] = state_dict["llama.layers.{}.mlp.gate_proj.weight".format(idx)]
-            unfused_state_dict["mlp.up_proj.weight"] = state_dict["llama.layers.{}.mlp.up_proj.weight".format(idx)]
-
-            concated_ffn1_weight = np.concatenate(
-                [unfused_state_dict["mlp.gate_proj.weight"], unfused_state_dict["mlp.up_proj.weight"]], axis=-1
-            )
             ffn1_weight_tensor = paddle.to_tensor(concated_ffn1_weight)
 
             qkv_weight_tensor = paddle.to_tensor(concated_qkv_weight)
@@ -534,7 +553,9 @@ def set_state_dict(self, state_dict):
                     paddle.cast(paddle.to_tensor(concated_qkv_weight), "int8")
                 )
             else:
-                self.transformer_block.qkv_weights[idx].set_value(qkv_weight_tensor)
+                self.transformer_block.qkv_weights[idx].set_value(
+                    qkv_weight_tensor.cast(self.transformer_block.qkv_weights[idx].dtype)
+                )
 
             linear_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.weight".format(idx)])
             if self.use_weight_only:
@@ -556,7 +577,9 @@ def set_state_dict(self, state_dict):
                     )
                 )
             else:
-                self.transformer_block.linear_weights[idx].set_value(linear_weight_tensor)
+                self.transformer_block.linear_weights[idx].set_value(
+                    linear_weight_tensor.cast(self.transformer_block.linear_weights[idx].dtype)
+                )
 
             if self.use_weight_only:
                 ffn1_quanted_weight_tensor, ffn1_weight_scale_tensor = weight_quantize(
@@ -572,7 +595,9 @@ def set_state_dict(self, state_dict):
                     paddle.cast(paddle.to_tensor(concated_ffn1_weight).transpose((1, 0)), "int8")
                 )
             else:
-                self.transformer_block.ffn1_weights[idx].set_value(ffn1_weight_tensor)
+                self.transformer_block.ffn1_weights[idx].set_value(
+                    ffn1_weight_tensor.cast(self.transformer_block.ffn1_weights[idx].dtype)
+                )
 
             ffn2_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.weight".format(idx)])
             if self.use_weight_only:
@@ -594,7 +619,9 @@ def set_state_dict(self, state_dict):
                     )
                 )
             else:
-                self.transformer_block.ffn2_weights[idx].set_value(ffn2_weight_tensor)
+                self.transformer_block.ffn2_weights[idx].set_value(
+                    ffn2_weight_tensor.cast(self.transformer_block.ffn2_weights[idx].dtype)
+                )
 
             if self.quant_type == "a8w8":
                 if self.shift_smooth_all_linears:
@@ -660,16 +687,14 @@ def set_state_dict(self, state_dict):
                 )
 
             self.transformer_block.ln_scales[idx].set_value(
-                paddle.to_tensor(
-                    state_dict["llama.layers.{}.input_layernorm.weight".format(idx)],
-                    dtype=self.transformer_block.ln_scales[idx].dtype,
+                paddle.to_tensor(state_dict["llama.layers.{}.input_layernorm.weight".format(idx)]).cast(
+                    self.transformer_block.ln_scales[idx].dtype
                 )
             )
 
             self.transformer_block.ffn_ln_scales[idx].set_value(
-                paddle.to_tensor(
-                    state_dict["llama.layers.{}.post_attention_layernorm.weight".format(idx)],
-                    dtype=self.transformer_block.ffn_ln_scales[idx].dtype,
+                paddle.to_tensor(state_dict["llama.layers.{}.post_attention_layernorm.weight".format(idx)]).cast(
+                    self.transformer_block.ffn_ln_scales[idx].dtype
                 )
             )
 
@@ -1264,7 +1289,9 @@ def forward(
     @paddle.no_grad()
     def set_state_dict(self, state_dict):
         if "lm_head.weight" in state_dict:
-            self.lm_head.weight.set_value(state_dict["lm_head.weight"])
+            self.lm_head.weight.set_value(
+                paddle.to_tensor(state_dict["lm_head.weight"]).cast(self.lm_head.weight.dtype)
+            )
         self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()})
 
 
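Note: the fused-QKV branch added above amounts to the layout transform sketched below. This is a minimal NumPy sketch with made-up sizes, not part of the diff; np.split is only a simplified stand-in for split_param_func, which also accounts for per-head layout and grouped key/value heads when splitting the fused projection.

# Minimal sketch (assumed sizes) of the fused qkv_proj -> fused transformer block layout.
import numpy as np

hidden_size, num_heads, head_size = 64, 4, 16
# Fused checkpoint weight: [hidden_size, 3 * num_heads * head_size]
fused_qkv = np.random.rand(hidden_size, 3 * num_heads * head_size).astype("float32")

# Stand-in for split_fn(fused_qkv, is_qkv=True, num_heads=..., num_key_value_heads=...)
q, k, v = np.split(fused_qkv, 3, axis=-1)

# Re-concatenate and lay the weight out as [3 * num_heads * head_size, hidden_size],
# mirroring the transpose(1, 0).reshape(...) path in set_state_dict.
concated_qkv_weight = (
    np.concatenate([q, k, v], axis=-1)
    .transpose(1, 0)
    .reshape(3 * num_heads * head_size, hidden_size)
)
assert concated_qkv_weight.shape == (3 * num_heads * head_size, hidden_size)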