@@ -474,47 +474,50 @@ def set_state_dict(self, state_dict):
         unfused_state_dict = {}
         head_size = self.hidden_size // self.num_attention_heads

-        self.embed_tokens.weight.set_value(paddle.to_tensor(state_dict["llama.embed_tokens.weight"]))
+        self.embed_tokens.weight.set_value(paddle.to_tensor(state_dict["llama.embed_tokens.weight"], dtype=self.embed_tokens.weight.dtype))
         self.norm.weight.set_value(paddle.to_tensor(state_dict["llama.norm.weight"], dtype=self.norm.weight.dtype))

         for idx in range(self.config.num_hidden_layers):
             logger.info(f"set state for layer {idx}")

             if self.use_weight_only:
                 logger.info("weight only is enabled")
-            unfused_state_dict = {}
-            unfused_state_dict["self_attn.q_proj.weight"] = state_dict[
-                "llama.layers.{}.self_attn.q_proj.weight".format(idx)
-            ]
-            unfused_state_dict["self_attn.k_proj.weight"] = state_dict[
-                "llama.layers.{}.self_attn.k_proj.weight".format(idx)
-            ]
-            unfused_state_dict["self_attn.v_proj.weight"] = state_dict[
-                "llama.layers.{}.self_attn.v_proj.weight".format(idx)
-            ]
-
-            concated_qkv_weight = (
-                np.concatenate(
-                    [
-                        unfused_state_dict["self_attn.q_proj.weight"],
-                        unfused_state_dict["self_attn.k_proj.weight"],
-                        unfused_state_dict["self_attn.v_proj.weight"],
-                    ],
-                    axis=-1,
-                )
-                .transpose(1, 0)
-                .reshape(
-                    3 * (self.num_attention_heads // self.config.tensor_parallel_degree) * (head_size),
-                    self.hidden_size,
+            if "llama.layers.{}.self_attn.qkv_proj.weight".format(idx) in state_dict.keys():
+                concated_qkv_weight = state_dict["llama.layers.{}.self_attn.qkv_proj.weight".format(idx)].transpose([1, 0])
+            else:
+                unfused_state_dict = {}
+                unfused_state_dict["self_attn.q_proj.weight"] = state_dict[
+                    "llama.layers.{}.self_attn.q_proj.weight".format(idx)
+                ]
+                unfused_state_dict["self_attn.k_proj.weight"] = state_dict[
+                    "llama.layers.{}.self_attn.k_proj.weight".format(idx)
+                ]
+                unfused_state_dict["self_attn.v_proj.weight"] = state_dict[
+                    "llama.layers.{}.self_attn.v_proj.weight".format(idx)
+                ]
+                concated_qkv_weight = (
+                    np.concatenate(
+                        [
+                            unfused_state_dict["self_attn.q_proj.weight"],
+                            unfused_state_dict["self_attn.k_proj.weight"],
+                            unfused_state_dict["self_attn.v_proj.weight"],
+                        ],
+                        axis=-1,
+                    )
+                    .transpose(1, 0)
+                    .reshape(
+                        3 * (self.num_attention_heads // self.config.tensor_parallel_degree) * (head_size),
+                        self.hidden_size,
+                    )
+                )  # reshape(3, self.num_attention_heads // self.config.tensor_parallel_degree, head_size, self.hidden_size)
+            if "llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx) in state_dict.keys():
+                concated_ffn1_weight = state_dict["llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx)]
+            else:
+                unfused_state_dict["mlp.gate_proj.weight"] = state_dict["llama.layers.{}.mlp.gate_proj.weight".format(idx)]
+                unfused_state_dict["mlp.up_proj.weight"] = state_dict["llama.layers.{}.mlp.up_proj.weight".format(idx)]
+                concated_ffn1_weight = np.concatenate(
+                    [unfused_state_dict["mlp.gate_proj.weight"], unfused_state_dict["mlp.up_proj.weight"]], axis=-1
                 )
-            )  # reshape(3, self.num_attention_heads // self.config.tensor_parallel_degree, head_size, self.hidden_size)
-
-            unfused_state_dict["mlp.gate_proj.weight"] = state_dict["llama.layers.{}.mlp.gate_proj.weight".format(idx)]
-            unfused_state_dict["mlp.up_proj.weight"] = state_dict["llama.layers.{}.mlp.up_proj.weight".format(idx)]
-
-            concated_ffn1_weight = np.concatenate(
-                [unfused_state_dict["mlp.gate_proj.weight"], unfused_state_dict["mlp.up_proj.weight"]], axis=-1
-            )

             ffn1_weight_tensor = paddle.to_tensor(concated_ffn1_weight)

             qkv_weight_tensor = paddle.to_tensor(concated_qkv_weight)
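Note: the new branch accepts either a pre-fused qkv_proj weight or separate q/k/v projections; in the unfused case the three [hidden_size, heads_per_rank * head_size] matrices are concatenated, transposed, and reshaped into the [3 * heads_per_rank * head_size, hidden_size] layout the fused transformer block expects. A minimal NumPy-only sketch of that shape manipulation (illustrative sizes, not values from the diff):

import numpy as np

# Illustrative sizes -- assumptions for the sketch, not taken from the diff.
hidden_size = 8
num_attention_heads = 2
tensor_parallel_degree = 1
head_size = hidden_size // num_attention_heads
heads_per_rank = num_attention_heads // tensor_parallel_degree

# Unfused projection weights are stored as [hidden_size, heads_per_rank * head_size].
q = np.random.rand(hidden_size, heads_per_rank * head_size).astype("float32")
k = np.random.rand(hidden_size, heads_per_rank * head_size).astype("float32")
v = np.random.rand(hidden_size, heads_per_rank * head_size).astype("float32")

# Concatenate along the output dimension, transpose, and reshape into the
# fused [3 * heads_per_rank * head_size, hidden_size] layout expected above.
concated_qkv_weight = (
    np.concatenate([q, k, v], axis=-1)
    .transpose(1, 0)
    .reshape(3 * heads_per_rank * head_size, hidden_size)
)
assert concated_qkv_weight.shape == (3 * heads_per_rank * head_size, hidden_size)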
@@ -534,7 +537,7 @@ def set_state_dict(self, state_dict):
                     paddle.cast(paddle.to_tensor(concated_qkv_weight), "int8")
                 )
             else:
-                self.transformer_block.qkv_weights[idx].set_value(qkv_weight_tensor)
+                self.transformer_block.qkv_weights[idx].set_value(qkv_weight_tensor.cast(self.transformer_block.qkv_weights[idx].dtype))

             linear_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.weight".format(idx)])
             if self.use_weight_only:
@@ -556,7 +559,7 @@ def set_state_dict(self, state_dict):
                     )
                 )
             else:
-                self.transformer_block.linear_weights[idx].set_value(linear_weight_tensor)
+                self.transformer_block.linear_weights[idx].set_value(linear_weight_tensor.cast(self.transformer_block.linear_weights[idx].dtype))

             if self.use_weight_only:
                 ffn1_quanted_weight_tensor, ffn1_weight_scale_tensor = weight_quantize(
@@ -572,7 +575,7 @@ def set_state_dict(self, state_dict):
                     paddle.cast(paddle.to_tensor(concated_ffn1_weight).transpose((1, 0)), "int8")
                 )
             else:
-                self.transformer_block.ffn1_weights[idx].set_value(ffn1_weight_tensor)
+                self.transformer_block.ffn1_weights[idx].set_value(ffn1_weight_tensor.cast(self.transformer_block.ffn1_weights[idx].dtype))

             ffn2_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.weight".format(idx)])
             if self.use_weight_only:
@@ -594,7 +597,7 @@ def set_state_dict(self, state_dict):
                     )
                 )
             else:
-                self.transformer_block.ffn2_weights[idx].set_value(ffn2_weight_tensor)
+                self.transformer_block.ffn2_weights[idx].set_value(ffn2_weight_tensor.cast(self.transformer_block.ffn2_weights[idx].dtype))

             if self.quant_type == "a8w8":
                 if self.shift_smooth_all_linears:
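Note: the four else-branches above now cast each loaded tensor to the destination parameter's dtype before set_value, so a float32 checkpoint can populate float16/bfloat16 inference weights without a dtype mismatch. A minimal sketch of that pattern, using a hypothetical standalone float16 parameter rather than the model's own weights:

import numpy as np
import paddle

# Hypothetical destination parameter standing in for e.g. qkv_weights[idx];
# inference parameters are typically created in float16/bfloat16.
dst = paddle.create_parameter(
    shape=[4, 4],
    dtype="float16",
    default_initializer=paddle.nn.initializer.Constant(0.0),
)

# Checkpoint weights usually arrive as float32 numpy arrays.
checkpoint_weight = np.random.rand(4, 4).astype("float32")

# Cast to the destination dtype before set_value, as in the else-branches above.
dst.set_value(paddle.to_tensor(checkpoint_weight).cast(dst.dtype))
assert dst.dtype == paddle.float16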
@@ -1264,7 +1267,7 @@ def forward(
     @paddle.no_grad()
     def set_state_dict(self, state_dict):
         if "lm_head.weight" in state_dict:
-            self.lm_head.weight.set_value(state_dict["lm_head.weight"])
+            self.lm_head.weight.set_value(paddle.to_tensor(state_dict["lm_head.weight"], dtype=self.lm_head.weight.dtype))
         self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()})


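Note: for lm_head.weight (and embed_tokens.weight in the first hunk) the dtype is instead passed to paddle.to_tensor, so the conversion happens when the tensor is built rather than via a later cast. A small sketch contrasting the two styles, using hypothetical shapes and float16 as the assumed target dtype:

import numpy as np
import paddle

weight = np.random.rand(8, 8).astype("float32")

# Convert while building the tensor (the embed_tokens / lm_head style above).
t1 = paddle.to_tensor(weight, dtype="float16")

# Build first, then cast (the qkv/linear/ffn style above).
t2 = paddle.to_tensor(weight).cast("float16")

assert t1.dtype == t2.dtype == paddle.float16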