[LLM Inference] Support Qwen2_Moe Inference Model #8892

Merged: 1 commit merged on Aug 28, 2024
40 changes: 40 additions & 0 deletions llm/predict/predictor.py
@@ -1359,6 +1359,32 @@ def create_predictor(
dtype=predictor_args.dtype,
)
model.eval()
elif "qwen2moe" in config.architectures[0].lower():
if predictor_args.block_attn:
config.max_seq_len = predictor_args.total_max_length
config.block_size = predictor_args.block_size
from paddlenlp.experimental.transformers import (
Qwen2MoeForCausalLMBlockInferenceModel as Qwen2MoeInferenceModel,
)

model = Qwen2MoeInferenceModel.from_pretrained(
predictor_args.model_name_or_path,
config=config,
dtype=predictor_args.dtype,
tensor_parallel_degree=tensor_parallel_degree,
tensor_parallel_rank=tensor_parallel_rank,
)
else:
from paddlenlp.experimental.transformers import (
Qwen2MoeForCausalLMInferenceModel as Qwen2MoeInferenceModel,
)

model = Qwen2MoeInferenceModel.from_pretrained(
predictor_args.model_name_or_path,
config=config,
dtype=predictor_args.dtype,
)
model.eval()
Collaborator: Could the design here be restructured? At the moment, every newly added model requires adding its own model-initialization branch.

Collaborator: Yes, that refactoring work is planned; we expect a conclusion around September.
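
As a rough sketch of the kind of refactor discussed above (illustrative only: the registry and helper names below are hypothetical, not an existing PaddleNLP API), architecture keywords could be mapped to inference-model classes so that create_predictor no longer needs a new elif branch for every model:

# Hypothetical registry-based dispatch; MODEL_REGISTRY, register_inference_model
# and resolve_inference_model are illustrative names, not PaddleNLP APIs.
MODEL_REGISTRY = {}  # architecture keyword -> (block-attention class, default class)

def register_inference_model(keyword, block_attn_cls, default_cls):
    MODEL_REGISTRY[keyword.lower()] = (block_attn_cls, default_cls)

def resolve_inference_model(architecture, block_attn):
    # Try longer keywords first so "qwen2moe" matches before "qwen2".
    for keyword in sorted(MODEL_REGISTRY, key=len, reverse=True):
        if keyword in architecture.lower():
            block_cls, default_cls = MODEL_REGISTRY[keyword]
            return block_cls if block_attn else default_cls
    raise ValueError(f"No inference model registered for architecture: {architecture}")

Registering Qwen2MoeForCausalLMBlockInferenceModel and Qwen2MoeForCausalLMInferenceModel under the keyword "qwen2moe" would then replace the branch added in this diff.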

elif "qwen2" in config.architectures[0].lower():
if predictor_args.block_attn:
config.max_seq_len = predictor_args.total_max_length
@@ -1495,6 +1521,20 @@ def create_predictor(
cache_kvs_shape = GPTForCausalLMInferenceModel.get_cache_kvs_shape(
config, predictor_args.batch_size, predictor_args.total_max_length
)
elif "qwen2moe" in config.architectures[0].lower():
if predictor_args.block_attn:
config.block_size = predictor_args.block_size
config.max_seq_len = predictor_args.total_max_length
from paddlenlp.experimental.transformers import (
Qwen2MoeForCausalLMBlockInferenceModel as Qwen2MoeInferenceModel,
)
else:
from paddlenlp.experimental.transformers import (
Qwen2MoeForCausalLMInferenceModel as Qwen2MoeInferenceModel,
)
cache_kvs_shape = Qwen2MoeInferenceModel.get_cache_kvs_shape(
config, predictor_args.batch_size, predictor_args.total_max_length
)
elif "qwen2" in config.architectures[0].lower():
if predictor_args.block_attn:
config.block_size = predictor_args.block_size
1 change: 1 addition & 0 deletions paddlenlp/experimental/transformers/__init__.py
@@ -22,3 +22,4 @@
from .opt import *
from .qwen import *
from .qwen2 import *
from .qwen2_moe import *

165 changes: 158 additions & 7 deletions paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -14,6 +14,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional


import paddle
import paddle.distributed as dist
@@ -157,12 +158,25 @@
norm_topk_prob: bool = True
moe_every2: bool = False

shared_expert_intermediate_size: int = 0
shared_expert_ffn1_weight_attrs: Optional[List[paddle.ParamAttr]] = None
shared_expert_ffn1_weight_scale_attrs: Optional[List[paddle.ParamAttr]] = None
shared_expert_ffn2_weight_attrs: Optional[List[paddle.ParamAttr]] = None
shared_expert_ffn2_weight_scale_attrs: Optional[List[paddle.ParamAttr]] = None
shared_expert_gate_weight_attrs: Optional[List[paddle.ParamAttr]] = None


def has_moe(self) -> bool:
return self.num_experts > 1

def use_moe(self, i: int) -> bool:
return self.has_moe() and (self.moe_every2 is False or (self.moe_every2 and i % 2 == 1))

def has_shared_expert(self) -> bool:
return self.has_moe() and self.shared_expert_intermediate_size > 0


def use_shared_expert(self, i: int) -> bool:
return self.use_moe(i) and self.shared_expert_intermediate_size > 0



class FusedMultiTransformerConfig:
def __init__(
@@ -342,9 +356,15 @@
self.gate_weights = []
self.ffn1_weights, self.ffn1_biases = [], []
self.ffn2_weights, self.ffn2_biases = [], []
if self.config.moe_config.has_shared_expert():
self.shared_expert_gate_weights = []
self.shared_expert_ffn1_weights = []
self.shared_expert_ffn2_weights = []

self.cache_k_scales, self.cache_v_scales = [], []
self.cache_k_out_scales, self.cache_v_out_scales = [], []

self.init_weight_shape(config)


for i in range(self.num_layers):
ln_scale_attr = self.get_attr(config.ln_scale_attrs, i)
ln_bias_attr = self.get_attr(config.ln_bias_attrs, i)
@@ -362,6 +382,11 @@
ffn2_weight_attr = self.get_attr(config.ffn2_weight_attrs, i)
ffn2_bias_attr = self.get_attr(config.ffn2_bias_attrs, i)

if self.config.moe_config.use_shared_expert(i):
shared_expert_gate_weight_attr = self.get_attr(config.moe_config.shared_expert_gate_weight_attrs, i)
shared_expert_ffn1_weight_attr = self.get_attr(config.moe_config.shared_expert_ffn1_weight_attrs, i)
shared_expert_ffn2_weight_attr = self.get_attr(config.moe_config.shared_expert_ffn2_weight_attrs, i)


cache_k_scale_attr = self.get_attr(config.cache_k_scale_attrs, i)
cache_v_scale_attr = self.get_attr(config.cache_v_scale_attrs, i)
cache_k_out_scale_attr = self.get_attr(config.cache_k_out_scale_attrs, i)
@@ -381,7 +406,6 @@
is_bias=True,
dtype=self._norm_weight_dtype,
)
self.init_weight_shape(config)

qkv_weight = self.create_parameter(
shape=self.qkv_weight_shape,
@@ -433,7 +457,8 @@
)

gate_weight = None
if config.moe_config.use_moe(i):

if self.config.moe_config.use_moe(i):

gate_weight = self.create_parameter(
shape=[config.embed_dim, self.config.moe_config.num_experts],
attr=gate_weight_attr,
@@ -442,7 +467,7 @@
default_initializer=paddle.nn.initializer.Constant(0),
)

if config.moe_config.use_moe(i):
if self.config.moe_config.use_moe(i):

ffn1_weight = self.create_parameter(
shape=self.moe_ffn1_weight_shape,
attr=ffn1_weight_attr,
@@ -493,7 +518,7 @@

ffn2_bias = None
if ffn2_bias_attr:
if config.moe_config.use_moe(i):
if self.config.moe_config.use_moe(i):

ffn2_bias = self.create_parameter(
shape=[self.config.moe_config.num_experts, config.embed_dim],
attr=ffn2_bias_attr,
@@ -508,6 +533,23 @@
is_bias=True,
)

if self.config.moe_config.use_shared_expert(i):
shared_expert_ffn1_weight = self.create_parameter(

shape=self.shared_expert_ffn1_weight_shape,
attr=shared_expert_ffn1_weight_attr,
dtype=self.create_params_type,
)
shared_expert_ffn2_weight = self.create_parameter(

shape=self.shared_expert_ffn2_weight_shape,
attr=shared_expert_ffn2_weight_attr,
dtype=self.create_params_type,
)
shared_expert_gate_weight = self.create_parameter(

shape=self.shared_expert_gate_weight_shape,
attr=shared_expert_gate_weight_attr,
dtype=self._helper.get_default_dtype(),
)

cache_k_scale = None
if cache_k_scale_attr:
cache_k_scale = self.create_parameter(
@@ -571,6 +613,11 @@
self.ffn2_weights.append(ffn2_weight)
self.ffn2_biases.append(ffn2_bias)

if self.config.moe_config.use_shared_expert(i):
self.shared_expert_ffn1_weights.append(shared_expert_ffn1_weight)
self.shared_expert_ffn2_weights.append(shared_expert_ffn2_weight)
self.shared_expert_gate_weights.append(shared_expert_gate_weight)


self.cache_k_scales.append(cache_k_scale)
self.cache_v_scales.append(cache_v_scale)
self.cache_k_out_scales.append(cache_k_out_scale)
@@ -592,6 +639,11 @@
self._add_parameter(ffn2_weight)
self._add_parameter(ffn2_bias)

if self.config.moe_config.use_shared_expert(i):
self._add_parameter(shared_expert_ffn1_weight)
self._add_parameter(shared_expert_ffn2_weight)
self._add_parameter(shared_expert_gate_weight)


self._add_parameter(cache_k_scale)
self._add_parameter(cache_v_scale)
self._add_parameter(cache_k_out_scale)
@@ -624,6 +676,7 @@
else [self.embed_dim, (self.num_heads + 2 * self.kv_num_heads) * self.head_dim]
)
self.linear_weight_shape = [self.num_heads * self.head_dim, self.embed_dim]

self.ffn1_weight_shape = (
[self.embed_dim, self.dim_feedforward * 2]
if self.activation.endswith("glu")
@@ -639,6 +692,20 @@
)
self.moe_ffn2_weight_shape = [self.config.moe_config.num_experts, self.dim_feedforward, self.embed_dim]

if self.config.moe_config.has_shared_expert():
self.shared_expert_ffn1_weight_shape = [

self.embed_dim,
self.config.moe_config.shared_expert_intermediate_size * 2,
]
self.shared_expert_ffn2_weight_shape = [

self.config.moe_config.shared_expert_intermediate_size,
self.embed_dim,
]
self.shared_expert_gate_weight_shape = [

self.embed_dim,
1,
]

def get_weight_create_dype(self):
return self._dtype

@@ -851,6 +918,15 @@
)[0]
return tmp_out, residual_input

def compute_shared_expert(self, tmp_out, i):
ffn1_out = paddle.matmul(tmp_out, self.shared_expert_ffn1_weights[i])
ffn1_out = fused_act_bias_wrapper(ffn1_out, None, act_method=self.activation)
ffn2_out = paddle.matmul(ffn1_out, self.shared_expert_ffn2_weights[i])
gate_out = paddle.matmul(tmp_out, self.shared_expert_gate_weights[i])
gate_out = paddle.nn.functional.sigmoid(gate_out)
shared_expert_output = gate_out * ffn2_out
return shared_expert_output


def pre_process(self, **kwargs):
pass

@@ -962,6 +1038,10 @@
# fused moe
ffn2_out = self.compute_fused_moe(tmp_out, i)

# shared_expert
if self.config.moe_config.use_shared_expert(i):
shared_expert_out = self.compute_shared_expert(tmp_out, i)
ffn2_out = ffn2_out + shared_expert_out

else:
# ffn1 matmul
ffn1_out = self.compute_ffn1(tmp_out, i)
@@ -1046,13 +1126,25 @@
self.ffn1_weights_scale = []
self.ffn2_weights_scale = []

if self.config.moe_config.has_shared_expert():
self.shared_expert_ffn1_weights_scale = []
self.shared_expert_ffn2_weights_scale = []


for i in range(self.num_layers):

qkv_weight_scale_attr = self.get_attr(config.qkv_weight_scale_attrs, i)
linear_weight_scale_attr = self.get_attr(config.linear_weight_scale_attrs, i)
ffn1_weight_scale_attr = self.get_attr(config.ffn1_weight_scale_attrs, i)
ffn2_weight_scale_attr = self.get_attr(config.ffn2_weight_scale_attrs, i)

if self.config.moe_config.use_shared_expert(i):
shared_expert_ffn1_weight_scale_attr = self.get_attr(

config.moe_config.shared_expert_ffn1_weight_scale_attrs, i
)
shared_expert_ffn2_weight_scale_attr = self.get_attr(

config.moe_config.shared_expert_ffn2_weight_scale_attrs, i
)

qkv_weight_scale = self.create_parameter(
shape=[(self.num_heads + 2 * self.kv_num_heads) * self.head_dim],
attr=qkv_weight_scale_attr,
@@ -1069,9 +1161,9 @@

if self.config.moe_config.use_moe(i):
ffn1_weight_scale = self.create_parameter(
shape=[config.moe_config.num_experts, self.dim_feedforward * 2]
shape=[self.config.moe_config.num_experts, self.dim_feedforward * 2]
if config.activation.endswith("glu")
else [config.moe_config.num_experts, self.dim_feedforward],
else [self.config.moe_config.num_experts, self.dim_feedforward],
attr=ffn1_weight_scale_attr,
dtype=self.weight_scale_dtype,
is_bias=False,
@@ -1086,7 +1178,7 @@

if self.config.moe_config.use_moe(i):
ffn2_weight_scale = self.create_parameter(
shape=[config.moe_config.num_experts, self.embed_dim],
shape=[self.config.moe_config.num_experts, self.embed_dim],
attr=ffn2_weight_scale_attr,
dtype=self.weight_scale_dtype,
is_bias=False,
@@ -1099,16 +1191,38 @@
is_bias=False,
)

if self.config.moe_config.use_shared_expert(i):
shared_expert_ffn1_weight_scale = self.create_parameter(

shape=[self.config.moe_config.shared_expert_intermediate_size * 2],
attr=shared_expert_ffn1_weight_scale_attr,
dtype=self.weight_scale_dtype,
is_bias=False,
)
shared_expert_ffn2_weight_scale = self.create_parameter(

shape=[self.embed_dim],
attr=shared_expert_ffn2_weight_scale_attr,
dtype=self.weight_scale_dtype,
is_bias=False,
)

self.qkv_weights_scale.append(qkv_weight_scale)
self.linear_weights_scale.append(linear_weight_scale)
self.ffn1_weights_scale.append(ffn1_weight_scale)
self.ffn2_weights_scale.append(ffn2_weight_scale)

if self.config.moe_config.use_shared_expert(i):
self.shared_expert_ffn1_weights_scale.append(shared_expert_ffn1_weight_scale)
self.shared_expert_ffn2_weights_scale.append(shared_expert_ffn2_weight_scale)


self._add_parameter(qkv_weight_scale)
self._add_parameter(linear_weight_scale)
self._add_parameter(ffn1_weight_scale)
self._add_parameter(ffn2_weight_scale)

if self.config.moe_config.use_shared_expert(i):
self._add_parameter(shared_expert_ffn1_weight_scale)
self._add_parameter(shared_expert_ffn2_weight_scale)


def get_weight_create_dype(self):
return "int8" # If use weightonly int4, params dtype is int8, and one of the dimension will be half.

@@ -1141,6 +1255,20 @@
self.moe_ffn1_weight_shape[2] //= 2
self.moe_ffn2_weight_shape[2] //= 2

if self.config.moe_config.has_shared_expert():
self.shared_expert_ffn1_weight_shape = [

self.config.moe_config.shared_expert_intermediate_size * 2,
self.embed_dim,
]
self.shared_expert_ffn2_weight_shape = [

self.embed_dim,
self.config.moe_config.shared_expert_intermediate_size,
]
self.shared_expert_gate_weight_shape = [

self.embed_dim,
1,
]

def compute_qkv_linear(self, ln_out, i):
return weight_only_linear(
ln_out,
Expand Down Expand Up @@ -1197,6 +1325,29 @@
weight_dtype=self.weight_dtype,
)

def compute_shared_expert(self, tmp_out, i):
ffn1_out = weight_only_linear(

tmp_out,
weight=self.shared_expert_ffn1_weights[i],
weight_scale=self.shared_expert_ffn1_weights_scale[i],
weight_dtype=self.weight_dtype,
)

ffn1_out = fused_act_bias_wrapper(ffn1_out, None, act_method=self.activation)


ffn2_out = weight_only_linear(

ffn1_out,
weight=self.shared_expert_ffn2_weights[i],
weight_scale=self.shared_expert_ffn2_weights_scale[i],
weight_dtype=self.weight_dtype,
)

gate_out = paddle.matmul(tmp_out, self.shared_expert_gate_weights[i])
gate_out = paddle.nn.functional.sigmoid(gate_out)


shared_expert_output = gate_out * ffn2_out
return shared_expert_output



class FusedMultiTransformerWeightOnlyPostLayernorm(
FusedMultiTransformerWeightOnly, FusedMultiTransformerPostLayernorm