inference support llama3(wint8|4/a8w8) #8630

Merged
merged 2 commits on Jun 27, 2024
29 changes: 19 additions & 10 deletions llm/predict/predictor.py
@@ -51,6 +51,7 @@
AutoTokenizer,
ChatGLMTokenizer,
ChatGLMv2Tokenizer,
Llama3Tokenizer,
LlamaTokenizer,
PretrainedModel,
PretrainedTokenizer,
@@ -739,13 +740,18 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer):

self.architectures = self.model_config.architectures[0].lower()

self.dtype = config.dtype or self.model_config
self.dtype = config.dtype or self.model_config.dtype

self.total_max_length = config.src_length + config.max_length
self.block_size = config.block_size
self.pre_max_block_num = (self.total_max_length + config.block_size - 1) // config.block_size
self.max_block_nums = config.batch_size * self.pre_max_block_num

try:
self.rope_theta = self.model_config.rope_theta
except:
self.rope_theta = 10000.0

self.pre_cache_length = 0

if config.export_precache:
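Note (annotation, not part of the diff): the block bookkeeping added above is plain ceiling division over the combined prompt and generation budget. A small worked example with made-up configuration values:

```python
# Hypothetical values for illustration only; src_length, max_length, block_size,
# and batch_size all come from PredictorArgument in the real code.
src_length, max_length, block_size, batch_size = 1024, 1024, 64, 2

total_max_length = src_length + max_length                              # 2048
pre_max_block_num = (total_max_length + block_size - 1) // block_size   # ceil(2048 / 64) = 32
max_block_nums = batch_size * pre_max_block_num                         # 64 cache blocks in total
```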
@@ -828,7 +834,7 @@ def init_inputs(self, config: PredictorArgument):
)
self.inputs["stop_nums"] = paddle.full(shape=[1], fill_value=config.batch_size, dtype="int64")
self.inputs["rope_emb"] = self._get_rotary_position_embedding(
paddle.arange(self.total_max_length).reshape((1, -1)), self.head_dim
paddle.arange(self.total_max_length).reshape((1, -1)), self.head_dim, self.rope_theta
)
eos_token_id = get_eos_token_id(self.tokenizer, self.generation_config)
if isinstance(eos_token_id, int):
@@ -895,7 +901,7 @@ def init_inputs(self, config: PredictorArgument):
self.inputs["free_list"] = paddle.to_tensor(free_list, dtype="int32")
self.inputs["free_list_len"] = paddle.full(shape=[1], fill_value=self.pre_max_block_num * 0.25, dtype="int32")

def _get_rotary_position_embedding(self, position_ids, head_dim):
def _get_rotary_position_embedding(self, position_ids, head_dim, rope_theta=10000.0):
"""
Pre-calculate rotary position embedding for position_ids.

@@ -908,7 +914,7 @@ def _get_rotary_position_embedding(self, position_ids, head_dim):
"""
bsz, max_seq_len = position_ids.shape[:2]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim), dtype="float32")
inv_freq = 10000 ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim)
inv_freq = rope_theta ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim)

# shape: [B, S, D/2]
freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
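Note (annotation, not part of the diff): making rope_theta configurable matters for Llama 3, whose published configs use a rotary base of 500000 instead of the 10000.0 that was previously hard-coded here. A minimal sketch of the effect on the inverse frequencies, assuming only that paddle is installed:

```python
import paddle

def inv_freq(rope_theta: float, head_dim: int) -> paddle.Tensor:
    # Same formula as in _get_rotary_position_embedding above.
    return rope_theta ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim)

head_dim = 128  # typical Llama head dimension, used here only as an example
# A larger base stretches the wavelengths of the higher channels, which is what
# lets the model encode positions over longer contexts.
print(inv_freq(10000.0, head_dim)[-1])   # old hard-coded base
print(inv_freq(500000.0, head_dim)[-1])  # Llama 3's base; noticeably smaller
```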
@@ -1213,8 +1219,8 @@ def create_predictor(
init_chat_template(tokenizer, predictor_args.model_name_or_path, predictor_args.chat_template)

# TODO(wj-Mcat): fix llama tokenzier pad_token bug
if isinstance(tokenizer, LlamaTokenizer) and not tokenizer.pad_token:
tokenizer.pad_token = tokenizer.unk_token
if (isinstance(tokenizer, (LlamaTokenizer, Llama3Tokenizer))) and not tokenizer.pad_token:
tokenizer.pad_token = tokenizer.bos_token

config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)

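Note (annotation, not part of the diff): the fallback above changes from unk_token to bos_token because the Llama 3 tokenizer ships without an unk_token, so the old assignment would have left pad_token unset. A hedged sketch of an equivalent defensive guard (illustrative, not the repository's code):

```python
# LlamaTokenizer / Llama3Tokenizer refer to the classes imported at the top of predictor.py.
if isinstance(tokenizer, (LlamaTokenizer, Llama3Tokenizer)) and not tokenizer.pad_token:
    # Prefer unk_token when it exists (Llama 1/2); fall back to bos_token for Llama 3.
    tokenizer.pad_token = tokenizer.unk_token or tokenizer.bos_token
```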
@@ -1310,10 +1316,13 @@ def create_predictor(
config.use_cachekv_int8 = predictor_args.use_cachekv_int8
config.single_card_ptq = True

if predictor_args.quant_type is not None and predictor_args.quant_type.startswith("weight_only_int"):
weight_only_quant_bits = int(predictor_args.quant_type[-1])
config.weight_only_quant_bits = weight_only_quant_bits
config.quant_type = predictor_args.quant_type
if predictor_args.quant_type is not None:
if predictor_args.quant_type.startswith("weight_only_int"):
weight_only_quant_bits = int(predictor_args.quant_type[-1])
config.weight_only_quant_bits = weight_only_quant_bits
config.quant_type = predictor_args.quant_type
elif predictor_args.quant_type == "a8w8":
config.quant_type = predictor_args.quant_type

if config.quantization_config.quant_type is not None and "a8w8" in config.quantization_config.quant_type:
config.model_name_or_path = predictor_args.model_name_or_path
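Note (annotation, not part of the diff): the restructured branch above distinguishes the two quantization families this PR targets: weight-only INT8/INT4 (wint8/wint4) and A8W8 (INT8 activations plus INT8 weights). A toy parser mirroring the branch, illustrative only:

```python
from typing import Optional


def parse_quant_type(quant_type: Optional[str]) -> dict:
    """Toy mirror of the branch in create_predictor; names are illustrative."""
    cfg = {}
    if quant_type is not None:
        if quant_type.startswith("weight_only_int"):
            cfg["weight_only_quant_bits"] = int(quant_type[-1])  # "weight_only_int8" -> 8
            cfg["quant_type"] = quant_type
        elif quant_type == "a8w8":
            cfg["quant_type"] = quant_type
    return cfg


assert parse_quant_type("weight_only_int4") == {"weight_only_quant_bits": 4, "quant_type": "weight_only_int4"}
assert parse_quant_type("a8w8") == {"quant_type": "a8w8"}
```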
12 changes: 6 additions & 6 deletions paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -443,7 +443,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
cache_k_scale = None
if cache_k_scale_attr:
cache_k_scale = self.create_parameter(
shape=[self.num_heads],
shape=[self.kv_num_heads],
attr=cache_k_scale_attr,
dtype="float32",
is_bias=False,
@@ -452,7 +452,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
cache_v_scale = None
if cache_v_scale_attr:
cache_v_scale = self.create_parameter(
shape=[self.num_heads],
shape=[self.kv_num_heads],
attr=cache_v_scale_attr,
dtype="float32",
is_bias=False,
@@ -461,7 +461,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
cache_k_out_scale = None
if cache_k_out_scale_attr:
cache_k_out_scale = self.create_parameter(
shape=[self.num_heads],
shape=[self.kv_num_heads],
attr=cache_k_out_scale_attr,
dtype="float32",
is_bias=False,
@@ -470,7 +470,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
cache_v_out_scale = None
if cache_v_out_scale_attr:
cache_v_out_scale = self.create_parameter(
shape=[self.num_heads],
shape=[self.kv_num_heads],
attr=cache_v_out_scale_attr,
dtype="float32",
is_bias=False,
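Note (annotation, not part of the diff): the four shape changes above follow from grouped-query attention (GQA). The KV cache stores kv_num_heads key/value heads rather than num_heads query heads, so per-head cache quantization scales must be sized by the KV head count. A minimal illustration using Llama 3 8B's published head counts:

```python
# Head bookkeeping under GQA; the numbers match Llama 3 8B (32 query heads, 8 KV heads).
num_heads, kv_num_heads, head_dim = 32, 8, 128

cache_scale_shape = [kv_num_heads]                        # one scale per K (or V) head -> [8]
qkv_out_size = (num_heads + 2 * kv_num_heads) * head_dim  # fused QKV projection width -> 6144
print(cache_scale_shape, qkv_out_size)
```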
@@ -549,7 +549,7 @@ def init_weight_shape(self, config):
self.qkv_weight_shape = (
[(self.num_heads + 2 * self.kv_num_heads) * self.head_dim, self.embed_dim]
if config.trans_qkvw
else [(self.num_heads + 2 * self.kv_num_heads) * self.head_dim, self.embed_dim]
else [self.embed_dim, (self.num_heads + 2 * self.kv_num_heads) * self.head_dim]
Contributor: Why was this shape changed here?

Collaborator (Author): Because the previous one was wrong.

)
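Note (annotation on the review exchange above, not part of the diff): before this fix both branches of the conditional produced the same shape, so with trans_qkvw disabled the QKV weight was declared transposed relative to how it is consumed. A minimal sketch of the intended shapes, using Llama 3 8B's dimensions purely as example numbers:

```python
# Illustration of init_weight_shape's intent; 4096 / 32 / 8 / 128 are Llama 3 8B's
# hidden size, query heads, KV heads, and head_dim, used here only as an example.
embed_dim, num_heads, kv_num_heads, head_dim = 4096, 32, 8, 128
fused_out = (num_heads + 2 * kv_num_heads) * head_dim  # 6144

for trans_qkvw in (True, False):
    qkv_weight_shape = (
        [fused_out, embed_dim]       # stored as [out_features, in_features]
        if trans_qkvw
        else [embed_dim, fused_out]  # stored as [in_features, out_features]; the corrected branch
    )
    print(trans_qkvw, qkv_weight_shape)
```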
self.linear_weight_shape = [self.num_heads * self.head_dim, self.embed_dim]
self.ffn1_weight_shape = (
@@ -1075,7 +1075,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
ffn2_smooth_attr = self.get_attr(config.ffn2_smooth_attrs, i)

qkv_out_scale = self.create_parameter(
shape=[self.head_dim * 3 * self.num_heads],
shape=[self.head_dim * (2 * self.kv_num_heads + self.num_heads)],
attr=qkv_out_scale_attr,
dtype="float32",
is_bias=False,
31 changes: 18 additions & 13 deletions paddlenlp/experimental/transformers/llama/modeling.py
@@ -96,12 +96,15 @@
self.vocab_size = config.vocab_size
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads

self.intermediate_size = config.intermediate_size
self.num_layers = config.num_hidden_layers
self.epsilon = config.rms_norm_eps
self.max_position_embeddings = config.max_position_embeddings
self.quant_type = config.quant_type

self.rope_theta = config.rope_theta


self.use_weight_only = False
self.weight_only_quant_bits = config.weight_only_quant_bits

@@ -188,8 +191,6 @@
ffn2_bias_attrs = None

if self.quant_type == "a8w8":
self.quant_round_type = config.quantization_config.quant_round_type

qkv_out_scale_attrs = [
paddle.ParamAttr(name="fusellama.{}.qkv_out_scale".format(i)) for i in range(self.num_layers)
]
@@ -277,9 +278,10 @@
]

transformer_config = FusedMultiTransformerConfig(
self.hidden_size,
self.num_attention_heads,
self.intermediate_size,
embed_dim=self.hidden_size,
num_heads=self.num_attention_heads,
kv_num_heads=self.num_key_value_heads,
dim_feedforward=self.intermediate_size,
weight_only_quant_bits=self.weight_only_quant_bits,
activation="swiglu",
num_layers=config.num_hidden_layers,
@@ -430,13 +432,12 @@
seq_lens = seq_len_decoder if is_decoder else seq_len_encoder

position_offset = 0
theta = 10000.0
if not is_decoder and pre_caches is not None:
position_offset = 128
from paddlenlp_ops import fused_get_rotary_embedding

new_rope = fused_get_rotary_embedding(
input_ids, position_ids, self.head_dim_shape_tensor, position_offset, theta, True
input_ids, position_ids, self.head_dim_shape_tensor, position_offset, self.rope_theta, True
)

with dy2st_nocheck_guard_context():
@@ -491,7 +492,7 @@
state_dict["llama.layers.{}.self_attn.qkv_proj.weight".format(idx)],
is_qkv=True,
num_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
num_key_value_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree,
),
axis=-1,
).transpose(1, 0)
@@ -517,10 +518,14 @@
)
.transpose(1, 0)
.reshape(
3 * (self.num_attention_heads // self.config.tensor_parallel_degree) * (head_size),
(
self.num_attention_heads // self.config.tensor_parallel_degree
+ 2 * self.num_key_value_heads // self.config.tensor_parallel_degree
)
* (head_size),
self.hidden_size,
)
) # reshape(3, self.num_attention_heself.hidden_sizeads // self.config.tensor_parallel_degree, head_size, )
)
if "llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx) in state_dict.keys():
concated_ffn1_weight = np.concatenate(
split_fn(state_dict["llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx)]), axis=-1
@@ -744,7 +749,7 @@
cache_scale_json_path,
cache_scale_map_dict,
num_of_layers=self.config.num_hidden_layers,
num_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
num_heads=self.num_key_value_heads // self.config.tensor_parallel_degree,
)
for k, v in cache_scales_loader.scale.items():
for i_layer, weight_scale in enumerate(v):
@@ -919,7 +924,7 @@
[
2,
max_batch_size,
config.num_attention_heads // max(config.tensor_parallel_degree, 1),
config.num_key_value_heads // max(config.tensor_parallel_degree, 1),
max_length,
config.hidden_size // config.num_attention_heads,
]
@@ -1205,7 +1210,7 @@
for _ in range(config.num_hidden_layers):
cache_kv_shape = [
max_block_nums,
config.num_attention_heads // max(config.tensor_parallel_degree, 1),
config.num_key_value_heads // max(config.tensor_parallel_degree, 1),
config.block_size,
config.hidden_size // config.num_attention_heads,
]