// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fp8_fp8_half_cuda_core_gemm.h"
#include "cutlass/numeric_conversion.h"
// cub::WarpReduce is used in the reduction below; include it directly in case
// the project header does not pull it in transitively.
#include <cub/cub.cuh>

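// cudaCoreGemm: a small-M GEMM that runs on plain CUDA cores (scalar FMAs, no
// tensor-core MMA). Each thread block computes a TILE_M x TILE_N tile of
//   output = alpha * act(m, k) * weight(n, k)^T [+ bias]
// loading the FP8 inputs in 128-bit vectors, converting them to float with
// cutlass::NumericArrayConverter, and accumulating in fp32.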
template <typename InputType,
          typename OutputType,
          int32_t TILE_M,
          int32_t TILE_N,
          int32_t BLOCK_SIZE,
          bool UseBias>
__global__ void cudaCoreGemm(InputType const* __restrict__ act,
                             InputType const* __restrict__ weight,
                             OutputType const* __restrict__ bias,
                             OutputType* __restrict__ output,
                             int32_t m,
                             int32_t n,
                             int32_t k,
                             float alpha) {
  // Each thread loads 128 bits (one int4) of quantized inputs per K step.
  using VecType = int4;
  static constexpr int32_t kStepK =
      static_cast<int32_t>(128 / (8 * sizeof(InputType)));
  static constexpr int32_t kTileK = kStepK * BLOCK_SIZE;
  auto tileIdM = static_cast<int32_t>(blockIdx.x * TILE_M);
  auto tileIdN = static_cast<int32_t>(blockIdx.y * TILE_N);
  auto tid = static_cast<int32_t>(threadIdx.x);
  float tile_a[kStepK], tile_w[TILE_N * kStepK];
  float acc[TILE_M * TILE_N];

  static_assert(kStepK % 4 == 0);
  using Converter = cutlass::NumericArrayConverter<float, InputType, 4>;
  using CvtSrcType = typename Converter::source_type;
  using CvtResType = typename Converter::result_type;

  static constexpr int32_t kCvtCount =
      static_cast<int32_t>(sizeof(VecType) / sizeof(CvtSrcType));

#pragma unroll
  for (int32_t i = 0; i < TILE_M * TILE_N; ++i) {
    acc[i] = 0;
  }
  // Advance the pointers to the tile owned by this thread block.
  act += tileIdM * k;
  weight += tileIdN * k;
  output += tileIdM * n + tileIdN;
  if constexpr (UseBias) {
    bias += tileIdN;
  }
  // Main loop: each thread walks the K dimension in strides of kTileK,
  // converts its FP8 fragments to float, and accumulates with FMAs.
  for (int32_t idxK = tid * kStepK; idxK < k; idxK += kTileK) {
    for (int32_t i = 0; i < TILE_N; ++i) {
      auto tile_w_quantized =
          reinterpret_cast<VecType const*>(weight + i * k + idxK)[0];
#pragma unroll
      for (int32_t cvtIdx = 0; cvtIdx < kCvtCount; ++cvtIdx) {
        reinterpret_cast<CvtResType*>(tile_w)[i * kCvtCount + cvtIdx] =
            Converter::convert(
                reinterpret_cast<CvtSrcType*>(&tile_w_quantized)[cvtIdx]);
      }
    }
#pragma unroll
    for (int32_t i = 0; i < TILE_M; ++i) {
      auto tile_a_quantized =
          reinterpret_cast<VecType const*>(act + i * k + idxK)[0];
#pragma unroll
      for (int32_t cvtIdx = 0; cvtIdx < kCvtCount; ++cvtIdx) {
        reinterpret_cast<CvtResType*>(tile_a)[cvtIdx] = Converter::convert(
            reinterpret_cast<CvtSrcType*>(&tile_a_quantized)[cvtIdx]);
      }
#pragma unroll
      for (int32_t j = 0; j < TILE_N; ++j) {
#pragma unroll
        for (int32_t l = 0; l < kStepK; ++l) {
          acc[i * TILE_N + j] =
              fma(tile_a[l], tile_w[j * kStepK + l], acc[i * TILE_N + j]);
        }
      }
    }
  }

  // Reduce the per-thread partial sums: first within each warp, then across
  // warps through shared memory.
  using WarpReduce = cub::WarpReduce<float>;

  static constexpr int32_t kWarpSize = 32;
  static constexpr int32_t kWarpNum = BLOCK_SIZE / kWarpSize;
  int32_t warpId = tid / kWarpSize, laneId = tid % kWarpSize;
  __shared__ float shmem[TILE_M * TILE_N * kWarpNum];
  __shared__ typename WarpReduce::TempStorage tempStorage[kWarpNum];
#pragma unroll
  for (int32_t mi = 0; mi < TILE_M; ++mi) {
#pragma unroll
    for (int32_t ni = 0; ni < TILE_N; ++ni) {
      float val = WarpReduce(tempStorage[warpId]).Sum(acc[mi * TILE_N + ni]);
      if (laneId == 0) {
        shmem[mi * TILE_N + ni + warpId * TILE_M * TILE_N] = val;
      }
    }
  }

  __syncthreads();
  // Epilogue: sum the per-warp results, apply the scale, and optionally add
  // the bias before casting to the output type.
  for (int32_t ii = tid; ii < TILE_M * TILE_N; ii += BLOCK_SIZE) {
    int32_t mid = ii / TILE_N, nid = ii % TILE_N;
    float val = 0;
#pragma unroll
    for (int32_t jj = 0; jj < kWarpNum; ++jj) {
      val += shmem[jj * TILE_M * TILE_N + ii];
    }

    if constexpr (UseBias) {
      output[mid * n + nid] = static_cast<OutputType>(
          val * alpha + static_cast<float>(bias[nid]));
    } else {
      output[mid * n + nid] = static_cast<OutputType>(val * alpha);
    }
  }
}

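// Host-side launch helper: one thread block of BLOCK_SIZE threads per
// TILE_M x TILE_N output tile, so the grid is (m / TILE_M, n / TILE_N).
// The UseBias specialization is selected based on whether params.bias is set.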
template <typename InputType,
          typename OutputType,
          int32_t TILE_M,
          int32_t TILE_N,
          int32_t BLOCK_SIZE>
void cudaCoreGemmKernel(GemmParams const& params) {
  dim3 block(BLOCK_SIZE);
  dim3 grid(params.m / TILE_M, params.n / TILE_N);

  if (params.bias != nullptr) {
    cudaCoreGemm<InputType, OutputType, TILE_M, TILE_N, BLOCK_SIZE, true>
        <<<grid, block, 0, params.stream>>>(
            reinterpret_cast<InputType const*>(params.act),
            reinterpret_cast<InputType const*>(params.weight),
            reinterpret_cast<OutputType const*>(params.bias),
            reinterpret_cast<OutputType*>(params.output),
            params.m,
            params.n,
            params.k,
            params.alpha);
  } else {
    cudaCoreGemm<InputType, OutputType, TILE_M, TILE_N, BLOCK_SIZE, false>
        <<<grid, block, 0, params.stream>>>(
            reinterpret_cast<InputType const*>(params.act),
            reinterpret_cast<InputType const*>(params.weight),
            reinterpret_cast<OutputType const*>(params.bias),
            reinterpret_cast<OutputType*>(params.output),
            params.m,
            params.n,
            params.k,
            params.alpha);
  }
}

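// Compile-time dispatch over TILE_M: recursively instantiates TILE_M from 1 up
// to cudaCoreGemmTemplateMaxM and launches the kernel whose TILE_M exactly
// matches the runtime params.m. Returns false when params.m is outside that
// range, so the caller can fall back to another GEMM implementation.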
template <typename InputType,
          typename OutputType,
          int TILE_M,
          int TILE_N,
          int BLOCK_SIZE>
bool cudaCoreGemmTemplateCaller(GemmParams const& params) {
  constexpr int cudaCoreGemmTemplateMaxM = 16;
  if (params.m == TILE_M) {
    cudaCoreGemmKernel<InputType, OutputType, TILE_M, TILE_N, BLOCK_SIZE>(
        params);
    return true;
  }
  if constexpr (TILE_M < cudaCoreGemmTemplateMaxM) {
    return cudaCoreGemmTemplateCaller<InputType,
                                      OutputType,
                                      TILE_M + 1,
                                      TILE_N,
                                      BLOCK_SIZE>(params);
  }
  return false;
}

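// Public entry point: fixes TILE_N = 2 and BLOCK_SIZE = 256 and dispatches on
// the runtime M (supported range 1..16).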
template <typename InputType, typename OutputType>
bool cuda_core_gemm_launcher(GemmParams const& params) {
  return cudaCoreGemmTemplateCaller<InputType, OutputType, 1, 2, 256>(params);
}

template bool cuda_core_gemm_launcher<__nv_fp8_e4m3, __nv_bfloat16>(
    GemmParams const&);
template bool cuda_core_gemm_launcher<__nv_fp8_e4m3, half>(GemmParams const&);
template bool cuda_core_gemm_launcher<__nv_fp8_e5m2, __nv_bfloat16>(
    GemmParams const&);
template bool cuda_core_gemm_launcher<__nv_fp8_e5m2, half>(GemmParams const&);
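
// Illustrative usage sketch (not part of the original file). It assumes the
// GemmParams struct declared in fp8_fp8_half_cuda_core_gemm.h carries exactly
// the fields referenced above (act, weight, bias, output, m, n, k, alpha,
// stream); the pointer names below (d_act, d_weight, d_out) are hypothetical.
//
//   GemmParams params;
//   params.act = d_act;        // [m, k] row-major FP8 activations
//   params.weight = d_weight;  // [n, k] row-major FP8 weights
//   params.bias = nullptr;     // optional [n] bias in the output dtype
//   params.output = d_out;     // [m, n] row-major half/bf16 output
//   params.m = 4;
//   params.n = 4096;           // should be a multiple of TILE_N (2)
//   params.k = 4096;           // should be a multiple of kStepK (16 for FP8)
//   params.alpha = 1.0f;       // scale applied to the fp32 accumulator
//   params.stream = stream;
//   bool launched = cuda_core_gemm_launcher<__nv_fp8_e4m3, half>(params);
//   // launched == false means m was not in [1, 16]; use another GEMM path.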