// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include <stdio.h>
#include <cstdint>
| 24 | + |
// Compile-time constants describing the wint2 weight-only packing scheme.
struct WeightOnlyTraits {
  // Number of consecutive rows sharing one dequantization scale.
  static constexpr int32_t kGroupSize = 64;
  // Logical rows packed into one stored (zipped) row.
  static constexpr int32_t kPackNum = 4;
  // Mask extracting one 6-bit sub-value from a decoded code word.
  static constexpr int16_t kWeightMask = 0x3F;
  // Zero point subtracted from each unpacked 6-bit value.
  static constexpr int32_t kBBZip = 32;
};
| 31 | + |
// Unzips one (TileRows x TileColumns) tile of packed wint2 weights into
// dequantized values of type T, writing them to args.out_ptr (a row-major
// buffer with stride kTileColumns — shared memory in the caller).
//
// Each stored uint8 in w_ptr encodes kPackNum logical rows: the byte is first
// mapped back to an integer code word via the per-column affine code
// (code = floor(w * code_scale + code_zp + 0.5)), then 6-bit sub-values are
// extracted at bit offsets {9, 6, 3, 0} and dequantized as
// scale * (value - kBBZip).
template <typename T, int64_t TileRows, int64_t TileColumns>
struct Wint2UnzipFunctor {
  using ScaleComputeT = float;

  static constexpr int64_t kTileRows = TileRows;
  static constexpr int64_t kTileColumns = TileColumns;

  struct Arguments {
    const uint8_t* w_ptr;           // packed weights; kPackNum logical rows per stored row
    const T* w_scale_ptr;           // per-group scales (one row per kGroupSize logical rows)
    const float* w_code_scale_ptr;  // per-column code scale
    const float* w_code_zp_ptr;     // per-column code zero point
    const T* w_super_scale_ptr;     // optional per-column super scale; may be nullptr
    T* out_ptr;                     // destination tile, row-major, stride kTileColumns
    const int in_stride;            // column stride (elements) of w_ptr and w_scale_ptr
  };

