Commit 8dc2cf7

support fused weights for export_model
1 parent f36ed75 commit 8dc2cf7
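
In short, the change lets set_state_dict consume checkpoints that already contain fused projection weights, as written by export_model (per-layer "self_attn.qkv_proj.weight" and "mlp.gate_up_fused_proj.weight" keys), while keeping the old path that concatenates the separate q/k/v and gate/up tensors. A minimal NumPy-only sketch of that lookup logic, using hypothetical helper names (fuse_qkv_weight, fuse_ffn1_weight) rather than the module's actual methods:

import numpy as np

def fuse_qkv_weight(state_dict, idx, num_heads, head_size, hidden_size, tp_degree=1):
    # Hypothetical helper: prefer the pre-fused qkv_proj weight when the checkpoint has one.
    fused_key = "llama.layers.{}.self_attn.qkv_proj.weight".format(idx)
    if fused_key in state_dict:
        # Fused checkpoints store (hidden_size, 3 * local_heads * head_size); only a transpose is needed.
        return np.asarray(state_dict[fused_key]).transpose(1, 0)
    # Otherwise concatenate q/k/v column-wise, then transpose and reshape, as the old code did.
    qkv = np.concatenate(
        [
            state_dict["llama.layers.{}.self_attn.q_proj.weight".format(idx)],
            state_dict["llama.layers.{}.self_attn.k_proj.weight".format(idx)],
            state_dict["llama.layers.{}.self_attn.v_proj.weight".format(idx)],
        ],
        axis=-1,
    )
    return qkv.transpose(1, 0).reshape(3 * (num_heads // tp_degree) * head_size, hidden_size)

def fuse_ffn1_weight(state_dict, idx):
    # Hypothetical helper: same fallback for the fused gate/up ("ffn1") projection.
    fused_key = "llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx)
    if fused_key in state_dict:
        return np.asarray(state_dict[fused_key])
    return np.concatenate(
        [
            state_dict["llama.layers.{}.mlp.gate_proj.weight".format(idx)],
            state_dict["llama.layers.{}.mlp.up_proj.weight".format(idx)],
        ],
        axis=-1,
    )

Note the asymmetry visible in the diff: the fused QKV weight is transposed on load, while the fused gate/up weight is used as stored.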

1 file changed: +41 -38 lines changed

paddlenlp/experimental/transformers/llama/modeling.py

Lines changed: 41 additions & 38 deletions
@@ -474,47 +474,50 @@ def set_state_dict(self, state_dict):
         unfused_state_dict = {}
         head_size = self.hidden_size // self.num_attention_heads
 
-        self.embed_tokens.weight.set_value(paddle.to_tensor(state_dict["llama.embed_tokens.weight"]))
+        self.embed_tokens.weight.set_value(paddle.to_tensor(state_dict["llama.embed_tokens.weight"], dtype=self.embed_tokens.weight.dtype))
         self.norm.weight.set_value(paddle.to_tensor(state_dict["llama.norm.weight"], dtype=self.norm.weight.dtype))
 
         for idx in range(self.config.num_hidden_layers):
             logger.info(f"set state for layer {idx}")
 
             if self.use_weight_only:
                 logger.info("weight only is enabled")
-            unfused_state_dict = {}
-            unfused_state_dict["self_attn.q_proj.weight"] = state_dict[
-                "llama.layers.{}.self_attn.q_proj.weight".format(idx)
-            ]
-            unfused_state_dict["self_attn.k_proj.weight"] = state_dict[
-                "llama.layers.{}.self_attn.k_proj.weight".format(idx)
-            ]
-            unfused_state_dict["self_attn.v_proj.weight"] = state_dict[
-                "llama.layers.{}.self_attn.v_proj.weight".format(idx)
-            ]
-
-            concated_qkv_weight = (
-                np.concatenate(
-                    [
-                        unfused_state_dict["self_attn.q_proj.weight"],
-                        unfused_state_dict["self_attn.k_proj.weight"],
-                        unfused_state_dict["self_attn.v_proj.weight"],
-                    ],
-                    axis=-1,
-                )
-                .transpose(1, 0)
-                .reshape(
-                    3 * (self.num_attention_heads // self.config.tensor_parallel_degree) * (head_size),
-                    self.hidden_size,
+            if "llama.layers.{}.self_attn.qkv_proj.weight".format(idx) in state_dict.keys():
+                concated_qkv_weight = state_dict["llama.layers.{}.self_attn.qkv_proj.weight".format(idx)].transpose([1, 0])
+            else:
+                unfused_state_dict = {}
+                unfused_state_dict["self_attn.q_proj.weight"] = state_dict[
+                    "llama.layers.{}.self_attn.q_proj.weight".format(idx)
+                ]
+                unfused_state_dict["self_attn.k_proj.weight"] = state_dict[
+                    "llama.layers.{}.self_attn.k_proj.weight".format(idx)
+                ]
+                unfused_state_dict["self_attn.v_proj.weight"] = state_dict[
+                    "llama.layers.{}.self_attn.v_proj.weight".format(idx)
+                ]
+                concated_qkv_weight = (
+                    np.concatenate(
+                        [
+                            unfused_state_dict["self_attn.q_proj.weight"],
+                            unfused_state_dict["self_attn.k_proj.weight"],
+                            unfused_state_dict["self_attn.v_proj.weight"],
+                        ],
+                        axis=-1,
+                    )
+                    .transpose(1, 0)
+                    .reshape(
+                        3 * (self.num_attention_heads // self.config.tensor_parallel_degree) * (head_size),
+                        self.hidden_size,
+                    )
+                )  # reshape(3, self.num_attention_heself.hidden_sizeads // self.config.tensor_parallel_degree, head_size, )
+            if "llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx) in state_dict.keys():
+                concated_ffn1_weight = state_dict["llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx)]
+            else:
+                unfused_state_dict["mlp.gate_proj.weight"] = state_dict["llama.layers.{}.mlp.gate_proj.weight".format(idx)]
+                unfused_state_dict["mlp.up_proj.weight"] = state_dict["llama.layers.{}.mlp.up_proj.weight".format(idx)]
+                concated_ffn1_weight = np.concatenate(
+                    [unfused_state_dict["mlp.gate_proj.weight"], unfused_state_dict["mlp.up_proj.weight"]], axis=-1
                 )
-            )  # reshape(3, self.num_attention_heself.hidden_sizeads // self.config.tensor_parallel_degree, head_size, )
-
-            unfused_state_dict["mlp.gate_proj.weight"] = state_dict["llama.layers.{}.mlp.gate_proj.weight".format(idx)]
-            unfused_state_dict["mlp.up_proj.weight"] = state_dict["llama.layers.{}.mlp.up_proj.weight".format(idx)]
-
-            concated_ffn1_weight = np.concatenate(
-                [unfused_state_dict["mlp.gate_proj.weight"], unfused_state_dict["mlp.up_proj.weight"]], axis=-1
-            )
             ffn1_weight_tensor = paddle.to_tensor(concated_ffn1_weight)
 
             qkv_weight_tensor = paddle.to_tensor(concated_qkv_weight)
@@ -534,7 +537,7 @@ def set_state_dict(self, state_dict):
                     paddle.cast(paddle.to_tensor(concated_qkv_weight), "int8")
                 )
             else:
-                self.transformer_block.qkv_weights[idx].set_value(qkv_weight_tensor)
+                self.transformer_block.qkv_weights[idx].set_value(qkv_weight_tensor.cast(self.transformer_block.qkv_weights[idx].dtype))
 
             linear_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.weight".format(idx)])
             if self.use_weight_only:
@@ -556,7 +559,7 @@ def set_state_dict(self, state_dict):
                     )
                 )
             else:
-                self.transformer_block.linear_weights[idx].set_value(linear_weight_tensor)
+                self.transformer_block.linear_weights[idx].set_value(linear_weight_tensor.cast(self.transformer_block.linear_weights[idx].dtype))
 
             if self.use_weight_only:
                 ffn1_quanted_weight_tensor, ffn1_weight_scale_tensor = weight_quantize(
@@ -572,7 +575,7 @@ def set_state_dict(self, state_dict):
                     paddle.cast(paddle.to_tensor(concated_ffn1_weight).transpose((1, 0)), "int8")
                 )
             else:
-                self.transformer_block.ffn1_weights[idx].set_value(ffn1_weight_tensor)
+                self.transformer_block.ffn1_weights[idx].set_value(ffn1_weight_tensor.cast(self.transformer_block.ffn1_weights[idx].dtype))
 
             ffn2_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.weight".format(idx)])
             if self.use_weight_only:
@@ -594,7 +597,7 @@ def set_state_dict(self, state_dict):
                     )
                 )
             else:
-                self.transformer_block.ffn2_weights[idx].set_value(ffn2_weight_tensor)
+                self.transformer_block.ffn2_weights[idx].set_value(ffn2_weight_tensor.cast(self.transformer_block.ffn2_weights[idx].dtype))
 
             if self.quant_type == "a8w8":
                 if self.shift_smooth_all_linears:
@@ -1264,7 +1267,7 @@ def forward(
     @paddle.no_grad()
     def set_state_dict(self, state_dict):
         if "lm_head.weight" in state_dict:
-            self.lm_head.weight.set_value(state_dict["lm_head.weight"])
+            self.lm_head.weight.set_value(paddle.to_tensor(state_dict["lm_head.weight"], dtype=self.lm_head.weight.dtype))
         self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()})
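The remaining hunks all make the same adjustment: the tensor handed to set_value is first cast to the dtype of the destination parameter (embed_tokens, qkv/linear/ffn1/ffn2 weights, lm_head), so checkpoints saved in a different precision still load cleanly. A small Paddle sketch of that pattern, using a stand-in parameter rather than the model's own weights:

import numpy as np
import paddle

# Stand-in parameter; the real code targets self.transformer_block.qkv_weights[idx],
# self.lm_head.weight, and similar parameters.
param = paddle.create_parameter(shape=[4, 4], dtype="float32")

# A checkpoint value that may arrive in a different dtype (NumPy defaults to float64 here).
loaded = np.random.rand(4, 4)

# Match the destination dtype before set_value, mirroring the commit's
# paddle.to_tensor(value, dtype=param.dtype) and tensor.cast(param.dtype) calls.
param.set_value(paddle.to_tensor(loaded, dtype=param.dtype))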