@@ -9,6 +9,98 @@ namespace torch_ipex {
namespace cpu {

IPEX_DEFINE_DISPATCH(fused_experts_impl_stub);
+template <typename T>
+inline void copy_and_fill(
+    T* __restrict__ out,
+    const T* __restrict__ input,
+    int size,
+    int pad_size,
+    T fill_value) {
+  using Vec = at::vec::Vectorized<T>;
+  int d = 0;
+  if (size >= Vec::size()) {
+#pragma GCC unroll 4
+    for (; d < size; d += Vec::size()) {
+      Vec data = Vec::loadu(input + d);
+      data.store(out + d);
+    }
+  }
+  for (; d < size; ++d) {
+    out[d] = input[d];
+  }
+  // scalar padding is enough here, as the pad region is smaller than the vector width
+  for (; d < pad_size; ++d) {
+    out[d] = fill_value;
+  }
+}
+
+at::Tensor fused_experts_with_shared(
+    const at::Tensor& hidden_states,
+    const at::Tensor& w1,
+    const at::Tensor& w2,
+    const at::Tensor& topk_weights,
+    const at::Tensor& topk_ids,
+    bool inplace,
+    bool is_vnni,
+    bool is_distributed,
+    bool is_woq,
+    int64_t woq_weight_dtype,
+    int64_t woq_group_size,
+    int64_t woq_lowp_mode,
+    const std::optional<at::Tensor>& w1_scale,
+    const std::optional<at::Tensor>& w1_zp,
+    const std::optional<at::Tensor>& w1_compensation,
+    const std::optional<at::Tensor>& w2_scale,
+    const std::optional<at::Tensor>& w2_zp,
+    const std::optional<at::Tensor>& w2_compensation) {
+  RECORD_FUNCTION(
+      "ipex::fused_experts_with_shared", c10::ArrayRef<c10::IValue>({}));
+  int32_t num_tokens = topk_weights.size(0);
+  int32_t num_topk_experts = topk_weights.size(1);
+  int32_t num_topk_experts_pad = num_topk_experts + 1;
+  int32_t num_experts = w1.size(0);
+  auto pad_weight =
+      at::empty({num_tokens, num_topk_experts_pad}, topk_weights.options());
+  auto pad_ids =
+      at::empty({num_tokens, num_topk_experts_pad}, topk_ids.options());
+  // pad one shared expert onto each token's routed experts:
+  // its topk_id is num_experts - 1 and its topk weight is 1.0
+  for (int id = 0; id < num_tokens; id++) {
+    copy_and_fill<int32_t>(
+        pad_ids.data_ptr<int32_t>() + id * num_topk_experts_pad,
+        topk_ids.data_ptr<int32_t>() + id * num_topk_experts,
+        num_topk_experts,
+        num_topk_experts_pad,
+        num_experts - 1);
+    copy_and_fill<float>(
+        pad_weight.data_ptr<float>() + id * num_topk_experts_pad,
+        topk_weights.data_ptr<float>() + id * num_topk_experts,
+        num_topk_experts,
+        num_topk_experts_pad,
+        1.0);
+  }
+  return fused_experts_impl_stub(
+      kCPU,
+      hidden_states,
+      w1,
+      w2,
+      pad_weight,
+      pad_ids,
+      inplace,
+      is_vnni,
+      is_distributed,
+      is_woq,
+      woq_weight_dtype,
+      woq_group_size,
+      woq_lowp_mode,
+      w1_scale,
+      w1_zp,
+      w1_compensation,
+      w2_scale,
+      w2_zp,
+      w2_compensation);
+}
+
at::Tensor fused_experts(
    const at::Tensor& hidden_states,
    const at::Tensor& w1,
    const at::Tensor& w2,
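Review note: the new entry point above does not run the shared expert as a separate GEMM. It widens each token's top-k list by one slot so the shared expert flows through the same fused kernel as the routed experts. Below is a minimal, self-contained sketch of that padding step, using plain std::vector and scalar loops in place of at::Tensor and the vectorized copy_and_fill; the token/expert counts are made-up values for illustration, not anything from this patch.

// padding_sketch.cpp -- illustrative only, not part of the patch
#include <cstdio>
#include <vector>

int main() {
  // hypothetical sizes: 2 tokens, top-2 routing, 9 experts where expert 8 is the shared one
  const int num_tokens = 2, num_topk_experts = 2, num_experts = 9;
  const int num_topk_experts_pad = num_topk_experts + 1;
  std::vector<int> topk_ids = {3, 7, 1, 4};                    // [num_tokens, num_topk_experts]
  std::vector<float> topk_weights = {0.6f, 0.4f, 0.7f, 0.3f};  // routing weights per token
  std::vector<int> pad_ids(num_tokens * num_topk_experts_pad);
  std::vector<float> pad_weights(num_tokens * num_topk_experts_pad);

  for (int t = 0; t < num_tokens; ++t) {
    // copy the routed top-k entries for this token
    for (int k = 0; k < num_topk_experts; ++k) {
      pad_ids[t * num_topk_experts_pad + k] = topk_ids[t * num_topk_experts + k];
      pad_weights[t * num_topk_experts_pad + k] = topk_weights[t * num_topk_experts + k];
    }
    // append the shared expert: id num_experts - 1, weight 1.0
    pad_ids[t * num_topk_experts_pad + num_topk_experts] = num_experts - 1;
    pad_weights[t * num_topk_experts_pad + num_topk_experts] = 1.0f;
  }

  for (int t = 0; t < num_tokens; ++t) {
    std::printf("token %d:", t);
    for (int k = 0; k < num_topk_experts_pad; ++k) {
      std::printf(" (id=%d, w=%.1f)",
                  pad_ids[t * num_topk_experts_pad + k],
                  pad_weights[t * num_topk_experts_pad + k]);
    }
    std::printf("\n");
  }
  return 0;
}

The resulting per-token layout, the routed (id, weight) pairs followed by (num_experts - 1, 1.0), is exactly what the patch hands to fused_experts_impl_stub as pad_ids and pad_weight.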
@@ -334,6 +426,15 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
      Tensor? w1_scale, Tensor? w1_zp, Tensor? w1_compensation, Tensor? w2_scale, Tensor? w2_zp, Tensor? w2_compensation) -> Tensor");
  m.impl(
      "fused_experts", c10::DispatchKey::CPU, torch_ipex::cpu::fused_experts);
+  m.def(
+      "fused_experts_with_shared(Tensor hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, \
+      Tensor topk_ids, bool inplace, bool is_vnni, \
+      bool is_distributed, bool is_woq, int woq_weight_dtype, int woq_group_size, int woq_lowp_mode, \
+      Tensor? w1_scale, Tensor? w1_zp, Tensor? w1_compensation, Tensor? w2_scale, Tensor? w2_zp, Tensor? w2_compensation) -> Tensor");
+  m.impl(
+      "fused_experts_with_shared",
+      c10::DispatchKey::CPU,
+      torch_ipex::cpu::fused_experts_with_shared);
  m.def(
      "grouped_topk(Tensor hidden_states, Tensor gating_output, \
      int topk, bool renormalize, int num_expert_group, int topk_group, Tensor e_score_correction_bias, Tensor routed_scaling_factor) -> (Tensor, Tensor)");
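Review note: after this registration the op is reachable from Python as torch.ops.torch_ipex.fused_experts_with_shared, or from C++ through the dispatcher. The following is a hedged sketch of a C++ caller, assuming libtorch and the IPEX extension library are already loaded; the flag values, the 0/-1 woq_* settings, and the nullopt quantization tensors are placeholders chosen for illustration, not defaults taken from this patch.

// caller_sketch.cpp -- illustrative only, not part of the patch
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

at::Tensor call_fused_experts_with_shared(
    const at::Tensor& hidden_states,
    const at::Tensor& w1,
    const at::Tensor& w2,
    const at::Tensor& topk_weights,
    const at::Tensor& topk_ids) {
  // look up the op registered above and bind it to a typed handle
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torch_ipex::fused_experts_with_shared", "")
          .typed<at::Tensor(
              const at::Tensor&, const at::Tensor&, const at::Tensor&,
              const at::Tensor&, const at::Tensor&,
              bool, bool, bool, bool,
              int64_t, int64_t, int64_t,
              const std::optional<at::Tensor>&, const std::optional<at::Tensor>&,
              const std::optional<at::Tensor>&, const std::optional<at::Tensor>&,
              const std::optional<at::Tensor>&, const std::optional<at::Tensor>&)>();
  // placeholder arguments: non-WOQ path, no quantization scales/zero points
  return op.call(
      hidden_states, w1, w2, topk_weights, topk_ids,
      /*inplace=*/false, /*is_vnni=*/true, /*is_distributed=*/false,
      /*is_woq=*/false, /*woq_weight_dtype=*/0, /*woq_group_size=*/-1,
      /*woq_lowp_mode=*/0,
      std::nullopt, std::nullopt, std::nullopt,
      std::nullopt, std::nullopt, std::nullopt);
}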