
Commit 5f40ee8

use fp8 cuda core gemm kernel when M<=4
1 parent 85333aa commit 5f40ee8

File tree

5 files changed: +283 -20 lines changed

csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.cu
csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.h
csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu
csrc/setup_cuda.py
llm/docs/predict/best_practices.md
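At a glance, the commit adds a plain CUDA-core FP8 GEMM kernel and routes small decode-time matmuls to it instead of the cutlass kernels. A condensed sketch of the gating condition (illustrative only; the variable names mirror the fp8_fp8_half_gemm.cu hunk further down, and the helper name is hypothetical):

#include <cstdlib>
#include <string>

// The CUDA-core kernel is used only for tiny M with row-major activations,
// transposed weights, no fused activation, and the env flag opted in.
inline bool use_cuda_core_fp8_path(int m, bool trans_x, bool trans_y,
                                   const std::string& act) {
  const char* flag = std::getenv("FLAGS_cuda_core_fp8_gemm");
  const bool enabled = (flag != nullptr) && std::string(flag) == "1";
  return m <= 4 && trans_y && !trans_x && act == "noact" && enabled;
}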
csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.cu

Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fp8_fp8_half_cuda_core_gemm.h"

#include <cub/cub.cuh>  // cub::WarpReduce

#include "cutlass/numeric_conversion.h"

// Plain CUDA-core (non tensor-core) GEMM for very small M: every thread block
// computes one TILE_M x TILE_N tile of the output. Each thread accumulates a
// strided slice of the K dimension in FP32; partial sums are then reduced
// within each warp and finally across warps through shared memory.
template <typename InputType,
          typename OutputType,
          int32_t TILE_M,
          int32_t TILE_N,
          int32_t BLOCK_SIZE,
          bool UseBias>
__global__ void cudaCoreGemm(InputType const* __restrict__ act,
                             InputType const* __restrict__ weight,
                             OutputType const* __restrict__ bias,
                             OutputType* __restrict__ output,
                             int32_t m,
                             int32_t n,
                             int32_t k,
                             float alpha) {
  using VecType = int4;
  // Each 16-byte vector load covers kStepK input elements.
  static constexpr int32_t kStepK =
      static_cast<int32_t>(128 / (8 * sizeof(InputType)));
  static constexpr int32_t kTileK = kStepK * BLOCK_SIZE;
  auto tileIdM = static_cast<int32_t>(blockIdx.x * TILE_M);
  auto tileIdN = static_cast<int32_t>(blockIdx.y * TILE_N);
  auto tid = static_cast<int32_t>(threadIdx.x);
  float tile_a[kStepK], tile_w[TILE_N * kStepK];
  float acc[TILE_M * TILE_N];

  static_assert(kStepK % 4 == 0);
  using Converter = cutlass::NumericArrayConverter<float, InputType, 4>;
  using CvtSrcType = typename Converter::source_type;
  using CvtResType = typename Converter::result_type;

  static constexpr int32_t kCvtCount =
      static_cast<int32_t>(sizeof(VecType) / sizeof(CvtSrcType));

#pragma unroll
  for (int32_t i = 0; i < TILE_M * TILE_N; ++i) {
    acc[i] = 0;
  }
  act += tileIdM * k;
  weight += tileIdN * k;
  output += tileIdM * n + tileIdN;
  if constexpr (UseBias) {
    bias += tileIdN;
  }
  for (int32_t idxK = tid * kStepK; idxK < k; idxK += kTileK) {
    for (int32_t i = 0; i < TILE_N; ++i) {
      auto tile_w_quantized =
          reinterpret_cast<VecType const*>(weight + i * k + idxK)[0];
#pragma unroll
      for (int32_t cvtIdx = 0; cvtIdx < kCvtCount; ++cvtIdx) {
        reinterpret_cast<CvtResType*>(tile_w)[i * kCvtCount + cvtIdx] =
            Converter::convert(
                reinterpret_cast<CvtSrcType*>(&tile_w_quantized)[cvtIdx]);
      }
    }
#pragma unroll
    for (int32_t i = 0; i < TILE_M; ++i) {
      auto tile_a_quantized =
          reinterpret_cast<VecType const*>(act + i * k + idxK)[0];
#pragma unroll
      for (int32_t cvtIdx = 0; cvtIdx < kCvtCount; ++cvtIdx) {
        reinterpret_cast<CvtResType*>(tile_a)[cvtIdx] = Converter::convert(
            reinterpret_cast<CvtSrcType*>(&tile_a_quantized)[cvtIdx]);
      }
#pragma unroll
      for (int32_t j = 0; j < TILE_N; ++j) {
#pragma unroll
        for (int32_t l = 0; l < kStepK; ++l) {
          acc[i * TILE_N + j] =
              fma(tile_a[l], tile_w[j * kStepK + l], acc[i * TILE_N + j]);
        }
      }
    }
  }

  typedef cub::WarpReduce<float> WarpReduce;

  static constexpr int32_t kWarpSize = 32;
  static constexpr int32_t kWarpNum = BLOCK_SIZE / kWarpSize;
  int32_t warpId = tid / kWarpSize, laneId = tid % kWarpSize;
  __shared__ float shmem[TILE_M * TILE_N * kWarpNum];
  __shared__ typename WarpReduce::TempStorage tempStorage[kWarpNum];
#pragma unroll
  for (int32_t mi = 0; mi < TILE_M; ++mi) {
#pragma unroll
    for (int32_t ni = 0; ni < TILE_N; ++ni) {
      float val = WarpReduce(tempStorage[warpId]).Sum(acc[mi * TILE_N + ni]);
      if (laneId == 0) {
        shmem[mi * TILE_N + ni + warpId * TILE_M * TILE_N] = val;
      }
    }
  }

  __syncthreads();
  for (int32_t ii = tid; ii < TILE_M * TILE_N; ii += BLOCK_SIZE) {
    int32_t mid = ii / TILE_N, nid = ii % TILE_N;
    float val = 0;
#pragma unroll
    for (int32_t jj = 0; jj < kWarpNum; ++jj) {
      val += shmem[jj * TILE_M * TILE_N + ii];
    }

    if constexpr (UseBias) {
      output[mid * n + nid] = static_cast<OutputType>(
          val * alpha + static_cast<float>(*(bias + nid)));
    } else {
      output[mid * n + nid] = static_cast<OutputType>(val * alpha);
    }
  }
}

template <typename InputType,
          typename OutputType,
          int32_t TILE_M,
          int32_t TILE_N,
          int32_t BLOCK_SIZE>
void cudaCoreGemmKernel(GemmParams const& params) {
  dim3 block(BLOCK_SIZE);
  dim3 grid(params.m / TILE_M, params.n / TILE_N);

  if (params.bias != nullptr) {
    cudaCoreGemm<InputType, OutputType, TILE_M, TILE_N, BLOCK_SIZE, true>
        <<<grid, block, 0, params.stream>>>(
            reinterpret_cast<InputType const*>(params.act),
            reinterpret_cast<InputType const*>(params.weight),
            reinterpret_cast<OutputType const*>(params.bias),
            reinterpret_cast<OutputType*>(params.output),
            params.m,
            params.n,
            params.k,
            params.alpha);
  } else {
    cudaCoreGemm<InputType, OutputType, TILE_M, TILE_N, BLOCK_SIZE, false>
        <<<grid, block, 0, params.stream>>>(
            reinterpret_cast<InputType const*>(params.act),
            reinterpret_cast<InputType const*>(params.weight),
            reinterpret_cast<OutputType const*>(params.bias),
            reinterpret_cast<OutputType*>(params.output),
            params.m,
            params.n,
            params.k,
            params.alpha);
  }
}

// Recursively instantiates TILE_M = 1..16 and dispatches on the runtime m.
// Returns false when m has no matching specialization so the caller can fall
// back to the cutlass path.
template <typename InputType,
          typename OutputType,
          int TILE_M,
          int TILE_N,
          int BLOCK_SIZE>
bool cudaCoreGemmTemplateCaller(GemmParams const& params) {
  constexpr int cudaCoreGemmTemplateMaxM = 16;
  if (params.m == TILE_M) {
    cudaCoreGemmKernel<InputType, OutputType, TILE_M, TILE_N, BLOCK_SIZE>(
        params);
    return true;
  }
  if constexpr (TILE_M < cudaCoreGemmTemplateMaxM) {
    return cudaCoreGemmTemplateCaller<InputType,
                                      OutputType,
                                      TILE_M + 1,
                                      TILE_N,
                                      BLOCK_SIZE>(params);
  }
  return false;
}

template <typename InputType, typename OutputType>
bool cuda_core_gemm_launcher(GemmParams const& params) {
  return cudaCoreGemmTemplateCaller<InputType, OutputType, 1, 2, 256>(params);
}

template bool cuda_core_gemm_launcher<__nv_fp8_e4m3, __nv_bfloat16>(GemmParams const&);
template bool cuda_core_gemm_launcher<__nv_fp8_e4m3, half>(GemmParams const&);
template bool cuda_core_gemm_launcher<__nv_fp8_e5m2, __nv_bfloat16>(GemmParams const&);
template bool cuda_core_gemm_launcher<__nv_fp8_e5m2, half>(GemmParams const&);
csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.h

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "fp8_common.h"  // NOLINT

// Parameter bundle shared by the host-side launcher and the kernels.
typedef struct {
  void const* act;     // [m, k] activations, row-major
  void const* weight;  // [n, k] weights (transposed-B layout)
  void const* bias;    // [n] bias, may be nullptr
  void* output;        // [m, n] output
  int32_t m, n, k;
  float alpha;         // output scale applied to the accumulated sum
  cudaStream_t stream;
} GemmParams;

// The CUDA-core path is opt-in: it is taken only when the
// FLAGS_cuda_core_fp8_gemm environment variable is set to "1".
inline bool enable_cuda_core_fp8_gemm() {
  static const char* enable_cuda_core_fp8_env =
      std::getenv("FLAGS_cuda_core_fp8_gemm");
  static const bool enable_cuda_core_fp8_gemm =
      enable_cuda_core_fp8_env != nullptr &&
      std::string(enable_cuda_core_fp8_env) == "1";
  return enable_cuda_core_fp8_gemm;
}

template <typename InputType, typename OutputType>
bool cuda_core_gemm_launcher(GemmParams const& params);
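
For orientation, a minimal sketch of how the new launcher can be driven directly (illustrative only: the wrapper name and shapes are hypothetical; it assumes device buffers are already allocated, n is even to match the launcher's TILE_N of 2, and k is a multiple of 16 so the kernel's 16-byte vector loads line up):

#include <cuda_fp16.h>  // half
#include <cuda_fp8.h>   // __nv_fp8_e4m3
#include "fp8_fp8_half_cuda_core_gemm.h"

// Hypothetical wrapper: small-M decode GEMM, activations [m, k] row-major,
// weights stored as [n, k] (the trans_y && !trans_x layout), optional bias [n].
bool run_small_m_fp8_gemm(const __nv_fp8_e4m3* d_act,
                          const __nv_fp8_e4m3* d_weight,
                          const half* d_bias,  // may be nullptr
                          half* d_out,
                          int32_t m, int32_t n, int32_t k,
                          float scale, cudaStream_t stream) {
  GemmParams params = {d_act, d_weight, d_bias, d_out, m, n, k, scale, stream};
  // Returns false if m has no TILE_M specialization (m > 16 in this commit),
  // in which case the caller should fall back to the cutlass kernels.
  return cuda_core_gemm_launcher<__nv_fp8_e4m3, half>(params);
}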

csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu

Lines changed: 50 additions & 20 deletions
@@ -16,6 +16,7 @@
 
 #include "fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h"
 #include "fp8_common.h"  // NOLINT
+#include "fp8_fp8_half_cuda_core_gemm.h"
 
 std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
     const paddle::Tensor& x,
@@ -116,26 +117,55 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
     }
   }
 
-  GemmEpilogueAllParams params = {
-      x_ptr,
-      y_ptr,
-      out_ptr,
-      scale,
-      M,
-      N,
-      K,
-      lda,
-      ldb,
-      ldd,
-      batch_count,
-      place,
-      stream,
-      sm_version,
-      0.01,  // for leaky_relu
-      bias_data,
-      bias_dims,
-      fuse_gemm_config};
-  fp8_fp8_gemm_scale_bias_act(params);
+  if (M <= 4 && trans_y && !trans_x && act == "noact" && enable_cuda_core_fp8_gemm()) {
+    GemmParams params = {
+        x_ptr,
+        y_ptr,
+        bias_data,
+        out_ptr,
+        M,
+        N,
+        K,
+        scale,
+        stream,
+    };
+
+    if (x.dtype() == phi::DataType::FLOAT8_E4M3FN) {
+      if (output_dtype == "bfloat16") {
+        cuda_core_gemm_launcher<__nv_fp8_e4m3, __nv_bfloat16>(params);
+      } else {
+        cuda_core_gemm_launcher<__nv_fp8_e4m3, half>(params);
+      }
+    } else {
+      if (output_dtype == "bfloat16") {
+        cuda_core_gemm_launcher<__nv_fp8_e5m2, __nv_bfloat16>(params);
+      } else {
+        cuda_core_gemm_launcher<__nv_fp8_e5m2, half>(params);
+      }
+    }
+  } else {
+    GemmEpilogueAllParams params = {x_ptr,
+                                    y_ptr,
+                                    out_ptr,
+                                    scale,
+                                    M,
+                                    N,
+                                    K,
+                                    lda,
+                                    ldb,
+                                    ldd,
+                                    batch_count,
+                                    place,
+                                    stream,
+                                    sm_version,
+                                    0.01,  // for leaky_relu
+                                    bias_data,
+                                    bias_dims,
+                                    fuse_gemm_config};
+    fp8_fp8_gemm_scale_bias_act(params);
+  }
   return {out};
 }

csrc/setup_cuda.py

Lines changed: 1 addition & 0 deletions
@@ -159,6 +159,7 @@ def get_gencode_flags():
 sources += find_end_files("gpu/cutlass_kernels/fp8_gemm_fused/autogen", ".cu")
 sources += [
     "gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu",
+    "gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.cu",
     "gpu/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu",
 ]

llm/docs/predict/best_practices.md

Lines changed: 4 additions & 0 deletions
@@ -9,6 +9,10 @@ PaddleNLP provides a number of environment variables for tuning inference performance and resource usage.
 
 - `FLAGS_cublaslt_device_best_config`: When FLAGS_enable_blaslt_global_search is set to 1, use `FLAGS_cublaslt_device_best_config` to point to an offline-tuned int8 gemm configuration file; the default value is "". The configuration file can be generated with `PaddleNLP/csrc/utils/tune_cublaslt_int8_gemm.py`, which automatically searches for the best gemm configuration cuBLASLt offers for the current input sizes and records the result; note that different CUDA versions must be tuned separately. When running A8W8 models with FLAGS_enable_blaslt_global_search set to 1, this FLAG gives better performance.
 
+- `FLAGS_cuda_core_int8_gemm`: Whether to enable the small-batch Int8 Gemm optimization; disabled by default. Set it to 1 to enable it for better performance when running A8W8 models.
+
+- `FLAGS_cuda_core_fp8_gemm`: Whether to enable the small-batch FP8 Gemm optimization; disabled by default. Set it to 1 to enable it for better performance.
+
 **GQA optimization**
 
 - `FLAGS_use_xqa_optim`: Whether to enable the xqa optimization for gqa; the default value is 0 (disabled). Setting it to 1 gives better performance for gqa models (such as llama3/3.1 and qwen2).
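
Note: `FLAGS_cuda_core_fp8_gemm` is read from the environment once, on the first call to `enable_cuda_core_fp8_gemm()` shown in the header above, so it should be exported before the inference process starts, e.g. `FLAGS_cuda_core_fp8_gemm=1`. If it is unset, the dispatch in `fp8_fp8_half_gemm.cu` falls back to the existing cutlass kernels even when M <= 4.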

0 commit comments
