diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py
index 034a56e55de7..9f213a08cd04 100644
--- a/paddlenlp/experimental/transformers/fused_transformer_layers.py
+++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -41,9 +41,10 @@
         "The paddlenlp_ops package is not installed. you can read the docs and install it by hand, "
         "you can refer to: https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
     )
+
 if (
     paddle.device.get_all_custom_device_type() is not None and len(paddle.device.get_all_custom_device_type()) > 0
-) or core.is_compiled_with_cuda():
+) or paddle.is_compiled_with_cuda():
     from paddlenlp_ops import rebuild_padding_v2
@@ -147,6 +148,7 @@ def __init__(
         activation="gelu",
         norm_type="layernorm",
         use_neox_rotary_style=False,
+        rope_theta=10000.0,
         normalize_before=True,
         ln_scale_attrs=None,
         ln_bias_attrs=None,
@@ -210,7 +212,7 @@ def __init__(
         self.dropout_rate = dropout_rate
         self.activation = activation
         self.norm_type = norm_type
-
+        self.rope_theta = rope_theta
         self.use_neox_rotary_style = use_neox_rotary_style
         self.normalize_before = normalize_before
         self.ln_scale_attrs = ln_scale_attrs
@@ -2234,6 +2236,7 @@ def compute_attn(
                 quant_round_type=self.config.quant_round_type,
                 quant_max_bound=self.config.quant_max_bound,
                 quant_min_bound=self.config.quant_min_bound,
+                rope_theta=self.rope_theta,
             )[0]
         else:
             k_quant_scales = kwargs.get("k_quant_scales", None)
@@ -2275,6 +2278,7 @@ def compute_attn(
                 quant_round_type=self.config.quant_round_type,
                 quant_max_bound=self.config.quant_max_bound,
                 quant_min_bound=self.config.quant_min_bound,
+                rope_theta=self.rope_theta,
             )[0]

         out_linear_out = self.compute_out_linear(fmha_out, i)
@@ -2420,6 +2424,7 @@ def compute_attn(
                 quant_min_bound=self.quant_min_bound,
                 out_scale=self.act_scales["out_linear_in_scale"][i],
                 compute_dtype=self._fuse_kernel_compute_dtype,
+                rope_theta=self.rope_theta,
             )[0]

         out_linear_out = self.compute_out_linear(fmha_out, i)
@@ -2932,6 +2937,7 @@ def compute_attn(
                 quant_max_bound=self.config.quant_max_bound,
                 quant_min_bound=self.config.quant_min_bound,
                 out_scale=self.act_scales.scale["out_linear_in_scale"][i],
+                rope_theta=self.rope_theta,
             )[0]

         out_linear_out = self.compute_out_linear(fmha_out, i)
diff --git a/paddlenlp/experimental/transformers/llama/modeling.py b/paddlenlp/experimental/transformers/llama/modeling.py
index 2386f61a33ea..8f163ea3e707 100644
--- a/paddlenlp/experimental/transformers/llama/modeling.py
+++ b/paddlenlp/experimental/transformers/llama/modeling.py
@@ -214,6 +214,7 @@ def __init__(self, config: LlamaConfig):
             ffn2_bias_attrs=None,
             norm_type="rmsnorm",
             epsilon=self.epsilon,
+            rope_theta=self.rope_theta,
             nranks=config.tensor_parallel_degree,
             avx_config=avx_config,
         )
@@ -629,6 +630,7 @@ def __init__(self, config: LlamaConfig):
             ffn2_weight_attrs=ffn2_weight_attrs,
             ffn2_bias_attrs=ffn2_bias_attrs,
             epsilon=self.epsilon,
+            rope_theta=self.rope_theta,
             norm_type="rmsnorm",
             use_neox_rotary_style=self.use_neox,
             rank_id=config.tensor_parallel_rank,
@@ -675,6 +677,7 @@ def __init__(self, config: LlamaConfig):
             cache_k_out_scale_attrs=cache_k_out_scale_attrs,
             cache_v_out_scale_attrs=cache_v_out_scale_attrs,
             epsilon=self.epsilon,
+            rope_theta=self.rope_theta,
             norm_type="rmsnorm",
             use_neox_rotary_style=self.use_neox,
             cachekv_int8_type=config.cachekv_int8_type,
diff --git a/paddlenlp/experimental/transformers/mixtral/modeling.py b/paddlenlp/experimental/transformers/mixtral/modeling.py
index ca3978bf3a9f..b7c40b761108 100644
--- a/paddlenlp/experimental/transformers/mixtral/modeling.py
+++ b/paddlenlp/experimental/transformers/mixtral/modeling.py
@@ -334,6 +334,7 @@ def __init__(self, config: MixtralConfig):
             cache_k_out_scale_attrs=cache_k_out_scale_attrs,
             cache_v_out_scale_attrs=cache_v_out_scale_attrs,
             epsilon=self.epsilon,
+            rope_theta=self.rope_theta,
             norm_type="rmsnorm",
             use_neox_rotary_style=self.use_neox,
             cachekv_int8_type=config.cachekv_int8_type,
diff --git a/paddlenlp/experimental/transformers/qwen2/modeling.py b/paddlenlp/experimental/transformers/qwen2/modeling.py
index eecbed055227..f945e244faf6 100644
--- a/paddlenlp/experimental/transformers/qwen2/modeling.py
+++ b/paddlenlp/experimental/transformers/qwen2/modeling.py
@@ -334,6 +334,7 @@ def __init__(self, config: Qwen2Config):
             ffn2_weight_attrs=ffn2_weight_attrs,
             ffn2_bias_attrs=ffn2_bias_attrs,
             epsilon=self.rms_norm_eps,
+            rope_theta=self.rope_theta,
             norm_type="rmsnorm",
             use_neox_rotary_style=self.use_neox,
             rank_id=config.tensor_parallel_rank,
@@ -380,6 +381,7 @@ def __init__(self, config: Qwen2Config):
             cache_k_out_scale_attrs=cache_k_out_scale_attrs,
             cache_v_out_scale_attrs=cache_v_out_scale_attrs,
             epsilon=self.rms_norm_eps,
+            rope_theta=self.rope_theta,
             norm_type="rmsnorm",
             use_neox_rotary_style=self.use_neox,
             cachekv_int8_type=config.cachekv_int8_type,
diff --git a/paddlenlp/experimental/transformers/qwen2_moe/modeling.py b/paddlenlp/experimental/transformers/qwen2_moe/modeling.py
index 709b8d47fd13..baafc0d41b5c 100644
--- a/paddlenlp/experimental/transformers/qwen2_moe/modeling.py
+++ b/paddlenlp/experimental/transformers/qwen2_moe/modeling.py
@@ -252,6 +252,7 @@ def __init__(self, config: Qwen2MoeConfig):
             ffn2_weight_scale_attrs=ffn2_weight_scale_attrs,
             qkv_bias_attrs=qkv_bias_attrs,
             epsilon=self.rms_norm_eps,
+            rope_theta=self.rope_theta,
             norm_type="rmsnorm",
             use_neox_rotary_style=self.use_neox,
             rank_id=config.tensor_parallel_rank,
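
Note (not part of the diff): `rope_theta` is the base of the rotary-position-embedding frequencies. The diff adds it to `FusedMultiTransformerConfig` with the conventional default of 10000.0 and forwards it from the Llama, Mixtral, Qwen2 and Qwen2-MoE configs into the fused attention calls, so models that ship a non-default base no longer fall back to 10000.0. The sketch below only illustrates what the parameter controls; the helper name `rotary_inv_freq` is hypothetical, not a PaddleNLP or paddlenlp_ops API.

```python
import paddle


def rotary_inv_freq(head_dim: int, rope_theta: float = 10000.0) -> paddle.Tensor:
    """Illustrative sketch of the rotary-embedding inverse frequencies.

    inv_freq_i = 1 / rope_theta ** (2i / head_dim), for i = 0 .. head_dim // 2 - 1.
    A larger rope_theta stretches the rotation wavelengths, which is how
    long-context configs (e.g. Qwen2's rope_theta=1000000.0) keep distant
    positions distinguishable.
    """
    exponents = paddle.arange(0, head_dim, 2, dtype="float32") / head_dim
    return 1.0 / paddle.to_tensor(rope_theta, dtype="float32") ** exponents


# With the default base nothing changes; a larger base yields longer wavelengths:
# rotary_inv_freq(128)                   -> frequencies for the 10000.0 default
# rotary_inv_freq(128, rope_theta=1e6)   -> lower frequencies for long-context configs
```

Because the new keyword defaults to 10000.0, checkpoints whose config does not set `rope_theta` behave exactly as before; only configs that carry a different base see it reach the fused kernels.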