 from __future__ import annotations
 
 from dataclasses import dataclass
+from typing import List, Optional
 
 import paddle
 import paddle.distributed as dist
@@ -157,12 +158,25 @@ class MoeConfig:
     norm_topk_prob: bool = True
     moe_every2: bool = False
 
+    shared_expert_intermediate_size: int = 0
+    shared_expert_ffn1_weight_attrs: Optional[List[paddle.ParamAttr]] = None
+    shared_expert_ffn1_weight_scale_attrs: Optional[List[paddle.ParamAttr]] = None
+    shared_expert_ffn2_weight_attrs: Optional[List[paddle.ParamAttr]] = None
+    shared_expert_ffn2_weight_scale_attrs: Optional[List[paddle.ParamAttr]] = None
+    shared_expert_gate_weight_attrs: Optional[List[paddle.ParamAttr]] = None
+
     def has_moe(self) -> bool:
         return self.num_experts > 1
 
     def use_moe(self, i: int) -> bool:
         return self.has_moe() and (self.moe_every2 is False or (self.moe_every2 and i % 2 == 1))
 
+    def has_shared_expert(self) -> bool:
+        return self.has_moe() and self.shared_expert_intermediate_size > 0
+
+    def use_shared_expert(self, i: int) -> bool:
+        return self.use_moe(i) and self.shared_expert_intermediate_size > 0
+
 
 class FusedMultiTransformerConfig:
     def __init__(
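As a quick illustration of the new config surface (hypothetical values; the remaining MoeConfig fields are assumed to keep their defaults), the added fields and helpers behave like this:

# Hypothetical example, not part of the patch: 64 routed experts on every
# second layer plus one always-on shared expert.
moe_cfg = MoeConfig(
    num_experts=64,
    moe_every2=True,
    shared_expert_intermediate_size=2816,
)
assert moe_cfg.has_shared_expert()
# The shared expert is only used where the routed experts run, i.e. on odd
# layer indices when moe_every2 is enabled.
assert moe_cfg.use_shared_expert(1)
assert not moe_cfg.use_shared_expert(0)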
@@ -342,9 +356,15 @@ def __init__(self, config: FusedMultiTransformerConfig):
         self.gate_weights = []
         self.ffn1_weights, self.ffn1_biases = [], []
         self.ffn2_weights, self.ffn2_biases = [], []
+        if self.config.moe_config.has_shared_expert():
+            self.shared_expert_gate_weights = []
+            self.shared_expert_ffn1_weights = []
+            self.shared_expert_ffn2_weights = []
         self.cache_k_scales, self.cache_v_scales = [], []
         self.cache_k_out_scales, self.cache_v_out_scales = [], []
 
+        self.init_weight_shape(config)
+
         for i in range(self.num_layers):
             ln_scale_attr = self.get_attr(config.ln_scale_attrs, i)
             ln_bias_attr = self.get_attr(config.ln_bias_attrs, i)
@@ -362,6 +382,11 @@ def __init__(self, config: FusedMultiTransformerConfig):
             ffn2_weight_attr = self.get_attr(config.ffn2_weight_attrs, i)
             ffn2_bias_attr = self.get_attr(config.ffn2_bias_attrs, i)
 
+            if self.config.moe_config.use_shared_expert(i):
+                shared_expert_gate_weight_attr = self.get_attr(config.moe_config.shared_expert_gate_weight_attrs, i)
+                shared_expert_ffn1_weight_attr = self.get_attr(config.moe_config.shared_expert_ffn1_weight_attrs, i)
+                shared_expert_ffn2_weight_attr = self.get_attr(config.moe_config.shared_expert_ffn2_weight_attrs, i)
+
             cache_k_scale_attr = self.get_attr(config.cache_k_scale_attrs, i)
             cache_v_scale_attr = self.get_attr(config.cache_v_scale_attrs, i)
             cache_k_out_scale_attr = self.get_attr(config.cache_k_out_scale_attrs, i)
@@ -381,7 +406,6 @@ def __init__(self, config: FusedMultiTransformerConfig):
                 is_bias=True,
                 dtype=self._norm_weight_dtype,
             )
-            self.init_weight_shape(config)
 
             qkv_weight = self.create_parameter(
                 shape=self.qkv_weight_shape,
@@ -433,7 +457,8 @@ def __init__(self, config: FusedMultiTransformerConfig):
             )
 
             gate_weight = None
-            if config.moe_config.use_moe(i):
+
+            if self.config.moe_config.use_moe(i):
                 gate_weight = self.create_parameter(
                     shape=[config.embed_dim, self.config.moe_config.num_experts],
                     attr=gate_weight_attr,
@@ -442,7 +467,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
                     default_initializer=paddle.nn.initializer.Constant(0),
                 )
 
-            if config.moe_config.use_moe(i):
+            if self.config.moe_config.use_moe(i):
                 ffn1_weight = self.create_parameter(
                     shape=self.moe_ffn1_weight_shape,
                     attr=ffn1_weight_attr,
@@ -493,7 +518,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
 
             ffn2_bias = None
             if ffn2_bias_attr:
-                if config.moe_config.use_moe(i):
+                if self.config.moe_config.use_moe(i):
                     ffn2_bias = self.create_parameter(
                         shape=[self.config.moe_config.num_experts, config.embed_dim],
                         attr=ffn2_bias_attr,
@@ -508,6 +533,23 @@ def __init__(self, config: FusedMultiTransformerConfig):
                         is_bias=True,
                     )
 
+            if self.config.moe_config.use_shared_expert(i):
+                shared_expert_ffn1_weight = self.create_parameter(
+                    shape=self.shared_expert_ffn1_weight_shape,
+                    attr=shared_expert_ffn1_weight_attr,
+                    dtype=self.create_params_type,
+                )
+                shared_expert_ffn2_weight = self.create_parameter(
+                    shape=self.shared_expert_ffn2_weight_shape,
+                    attr=shared_expert_ffn2_weight_attr,
+                    dtype=self.create_params_type,
+                )
+                shared_expert_gate_weight = self.create_parameter(
+                    shape=self.shared_expert_gate_weight_shape,
+                    attr=shared_expert_gate_weight_attr,
+                    dtype=self._helper.get_default_dtype(),
+                )
+
             cache_k_scale = None
             if cache_k_scale_attr:
                 cache_k_scale = self.create_parameter(
@@ -571,6 +613,11 @@ def __init__(self, config: FusedMultiTransformerConfig):
             self.ffn2_weights.append(ffn2_weight)
             self.ffn2_biases.append(ffn2_bias)
 
+            if self.config.moe_config.use_shared_expert(i):
+                self.shared_expert_ffn1_weights.append(shared_expert_ffn1_weight)
+                self.shared_expert_ffn2_weights.append(shared_expert_ffn2_weight)
+                self.shared_expert_gate_weights.append(shared_expert_gate_weight)
+
             self.cache_k_scales.append(cache_k_scale)
             self.cache_v_scales.append(cache_v_scale)
             self.cache_k_out_scales.append(cache_k_out_scale)
@@ -592,6 +639,11 @@ def __init__(self, config: FusedMultiTransformerConfig):
             self._add_parameter(ffn2_weight)
             self._add_parameter(ffn2_bias)
 
+            if self.config.moe_config.use_shared_expert(i):
+                self._add_parameter(shared_expert_ffn1_weight)
+                self._add_parameter(shared_expert_ffn2_weight)
+                self._add_parameter(shared_expert_gate_weight)
+
             self._add_parameter(cache_k_scale)
             self._add_parameter(cache_v_scale)
             self._add_parameter(cache_k_out_scale)
@@ -624,6 +676,7 @@ def init_weight_shape(self, config):
             else [self.embed_dim, (self.num_heads + 2 * self.kv_num_heads) * self.head_dim]
         )
         self.linear_weight_shape = [self.num_heads * self.head_dim, self.embed_dim]
+
         self.ffn1_weight_shape = (
             [self.embed_dim, self.dim_feedforward * 2]
             if self.activation.endswith("glu")
@@ -639,6 +692,20 @@ def init_weight_shape(self, config):
         )
         self.moe_ffn2_weight_shape = [self.config.moe_config.num_experts, self.dim_feedforward, self.embed_dim]
 
+        if self.config.moe_config.has_shared_expert():
+            self.shared_expert_ffn1_weight_shape = [
+                self.embed_dim,
+                self.config.moe_config.shared_expert_intermediate_size * 2,
+            ]
+            self.shared_expert_ffn2_weight_shape = [
+                self.config.moe_config.shared_expert_intermediate_size,
+                self.embed_dim,
+            ]
+            self.shared_expert_gate_weight_shape = [
+                self.embed_dim,
+                1,
+            ]
+
     def get_weight_create_dype(self):
         return self._dtype
 
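For the bf16 path the shared-expert weights are laid out activation-major ([in_features, out_features]); the factor of 2 on the ffn1 shape appears to assume a GLU-family activation, with the gate and up projections packed into a single matmul. A toy shape check with hypothetical sizes:

embed_dim = 4096       # hypothetical hidden size
shared_inter = 2816    # hypothetical shared_expert_intermediate_size
ffn1_shape = [embed_dim, shared_inter * 2]  # gate + up projections fused
ffn2_shape = [shared_inter, embed_dim]      # down projection
gate_shape = [embed_dim, 1]                 # one sigmoid gate logit per token
assert ffn1_shape == [4096, 5632]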
@@ -851,6 +918,15 @@ def compute_bias_residual_layernorm(self, ffn2_out, residual_input, i, num_layer
         )[0]
         return tmp_out, residual_input
 
+    def compute_shared_expert(self, tmp_out, i):
+        ffn1_out = paddle.matmul(tmp_out, self.shared_expert_ffn1_weights[i])
+        ffn1_out = fused_act_bias_wrapper(ffn1_out, None, act_method=self.activation)
+        ffn2_out = paddle.matmul(ffn1_out, self.shared_expert_ffn2_weights[i])
+        gate_out = paddle.matmul(tmp_out, self.shared_expert_gate_weights[i])
+        gate_out = paddle.nn.functional.sigmoid(gate_out)
+        shared_expert_output = gate_out * ffn2_out
+        return shared_expert_output
+
     def pre_process(self, **kwargs):
         pass
 
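For readers following the math, here is a rough plain-op equivalent of the new compute_shared_expert, assuming a SwiGLU-style activation (the method above delegates the activation to fused_act_bias_wrapper via act_method); the function name and the exact GLU split order are illustrative, not taken from the kernel:

import paddle
import paddle.nn.functional as F

def shared_expert_reference(x, ffn1_w, ffn2_w, gate_w):
    # x: [tokens, embed_dim], ffn1_w: [embed_dim, 2 * inter],
    # ffn2_w: [inter, embed_dim], gate_w: [embed_dim, 1]
    h = paddle.matmul(x, ffn1_w)
    gate_h, up_h = paddle.split(h, 2, axis=-1)    # assumed GLU packing: [gate | up]
    h = F.silu(gate_h) * up_h                     # SwiGLU
    out = paddle.matmul(h, ffn2_w)                # down projection
    scale = F.sigmoid(paddle.matmul(x, gate_w))   # per-token scalar gate in (0, 1)
    return scale * out                            # summed with the routed experts in forward()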
@@ -962,6 +1038,10 @@ def forward(
                 # fused moe
                 ffn2_out = self.compute_fused_moe(tmp_out, i)
 
+                # shared_expert
+                if self.config.moe_config.use_shared_expert(i):
+                    shared_expert_out = self.compute_shared_expert(tmp_out, i)
+                    ffn2_out = ffn2_out + shared_expert_out
             else:
                 # ffn1 matmul
                 ffn1_out = self.compute_ffn1(tmp_out, i)
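With this wiring, on a MoE layer every token passes through the shared expert in addition to its routed experts, and the two results are summed before the bias/residual step. Schematically (illustrative only, mirroring the control flow above):

# ffn_out(x) = fused_moe(x) + sigmoid(x @ W_gate) * shared_ffn(x)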
@@ -1046,13 +1126,25 @@ def __init__(self, config: FusedMultiTransformerConfig):
         self.ffn1_weights_scale = []
         self.ffn2_weights_scale = []
 
+        if self.config.moe_config.has_shared_expert():
+            self.shared_expert_ffn1_weights_scale = []
+            self.shared_expert_ffn2_weights_scale = []
+
         for i in range(self.num_layers):
 
             qkv_weight_scale_attr = self.get_attr(config.qkv_weight_scale_attrs, i)
             linear_weight_scale_attr = self.get_attr(config.linear_weight_scale_attrs, i)
             ffn1_weight_scale_attr = self.get_attr(config.ffn1_weight_scale_attrs, i)
             ffn2_weight_scale_attr = self.get_attr(config.ffn2_weight_scale_attrs, i)
 
+            if self.config.moe_config.use_shared_expert(i):
+                shared_expert_ffn1_weight_scale_attr = self.get_attr(
+                    config.moe_config.shared_expert_ffn1_weight_scale_attrs, i
+                )
+                shared_expert_ffn2_weight_scale_attr = self.get_attr(
+                    config.moe_config.shared_expert_ffn2_weight_scale_attrs, i
+                )
+
             qkv_weight_scale = self.create_parameter(
                 shape=[(self.num_heads + 2 * self.kv_num_heads) * self.head_dim],
                 attr=qkv_weight_scale_attr,
@@ -1069,9 +1161,9 @@ def __init__(self, config: FusedMultiTransformerConfig):
 
             if self.config.moe_config.use_moe(i):
                 ffn1_weight_scale = self.create_parameter(
-                    shape=[config.moe_config.num_experts, self.dim_feedforward * 2]
+                    shape=[self.config.moe_config.num_experts, self.dim_feedforward * 2]
                     if config.activation.endswith("glu")
-                    else [config.moe_config.num_experts, self.dim_feedforward],
+                    else [self.config.moe_config.num_experts, self.dim_feedforward],
                     attr=ffn1_weight_scale_attr,
                     dtype=self.weight_scale_dtype,
                     is_bias=False,
@@ -1086,7 +1178,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
 
             if self.config.moe_config.use_moe(i):
                 ffn2_weight_scale = self.create_parameter(
-                    shape=[config.moe_config.num_experts, self.embed_dim],
+                    shape=[self.config.moe_config.num_experts, self.embed_dim],
                     attr=ffn2_weight_scale_attr,
                     dtype=self.weight_scale_dtype,
                     is_bias=False,
@@ -1099,16 +1191,38 @@ def __init__(self, config: FusedMultiTransformerConfig):
                     is_bias=False,
                 )
 
+            if self.config.moe_config.use_shared_expert(i):
+                shared_expert_ffn1_weight_scale = self.create_parameter(
+                    shape=[self.config.moe_config.shared_expert_intermediate_size * 2],
+                    attr=shared_expert_ffn1_weight_scale_attr,
+                    dtype=self.weight_scale_dtype,
+                    is_bias=False,
+                )
+                shared_expert_ffn2_weight_scale = self.create_parameter(
+                    shape=[self.embed_dim],
+                    attr=shared_expert_ffn2_weight_scale_attr,
+                    dtype=self.weight_scale_dtype,
+                    is_bias=False,
+                )
+
             self.qkv_weights_scale.append(qkv_weight_scale)
             self.linear_weights_scale.append(linear_weight_scale)
             self.ffn1_weights_scale.append(ffn1_weight_scale)
             self.ffn2_weights_scale.append(ffn2_weight_scale)
 
+            if self.config.moe_config.use_shared_expert(i):
+                self.shared_expert_ffn1_weights_scale.append(shared_expert_ffn1_weight_scale)
+                self.shared_expert_ffn2_weights_scale.append(shared_expert_ffn2_weight_scale)
+
             self._add_parameter(qkv_weight_scale)
             self._add_parameter(linear_weight_scale)
             self._add_parameter(ffn1_weight_scale)
             self._add_parameter(ffn2_weight_scale)
 
+            if self.config.moe_config.use_shared_expert(i):
+                self._add_parameter(shared_expert_ffn1_weight_scale)
+                self._add_parameter(shared_expert_ffn2_weight_scale)
+
     def get_weight_create_dype(self):
         return "int8"  # If use weightonly int4, params dtype is int8, and one of the dimension will be half.
 
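The new scale tensors are one-dimensional because the weight-only scales here are per output channel: one entry per packed ffn1 output column (2 * shared_expert_intermediate_size) and one per ffn2 output column (embed_dim). A rough sketch of the symmetric per-channel quantization intuition (toy sizes, not the actual Paddle kernel):

import paddle

inter2, embed_dim = 16, 8                          # toy stand-ins for 2*inter and embed_dim
w = paddle.rand([embed_dim, inter2])               # dense ffn1-like weight, [in, out]
scale = w.abs().max(axis=0) / 127.0                # one scale per output channel -> shape [inter2]
w_int8 = paddle.round(w / scale).astype("int8")    # symmetric int8 quantization
w_approx = w_int8.astype("float32") * scale        # dequantized weight, close to w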
@@ -1141,6 +1255,20 @@ def init_weight_shape(self, config):
             self.moe_ffn1_weight_shape[2] //= 2
             self.moe_ffn2_weight_shape[2] //= 2
 
+        if self.config.moe_config.has_shared_expert():
+            self.shared_expert_ffn1_weight_shape = [
+                self.config.moe_config.shared_expert_intermediate_size * 2,
+                self.embed_dim,
+            ]
+            self.shared_expert_ffn2_weight_shape = [
+                self.embed_dim,
+                self.config.moe_config.shared_expert_intermediate_size,
+            ]
+            self.shared_expert_gate_weight_shape = [
+                self.embed_dim,
+                1,
+            ]
+
     def compute_qkv_linear(self, ln_out, i):
         return weight_only_linear(
             ln_out,
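Compared with the bf16 shapes added earlier, the weight-only path declares the shared-expert matrices with their dimensions swapped ([out_features, in_features]), which appears to match the transposed layout this class uses for its other quantized weights; the gate stays a small unquantized [embed_dim, 1] matmul. With the same hypothetical sizes as before:

embed_dim = 4096                                   # hypothetical hidden size
shared_inter = 2816                                # hypothetical shared_expert_intermediate_size
quant_ffn1_shape = [shared_inter * 2, embed_dim]   # transposed relative to the bf16 path
quant_ffn2_shape = [embed_dim, shared_inter]
gate_shape = [embed_dim, 1]                        # gate weight is not quantized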
@@ -1197,6 +1325,29 @@ def compute_ffn2(self, ffn1_out, i):
             weight_dtype=self.weight_dtype,
         )
 
+    def compute_shared_expert(self, tmp_out, i):
+        ffn1_out = weight_only_linear(
+            tmp_out,
+            weight=self.shared_expert_ffn1_weights[i],
+            weight_scale=self.shared_expert_ffn1_weights_scale[i],
+            weight_dtype=self.weight_dtype,
+        )
+
+        ffn1_out = fused_act_bias_wrapper(ffn1_out, None, act_method=self.activation)
+
+        ffn2_out = weight_only_linear(
+            ffn1_out,
+            weight=self.shared_expert_ffn2_weights[i],
+            weight_scale=self.shared_expert_ffn2_weights_scale[i],
+            weight_dtype=self.weight_dtype,
+        )
+
+        gate_out = paddle.matmul(tmp_out, self.shared_expert_gate_weights[i])
+        gate_out = paddle.nn.functional.sigmoid(gate_out)
+
+        shared_expert_output = gate_out * ffn2_out
+        return shared_expert_output
+
 
 class FusedMultiTransformerWeightOnlyPostLayernorm(
     FusedMultiTransformerWeightOnly, FusedMultiTransformerPostLayernorm
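Finally, a hedged sketch of how a quantized shared-expert ffn1 weight and its per-channel scale could be produced and consumed with Paddle's stock weight-only ops (hypothetical sizes; requires a CUDA build of Paddle, and the checkpoint-loading path used by this PR may differ):

import paddle
from paddle.nn.quant import weight_only_linear, weight_quantize

tokens, embed_dim, shared_inter = 8, 4096, 2816          # hypothetical sizes
x = paddle.randn([tokens, embed_dim]).astype("float16")
w = paddle.randn([embed_dim, shared_inter * 2]).astype("float16")  # dense ffn1 weight, [in, out]

qw, scale = weight_quantize(w, algo="weight_only_int8")  # int8 weight + per-channel scale
y = weight_only_linear(x, qw, weight_scale=scale, weight_dtype="int8")
print(y.shape)  # [8, 5632]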