Skip to content

Commit cb22b34

Browse files
committed
Support wint2 unzip
1 parent e5a6a9f commit cb22b34

File tree

4 files changed

+499
-0
lines changed

4 files changed

+499
-0
lines changed
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
2+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
#pragma once
17+
18+
#include <cuda.h>
19+
#include <cuda_fp16.h>
20+
#include <cuda_bf16.h>
21+
#include <stdio.h>
22+
#include <cstdint>
23+
24+
25+
struct WeightOnlyTraits {
26+
static constexpr int32_t kGroupSize = 64;
27+
static constexpr int32_t kPackNum = 4;
28+
static constexpr int16_t kWeightMask = 0x3F;
29+
static constexpr int32_t kBBZip = 32;
30+
};
31+
32+
template <typename T, int64_t TileRows, int64_t TileColumns>
33+
struct Wint2UnzipFunctor {
34+
using ScaleComputeT = float;
35+
36+
static constexpr int64_t kTileRows = TileRows;
37+
static constexpr int64_t kTileColumns = TileColumns;
38+
39+
struct Arguments {
40+
const uint8_t* w_ptr;
41+
const T* w_scale_ptr;
42+
const float* w_code_scale_ptr;
43+
const float* w_code_zp_ptr;
44+
const T* w_super_scale_ptr;
45+
T* out_ptr;
46+
const int in_stride;
47+
};
48+
49+
__device__ void operator()(const Arguments& args, const int tid, const int num_threads) {
50+
int16_t shift_bits[4] = {9, 6, 3, 0};
51+
52+
for (int col = tid; col < kTileColumns; col += num_threads) {
53+
for (int row = 0; row < kTileRows; ++row) {
54+
int w_row = row / WeightOnlyTraits::kPackNum;
55+
int w_offset = w_row * args.in_stride + col;
56+
ScaleComputeT w = static_cast<ScaleComputeT>(args.w_ptr[w_offset]);
57+
ScaleComputeT w_code_scale = static_cast<ScaleComputeT>(args.w_code_scale_ptr[col]);
58+
ScaleComputeT w_code_zp = static_cast<ScaleComputeT>(args.w_code_zp_ptr[col]);
59+
60+
int16_t w_zipped_value = static_cast<int16_t>(floor(w * w_code_scale + w_code_zp + 0.5));
61+
int16_t shift_bit = shift_bits[row % WeightOnlyTraits::kPackNum];
62+
int16_t w_shifted_value = (w_zipped_value >> shift_bit) & WeightOnlyTraits::kWeightMask;
63+
64+
int w_scale_row = row / WeightOnlyTraits::kGroupSize;
65+
int w_scale_offset = w_scale_row * args.in_stride + col;
66+
T w_scale = static_cast<T>(args.w_scale_ptr[w_scale_offset]);
67+
68+
if (args.w_super_scale_ptr) {
69+
T w_super_scale = static_cast<T>(args.w_super_scale_ptr[col]);
70+
w_scale = w_scale * w_super_scale;
71+
}
72+
73+
args.out_ptr[row * kTileColumns + col] = static_cast<T>(w_scale) * (static_cast<T>(w_shifted_value) - static_cast<T>(WeightOnlyTraits::kBBZip));
74+
}
75+
}
76+
__syncthreads();
77+
}
78+
};
79+
80+
template <typename T, int64_t TileRows, int64_t TileColumns>
81+
__global__ void Wint2UnzipKernel(
82+
const uint8_t* w_ptr,
83+
const T* w_scale_ptr,
84+
const float* w_code_scale_ptr,
85+
const float* w_code_zp_ptr,
86+
const T* w_super_scale_ptr,
87+
T* output_tensor_ptr,
88+
const int64_t batch,
89+
const int64_t num_rows,
90+
const int64_t num_columns) {
91+
__shared__ T smem[TileRows * TileColumns];
92+
93+
int64_t block_start_column = blockIdx.x * TileColumns;
94+
95+
int64_t block_start_row = blockIdx.z * num_rows + blockIdx.y * TileRows;
96+
97+
int64_t block_start_w_row = block_start_row / WeightOnlyTraits::kPackNum;
98+
int64_t block_w_offset = block_start_w_row * num_columns + block_start_column;
99+
const uint8_t* block_w_ptr = w_ptr + block_w_offset;
100+
101+
int64_t block_start_w_scale_row = block_start_row / WeightOnlyTraits::kGroupSize;
102+
int64_t block_w_scale_offset = block_start_w_scale_row * num_columns + block_start_column;
103+
const T* block_w_scale_ptr = w_scale_ptr + block_w_scale_offset;
104+
105+
const float* block_w_code_scale_ptr = w_code_scale_ptr + blockIdx.z * num_columns + block_start_column;
106+
const float* block_w_code_zp_ptr = w_code_zp_ptr + blockIdx.z * num_columns + block_start_column;
107+
const T* block_w_super_scale_ptr = w_super_scale_ptr ? w_super_scale_ptr + blockIdx.z * num_columns + block_start_column : nullptr;
108+
109+
// unzip to shared memory
110+
typename Wint2UnzipFunctor<T, TileRows, TileColumns>::Arguments args{
111+
block_w_ptr, block_w_scale_ptr, block_w_code_scale_ptr, block_w_code_zp_ptr, block_w_super_scale_ptr, smem, num_columns};
112+
113+
Wint2UnzipFunctor<T, TileRows, TileColumns> winx_unzipper;
114+
winx_unzipper(args, threadIdx.x, blockDim.x);
115+
116+
// write back to global memory
117+
for (int row = 0; row < TileRows; ++row) {
118+
for (int col = 0; col < TileColumns; ++col) {
119+
int64_t global_row = block_start_row + row;
120+
int64_t global_col = block_start_column + col;
121+
output_tensor_ptr[global_row * num_columns + global_col] = smem[row * TileColumns + col];
122+
}
123+
}
124+
}
125+
126+
template <typename T>
127+
void Wint2UnzipKernelLauncher(
128+
const uint8_t* w_ptr,
129+
const T* w_scale_ptr,
130+
const float* w_code_scale_ptr,
131+
const float* w_code_zp_ptr,
132+
const T* w_super_scale_ptr,
133+
T* output_tensor_ptr,
134+
const int64_t batch,
135+
const int64_t num_rows,
136+
const int64_t num_columns) {
137+
constexpr int kTileRows = 64;
138+
constexpr int kTileColumns = 128;
139+
140+
const int num_threads = 128;
141+
const int block_dim_x = (num_columns + kTileColumns - 1) / kTileColumns;
142+
const int block_dim_y = (num_rows + kTileRows - 1) / kTileRows;
143+
144+
dim3 block_dim(num_threads, 1, 1);
145+
dim3 grid_dim(block_dim_x, block_dim_y, batch);
146+
// printf("Launch config: grid_dim={%d, %d, %d}, block_dim={%d, 1, 1}\n", block_dim_x, block_dim_y, batch, num_threads);
147+
148+
Wint2UnzipKernel<T, kTileRows, kTileColumns><<<grid_dim, block_dim>>>(
149+
w_ptr, w_scale_ptr, w_code_scale_ptr, w_code_zp_ptr, w_super_scale_ptr, output_tensor_ptr, batch, num_rows, num_columns);
150+
}
151+

