
Commit 88e23ee

Reduce padding overhead for sharedMoE (#3606)
1 parent b80f700

File tree

3 files changed: +151 additions, -10 deletions

  csrc/cpu/aten/DSMoE.cpp
  intel_extension_for_pytorch/transformers/models/reference/modules/decoder.py
  tests/cpu/test_deepseek_ops.py

csrc/cpu/aten/DSMoE.cpp

Lines changed: 101 additions & 0 deletions
@@ -9,6 +9,98 @@ namespace torch_ipex {
 namespace cpu {
 
 IPEX_DEFINE_DISPATCH(fused_experts_impl_stub);
+template <typename T>
+inline void copy_and_fill(
+    T* __restrict__ out,
+    const T* __restrict__ input,
+    int size,
+    int pad_size,
+    T fill_value) {
+  using Vec = at::vec::Vectorized<T>;
+  int d = 0;
+  if (size >= Vec::size()) {
+#pragma GCC unroll 4
+    for (; d < size; d += Vec::size()) {
+      Vec data = Vec::loadu(input + d);
+      data.store(out + d);
+    }
+  }
+  for (; d < size; ++d) {
+    out[d] = input[d];
+  }
+  // using scalar padding as pad_size is less than vec size here
+  for (; d < pad_size; ++d) {
+    out[d] = fill_value;
+  }
+}
+
+at::Tensor fused_experts_with_shared(
+    const at::Tensor& hidden_states,
+    const at::Tensor& w1,
+    const at::Tensor& w2,
+    const at::Tensor& topk_weights,
+    const at::Tensor& topk_ids,
+    bool inplace,
+    bool is_vnni,
+    bool is_distributed,
+    bool is_woq,
+    int64_t woq_weight_dtype,
+    int64_t woq_group_size,
+    int64_t woq_lowp_mode,
+    const std::optional<at::Tensor>& w1_scale,
+    const std::optional<at::Tensor>& w1_zp,
+    const std::optional<at::Tensor>& w1_compensation,
+    const std::optional<at::Tensor>& w2_scale,
+    const std::optional<at::Tensor>& w2_zp,
+    const std::optional<at::Tensor>& w2_compensation) {
+  RECORD_FUNCTION(
+      "ipex::fused_experts_with_shared", c10::ArrayRef<c10::IValue>({}));
+  int32_t num_tokens = topk_weights.size(0);
+  int32_t num_topk_experts = topk_weights.size(1);
+  int32_t num_topk_experts_pad = num_topk_experts + 1;
+  int32_t num_experts = w1.size(0);
+  auto pad_weight =
+      at::empty({num_tokens, num_topk_experts_pad}, topk_weights.options());
+  auto pad_ids =
+      at::empty({num_tokens, num_topk_experts_pad}, topk_ids.options());
+  // padding 1 shared expert to routed expert
+  // topk_id is num_experts - 1, and topk weights is 1.0
+  for (int id = 0; id < num_tokens; id++) {
+    copy_and_fill<int32_t>(
+        pad_ids.data_ptr<int32_t>() + id * num_topk_experts_pad,
+        topk_ids.data_ptr<int32_t>() + id * num_topk_experts,
+        num_topk_experts,
+        num_topk_experts_pad,
+        num_experts - 1);
+    copy_and_fill<float>(
+        pad_weight.data_ptr<float>() + id * num_topk_experts_pad,
+        topk_weights.data_ptr<float>() + id * num_topk_experts,
+        num_topk_experts,
+        num_topk_experts_pad,
+        1.0);
+  }
+  return fused_experts_impl_stub(
+      kCPU,
+      hidden_states,
+      w1,
+      w2,
+      pad_weight,
+      pad_ids,
+      inplace,
+      is_vnni,
+      is_distributed,
+      is_woq,
+      woq_weight_dtype,
+      woq_group_size,
+      woq_lowp_mode,
+      w1_scale,
+      w1_zp,
+      w1_compensation,
+      w2_scale,
+      w2_zp,
+      w2_compensation);
+}
+
 at::Tensor fused_experts(
     const at::Tensor& hidden_states,
     const at::Tensor& w1,
@@ -334,6 +426,15 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
       Tensor? w1_scale, Tensor? w1_zp, Tensor? w1_compensation, Tensor? w2_scale, Tensor? w2_zp, Tensor? w2_compensation) -> Tensor");
   m.impl(
       "fused_experts", c10::DispatchKey::CPU, torch_ipex::cpu::fused_experts);
+  m.def(
+      "fused_experts_with_shared(Tensor hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, \
+      Tensor topk_ids, bool inplace, bool is_vnni, \
+      bool is_distributed, bool is_woq, int woq_weight_dtype, int woq_group_size, int woq_lowp_mode, \
+      Tensor? w1_scale, Tensor? w1_zp, Tensor? w1_compensation, Tensor? w2_scale, Tensor? w2_zp, Tensor? w2_compensation) -> Tensor");
+  m.impl(
+      "fused_experts_with_shared",
+      c10::DispatchKey::CPU,
+      torch_ipex::cpu::fused_experts_with_shared);
   m.def(
       "grouped_topk(Tensor hidden_states, Tensor gating_output, \
       int topk, bool renormalize, int num_expert_group, int topk_group, Tensor e_score_correction_bias, Tensor routed_scaling_factor) -> (Tensor, Tensor)");
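
For reference, the padding the new op performs internally is equivalent to the following Python sketch (the standalone helper and its name are illustrative, not part of this commit): each token's routed top-k list gains one extra slot whose expert id is num_experts - 1 (the shared expert) and whose weight is fixed at 1.0.

import torch

def pad_shared_expert(topk_ids, topk_weights, num_experts):
    # Illustrative equivalent of the per-token copy_and_fill calls above:
    # append one column holding the shared expert id and a weight of 1.0.
    num_tokens = topk_ids.size(0)
    shared_ids = torch.full((num_tokens, 1), num_experts - 1, dtype=topk_ids.dtype)
    shared_weights = torch.ones((num_tokens, 1), dtype=topk_weights.dtype)
    pad_ids = torch.cat((topk_ids, shared_ids), dim=-1)
    pad_weight = torch.cat((topk_weights, shared_weights), dim=-1)
    return pad_ids, pad_weight

Performing this once inside the C++ op, writing directly into preallocated [num_tokens, num_topk_experts + 1] buffers, removes the per-forward Python tensor allocations, casts, and concatenations deleted from decoder.py below.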

intel_extension_for_pytorch/transformers/models/reference/modules/decoder.py

Lines changed: 1 addition & 9 deletions
@@ -1956,15 +1956,7 @@ def JambaMambaDecoderLayer_forward(
 def moe_infer(self, x, topk_ids, topk_weight):
     if self.use_fused_moe or self.use_fused_moe_woq:
         if self.unify_experts:
-            pad_weights = torch.ones(x.size(0), 1)
-            pad_ids = torch.full((x.size(0), 1), self.unify_shared_expert_id - 1).to(
-                torch.int
-            )
-            topk_weight = torch.cat((topk_weight.to(torch.float), pad_weights), -1).to(
-                torch.float
-            )
-            topk_ids = torch.cat((topk_ids.to(torch.int), pad_ids), -1).to(torch.int)
-            final_out = torch.ops.torch_ipex.fused_experts(
+            final_out = torch.ops.torch_ipex.fused_experts_with_shared(
                 x,
                 self.w13_weight,
                 self.w2_weight,
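
A note on the simplified call site: the kernel reads the router outputs through data_ptr<int32_t>() and data_ptr<float>() and indexes them as dense [num_tokens, topk] buffers, so the caller still has to supply int32 ids and float32 weights in contiguous layout. A minimal sanity check one could place before the call (the helper name is hypothetical, not part of the commit):

import torch

def check_router_outputs(topk_ids, topk_weights):
    # Layout the C++ kernel above assumes: 2-D, same shape, contiguous,
    # int32 expert ids and float32 routing weights.
    assert topk_ids.dtype == torch.int32
    assert topk_weights.dtype == torch.float32
    assert topk_ids.dim() == 2 and topk_ids.shape == topk_weights.shape
    assert topk_ids.is_contiguous() and topk_weights.is_contiguous()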

tests/cpu/test_deepseek_ops.py

Lines changed: 49 additions & 1 deletion
@@ -815,6 +815,51 @@ def fuse_moe_with_sharedmoe(a, w1, w2, score, topk, renormalize):
         w2_comp,
     )
 
+def fuse_moe_with_sharedmoe_v2(a, w1, w2, score, topk, renormalize):
+
+    G = 1
+    topk_group = 1
+
+    B, D = a.shape
+    topk_weights = torch.empty(B, topk, dtype=torch.float32)
+    topk_ids = torch.empty(B, topk, dtype=torch.int32)
+    topk_weights, topk_ids = grouped_topk_native(
+        a, score, topk, renormalize, G, topk_group
+    )
+
+    packed_w1 = torch.ops.torch_ipex.convert_weight_packed_bf16(w1)
+    packed_w2 = torch.ops.torch_ipex.convert_weight_packed_bf16(w2)
+    w13_scale = None
+    w13_zp = None
+    w13_comp = None
+    w2_scale = None
+    w2_zp = None
+    w2_comp = None
+    inplace = False
+    group_size = -1
+    weight_dtype = WoqWeightDtype.INT8
+    lowp_mode = WoqLowpMode.BF16
+    return torch.ops.torch_ipex.fused_experts_with_shared(
+        a,
+        packed_w1,
+        packed_w2,
+        topk_weights,
+        topk_ids,
+        inplace,
+        True,
+        False,
+        False,
+        weight_dtype,
+        group_size,
+        lowp_mode,
+        w13_scale,
+        w13_zp,
+        w13_comp,
+        w2_scale,
+        w2_zp,
+        w2_comp,
+    )
+
 def run_single_test(m, n, k, e, topk, dtype, renormalize=False):
 
     a = torch.randn((m, k), device="cpu", dtype=dtype) / 10
@@ -841,8 +886,11 @@ def run_single_test(m, n, k, e, topk, dtype, renormalize=False):
     fused_output = fuse_moe_with_sharedmoe(
         a, w1_, w2_, score, topk, renormalize
     )
-
+    fused_output_v2 = fuse_moe_with_sharedmoe_v2(
+        a, w1_, w2_, score, topk, renormalize
+    )
     compare(torch_output, fused_output)
+    compare(torch_output, fused_output_v2)
 
 run_single_test(2, 2048, 2048, 4, 2, torch.bfloat16, renormalize=True)
 run_single_test(2, 128, 32, 4, 2, torch.bfloat16, renormalize=True)
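
Because the schema registered in DSMoE.cpp names every argument, the op can also be called with keywords, which makes the bare True/False flags in fuse_moe_with_sharedmoe_v2 easier to read. A sketch under the same non-WOQ assumptions as the test (the wrapper name and the placeholder WOQ integers are illustrative; the test passes the real WoqWeightDtype/WoqLowpMode enum values instead):

import torch
import intel_extension_for_pytorch  # noqa: F401  (registers the torch_ipex ops)

def call_shared_moe(a, packed_w1, packed_w2, topk_weights, topk_ids):
    # Same call as in fuse_moe_with_sharedmoe_v2, spelled with schema argument names.
    return torch.ops.torch_ipex.fused_experts_with_shared(
        hidden_states=a,
        w1=packed_w1,
        w2=packed_w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        is_vnni=True,        # weights packed with convert_weight_packed_bf16
        is_distributed=False,
        is_woq=False,
        woq_weight_dtype=1,  # placeholder; assumed unused when is_woq=False
        woq_group_size=-1,
        woq_lowp_mode=0,     # placeholder; assumed unused when is_woq=False
        w1_scale=None,
        w1_zp=None,
        w1_compensation=None,
        w2_scale=None,
        w2_zp=None,
        w2_compensation=None,
    )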
