Skip to content

[LLM INFER] add rope_theta for block_multihead_attention #9334

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions paddlenlp/experimental/transformers/fused_transformer_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@
"The paddlenlp_ops package is not installed. you can read the docs and install it by hand, "
"you can refer to: https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
)

if (
paddle.device.get_all_custom_device_type() is not None and len(paddle.device.get_all_custom_device_type()) > 0
) or core.is_compiled_with_cuda():
) or paddle.is_compiled_with_cuda():
from paddlenlp_ops import rebuild_padding_v2


Expand Down Expand Up @@ -147,6 +148,7 @@
activation="gelu",
norm_type="layernorm",
use_neox_rotary_style=False,
rope_theta=10000.0,
normalize_before=True,
ln_scale_attrs=None,
ln_bias_attrs=None,
Expand Down Expand Up @@ -210,7 +212,7 @@
self.dropout_rate = dropout_rate
self.activation = activation
self.norm_type = norm_type

self.rope_theta = rope_theta

Check warning on line 215 in paddlenlp/experimental/transformers/fused_transformer_layers.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/fused_transformer_layers.py#L215

Added line #L215 was not covered by tests
self.use_neox_rotary_style = use_neox_rotary_style
self.normalize_before = normalize_before
self.ln_scale_attrs = ln_scale_attrs
Expand Down Expand Up @@ -2234,6 +2236,7 @@
quant_round_type=self.config.quant_round_type,
quant_max_bound=self.config.quant_max_bound,
quant_min_bound=self.config.quant_min_bound,
rope_theta=self.rope_theta,
)[0]
else:
k_quant_scales = kwargs.get("k_quant_scales", None)
Expand Down Expand Up @@ -2275,6 +2278,7 @@
quant_round_type=self.config.quant_round_type,
quant_max_bound=self.config.quant_max_bound,
quant_min_bound=self.config.quant_min_bound,
rope_theta=self.rope_theta,
)[0]

out_linear_out = self.compute_out_linear(fmha_out, i)
Expand Down Expand Up @@ -2420,6 +2424,7 @@
quant_min_bound=self.quant_min_bound,
out_scale=self.act_scales["out_linear_in_scale"][i],
compute_dtype=self._fuse_kernel_compute_dtype,
rope_theta=self.rope_theta,
)[0]

out_linear_out = self.compute_out_linear(fmha_out, i)
Expand Down Expand Up @@ -2932,6 +2937,7 @@
quant_max_bound=self.config.quant_max_bound,
quant_min_bound=self.config.quant_min_bound,
out_scale=self.act_scales.scale["out_linear_in_scale"][i],
rope_theta=self.rope_theta,
)[0]
out_linear_out = self.compute_out_linear(fmha_out, i)

Expand Down
3 changes: 3 additions & 0 deletions paddlenlp/experimental/transformers/llama/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def __init__(self, config: LlamaConfig):
ffn2_bias_attrs=None,
norm_type="rmsnorm",
epsilon=self.epsilon,
rope_theta=self.rope_theta,
nranks=config.tensor_parallel_degree,
avx_config=avx_config,
)
Expand Down Expand Up @@ -629,6 +630,7 @@ def __init__(self, config: LlamaConfig):
ffn2_weight_attrs=ffn2_weight_attrs,
ffn2_bias_attrs=ffn2_bias_attrs,
epsilon=self.epsilon,
rope_theta=self.rope_theta,
norm_type="rmsnorm",
use_neox_rotary_style=self.use_neox,
rank_id=config.tensor_parallel_rank,
Expand Down Expand Up @@ -675,6 +677,7 @@ def __init__(self, config: LlamaConfig):
cache_k_out_scale_attrs=cache_k_out_scale_attrs,
cache_v_out_scale_attrs=cache_v_out_scale_attrs,
epsilon=self.epsilon,
rope_theta=self.rope_theta,
norm_type="rmsnorm",
use_neox_rotary_style=self.use_neox,
cachekv_int8_type=config.cachekv_int8_type,
Expand Down
1 change: 1 addition & 0 deletions paddlenlp/experimental/transformers/mixtral/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def __init__(self, config: MixtralConfig):
cache_k_out_scale_attrs=cache_k_out_scale_attrs,
cache_v_out_scale_attrs=cache_v_out_scale_attrs,
epsilon=self.epsilon,
rope_theta=self.rope_theta,
norm_type="rmsnorm",
use_neox_rotary_style=self.use_neox,
cachekv_int8_type=config.cachekv_int8_type,
Expand Down
2 changes: 2 additions & 0 deletions paddlenlp/experimental/transformers/qwen2/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def __init__(self, config: Qwen2Config):
ffn2_weight_attrs=ffn2_weight_attrs,
ffn2_bias_attrs=ffn2_bias_attrs,
epsilon=self.rms_norm_eps,
rope_theta=self.rope_theta,
norm_type="rmsnorm",
use_neox_rotary_style=self.use_neox,
rank_id=config.tensor_parallel_rank,
Expand Down Expand Up @@ -380,6 +381,7 @@ def __init__(self, config: Qwen2Config):
cache_k_out_scale_attrs=cache_k_out_scale_attrs,
cache_v_out_scale_attrs=cache_v_out_scale_attrs,
epsilon=self.rms_norm_eps,
rope_theta=self.rope_theta,
norm_type="rmsnorm",
use_neox_rotary_style=self.use_neox,
cachekv_int8_type=config.cachekv_int8_type,
Expand Down
1 change: 1 addition & 0 deletions paddlenlp/experimental/transformers/qwen2_moe/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def __init__(self, config: Qwen2MoeConfig):
ffn2_weight_scale_attrs=ffn2_weight_scale_attrs,
qkv_bias_attrs=qkv_bias_attrs,
epsilon=self.rms_norm_eps,
rope_theta=self.rope_theta,
norm_type="rmsnorm",
use_neox_rotary_style=self.use_neox,
rank_id=config.tensor_parallel_rank,
Expand Down
Loading