Skip to content

Commit 7a72522

Browse files
authored
Add fused_transpose_wlch_split_quant op (#73520)
1 parent 67dc159 commit 7a72522

File tree

9 files changed

+504
-0
lines changed

9 files changed

+504
-0
lines changed

paddle/phi/infermeta/fusion.cc

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2504,6 +2504,88 @@ void FusedTransposeSplitQuantInferMeta(const MetaTensor& x,
25042504
"x.shape[1] (%d) must be <= 65535 * 128", N));
25052505
}
25062506

2507+
void FusedTransposeWLCHSplitQuantInferMeta(const MetaTensor& x,
2508+
const IntArray& tokens_per_expert,
2509+
bool pow_2_scales,
2510+
std::vector<MetaTensor*> outs,
2511+
std::vector<MetaTensor*> scales) {
2512+
PADDLE_ENFORCE_EQ(
2513+
x.dtype(),
2514+
DataType::BFLOAT16,
2515+
common::errors::InvalidArgument(
2516+
"The dtype of Input(x) must be BFLOAT16, but received %s",
2517+
x.dtype()));
2518+
2519+
auto x_dims = x.dims();
2520+
2521+
PADDLE_ENFORCE_EQ(
2522+
x_dims.size(),
2523+
4,
2524+
common::errors::InvalidArgument(
2525+
"Input(x) must have dimension of 4, but got %d.", x_dims.size()));
2526+
2527+
const int64_t M = x_dims[0] * x_dims[1] * x_dims[2];
2528+
const int64_t H = x_dims[3];
2529+
2530+
auto tokens_list = tokens_per_expert.GetData();
2531+
const size_t num_experts = tokens_list.size();
2532+
2533+
PADDLE_ENFORCE_EQ(
2534+
outs.size(),
2535+
num_experts,
2536+
common::errors::InvalidArgument(
2537+
"Size of outs (%d) must equal size of tokens_per_expert (%d)",
2538+
outs.size(),
2539+
num_experts));
2540+
2541+
PADDLE_ENFORCE_EQ(
2542+
scales.size(),
2543+
num_experts,
2544+
common::errors::InvalidArgument(
2545+
"Size of scales (%d) must equal size of tokens_per_expert (%d)",
2546+
scales.size(),
2547+
num_experts));
2548+
2549+
int64_t sum_tokens = 0;
2550+
for (size_t i = 0; i < num_experts; ++i) {
2551+
const int64_t tokens = tokens_list[i];
2552+
2553+
PADDLE_ENFORCE_EQ(
2554+
tokens % 128,
2555+
0,
2556+
common::errors::InvalidArgument(
2557+
"tokens_per_expert[%d] (%d) must be divisible by 128", i, tokens));
2558+
2559+
sum_tokens += tokens;
2560+
2561+
if (outs[i] != nullptr) {
2562+
outs[i]->set_dims(common::make_ddim({H, tokens}));
2563+
outs[i]->set_dtype(DataType::FLOAT8_E4M3FN);
2564+
outs[i]->set_layout(x.layout());
2565+
}
2566+
2567+
if (scales[i] != nullptr) {
2568+
scales[i]->set_dims(common::make_ddim({tokens / 128, H}));
2569+
scales[i]->set_dtype(DataType::FLOAT32);
2570+
scales[i]->set_layout(x.layout());
2571+
}
2572+
}
2573+
2574+
PADDLE_ENFORCE_EQ(
2575+
sum_tokens,
2576+
M,
2577+
common::errors::InvalidArgument("Sum of tokens_per_expert (%d) must "
2578+
"equal the upper dims of Input(x) (%d)",
2579+
sum_tokens,
2580+
M));
2581+
PADDLE_ENFORCE_LE(
2582+
H,
2583+
65535 * 128,
2584+
common::errors::InvalidArgument("Currently only supports the hidden size "
2585+
"of Input(x) <= 65535 * 128, but got %d.",
2586+
H));
2587+
}
2588+
25072589
void YoloBoxXPUInferMeta(const MetaTensor& x,
25082590
const MetaTensor& x_max,
25092591
const MetaTensor& grid,

paddle/phi/infermeta/fusion.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,12 @@ void FusedTransposeSplitQuantInferMeta(const MetaTensor& x,
674674
std::vector<MetaTensor*> outs,
675675
std::vector<MetaTensor*> scales);
676676

677+
// Shape/dtype inference for fused_transpose_wlch_split_quant: splits the
// flattened leading dims of 4-D BFLOAT16 x per tokens_per_expert, setting
// each outs[i] to an [H, tokens_i] FLOAT8_E4M3FN tensor and each scales[i]
// to a [tokens_i / 128, H] FLOAT32 tensor. pow_2_scales does not affect the
// inferred shapes.
void FusedTransposeWLCHSplitQuantInferMeta(const MetaTensor& x,
                                           const IntArray& tokens_per_expert,
                                           bool pow_2_scales,
                                           std::vector<MetaTensor*> outs,
                                           std::vector<MetaTensor*> scales);
682+
677683
void YoloBoxXPUInferMeta(const MetaTensor& x,
678684
const MetaTensor& x_max,
679685
const MetaTensor& grid,

paddle/phi/kernels/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ if(((WITH_GPU) AND (CUDA_VERSION VERSION_LESS 12.0))
7777
"fusion/gpu/fused_stack_transpose_quant_kernel.cu"
7878
"fusion/gpu/fused_stack_quant_kernel.cu"
7979
"fusion/gpu/fused_transpose_split_quant_kernel.cu"
80+
"fusion/gpu/fused_transpose_wlch_split_quant_kernel.cu"
8081
"fusion/gpu/fused_swiglu_weighted_bwd_kernel.cu"
8182
"fusion/gpu/fused_weighted_swiglu_act_quant_kernel.cu")
8283
endif()

0 commit comments

Comments
 (0)