csrc/gpu/moe/fused_moe/wint2_unzip.cu

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
2+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
#include "moe/wint2_unzip_impl_op.h"
17+
#include "helper.h"
18+
19+
template <paddle::DataType T>
20+
void Wint2UnzipKernel(const paddle::Tensor& w,
21+
const paddle::Tensor& w_scale,
22+
const paddle::Tensor& w_code_scale,
23+
const paddle::Tensor& w_code_zp,
24+
const paddle::Tensor& w_super_scale,
25+
paddle::Tensor& output_tensor,
26+
const std::string& quant_method) {
27+
using data_t = typename PDTraits<T>::data_t;
28+
using NvType = typename PDTraits<T>::DataType;
29+
30+
if (quant_method == "weight_only_int2") {
31+
const uint8_t* w_ptr = w.data<uint8_t>();
32+
const NvType* w_scale_ptr = reinterpret_cast<const NvType*>(w_scale.data<data_t>());
33+
const float* w_code_scale_ptr = w_code_scale.data<float>();
34+
const float* w_code_zp_ptr = w_code_zp.data<float>();
35+
const NvType* w_super_scale_kernel_ptr = w_super_scale.initialized() ? reinterpret_cast<const NvType*>(w_super_scale.data<data_t>()) : nullptr;
36+
37+
NvType* output_tensor_ptr = reinterpret_cast<NvType*>(output_tensor.data<data_t>());
38+
39+
const int64_t batch = output_tensor.shape()[0];
40+
const int64_t num_rows = output_tensor.shape()[1];
41+
const int64_t num_columns = output_tensor.shape()[2];
42+
Wint2UnzipKernelLauncher<NvType>(
43+
w_ptr,
44+
w_scale_ptr,
45+
w_code_scale_ptr,
46+
w_code_zp_ptr,
47+
w_super_scale_kernel_ptr,
48+
output_tensor_ptr,
49+
batch, num_rows, num_columns);
50+
} else {
51+
PD_THROW("Unsupported quant_method for Wint2Unzip.");
52+
}
53+
}
54+
55+
std::vector<paddle::Tensor> Wint2Unzip(const paddle::Tensor& w,
56+
const paddle::Tensor& w_scale,
57+
const paddle::Tensor& w_code_scale,
58+
const paddle::Tensor& w_code_zp,
59+
const paddle::Tensor& w_super_scale,
60+
const std::string& quant_method) {
61+
auto place = w.place();
62+
auto dtype = w_scale.dtype();
63+
64+
auto output_dims = w.dims();
65+
const int unzip_axis = 1;
66+
67+
if (quant_method == "weight_only_int2") {
68+
output_dims[unzip_axis] = output_dims[unzip_axis] * WeightOnlyTraits::kPackNum;
69+
// PD_CHECK(output_shape[unzip_axis] % WeightOnlyTraits::kGroupSize == 0, "unzip_size must be divisible by 64 in wint2!");
70+
} else {
71+
PD_THROW("Unsupported data type for Wint2Unzip");
72+
}
73+
auto output_tensor = GetEmptyTensor(output_dims, dtype, place);
74+
75+
switch (w_scale.dtype()) {
76+
case paddle::DataType::BFLOAT16:
77+
Wint2UnzipKernel<paddle::DataType::BFLOAT16>(w,
78+
w_scale,
79+
w_code_scale,
80+
w_code_zp,
81+
w_super_scale,
82+
output_tensor,
83+
quant_method);
84+
break;
85+
case paddle::DataType::FLOAT16:
86+
Wint2UnzipKernel<paddle::DataType::FLOAT16>(w,
87+
w_scale,
88+
w_code_scale,
89+
w_code_zp,
90+
w_super_scale,
91+
output_tensor,
92+
quant_method);
93+
break;
94+
default:
95+
PD_THROW("Unsupported data type for Wint2Unzip");
96+
}
97+
return {output_tensor};
98+
}
99+
100+
std::vector<std::vector<int64_t>> Wint2UnzipInferShape(
101+
const std::vector<int64_t>& w_shape,
102+
const std::vector<int64_t>& w_scale_shape,
103+
const std::vector<int64_t>& w_code_scale_shape,
104+
const std::vector<int64_t>& w_code_zp_shape,
105+
const std::vector<int64_t>& w_super_scale_shape,
106+
const std::string& quant_method) {
107+
std::vector<int64_t> output_shape(w_shape);
108+
const int unzip_axis = 1;
109+
if(quant_method == "weight_only_int2") {
110+
output_shape[unzip_axis] = w_shape[unzip_axis] * WeightOnlyTraits::kPackNum;
111+
PD_CHECK(output_shape[unzip_axis] % WeightOnlyTraits::kGroupSize == 0, "unzip_size must be divisible by 64 in wint2!");
112+
} else {
113+
PD_THROW("Unsupported data type for Wint2Unzip");
114+
}
115+
return {output_shape};
116+
}
117+
118+
std::vector<paddle::DataType> Wint2UnzipInferDtype(
119+
const paddle::DataType& w_dtype,
120+
const paddle::DataType& w_scale_dtype,
121+
const paddle::DataType& w_code_scale_dtype,
122+
const paddle::DataType& w_code_zp_dtype,
123+
const paddle::DataType& w_super_scale_dtype) {
124+
return {w_scale_dtype};
125+
}
126+
127+
PD_BUILD_OP(win2_unzip)
128+
.Inputs({"w", "w_scale", "w_code_scale", "w_code_zp", "w_super_scale"})
129+
.Outputs({"output_tensor"})
130+
.Attrs({"quant_method:std::string"})
131+
.SetKernelFn(PD_KERNEL(Wint2Unzip))
132+
.SetInferShapeFn(PD_INFER_SHAPE(Wint2UnzipInferShape))
133+
.SetInferDtypeFn(PD_INFER_DTYPE(Wint2UnzipInferDtype));

0 commit comments

Comments
 (0)