  // Cooperative unzip: threads stride over columns, each walking every row of
  // its columns. Must be called by ALL threads of the block — it ends with a
  // __syncthreads(), after which out_ptr is fully populated for everyone.
  __device__ void operator()(const Arguments& args, const int tid, const int num_threads) {
    // Bit offsets of the kPackNum sub-values inside one decoded code word.
    const int16_t shift_bits[4] = {9, 6, 3, 0};

    for (int col = tid; col < kTileColumns; col += num_threads) {
      // Per-column code parameters are invariant over rows; load them once
      // instead of re-reading global memory on every row iteration.
      const ScaleComputeT w_code_scale =
          static_cast<ScaleComputeT>(args.w_code_scale_ptr[col]);
      const ScaleComputeT w_code_zp =
          static_cast<ScaleComputeT>(args.w_code_zp_ptr[col]);

      for (int row = 0; row < kTileRows; ++row) {
        const int w_row = row / WeightOnlyTraits::kPackNum;
        const int w_offset = w_row * args.in_stride + col;
        const ScaleComputeT w = static_cast<ScaleComputeT>(args.w_ptr[w_offset]);

        // Recover the zipped integer code. floorf/0.5f keep the whole
        // computation in single precision; the previous floor(... + 0.5)
        // silently promoted the expression to double on the device.
        const int16_t w_zipped_value =
            static_cast<int16_t>(floorf(w * w_code_scale + w_code_zp + 0.5f));
        const int16_t shift_bit = shift_bits[row % WeightOnlyTraits::kPackNum];
        const int16_t w_shifted_value =
            (w_zipped_value >> shift_bit) & WeightOnlyTraits::kWeightMask;

        const int w_scale_row = row / WeightOnlyTraits::kGroupSize;
        const int w_scale_offset = w_scale_row * args.in_stride + col;
        T w_scale = static_cast<T>(args.w_scale_ptr[w_scale_offset]);

        if (args.w_super_scale_ptr) {
          const T w_super_scale = static_cast<T>(args.w_super_scale_ptr[col]);
          w_scale = w_scale * w_super_scale;
        }

        // Dequantize: scale * (code - zero_point).
        args.out_ptr[row * kTileColumns + col] =
            w_scale * (static_cast<T>(w_shifted_value) -
                       static_cast<T>(WeightOnlyTraits::kBBZip));
      }
    }
    __syncthreads();
  }
};
| 79 | + |
// Unzips packed wint2 weights tile-by-tile into `output_tensor_ptr`.
//
// Grid layout: x -> column tiles, y -> row tiles, z -> batch index.
// Block layout: 1-D, any thread count (threads cooperate via strided loops).
// Shared memory: TileRows * TileColumns elements of T (static allocation).
//
// NOTE(review): the unzip stage still assumes full tiles — for edge tiles it
// computes (and may read) past the logical bounds of w_ptr; only the final
// global-memory write-back below is bounds-guarded. Confirm the packed
// buffers are padded to tile multiples, as the launcher's ceil-div grid implies.
template <typename T, int64_t TileRows, int64_t TileColumns>
__global__ void Wint2UnzipKernel(
    const uint8_t* w_ptr,
    const T* w_scale_ptr,
    const float* w_code_scale_ptr,
    const float* w_code_zp_ptr,
    const T* w_super_scale_ptr,
    T* output_tensor_ptr,
    const int64_t batch,  // kept for interface symmetry; gridDim.z encodes it
    const int64_t num_rows,
    const int64_t num_columns) {
  __shared__ T smem[TileRows * TileColumns];

  const int64_t block_start_column = blockIdx.x * TileColumns;
  const int64_t block_start_row = blockIdx.z * num_rows + blockIdx.y * TileRows;

  const int64_t block_start_w_row = block_start_row / WeightOnlyTraits::kPackNum;
  const int64_t block_w_offset = block_start_w_row * num_columns + block_start_column;
  const uint8_t* block_w_ptr = w_ptr + block_w_offset;

  const int64_t block_start_w_scale_row = block_start_row / WeightOnlyTraits::kGroupSize;
  const int64_t block_w_scale_offset =
      block_start_w_scale_row * num_columns + block_start_column;
  const T* block_w_scale_ptr = w_scale_ptr + block_w_scale_offset;

  const int64_t col_param_offset = blockIdx.z * num_columns + block_start_column;
  const float* block_w_code_scale_ptr = w_code_scale_ptr + col_param_offset;
  const float* block_w_code_zp_ptr = w_code_zp_ptr + col_param_offset;
  const T* block_w_super_scale_ptr =
      w_super_scale_ptr ? w_super_scale_ptr + col_param_offset : nullptr;

  // Unzip this tile into shared memory (block-cooperative; ends with a
  // __syncthreads(), so smem is fully populated below).
  typename Wint2UnzipFunctor<T, TileRows, TileColumns>::Arguments args{
      block_w_ptr, block_w_scale_ptr, block_w_code_scale_ptr,
      block_w_code_zp_ptr, block_w_super_scale_ptr, smem,
      static_cast<int>(num_columns)};

  Wint2UnzipFunctor<T, TileRows, TileColumns> winx_unzipper;
  winx_unzipper(args, threadIdx.x, blockDim.x);

  // Write back to global memory. The work is strided over the block's threads
  // (the previous version had every thread redundantly store the entire tile)
  // and bounds-guarded so edge tiles of non-multiple shapes do not write OOB.
  constexpr int64_t kTileElems = TileRows * TileColumns;
  for (int64_t idx = threadIdx.x; idx < kTileElems; idx += blockDim.x) {
    const int64_t row = idx / TileColumns;
    const int64_t col = idx % TileColumns;
    const int64_t logical_row = blockIdx.y * TileRows + row;   // row within this batch
    const int64_t global_col = block_start_column + col;
    if (logical_row < num_rows && global_col < num_columns) {
      const int64_t global_row = blockIdx.z * num_rows + logical_row;
      output_tensor_ptr[global_row * num_columns + global_col] = smem[idx];
    }
  }
}
| 125 | + |
// Host-side launcher for Wint2UnzipKernel.
//
// Launches one thread block of 128 threads per (64 x 128) output tile, with
// gridDim.z covering the batch dimension. `stream` defaults to the legacy
// default stream, keeping existing call sites source-compatible.
//
// Checks for launch-configuration errors via cudaGetLastError(); asynchronous
// execution errors surface at the caller's next synchronizing call.
template <typename T>
void Wint2UnzipKernelLauncher(
    const uint8_t* w_ptr,
    const T* w_scale_ptr,
    const float* w_code_scale_ptr,
    const float* w_code_zp_ptr,
    const T* w_super_scale_ptr,
    T* output_tensor_ptr,
    const int64_t batch,
    const int64_t num_rows,
    const int64_t num_columns,
    cudaStream_t stream = 0) {
  constexpr int kTileRows = 64;
  constexpr int kTileColumns = 128;
  constexpr int kNumThreads = 128;

  // Ceil-div in 64-bit first to avoid int overflow on very large shapes.
  const int64_t grid_x = (num_columns + kTileColumns - 1) / kTileColumns;
  const int64_t grid_y = (num_rows + kTileRows - 1) / kTileRows;

  dim3 block_dim(kNumThreads, 1, 1);
  dim3 grid_dim(static_cast<unsigned int>(grid_x),
                static_cast<unsigned int>(grid_y),
                static_cast<unsigned int>(batch));

  Wint2UnzipKernel<T, kTileRows, kTileColumns><<<grid_dim, block_dim, 0, stream>>>(
      w_ptr, w_scale_ptr, w_code_scale_ptr, w_code_zp_ptr, w_super_scale_ptr,
      output_tensor_ptr, batch, num_rows, num_columns);

  // Kernel launches do not return a status; poll for configuration errors.
  const cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    fprintf(stderr, "[Wint2UnzipKernelLauncher] kernel launch failed: %s\n",
            cudaGetErrorString(err));
  }
}