Commit faabf87

inference support llama3(wint8|4/a8w8) (#8630)
* inference support llama3
* fix
1 parent 69be4db commit faabf87

File tree (3 files changed: +43 −29 lines)

- llm/predict/predictor.py
- paddlenlp/experimental/transformers/fused_transformer_layers.py
- paddlenlp/experimental/transformers/llama/modeling.py

llm/predict/predictor.py

Lines changed: 19 additions & 10 deletions
@@ -51,6 +51,7 @@
     AutoTokenizer,
     ChatGLMTokenizer,
     ChatGLMv2Tokenizer,
+    Llama3Tokenizer,
     LlamaTokenizer,
     PretrainedModel,
     PretrainedTokenizer,
@@ -739,13 +740,18 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer):

        self.architectures = self.model_config.architectures[0].lower()

-        self.dtype = config.dtype or self.model_config
+        self.dtype = config.dtype or self.model_config.dtype

        self.total_max_length = config.src_length + config.max_length
        self.block_size = config.block_size
        self.pre_max_block_num = (self.total_max_length + config.block_size - 1) // config.block_size
        self.max_block_nums = config.batch_size * self.pre_max_block_num

+        try:
+            self.rope_theta = self.model_config.rope_theta
+        except:
+            self.rope_theta = 10000.0
+
        self.pre_cache_length = 0

        if config.export_precache:
@@ -828,7 +834,7 @@ def init_inputs(self, config: PredictorArgument):
        )
        self.inputs["stop_nums"] = paddle.full(shape=[1], fill_value=config.batch_size, dtype="int64")
        self.inputs["rope_emb"] = self._get_rotary_position_embedding(
-            paddle.arange(self.total_max_length).reshape((1, -1)), self.head_dim
+            paddle.arange(self.total_max_length).reshape((1, -1)), self.head_dim, self.rope_theta
        )
        eos_token_id = get_eos_token_id(self.tokenizer, self.generation_config)
        if isinstance(eos_token_id, int):
@@ -895,7 +901,7 @@ def init_inputs(self, config: PredictorArgument):
        self.inputs["free_list"] = paddle.to_tensor(free_list, dtype="int32")
        self.inputs["free_list_len"] = paddle.full(shape=[1], fill_value=self.pre_max_block_num * 0.25, dtype="int32")

-    def _get_rotary_position_embedding(self, position_ids, head_dim):
+    def _get_rotary_position_embedding(self, position_ids, head_dim, rope_theta=10000.0):
        """
        Pre-calculate rotary position embedding for position_ids.

@@ -908,7 +914,7 @@ def _get_rotary_position_embedding(self, position_ids, head_dim):
        """
        bsz, max_seq_len = position_ids.shape[:2]
        rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim), dtype="float32")
-        inv_freq = 10000 ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim)
+        inv_freq = rope_theta ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim)

        # shape: [B, S, D/2]
        freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
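
The hunks above thread a configurable RoPE base (`rope_theta`) through the predictor instead of the hard-coded 10000; Llama 3 configs ship a much larger base (500000.0 in the released configs). Below is a minimal NumPy sketch of the same precomputation, for illustration only (it mirrors, but is not, the PaddleNLP implementation):

```python
# Minimal NumPy sketch of the rotary-embedding precompute above (not the
# PaddleNLP code): only rope_theta differs between Llama-2-style (10000.0)
# and Llama-3-style (500000.0) models.
import numpy as np


def get_rotary_position_embedding(position_ids, head_dim, rope_theta=10000.0):
    bsz, max_seq_len = position_ids.shape[:2]
    # Same formula as the diff: theta ** (-2i / d) for i = 0 .. d/2 - 1
    inv_freq = rope_theta ** (-np.arange(0, head_dim, 2, dtype="float32") / head_dim)
    # shape: [B, S, D/2]
    freqs = np.einsum("ij,k->ijk", position_ids.astype("float32"), inv_freq)
    emb = np.concatenate([freqs, freqs], axis=-1)[:, :, None, :]  # [B, S, 1, D]
    rot_emb = np.stack([np.cos(emb), np.sin(emb)], axis=0)        # [2, B, S, 1, D]
    return rot_emb


positions = np.arange(8192).reshape(1, -1)
print(get_rotary_position_embedding(positions, head_dim=128, rope_theta=500000.0).shape)
# (2, 1, 8192, 1, 128)
```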
@@ -1213,8 +1219,8 @@ def create_predictor(
    init_chat_template(tokenizer, predictor_args.model_name_or_path, predictor_args.chat_template)

    # TODO(wj-Mcat): fix llama tokenzier pad_token bug
-    if isinstance(tokenizer, LlamaTokenizer) and not tokenizer.pad_token:
-        tokenizer.pad_token = tokenizer.unk_token
+    if (isinstance(tokenizer, (LlamaTokenizer, Llama3Tokenizer))) and not tokenizer.pad_token:
+        tokenizer.pad_token = tokenizer.bos_token

    config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)

@@ -1310,10 +1316,13 @@ def create_predictor(
            config.use_cachekv_int8 = predictor_args.use_cachekv_int8
            config.single_card_ptq = True

-            if predictor_args.quant_type is not None and predictor_args.quant_type.startswith("weight_only_int"):
-                weight_only_quant_bits = int(predictor_args.quant_type[-1])
-                config.weight_only_quant_bits = weight_only_quant_bits
-                config.quant_type = predictor_args.quant_type
+            if predictor_args.quant_type is not None:
+                if predictor_args.quant_type.startswith("weight_only_int"):
+                    weight_only_quant_bits = int(predictor_args.quant_type[-1])
+                    config.weight_only_quant_bits = weight_only_quant_bits
+                    config.quant_type = predictor_args.quant_type
+                elif predictor_args.quant_type == "a8w8":
+                    config.quant_type = predictor_args.quant_type

            if config.quantization_config.quant_type is not None and "a8w8" in config.quantization_config.quant_type:
                config.model_name_or_path = predictor_args.model_name_or_path
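
The second hunk widens quantization handling from only `weight_only_int8`/`weight_only_int4` to also accept `a8w8`. A small hypothetical helper (not a PaddleNLP API) sketching the same string dispatch:

```python
# Hypothetical helper (not part of PaddleNLP) mirroring the quant-type dispatch above.
from typing import Optional


def parse_quant_type(quant_type: Optional[str]) -> dict:
    cfg = {"quant_type": None, "weight_only_quant_bits": None}
    if quant_type is None:
        return cfg
    if quant_type.startswith("weight_only_int"):
        # e.g. "weight_only_int8" / "weight_only_int4" -> 8 / 4 bits
        cfg["weight_only_quant_bits"] = int(quant_type[-1])
        cfg["quant_type"] = quant_type
    elif quant_type == "a8w8":
        # 8-bit activations + 8-bit weights; no weight-only bit width to record
        cfg["quant_type"] = quant_type
    return cfg


print(parse_quant_type("weight_only_int4"))
# {'quant_type': 'weight_only_int4', 'weight_only_quant_bits': 4}
print(parse_quant_type("a8w8"))
# {'quant_type': 'a8w8', 'weight_only_quant_bits': None}
```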

paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 6 additions & 6 deletions
@@ -443,7 +443,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
            cache_k_scale = None
            if cache_k_scale_attr:
                cache_k_scale = self.create_parameter(
-                    shape=[self.num_heads],
+                    shape=[self.kv_num_heads],
                    attr=cache_k_scale_attr,
                    dtype="float32",
                    is_bias=False,
@@ -452,7 +452,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
            cache_v_scale = None
            if cache_v_scale_attr:
                cache_v_scale = self.create_parameter(
-                    shape=[self.num_heads],
+                    shape=[self.kv_num_heads],
                    attr=cache_v_scale_attr,
                    dtype="float32",
                    is_bias=False,
@@ -461,7 +461,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
            cache_k_out_scale = None
            if cache_k_out_scale_attr:
                cache_k_out_scale = self.create_parameter(
-                    shape=[self.num_heads],
+                    shape=[self.kv_num_heads],
                    attr=cache_k_out_scale_attr,
                    dtype="float32",
                    is_bias=False,
@@ -470,7 +470,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
            cache_v_out_scale = None
            if cache_v_out_scale_attr:
                cache_v_out_scale = self.create_parameter(
-                    shape=[self.num_heads],
+                    shape=[self.kv_num_heads],
                    attr=cache_v_out_scale_attr,
                    dtype="float32",
                    is_bias=False,
@@ -549,7 +549,7 @@ def init_weight_shape(self, config):
        self.qkv_weight_shape = (
            [(self.num_heads + 2 * self.kv_num_heads) * self.head_dim, self.embed_dim]
            if config.trans_qkvw
-            else [(self.num_heads + 2 * self.kv_num_heads) * self.head_dim, self.embed_dim]
+            else [self.embed_dim, (self.num_heads + 2 * self.kv_num_heads) * self.head_dim]
        )
        self.linear_weight_shape = [self.num_heads * self.head_dim, self.embed_dim]
        self.ffn1_weight_shape = (
@@ -1075,7 +1075,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
            ffn2_smooth_attr = self.get_attr(config.ffn2_smooth_attrs, i)

            qkv_out_scale = self.create_parameter(
-                shape=[self.head_dim * 3 * self.num_heads],
+                shape=[self.head_dim * (2 * self.kv_num_heads + self.num_heads)],
                attr=qkv_out_scale_attr,
                dtype="float32",
                is_bias=False,
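
These shape changes all follow from grouped-query attention: K/V tensors, and therefore their cache quantization scales, carry `kv_num_heads` heads, while the fused QKV projection carries `num_heads + 2 * kv_num_heads` heads. Illustrative shape accounting, assuming Llama-3-8B-like sizes (32 query heads, 8 KV heads, head_dim 128):

```python
# Illustrative GQA shape accounting (Llama-3-8B-like sizes assumed:
# 32 query heads, 8 KV heads, head_dim 128, hidden size 4096).
num_heads, kv_num_heads, head_dim = 32, 8, 128
embed_dim = num_heads * head_dim  # 4096

# One cache quantization scale per KV head (previously per attention head):
cache_k_scale_shape = [kv_num_heads]                                  # [8]
# Fused QKV output carries Q heads plus K and V heads:
qkv_out_scale_shape = [head_dim * (2 * kv_num_heads + num_heads)]     # [6144]
# Fused QKV weight, transposed vs. non-transposed layout:
qkv_weight_shape_trans = [(num_heads + 2 * kv_num_heads) * head_dim, embed_dim]  # [6144, 4096]
qkv_weight_shape_plain = [embed_dim, (num_heads + 2 * kv_num_heads) * head_dim]  # [4096, 6144]

print(cache_k_scale_shape, qkv_out_scale_shape)
print(qkv_weight_shape_trans, qkv_weight_shape_plain)
```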

paddlenlp/experimental/transformers/llama/modeling.py

Lines changed: 18 additions & 13 deletions
@@ -96,12 +96,15 @@ def __init__(self, config: LlamaConfig):
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
+        self.num_key_value_heads = config.num_key_value_heads
        self.intermediate_size = config.intermediate_size
        self.num_layers = config.num_hidden_layers
        self.epsilon = config.rms_norm_eps
        self.max_position_embeddings = config.max_position_embeddings
        self.quant_type = config.quant_type

+        self.rope_theta = config.rope_theta
+
        self.use_weight_only = False
        self.weight_only_quant_bits = config.weight_only_quant_bits

@@ -188,8 +191,6 @@ def __init__(self, config: LlamaConfig):
        ffn2_bias_attrs = None

        if self.quant_type == "a8w8":
-            self.quant_round_type = config.quantization_config.quant_round_type
-
            qkv_out_scale_attrs = [
                paddle.ParamAttr(name="fusellama.{}.qkv_out_scale".format(i)) for i in range(self.num_layers)
            ]
@@ -277,9 +278,10 @@ def __init__(self, config: LlamaConfig):
        ]

        transformer_config = FusedMultiTransformerConfig(
-            self.hidden_size,
-            self.num_attention_heads,
-            self.intermediate_size,
+            embed_dim=self.hidden_size,
+            num_heads=self.num_attention_heads,
+            kv_num_heads=self.num_key_value_heads,
+            dim_feedforward=self.intermediate_size,
            weight_only_quant_bits=self.weight_only_quant_bits,
            activation="swiglu",
            num_layers=config.num_hidden_layers,
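
Since Llama 3 uses grouped-query attention, `num_key_value_heads` is now passed to the fused transformer config separately from `num_attention_heads`. A hedged sketch of that wiring using an illustrative stub (not PaddleNLP's `FusedMultiTransformerConfig`; only the fields visible in the diff are modelled, and the hyperparameters are assumed Llama-3-8B-like values):

```python
# Illustrative stub, not PaddleNLP's FusedMultiTransformerConfig; values are
# assumed Llama-3-8B-like hyperparameters.
from dataclasses import dataclass


@dataclass
class FusedConfigStub:
    embed_dim: int
    num_heads: int
    kv_num_heads: int
    dim_feedforward: int
    activation: str = "swiglu"
    num_layers: int = 32


cfg = FusedConfigStub(
    embed_dim=4096,         # hidden_size
    num_heads=32,           # num_attention_heads
    kv_num_heads=8,         # num_key_value_heads
    dim_feedforward=14336,  # intermediate_size
)
print(cfg.num_heads // cfg.kv_num_heads)  # 4 query heads share each KV head
```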
@@ -430,13 +432,12 @@ def forward(
        seq_lens = seq_len_decoder if is_decoder else seq_len_encoder

        position_offset = 0
-        theta = 10000.0
        if not is_decoder and pre_caches is not None:
            position_offset = 128
        from paddlenlp_ops import fused_get_rotary_embedding

        new_rope = fused_get_rotary_embedding(
-            input_ids, position_ids, self.head_dim_shape_tensor, position_offset, theta, True
+            input_ids, position_ids, self.head_dim_shape_tensor, position_offset, self.rope_theta, True
        )

        with dy2st_nocheck_guard_context():
@@ -491,7 +492,7 @@ def set_state_dict(self, state_dict):
                    state_dict["llama.layers.{}.self_attn.qkv_proj.weight".format(idx)],
                    is_qkv=True,
                    num_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
-                    num_key_value_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
+                    num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree,
                ),
                axis=-1,
            ).transpose(1, 0)
@@ -517,10 +518,14 @@ def set_state_dict(self, state_dict):
                )
                .transpose(1, 0)
                .reshape(
-                    3 * (self.num_attention_heads // self.config.tensor_parallel_degree) * (head_size),
+                    (
+                        self.num_attention_heads // self.config.tensor_parallel_degree
+                        + 2 * self.num_key_value_heads // self.config.tensor_parallel_degree
+                    )
+                    * (head_size),
                    self.hidden_size,
                )
-            )  # reshape(3, self.num_attention_heself.hidden_sizeads // self.config.tensor_parallel_degree, head_size, )
+            )
            if "llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx) in state_dict.keys():
                concated_ffn1_weight = np.concatenate(
                    split_fn(state_dict["llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx)]), axis=-1
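
The `set_state_dict` changes account for the fused QKV checkpoint no longer holding three equal-sized blocks: under GQA it holds `num_q_heads + 2 * num_kv_heads` heads' worth of rows. A hedged NumPy sketch of that row accounting (not PaddleNLP's split/concat helpers; Llama-3-8B-like sizes and tensor-parallel degree 1 assumed):

```python
# Hedged NumPy sketch of the fused QKV row accounting used above
# (not PaddleNLP's helpers; Llama-3-8B-like sizes assumed, TP degree 1).
import numpy as np

num_q_heads, num_kv_heads, head_size, hidden_size = 32, 8, 128, 4096

q_w = np.zeros((hidden_size, num_q_heads * head_size), dtype="float16")   # [4096, 4096]
k_w = np.zeros((hidden_size, num_kv_heads * head_size), dtype="float16")  # [4096, 1024]
v_w = np.zeros((hidden_size, num_kv_heads * head_size), dtype="float16")  # [4096, 1024]

fused = np.concatenate([q_w, k_w, v_w], axis=-1).transpose(1, 0)          # [6144, 4096]
fused = fused.reshape((num_q_heads + 2 * num_kv_heads) * head_size, hidden_size)

# With plain multi-head attention (num_kv_heads == num_q_heads) the first dim would
# be 3 * num_q_heads * head_size; with GQA it is (32 + 2 * 8) * 128 = 6144.
print(fused.shape)  # (6144, 4096)
```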
@@ -744,7 +749,7 @@ def set_state_dict(self, state_dict):
                cache_scale_json_path,
                cache_scale_map_dict,
                num_of_layers=self.config.num_hidden_layers,
-                num_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
+                num_heads=self.num_key_value_heads // self.config.tensor_parallel_degree,
            )
            for k, v in cache_scales_loader.scale.items():
                for i_layer, weight_scale in enumerate(v):
@@ -919,7 +924,7 @@ def get_cache_kvs_shape(
            [
                2,
                max_batch_size,
-                config.num_attention_heads // max(config.tensor_parallel_degree, 1),
+                config.num_key_value_heads // max(config.tensor_parallel_degree, 1),
                max_length,
                config.hidden_size // config.num_attention_heads,
            ]
@@ -1205,7 +1210,7 @@ def get_cache_kvs_shape(
        for _ in range(config.num_hidden_layers):
            cache_kv_shape = [
                max_block_nums,
-                config.num_attention_heads // max(config.tensor_parallel_degree, 1),
+                config.num_key_value_heads // max(config.tensor_parallel_degree, 1),
                config.block_size,
                config.hidden_size // config.num_attention_heads,
            ]
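
The cache-shape hunks show the memory payoff of GQA: the KV cache is allocated per key/value head rather than per attention head. A hedged sketch of the block-wise cache shape from the last hunk, with assumed Llama-3-8B-like values:

```python
# Hedged sketch of the block-wise KV-cache shape computation above
# (Llama-3-8B-like values assumed; not the PaddleNLP function itself).
num_attention_heads, num_key_value_heads = 32, 8
hidden_size, tensor_parallel_degree = 4096, 1
block_size, max_block_nums = 64, 512

head_dim = hidden_size // num_attention_heads  # 128

cache_kv_shape = [
    max_block_nums,
    num_key_value_heads // max(tensor_parallel_degree, 1),  # 8 KV heads, not 32 query heads
    block_size,
    head_dim,
]
print(cache_kv_shape)  # [512, 8, 64, 128] -> 4x smaller than with 32 heads per layer
```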
