Commit ff61d4a
support fused weights for export_model
1 parent f36ed75

File tree

1 file changed: +70 -43 lines changed

paddlenlp/experimental/transformers/llama/modeling.py

Lines changed: 70 additions & 43 deletions
@@ -47,6 +47,7 @@
     GenerationInferenceModel,
 )
 from paddlenlp.transformers import LlamaConfig, LlamaPretrainedModel
+from paddlenlp.transformers.conversion_utils import split_param_func
 from paddlenlp.transformers.llama.modeling import LlamaLMHead
 from paddlenlp.transformers.model_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
@@ -473,48 +474,66 @@ def forward(
     def set_state_dict(self, state_dict):
         unfused_state_dict = {}
         head_size = self.hidden_size // self.num_attention_heads
+        split_fn = split_param_func()

-        self.embed_tokens.weight.set_value(paddle.to_tensor(state_dict["llama.embed_tokens.weight"]))
-        self.norm.weight.set_value(paddle.to_tensor(state_dict["llama.norm.weight"], dtype=self.norm.weight.dtype))
+        self.embed_tokens.weight.set_value(
+            paddle.to_tensor(state_dict["llama.embed_tokens.weight"]).cast(self.embed_tokens.weight.dtype)
+        )
+        self.norm.weight.set_value(paddle.to_tensor(state_dict["llama.norm.weight"]).cast(self.norm.weight.dtype))

         for idx in range(self.config.num_hidden_layers):
             logger.info(f"set state for layer {idx}")

             if self.use_weight_only:
                 logger.info("weight only is enabled")
-            unfused_state_dict = {}
-            unfused_state_dict["self_attn.q_proj.weight"] = state_dict[
-                "llama.layers.{}.self_attn.q_proj.weight".format(idx)
-            ]
-            unfused_state_dict["self_attn.k_proj.weight"] = state_dict[
-                "llama.layers.{}.self_attn.k_proj.weight".format(idx)
-            ]
-            unfused_state_dict["self_attn.v_proj.weight"] = state_dict[
-                "llama.layers.{}.self_attn.v_proj.weight".format(idx)
-            ]
-
-            concated_qkv_weight = (
-                np.concatenate(
-                    [
-                        unfused_state_dict["self_attn.q_proj.weight"],
-                        unfused_state_dict["self_attn.k_proj.weight"],
-                        unfused_state_dict["self_attn.v_proj.weight"],
-                    ],
+            if "llama.layers.{}.self_attn.qkv_proj.weight".format(idx) in state_dict.keys():
+                concated_qkv_weight = np.concatenate(
+                    split_fn(
+                        state_dict["llama.layers.{}.self_attn.qkv_proj.weight".format(idx)],
+                        is_qkv=True,
+                        num_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
+                        num_key_value_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
+                    ),
                     axis=-1,
                 )
-                .transpose(1, 0)
-                .reshape(
-                    3 * (self.num_attention_heads // self.config.tensor_parallel_degree) * (head_size),
-                    self.hidden_size,
+            else:
+                unfused_state_dict = {}
+                unfused_state_dict["self_attn.q_proj.weight"] = state_dict[
+                    "llama.layers.{}.self_attn.q_proj.weight".format(idx)
+                ]
+                unfused_state_dict["self_attn.k_proj.weight"] = state_dict[
+                    "llama.layers.{}.self_attn.k_proj.weight".format(idx)
+                ]
+                unfused_state_dict["self_attn.v_proj.weight"] = state_dict[
+                    "llama.layers.{}.self_attn.v_proj.weight".format(idx)
+                ]
+                concated_qkv_weight = (
+                    np.concatenate(
+                        [
+                            unfused_state_dict["self_attn.q_proj.weight"],
+                            unfused_state_dict["self_attn.k_proj.weight"],
+                            unfused_state_dict["self_attn.v_proj.weight"],
+                        ],
+                        axis=-1,
+                    )
+                    .transpose(1, 0)
+                    .reshape(
+                        3 * (self.num_attention_heads // self.config.tensor_parallel_degree) * (head_size),
+                        self.hidden_size,
+                    )
+                )  # reshape to [3 * num_attention_heads // tensor_parallel_degree * head_size, hidden_size]
+            if "llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx) in state_dict.keys():
+                concated_ffn1_weight = np.concatenate(
+                    split_fn(state_dict["llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx)]), axis=-1
+                )
+            else:
+                unfused_state_dict["mlp.gate_proj.weight"] = state_dict[
+                    "llama.layers.{}.mlp.gate_proj.weight".format(idx)
+                ]
+                unfused_state_dict["mlp.up_proj.weight"] = state_dict["llama.layers.{}.mlp.up_proj.weight".format(idx)]
+                concated_ffn1_weight = np.concatenate(
+                    [unfused_state_dict["mlp.gate_proj.weight"], unfused_state_dict["mlp.up_proj.weight"]], axis=-1
                 )
-            ) # reshape(3, self.num_attention_heself.hidden_sizeads // self.config.tensor_parallel_degree, head_size, )
-
-            unfused_state_dict["mlp.gate_proj.weight"] = state_dict["llama.layers.{}.mlp.gate_proj.weight".format(idx)]
-            unfused_state_dict["mlp.up_proj.weight"] = state_dict["llama.layers.{}.mlp.up_proj.weight".format(idx)]
-
-            concated_ffn1_weight = np.concatenate(
-                [unfused_state_dict["mlp.gate_proj.weight"], unfused_state_dict["mlp.up_proj.weight"]], axis=-1
-            )
             ffn1_weight_tensor = paddle.to_tensor(concated_ffn1_weight)

             qkv_weight_tensor = paddle.to_tensor(concated_qkv_weight)
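
As a rough illustration of what the new fused-weight branches do conceptually: a fused checkpoint weight is split back into q/k/v pieces and re-concatenated along the last axis into the block layout the fused attention weights expect. The interleaved-per-head layout, the split_interleaved_qkv helper, and the toy sizes below are assumptions for illustration only; the real splitting is done by split_param_func from paddlenlp.transformers.conversion_utils, whose exact layout handling may differ.

# Illustrative sketch only -- split_interleaved_qkv is a hypothetical stand-in for
# split_param_func, and the per-head-interleaved fused layout is an assumption.
import numpy as np

hidden_size, num_heads, head_size = 8, 2, 4  # toy sizes, not a real Llama config

def split_interleaved_qkv(fused, num_heads, head_size):
    # fused: [hidden_size, num_heads * 3 * head_size], columns grouped per head as [q_h, k_h, v_h]
    per_head = fused.reshape(fused.shape[0], num_heads, 3, head_size)
    q = per_head[:, :, 0, :].reshape(fused.shape[0], num_heads * head_size)
    k = per_head[:, :, 1, :].reshape(fused.shape[0], num_heads * head_size)
    v = per_head[:, :, 2, :].reshape(fused.shape[0], num_heads * head_size)
    return q, k, v

fused = np.arange(hidden_size * num_heads * 3 * head_size, dtype="float32").reshape(
    hidden_size, num_heads * 3 * head_size
)
q, k, v = split_interleaved_qkv(fused, num_heads, head_size)
concated_qkv_weight = np.concatenate([q, k, v], axis=-1)  # block layout [Q | K | V]
print(concated_qkv_weight.shape)  # (8, 24)

The gate_up_fused_proj branch follows the same idea for the feed-forward weights: split_fn returns the gate and up blocks, which are concatenated along the last axis into the ffn1 weight.
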
@@ -534,7 +553,9 @@ def set_state_dict(self, state_dict):
                     paddle.cast(paddle.to_tensor(concated_qkv_weight), "int8")
                 )
             else:
-                self.transformer_block.qkv_weights[idx].set_value(qkv_weight_tensor)
+                self.transformer_block.qkv_weights[idx].set_value(
+                    qkv_weight_tensor.cast(self.transformer_block.qkv_weights[idx].dtype)
+                )

             linear_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.weight".format(idx)])
             if self.use_weight_only:
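
The remaining hunks repeat the same dtype-handling change for the linear, ffn1, ffn2, layernorm-scale, and lm_head weights: instead of relying on the checkpoint dtype, each tensor is cast to the destination parameter's dtype before set_value. A minimal standalone sketch of that pattern, with a made-up parameter and checkpoint array (not model state), could look like this:

# Minimal sketch of the cast-before-set_value pattern used throughout this commit.
# The parameter shape and dtypes here are toy values chosen so the snippet runs anywhere.
import numpy as np
import paddle

param = paddle.create_parameter(shape=[4, 4], dtype="float32")  # destination parameter dtype
checkpoint_weight = np.random.rand(4, 4).astype("float64")      # checkpoint saved in another dtype

# Cast to the parameter's dtype before assigning, mirroring the qkv/linear/ffn updates above.
param.set_value(paddle.to_tensor(checkpoint_weight).cast(param.dtype))
print(param.dtype)  # paddle.float32
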
@@ -556,7 +577,9 @@ def set_state_dict(self, state_dict):
                     )
                 )
             else:
-                self.transformer_block.linear_weights[idx].set_value(linear_weight_tensor)
+                self.transformer_block.linear_weights[idx].set_value(
+                    linear_weight_tensor.cast(self.transformer_block.linear_weights[idx].dtype)
+                )

             if self.use_weight_only:
                 ffn1_quanted_weight_tensor, ffn1_weight_scale_tensor = weight_quantize(
@@ -572,7 +595,9 @@ def set_state_dict(self, state_dict):
                     paddle.cast(paddle.to_tensor(concated_ffn1_weight).transpose((1, 0)), "int8")
                 )
             else:
-                self.transformer_block.ffn1_weights[idx].set_value(ffn1_weight_tensor)
+                self.transformer_block.ffn1_weights[idx].set_value(
+                    ffn1_weight_tensor.cast(self.transformer_block.ffn1_weights[idx].dtype)
+                )

             ffn2_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.weight".format(idx)])
             if self.use_weight_only:
@@ -594,7 +619,9 @@ def set_state_dict(self, state_dict):
                     )
                 )
             else:
-                self.transformer_block.ffn2_weights[idx].set_value(ffn2_weight_tensor)
+                self.transformer_block.ffn2_weights[idx].set_value(
+                    ffn2_weight_tensor.cast(self.transformer_block.ffn2_weights[idx].dtype)
+                )

             if self.quant_type == "a8w8":
                 if self.shift_smooth_all_linears:
@@ -660,16 +687,14 @@ def set_state_dict(self, state_dict):
                 )

             self.transformer_block.ln_scales[idx].set_value(
-                paddle.to_tensor(
-                    state_dict["llama.layers.{}.input_layernorm.weight".format(idx)],
-                    dtype=self.transformer_block.ln_scales[idx].dtype,
+                paddle.to_tensor(state_dict["llama.layers.{}.input_layernorm.weight".format(idx)]).cast(
+                    self.transformer_block.ln_scales[idx].dtype
                 )
             )

             self.transformer_block.ffn_ln_scales[idx].set_value(
-                paddle.to_tensor(
-                    state_dict["llama.layers.{}.post_attention_layernorm.weight".format(idx)],
-                    dtype=self.transformer_block.ffn_ln_scales[idx].dtype,
+                paddle.to_tensor(state_dict["llama.layers.{}.post_attention_layernorm.weight".format(idx)]).cast(
+                    self.transformer_block.ffn_ln_scales[idx].dtype
                 )
             )

@@ -1264,7 +1289,9 @@ def forward(
     @paddle.no_grad()
     def set_state_dict(self, state_dict):
         if "lm_head.weight" in state_dict:
-            self.lm_head.weight.set_value(state_dict["lm_head.weight"])
+            self.lm_head.weight.set_value(
+                paddle.to_tensor(state_dict["lm_head.weight"]).cast(self.lm_head.weight.dtype)
+            )
         self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()})
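
Taken together, set_state_dict now accepts either naming convention for the attention and feed-forward weights. A rough sketch of the two key layouts involved (keys only, layer index 0 as an example; the dict-of-lists below is illustrative, only the key names come from the diff):

# Keys of a conventional per-projection checkpoint vs. a fused checkpoint for layer 0.
idx = 0

per_projection_keys = [
    "llama.layers.{}.self_attn.q_proj.weight".format(idx),
    "llama.layers.{}.self_attn.k_proj.weight".format(idx),
    "llama.layers.{}.self_attn.v_proj.weight".format(idx),
    "llama.layers.{}.mlp.gate_proj.weight".format(idx),
    "llama.layers.{}.mlp.up_proj.weight".format(idx),
]
fused_keys = [
    "llama.layers.{}.self_attn.qkv_proj.weight".format(idx),
    "llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx),
]

# set_state_dict takes the split_fn path whenever the fused key is present:
state_dict_keys = set(fused_keys)
print("llama.layers.{}.self_attn.qkv_proj.weight".format(idx) in state_dict_keys)  # True
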
0 commit comments
