diff --git a/llm/predict/predictor.py b/llm/predict/predictor.py
index 262a21fa6a0b..f53344f388b6 100644
--- a/llm/predict/predictor.py
+++ b/llm/predict/predictor.py
@@ -51,6 +51,7 @@
     AutoTokenizer,
     ChatGLMTokenizer,
     ChatGLMv2Tokenizer,
+    Llama3Tokenizer,
     LlamaTokenizer,
     PretrainedModel,
     PretrainedTokenizer,
@@ -739,13 +740,18 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer):
 
         self.architectures = self.model_config.architectures[0].lower()
 
-        self.dtype = config.dtype or self.model_config
+        self.dtype = config.dtype or self.model_config.dtype
 
         self.total_max_length = config.src_length + config.max_length
         self.block_size = config.block_size
         self.pre_max_block_num = (self.total_max_length + config.block_size - 1) // config.block_size
         self.max_block_nums = config.batch_size * self.pre_max_block_num
 
+        try:
+            self.rope_theta = self.model_config.rope_theta
+        except:
+            self.rope_theta = 10000.0
+
         self.pre_cache_length = 0
 
         if config.export_precache:
@@ -828,7 +834,7 @@ def init_inputs(self, config: PredictorArgument):
         )
         self.inputs["stop_nums"] = paddle.full(shape=[1], fill_value=config.batch_size, dtype="int64")
         self.inputs["rope_emb"] = self._get_rotary_position_embedding(
-            paddle.arange(self.total_max_length).reshape((1, -1)), self.head_dim
+            paddle.arange(self.total_max_length).reshape((1, -1)), self.head_dim, self.rope_theta
         )
         eos_token_id = get_eos_token_id(self.tokenizer, self.generation_config)
         if isinstance(eos_token_id, int):
@@ -895,7 +901,7 @@ def init_inputs(self, config: PredictorArgument):
             self.inputs["free_list"] = paddle.to_tensor(free_list, dtype="int32")
             self.inputs["free_list_len"] = paddle.full(shape=[1], fill_value=self.pre_max_block_num * 0.25, dtype="int32")
 
-    def _get_rotary_position_embedding(self, position_ids, head_dim):
+    def _get_rotary_position_embedding(self, position_ids, head_dim, rope_theta=10000.0):
         """
         Pre-calculate rotary position embedding for position_ids.
@@ -908,7 +914,7 @@ def _get_rotary_position_embedding(self, position_ids, head_dim):
         """
         bsz, max_seq_len = position_ids.shape[:2]
         rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim), dtype="float32")
-        inv_freq = 10000 ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim)
+        inv_freq = rope_theta ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim)
 
         # shape: [B, S, D/2]
         freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
@@ -1213,8 +1219,8 @@ def create_predictor(
     init_chat_template(tokenizer, predictor_args.model_name_or_path, predictor_args.chat_template)
 
     # TODO(wj-Mcat): fix llama tokenzier pad_token bug
-    if isinstance(tokenizer, LlamaTokenizer) and not tokenizer.pad_token:
-        tokenizer.pad_token = tokenizer.unk_token
+    if (isinstance(tokenizer, (LlamaTokenizer, Llama3Tokenizer))) and not tokenizer.pad_token:
+        tokenizer.pad_token = tokenizer.bos_token
 
     config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)
 
@@ -1310,10 +1316,13 @@ def create_predictor(
             config.use_cachekv_int8 = predictor_args.use_cachekv_int8
             config.single_card_ptq = True
 
-            if predictor_args.quant_type is not None and predictor_args.quant_type.startswith("weight_only_int"):
-                weight_only_quant_bits = int(predictor_args.quant_type[-1])
-                config.weight_only_quant_bits = weight_only_quant_bits
-                config.quant_type = predictor_args.quant_type
+            if predictor_args.quant_type is not None:
+                if predictor_args.quant_type.startswith("weight_only_int"):
+                    weight_only_quant_bits = int(predictor_args.quant_type[-1])
+                    config.weight_only_quant_bits = weight_only_quant_bits
+                    config.quant_type = predictor_args.quant_type
+                elif predictor_args.quant_type == "a8w8":
+                    config.quant_type = predictor_args.quant_type
 
             if config.quantization_config.quant_type is not None and "a8w8" in config.quantization_config.quant_type:
                 config.model_name_or_path = predictor_args.model_name_or_path
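The predictor changes above read `rope_theta` from the model config (falling back to the long-standing 10000.0 default when the field is absent) and thread it into the precomputed rotary table, which is what a Llama 3 checkpoint with its larger base frequency needs. Below is a minimal NumPy sketch of the frequency math that `_get_rotary_position_embedding` now parameterizes; the function name and the example values (e.g. `rope_theta=500000.0` for a Llama-3-style base) are illustrative, not part of the patch.

```python
import numpy as np

def rotary_table(max_seq_len: int, head_dim: int, rope_theta: float = 10000.0) -> np.ndarray:
    """Sketch of the cos/sin table the predictor precomputes, with theta configurable."""
    # Inverse frequencies: theta ** (-2i / head_dim) for i in [0, head_dim / 2).
    inv_freq = rope_theta ** (-np.arange(0, head_dim, 2, dtype="float32") / head_dim)
    positions = np.arange(max_seq_len, dtype="float32")
    freqs = np.outer(positions, inv_freq)           # [seq_len, head_dim / 2]
    emb = np.concatenate([freqs, freqs], axis=-1)   # [seq_len, head_dim]
    return np.stack([np.cos(emb), np.sin(emb)])     # [2, seq_len, head_dim]: cos, sin

# Only the base changes between model generations; the table layout stays the same.
llama2_like = rotary_table(4096, 128, rope_theta=10000.0)
llama3_like = rotary_table(4096, 128, rope_theta=500000.0)
```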
diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py
index f6ba2d0373c8..72a71d21a49c 100644
--- a/paddlenlp/experimental/transformers/fused_transformer_layers.py
+++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -443,7 +443,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
             cache_k_scale = None
             if cache_k_scale_attr:
                 cache_k_scale = self.create_parameter(
-                    shape=[self.num_heads],
+                    shape=[self.kv_num_heads],
                     attr=cache_k_scale_attr,
                     dtype="float32",
                     is_bias=False,
@@ -452,7 +452,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
             cache_v_scale = None
             if cache_v_scale_attr:
                 cache_v_scale = self.create_parameter(
-                    shape=[self.num_heads],
+                    shape=[self.kv_num_heads],
                     attr=cache_v_scale_attr,
                     dtype="float32",
                     is_bias=False,
@@ -461,7 +461,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
             cache_k_out_scale = None
             if cache_k_out_scale_attr:
                 cache_k_out_scale = self.create_parameter(
-                    shape=[self.num_heads],
+                    shape=[self.kv_num_heads],
                     attr=cache_k_out_scale_attr,
                     dtype="float32",
                     is_bias=False,
@@ -470,7 +470,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
             cache_v_out_scale = None
             if cache_v_out_scale_attr:
                 cache_v_out_scale = self.create_parameter(
-                    shape=[self.num_heads],
+                    shape=[self.kv_num_heads],
                     attr=cache_v_out_scale_attr,
                     dtype="float32",
                     is_bias=False,
@@ -549,7 +549,7 @@ def init_weight_shape(self, config):
         self.qkv_weight_shape = (
             [(self.num_heads + 2 * self.kv_num_heads) * self.head_dim, self.embed_dim]
             if config.trans_qkvw
-            else [(self.num_heads + 2 * self.kv_num_heads) * self.head_dim, self.embed_dim]
+            else [self.embed_dim, (self.num_heads + 2 * self.kv_num_heads) * self.head_dim]
         )
         self.linear_weight_shape = [self.num_heads * self.head_dim, self.embed_dim]
         self.ffn1_weight_shape = (
@@ -1075,7 +1075,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
             ffn2_smooth_attr = self.get_attr(config.ffn2_smooth_attrs, i)
 
             qkv_out_scale = self.create_parameter(
-                shape=[self.head_dim * 3 * self.num_heads],
+                shape=[self.head_dim * (2 * self.kv_num_heads + self.num_heads)],
                 attr=qkv_out_scale_attr,
                 dtype="float32",
                 is_bias=False,
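In the fused-transformer layer, the per-head cache-quantization scales now follow the key/value head count rather than the query head count, and the fused QKV out-scale (like the non-transposed QKV weight shape) is sized as `(num_heads + 2 * kv_num_heads) * head_dim`, since with grouped-query attention K and V have fewer heads than Q. The quick geometry check below uses Llama-3-8B-like numbers purely as an example; nothing here is read from a real config.

```python
# Illustrative GQA geometry check; the head counts are example values only.
num_heads = 32        # query heads
kv_num_heads = 8      # key/value heads (grouped-query attention)
head_dim = 128
embed_dim = num_heads * head_dim                      # 4096

# Cache quantization scales are per K/V head after this patch.
cache_k_scale_shape = [kv_num_heads]                  # previously [num_heads]
cache_v_scale_shape = [kv_num_heads]

# The fused QKV projection emits Q for every query head plus one K and one V
# block per KV head, so its output width is (num_heads + 2 * kv_num_heads) * head_dim.
qkv_out_width = (num_heads + 2 * kv_num_heads) * head_dim
assert qkv_out_width == 6144                          # 48 heads * 128 dims

qkv_weight_shape_transposed = [qkv_out_width, embed_dim]   # config.trans_qkvw=True
qkv_weight_shape_plain = [embed_dim, qkv_out_width]        # config.trans_qkvw=False
```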
diff --git a/paddlenlp/experimental/transformers/llama/modeling.py b/paddlenlp/experimental/transformers/llama/modeling.py
index 9ec5661cb6ab..622371336215 100644
--- a/paddlenlp/experimental/transformers/llama/modeling.py
+++ b/paddlenlp/experimental/transformers/llama/modeling.py
@@ -96,12 +96,15 @@ def __init__(self, config: LlamaConfig):
         self.vocab_size = config.vocab_size
         self.hidden_size = config.hidden_size
         self.num_attention_heads = config.num_attention_heads
+        self.num_key_value_heads = config.num_key_value_heads
         self.intermediate_size = config.intermediate_size
         self.num_layers = config.num_hidden_layers
         self.epsilon = config.rms_norm_eps
         self.max_position_embeddings = config.max_position_embeddings
         self.quant_type = config.quant_type
+        self.rope_theta = config.rope_theta
+
         self.use_weight_only = False
         self.weight_only_quant_bits = config.weight_only_quant_bits
@@ -188,8 +191,6 @@ def __init__(self, config: LlamaConfig):
         ffn2_bias_attrs = None
 
         if self.quant_type == "a8w8":
-            self.quant_round_type = config.quantization_config.quant_round_type
-
             qkv_out_scale_attrs = [
                 paddle.ParamAttr(name="fusellama.{}.qkv_out_scale".format(i)) for i in range(self.num_layers)
             ]
@@ -277,9 +278,10 @@ def __init__(self, config: LlamaConfig):
         ]
 
         transformer_config = FusedMultiTransformerConfig(
-            self.hidden_size,
-            self.num_attention_heads,
-            self.intermediate_size,
+            embed_dim=self.hidden_size,
+            num_heads=self.num_attention_heads,
+            kv_num_heads=self.num_key_value_heads,
+            dim_feedforward=self.intermediate_size,
             weight_only_quant_bits=self.weight_only_quant_bits,
             activation="swiglu",
             num_layers=config.num_hidden_layers,
@@ -430,13 +432,12 @@ def forward(
         seq_lens = seq_len_decoder if is_decoder else seq_len_encoder
 
         position_offset = 0
-        theta = 10000.0
         if not is_decoder and pre_caches is not None:
             position_offset = 128
 
         from paddlenlp_ops import fused_get_rotary_embedding
 
         new_rope = fused_get_rotary_embedding(
-            input_ids, position_ids, self.head_dim_shape_tensor, position_offset, theta, True
+            input_ids, position_ids, self.head_dim_shape_tensor, position_offset, self.rope_theta, True
         )
 
         with dy2st_nocheck_guard_context():
@@ -491,7 +492,7 @@ def set_state_dict(self, state_dict):
                         state_dict["llama.layers.{}.self_attn.qkv_proj.weight".format(idx)],
                         is_qkv=True,
                         num_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
-                        num_key_value_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
+                        num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree,
                     ),
                     axis=-1,
                 ).transpose(1, 0)
@@ -517,10 +518,14 @@ def set_state_dict(self, state_dict):
                     )
                     .transpose(1, 0)
                     .reshape(
-                        3 * (self.num_attention_heads // self.config.tensor_parallel_degree) * (head_size),
+                        (
+                            self.num_attention_heads // self.config.tensor_parallel_degree
+                            + 2 * self.num_key_value_heads // self.config.tensor_parallel_degree
+                        )
+                        * (head_size),
                         self.hidden_size,
                     )
-                )  # reshape(3, self.num_attention_heads // self.config.tensor_parallel_degree, head_size, self.hidden_size)
+                )
                 if "llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx) in state_dict.keys():
                     concated_ffn1_weight = np.concatenate(
                         split_fn(state_dict["llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx)]), axis=-1
@@ -744,7 +749,7 @@ def set_state_dict(self, state_dict):
                 cache_scale_json_path,
                 cache_scale_map_dict,
                 num_of_layers=self.config.num_hidden_layers,
-                num_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
+                num_heads=self.num_key_value_heads // self.config.tensor_parallel_degree,
             )
             for k, v in cache_scales_loader.scale.items():
                 for i_layer, weight_scale in enumerate(v):
@@ -919,7 +924,7 @@ def get_cache_kvs_shape(
                 [
                     2,
                     max_batch_size,
-                    config.num_attention_heads // max(config.tensor_parallel_degree, 1),
+                    config.num_key_value_heads // max(config.tensor_parallel_degree, 1),
                     max_length,
                     config.hidden_size // config.num_attention_heads,
                 ]
@@ -1205,7 +1210,7 @@ def get_cache_kvs_shape(
         for _ in range(config.num_hidden_layers):
             cache_kv_shape = [
                 max_block_nums,
-                config.num_attention_heads // max(config.tensor_parallel_degree, 1),
+                config.num_key_value_heads // max(config.tensor_parallel_degree, 1),
                 config.block_size,
                 config.hidden_size // config.num_attention_heads,
             ]
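Finally, both `get_cache_kvs_shape` variants size the KV cache with `num_key_value_heads` instead of `num_attention_heads`, so a grouped-query model only allocates cache for its K/V heads. A rough sizing sketch with assumed example values follows; the batch size, lengths, and head counts are illustrative, not defaults taken from the code.

```python
# Example-only numbers; none of these are read from a real config.
num_attention_heads = 32
num_key_value_heads = 8
hidden_size = 4096
tensor_parallel_degree = 1
head_dim = hidden_size // num_attention_heads          # 128

# Contiguous layout: one [2, batch, kv_heads, max_len, head_dim] tensor per layer.
max_batch_size, max_length = 8, 8192
cache_kv_shape = [
    2,
    max_batch_size,
    num_key_value_heads // max(tensor_parallel_degree, 1),
    max_length,
    head_dim,
]

# Block layout: [max_block_nums, kv_heads, block_size, head_dim] per layer.
block_size = 64
pre_max_block_num = (max_length + block_size - 1) // block_size
max_block_nums = max_batch_size * pre_max_block_num
block_cache_kv_shape = [
    max_block_nums,
    num_key_value_heads // max(tensor_parallel_degree, 1),
    block_size,
    head_dim,
]

# With 8 KV heads instead of 32 attention heads, each layer's cache shrinks 4x.
```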