
Commit a2bf616

CJ77Qiyuan authored and yuanlehome committed
supprot qwen-moe (PaddlePaddle#8892)
Co-authored-by: yuanlehome <yuanlehome@163.com>
1 parent 63e059f commit a2bf616

File tree

5 files changed: +1197 -7 lines


llm/predict/predictor.py

Lines changed: 40 additions & 0 deletions
@@ -1359,6 +1359,32 @@ def create_predictor(
                 dtype=predictor_args.dtype,
             )
             model.eval()
+        elif "qwen2moe" in config.architectures[0].lower():
+            if predictor_args.block_attn:
+                config.max_seq_len = predictor_args.total_max_length
+                config.block_size = predictor_args.block_size
+                from paddlenlp.experimental.transformers import (
+                    Qwen2MoeForCausalLMBlockInferenceModel as Qwen2MoeInferenceModel,
+                )
+
+                model = Qwen2MoeInferenceModel.from_pretrained(
+                    predictor_args.model_name_or_path,
+                    config=config,
+                    dtype=predictor_args.dtype,
+                    tensor_parallel_degree=tensor_parallel_degree,
+                    tensor_parallel_rank=tensor_parallel_rank,
+                )
+            else:
+                from paddlenlp.experimental.transformers import (
+                    Qwen2MoeForCausalLMInferenceModel as Qwen2MoeInferenceModel,
+                )
+
+                model = Qwen2MoeInferenceModel.from_pretrained(
+                    predictor_args.model_name_or_path,
+                    config=config,
+                    dtype=predictor_args.dtype,
+                )
+            model.eval()
         elif "qwen2" in config.architectures[0].lower():
             if predictor_args.block_attn:
                 config.max_seq_len = predictor_args.total_max_length
@@ -1495,6 +1521,20 @@ def create_predictor(
             cache_kvs_shape = GPTForCausalLMInferenceModel.get_cache_kvs_shape(
                 config, predictor_args.batch_size, predictor_args.total_max_length
             )
+        elif "qwen2moe" in config.architectures[0].lower():
+            if predictor_args.block_attn:
+                config.block_size = predictor_args.block_size
+                config.max_seq_len = predictor_args.total_max_length
+                from paddlenlp.experimental.transformers import (
+                    Qwen2MoeForCausalLMBlockInferenceModel as Qwen2MoeInferenceModel,
+                )
+            else:
+                from paddlenlp.experimental.transformers import (
+                    Qwen2MoeForCausalLMInferenceModel as Qwen2MoeInferenceModel,
+                )
+            cache_kvs_shape = Qwen2MoeInferenceModel.get_cache_kvs_shape(
+                config, predictor_args.batch_size, predictor_args.total_max_length
+            )
         elif "qwen2" in config.architectures[0].lower():
             if predictor_args.block_attn:
                 config.block_size = predictor_args.block_size
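
Note on ordering: because "qwen2" is a substring of "qwen2moe", the new branch has to sit before the existing "qwen2" branch in both dispatch chains, otherwise a Qwen2Moe architecture string would fall into the dense Qwen2 path. A minimal standalone sketch of the dispatch (illustration only, not code from the repo):

def pick_branch(architecture: str) -> str:
    # Mirrors the substring dispatch in create_predictor above.
    name = architecture.lower()
    if "qwen2moe" in name:  # must be checked first
        return "qwen2moe"
    elif "qwen2" in name:  # would otherwise also match "Qwen2MoeForCausalLM"
        return "qwen2"
    return "other"

assert pick_branch("Qwen2MoeForCausalLM") == "qwen2moe"
assert pick_branch("Qwen2ForCausalLM") == "qwen2"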

paddlenlp/experimental/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -22,3 +22,4 @@
 from .opt import *
 from .qwen import *
 from .qwen2 import *
+from .qwen2_moe import *
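
With this re-export in place, the imports used in predictor.py above resolve from the package namespace. A quick smoke test (hypothetical; assumes a PaddleNLP build that includes this commit):

from paddlenlp.experimental.transformers import (
    Qwen2MoeForCausalLMBlockInferenceModel,
    Qwen2MoeForCausalLMInferenceModel,
)

print(Qwen2MoeForCausalLMInferenceModel.__name__)
print(Qwen2MoeForCausalLMBlockInferenceModel.__name__)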

paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 158 additions & 7 deletions
@@ -14,6 +14,7 @@
 from __future__ import annotations

 from dataclasses import dataclass
+from typing import List, Optional

 import paddle
 import paddle.distributed as dist
@@ -157,12 +158,25 @@ class MoeConfig:
     norm_topk_prob: bool = True
     moe_every2: bool = False

+    shared_expert_intermediate_size: int = 0
+    shared_expert_ffn1_weight_attrs: Optional[List[paddle.ParamAttr]] = None
+    shared_expert_ffn1_weight_scale_attrs: Optional[List[paddle.ParamAttr]] = None
+    shared_expert_ffn2_weight_attrs: Optional[List[paddle.ParamAttr]] = None
+    shared_expert_ffn2_weight_scale_attrs: Optional[List[paddle.ParamAttr]] = None
+    shared_expert_gate_weight_attrs: Optional[List[paddle.ParamAttr]] = None
+
     def has_moe(self) -> bool:
         return self.num_experts > 1

     def use_moe(self, i: int) -> bool:
         return self.has_moe() and (self.moe_every2 is False or (self.moe_every2 and i % 2 == 1))

+    def has_shared_expert(self) -> bool:
+        return self.has_moe() and self.shared_expert_intermediate_size > 0
+
+    def use_shared_expert(self, i: int) -> bool:
+        return self.use_moe(i) and self.shared_expert_intermediate_size > 0
+

 class FusedMultiTransformerConfig:
     def __init__(
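
The new predicates compose with the existing ones: has_shared_expert() is true only for MoE models (num_experts > 1) that also declare a shared_expert_intermediate_size, and use_shared_expert(i) additionally follows the per-layer use_moe(i) gating (moe_every2). A standalone mirror of that logic with illustrative values, not a real MoeConfig instance (the actual dataclass has more fields):

from dataclasses import dataclass

@dataclass
class MoeConfigSketch:
    num_experts: int = 1
    moe_every2: bool = False
    shared_expert_intermediate_size: int = 0

    def has_moe(self) -> bool:
        return self.num_experts > 1

    def use_moe(self, i: int) -> bool:
        return self.has_moe() and (self.moe_every2 is False or (self.moe_every2 and i % 2 == 1))

    def has_shared_expert(self) -> bool:
        return self.has_moe() and self.shared_expert_intermediate_size > 0

    def use_shared_expert(self, i: int) -> bool:
        return self.use_moe(i) and self.shared_expert_intermediate_size > 0

# Qwen1.5-MoE-A2.7B-like settings (illustrative numbers).
cfg = MoeConfigSketch(num_experts=60, shared_expert_intermediate_size=5632)
assert cfg.has_shared_expert() and cfg.use_shared_expert(0)
assert not MoeConfigSketch(num_experts=1).has_shared_expert()  # dense model: no shared expert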
@@ -342,9 +356,15 @@ def __init__(self, config: FusedMultiTransformerConfig):
         self.gate_weights = []
         self.ffn1_weights, self.ffn1_biases = [], []
         self.ffn2_weights, self.ffn2_biases = [], []
+        if self.config.moe_config.has_shared_expert():
+            self.shared_expert_gate_weights = []
+            self.shared_expert_ffn1_weights = []
+            self.shared_expert_ffn2_weights = []
         self.cache_k_scales, self.cache_v_scales = [], []
         self.cache_k_out_scales, self.cache_v_out_scales = [], []

+        self.init_weight_shape(config)
+
         for i in range(self.num_layers):
             ln_scale_attr = self.get_attr(config.ln_scale_attrs, i)
             ln_bias_attr = self.get_attr(config.ln_bias_attrs, i)
@@ -362,6 +382,11 @@ def __init__(self, config: FusedMultiTransformerConfig):
             ffn2_weight_attr = self.get_attr(config.ffn2_weight_attrs, i)
             ffn2_bias_attr = self.get_attr(config.ffn2_bias_attrs, i)

+            if self.config.moe_config.use_shared_expert(i):
+                shared_expert_gate_weight_attr = self.get_attr(config.moe_config.shared_expert_gate_weight_attrs, i)
+                shared_expert_ffn1_weight_attr = self.get_attr(config.moe_config.shared_expert_ffn1_weight_attrs, i)
+                shared_expert_ffn2_weight_attr = self.get_attr(config.moe_config.shared_expert_ffn2_weight_attrs, i)
+
             cache_k_scale_attr = self.get_attr(config.cache_k_scale_attrs, i)
             cache_v_scale_attr = self.get_attr(config.cache_v_scale_attrs, i)
             cache_k_out_scale_attr = self.get_attr(config.cache_k_out_scale_attrs, i)
@@ -381,7 +406,6 @@ def __init__(self, config: FusedMultiTransformerConfig):
                 is_bias=True,
                 dtype=self._norm_weight_dtype,
             )
-            self.init_weight_shape(config)

             qkv_weight = self.create_parameter(
                 shape=self.qkv_weight_shape,
@@ -433,7 +457,8 @@ def __init__(self, config: FusedMultiTransformerConfig):
                 )

             gate_weight = None
-            if config.moe_config.use_moe(i):
+
+            if self.config.moe_config.use_moe(i):
                 gate_weight = self.create_parameter(
                     shape=[config.embed_dim, self.config.moe_config.num_experts],
                     attr=gate_weight_attr,
@@ -442,7 +467,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
                     default_initializer=paddle.nn.initializer.Constant(0),
                 )

-            if config.moe_config.use_moe(i):
+            if self.config.moe_config.use_moe(i):
                 ffn1_weight = self.create_parameter(
                     shape=self.moe_ffn1_weight_shape,
                     attr=ffn1_weight_attr,
@@ -493,7 +518,7 @@ def __init__(self, config: FusedMultiTransformerConfig):

             ffn2_bias = None
             if ffn2_bias_attr:
-                if config.moe_config.use_moe(i):
+                if self.config.moe_config.use_moe(i):
                     ffn2_bias = self.create_parameter(
                         shape=[self.config.moe_config.num_experts, config.embed_dim],
                         attr=ffn2_bias_attr,
@@ -508,6 +533,23 @@ def __init__(self, config: FusedMultiTransformerConfig):
                         is_bias=True,
                     )

+            if self.config.moe_config.use_shared_expert(i):
+                shared_expert_ffn1_weight = self.create_parameter(
+                    shape=self.shared_expert_ffn1_weight_shape,
+                    attr=shared_expert_ffn1_weight_attr,
+                    dtype=self.create_params_type,
+                )
+                shared_expert_ffn2_weight = self.create_parameter(
+                    shape=self.shared_expert_ffn2_weight_shape,
+                    attr=shared_expert_ffn2_weight_attr,
+                    dtype=self.create_params_type,
+                )
+                shared_expert_gate_weight = self.create_parameter(
+                    shape=self.shared_expert_gate_weight_shape,
+                    attr=shared_expert_gate_weight_attr,
+                    dtype=self._helper.get_default_dtype(),
+                )
+
             cache_k_scale = None
             if cache_k_scale_attr:
                 cache_k_scale = self.create_parameter(
@@ -571,6 +613,11 @@ def __init__(self, config: FusedMultiTransformerConfig):
             self.ffn2_weights.append(ffn2_weight)
             self.ffn2_biases.append(ffn2_bias)

+            if self.config.moe_config.use_shared_expert(i):
+                self.shared_expert_ffn1_weights.append(shared_expert_ffn1_weight)
+                self.shared_expert_ffn2_weights.append(shared_expert_ffn2_weight)
+                self.shared_expert_gate_weights.append(shared_expert_gate_weight)
+
             self.cache_k_scales.append(cache_k_scale)
             self.cache_v_scales.append(cache_v_scale)
             self.cache_k_out_scales.append(cache_k_out_scale)
@@ -592,6 +639,11 @@ def __init__(self, config: FusedMultiTransformerConfig):
             self._add_parameter(ffn2_weight)
             self._add_parameter(ffn2_bias)

+            if self.config.moe_config.use_shared_expert(i):
+                self._add_parameter(shared_expert_ffn1_weight)
+                self._add_parameter(shared_expert_ffn2_weight)
+                self._add_parameter(shared_expert_gate_weight)
+
             self._add_parameter(cache_k_scale)
             self._add_parameter(cache_v_scale)
             self._add_parameter(cache_k_out_scale)
@@ -624,6 +676,7 @@ def init_weight_shape(self, config):
             else [self.embed_dim, (self.num_heads + 2 * self.kv_num_heads) * self.head_dim]
         )
         self.linear_weight_shape = [self.num_heads * self.head_dim, self.embed_dim]
+
         self.ffn1_weight_shape = (
             [self.embed_dim, self.dim_feedforward * 2]
             if self.activation.endswith("glu")
@@ -639,6 +692,20 @@ def init_weight_shape(self, config):
             )
             self.moe_ffn2_weight_shape = [self.config.moe_config.num_experts, self.dim_feedforward, self.embed_dim]

+        if self.config.moe_config.has_shared_expert():
+            self.shared_expert_ffn1_weight_shape = [
+                self.embed_dim,
+                self.config.moe_config.shared_expert_intermediate_size * 2,
+            ]
+            self.shared_expert_ffn2_weight_shape = [
+                self.config.moe_config.shared_expert_intermediate_size,
+                self.embed_dim,
+            ]
+            self.shared_expert_gate_weight_shape = [
+                self.embed_dim,
+                1,
+            ]
+
     def get_weight_create_dype(self):
         return self._dtype

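For the non-quantized path, the shared-expert ffn1 weight is twice the intermediate width because the glu-style activation packs the gate and up projections into a single matmul, and the shared-expert gate is a single column (one scalar gate per token). Shape arithmetic with illustrative dimensions (hypothetical, not read from any checkpoint):

# Illustrative dimensions only.
embed_dim = 2048
shared_expert_intermediate_size = 5632

# ffn1 packs [gate_proj, up_proj] side by side, hence the factor of 2.
shared_expert_ffn1_weight_shape = [embed_dim, shared_expert_intermediate_size * 2]
shared_expert_ffn2_weight_shape = [shared_expert_intermediate_size, embed_dim]
shared_expert_gate_weight_shape = [embed_dim, 1]

print(shared_expert_ffn1_weight_shape)  # [2048, 11264]
print(shared_expert_ffn2_weight_shape)  # [5632, 2048]
print(shared_expert_gate_weight_shape)  # [2048, 1]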
@@ -851,6 +918,15 @@ def compute_bias_residual_layernorm(self, ffn2_out, residual_input, i, num_layer
         )[0]
         return tmp_out, residual_input

+    def compute_shared_expert(self, tmp_out, i):
+        ffn1_out = paddle.matmul(tmp_out, self.shared_expert_ffn1_weights[i])
+        ffn1_out = fused_act_bias_wrapper(ffn1_out, None, act_method=self.activation)
+        ffn2_out = paddle.matmul(ffn1_out, self.shared_expert_ffn2_weights[i])
+        gate_out = paddle.matmul(tmp_out, self.shared_expert_gate_weights[i])
+        gate_out = paddle.nn.functional.sigmoid(gate_out)
+        shared_expert_output = gate_out * ffn2_out
+        return shared_expert_output
+
     def pre_process(self, **kwargs):
         pass

@@ -962,6 +1038,10 @@ def forward(
                 # fused moe
                 ffn2_out = self.compute_fused_moe(tmp_out, i)

+                # shared_expert
+                if self.config.moe_config.use_shared_expert(i):
+                    shared_expert_out = self.compute_shared_expert(tmp_out, i)
+                    ffn2_out = ffn2_out + shared_expert_out
             else:
                 # ffn1 matmul
                 ffn1_out = self.compute_ffn1(tmp_out, i)
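
In the forward pass the shared expert sees the same layer input as the routed experts and its result is simply added to the fused-MoE output, i.e. ffn2_out = fused_moe(x) + sigmoid(x @ W_gate) * (act(x @ W_ffn1) @ W_ffn2). A plain-Paddle reference sketch of compute_shared_expert (assumes a swiglu-style activation in which ffn1 packs [gate, up] halves; the fused kernel's exact packing may differ):

import paddle
import paddle.nn.functional as F

def shared_expert_reference(x, w_ffn1, w_ffn2, w_gate):
    # x: [tokens, embed_dim]; w_ffn1: [embed_dim, 2 * inter];
    # w_ffn2: [inter, embed_dim]; w_gate: [embed_dim, 1]
    ffn1_out = paddle.matmul(x, w_ffn1)
    gate_half, up_half = paddle.chunk(ffn1_out, 2, axis=-1)
    act_out = F.silu(gate_half) * up_half           # swiglu-style activation
    ffn2_out = paddle.matmul(act_out, w_ffn2)
    gate_out = F.sigmoid(paddle.matmul(x, w_gate))  # scalar gate per token
    return gate_out * ffn2_out

# Tiny smoke test with random weights (hypothetical sizes).
x = paddle.randn([4, 8])
out = shared_expert_reference(x, paddle.randn([8, 32]), paddle.randn([16, 8]), paddle.randn([8, 1]))
print(out.shape)  # [4, 8]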
@@ -1046,13 +1126,25 @@ def __init__(self, config: FusedMultiTransformerConfig):
         self.ffn1_weights_scale = []
         self.ffn2_weights_scale = []

+        if self.config.moe_config.has_shared_expert():
+            self.shared_expert_ffn1_weights_scale = []
+            self.shared_expert_ffn2_weights_scale = []
+
         for i in range(self.num_layers):

             qkv_weight_scale_attr = self.get_attr(config.qkv_weight_scale_attrs, i)
             linear_weight_scale_attr = self.get_attr(config.linear_weight_scale_attrs, i)
             ffn1_weight_scale_attr = self.get_attr(config.ffn1_weight_scale_attrs, i)
             ffn2_weight_scale_attr = self.get_attr(config.ffn2_weight_scale_attrs, i)

+            if self.config.moe_config.use_shared_expert(i):
+                shared_expert_ffn1_weight_scale_attr = self.get_attr(
+                    config.moe_config.shared_expert_ffn1_weight_scale_attrs, i
+                )
+                shared_expert_ffn2_weight_scale_attr = self.get_attr(
+                    config.moe_config.shared_expert_ffn2_weight_scale_attrs, i
+                )
+
             qkv_weight_scale = self.create_parameter(
                 shape=[(self.num_heads + 2 * self.kv_num_heads) * self.head_dim],
                 attr=qkv_weight_scale_attr,
@@ -1069,9 +1161,9 @@ def __init__(self, config: FusedMultiTransformerConfig):

             if self.config.moe_config.use_moe(i):
                 ffn1_weight_scale = self.create_parameter(
-                    shape=[config.moe_config.num_experts, self.dim_feedforward * 2]
+                    shape=[self.config.moe_config.num_experts, self.dim_feedforward * 2]
                     if config.activation.endswith("glu")
-                    else [config.moe_config.num_experts, self.dim_feedforward],
+                    else [self.config.moe_config.num_experts, self.dim_feedforward],
                     attr=ffn1_weight_scale_attr,
                     dtype=self.weight_scale_dtype,
                     is_bias=False,
@@ -1086,7 +1178,7 @@ def __init__(self, config: FusedMultiTransformerConfig):

             if self.config.moe_config.use_moe(i):
                 ffn2_weight_scale = self.create_parameter(
-                    shape=[config.moe_config.num_experts, self.embed_dim],
+                    shape=[self.config.moe_config.num_experts, self.embed_dim],
                     attr=ffn2_weight_scale_attr,
                     dtype=self.weight_scale_dtype,
                     is_bias=False,
@@ -1099,16 +1191,38 @@ def __init__(self, config: FusedMultiTransformerConfig):
                     is_bias=False,
                 )

+            if self.config.moe_config.use_shared_expert(i):
+                shared_expert_ffn1_weight_scale = self.create_parameter(
+                    shape=[self.config.moe_config.shared_expert_intermediate_size * 2],
+                    attr=shared_expert_ffn1_weight_scale_attr,
+                    dtype=self.weight_scale_dtype,
+                    is_bias=False,
+                )
+                shared_expert_ffn2_weight_scale = self.create_parameter(
+                    shape=[self.embed_dim],
+                    attr=shared_expert_ffn2_weight_scale_attr,
+                    dtype=self.weight_scale_dtype,
+                    is_bias=False,
+                )
+
             self.qkv_weights_scale.append(qkv_weight_scale)
             self.linear_weights_scale.append(linear_weight_scale)
             self.ffn1_weights_scale.append(ffn1_weight_scale)
             self.ffn2_weights_scale.append(ffn2_weight_scale)

+            if self.config.moe_config.use_shared_expert(i):
+                self.shared_expert_ffn1_weights_scale.append(shared_expert_ffn1_weight_scale)
+                self.shared_expert_ffn2_weights_scale.append(shared_expert_ffn2_weight_scale)
+
             self._add_parameter(qkv_weight_scale)
             self._add_parameter(linear_weight_scale)
             self._add_parameter(ffn1_weight_scale)
             self._add_parameter(ffn2_weight_scale)

+            if self.config.moe_config.use_shared_expert(i):
+                self._add_parameter(shared_expert_ffn1_weight_scale)
+                self._add_parameter(shared_expert_ffn2_weight_scale)
+
     def get_weight_create_dype(self):
         return "int8"  # If use weightonly int4, params dtype is int8, and one of the dimension will be half.

@@ -1141,6 +1255,20 @@ def init_weight_shape(self, config):
                 self.moe_ffn1_weight_shape[2] //= 2
                 self.moe_ffn2_weight_shape[2] //= 2

+        if self.config.moe_config.has_shared_expert():
+            self.shared_expert_ffn1_weight_shape = [
+                self.config.moe_config.shared_expert_intermediate_size * 2,
+                self.embed_dim,
+            ]
+            self.shared_expert_ffn2_weight_shape = [
+                self.embed_dim,
+                self.config.moe_config.shared_expert_intermediate_size,
+            ]
+            self.shared_expert_gate_weight_shape = [
+                self.embed_dim,
+                1,
+            ]
+
     def compute_qkv_linear(self, ln_out, i):
         return weight_only_linear(
             ln_out,
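
In the weight-only branch the shared-expert weights are stored in the quantized layout (note the transposed shapes relative to the non-quantized path) and each matmul carries a per-output-channel scale: shared_expert_ffn1_weight_scale has shared_expert_intermediate_size * 2 entries and shared_expert_ffn2_weight_scale has embed_dim entries. A hedged sketch of quantizing and consuming such a weight, assuming a GPU build of Paddle that provides the paddle.nn.quant helpers (sizes illustrative):

import paddle
from paddle.nn.quant import weight_only_linear, weight_quantize

embed_dim, inter = 64, 128
w_fp16 = paddle.randn([embed_dim, inter * 2], dtype="float16")

# Quantize once offline: int8 weight plus one scale per output channel
# ([inter * 2] entries, matching shared_expert_ffn1_weight_scale above).
w_int8, w_scale = weight_quantize(w_fp16, algo="weight_only_int8")

x = paddle.randn([4, embed_dim], dtype="float16")
ffn1_out = weight_only_linear(x, weight=w_int8, weight_scale=w_scale, weight_dtype="int8")
print(ffn1_out.shape)  # [4, 256]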
@@ -1197,6 +1325,29 @@ def compute_ffn2(self, ffn1_out, i):
             weight_dtype=self.weight_dtype,
         )

+    def compute_shared_expert(self, tmp_out, i):
+        ffn1_out = weight_only_linear(
+            tmp_out,
+            weight=self.shared_expert_ffn1_weights[i],
+            weight_scale=self.shared_expert_ffn1_weights_scale[i],
+            weight_dtype=self.weight_dtype,
+        )
+
+        ffn1_out = fused_act_bias_wrapper(ffn1_out, None, act_method=self.activation)
+
+        ffn2_out = weight_only_linear(
+            ffn1_out,
+            weight=self.shared_expert_ffn2_weights[i],
+            weight_scale=self.shared_expert_ffn2_weights_scale[i],
+            weight_dtype=self.weight_dtype,
+        )
+
+        gate_out = paddle.matmul(tmp_out, self.shared_expert_gate_weights[i])
+        gate_out = paddle.nn.functional.sigmoid(gate_out)
+
+        shared_expert_output = gate_out * ffn2_out
+        return shared_expert_output
+

 class FusedMultiTransformerWeightOnlyPostLayernorm(
     FusedMultiTransformerWeightOnly, FusedMultiTransformerPostLayernorm
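
Putting the pieces together, a hedged end-to-end sketch of loading a Qwen2-MoE checkpoint through the new block-attention inference class, mirroring the block_attn branch in predictor.py above (the checkpoint id, dtype, and lengths are placeholders, and the exact config fields required may vary with the PaddleNLP version):

from paddlenlp.transformers import AutoConfig
from paddlenlp.experimental.transformers import Qwen2MoeForCausalLMBlockInferenceModel

model_name = "Qwen/Qwen1.5-MoE-A2.7B"  # placeholder checkpoint id
config = AutoConfig.from_pretrained(model_name)
config.block_size = 64        # predictor_args.block_size in the diff
config.max_seq_len = 8192     # predictor_args.total_max_length in the diff

model = Qwen2MoeForCausalLMBlockInferenceModel.from_pretrained(
    model_name, config=config, dtype="float16"
)
model.eval()

# Cache-shape query used by the static-graph export path above.
cache_kvs_shape = Qwen2MoeForCausalLMBlockInferenceModel.get_cache_kvs_shape(config, 1, 8192)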
