
Commit 0977858

[LLM INFER] Optimize fuse some kernels in postprocess (#9201)
* optimize fuse some kernels
* optimize fuse some kernels
* fix top_p reject
* fix
* ci
* fix review
* fix
1 parent: 19a2e1f

10 files changed: +837 / -58 lines

csrc/gpu/get_padding_offset_v2.cu

Lines changed: 19 additions & 27 deletions
@@ -13,26 +13,14 @@
 // limitations under the License.
 
 #include "paddle/extension.h"
+#include "helper.h"
 
-__global__ void RemovePaddingV2(int64_t *output_data,
-                                const int64_t *input_data,
-                                const int *seq_lens,
-                                const int *cum_offsets,
-                                const int sequence_length) {
-  const int bi = blockIdx.x;
-  const int tid = threadIdx.x;
-
-  for (int i = tid; i < seq_lens[bi]; i += blockDim.x) {
-    const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i;
-    const int src_seq_id = bi * sequence_length + i;
-    output_data[tgt_seq_id] = input_data[src_seq_id];
-  }
-}
-
-__global__ void GetPaddingOffsetKernelV2(int *padding_offset,
+__global__ void GetPaddingOffsetV2Kernel(int *padding_offset,
                                          int *cum_offsets_out,
                                          int *cu_seqlens_q,
                                          int *cu_seqlens_k,
+                                         int64_t *output_data,
+                                         const int64_t *input_data,
                                          const int *cum_offsets,
                                          const int *seq_lens,
                                          const int max_seq_len) {
@@ -42,8 +30,15 @@ __global__ void GetPaddingOffsetKernelV2(int *padding_offset,
   int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
   for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
     padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
+    const int tgt_seq_id = bi * max_seq_len - cum_offset + i;
+    const int src_seq_id = bi * max_seq_len + i;
+    output_data[tgt_seq_id] = input_data[src_seq_id];
   }
   if (ti == 0) {
+    if (bi == 0) {
+      cu_seqlens_q[0] = 0;
+      cu_seqlens_k[0] = 0;
+    }
     cum_offsets_out[bi] = cum_offset;
     int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi];
     cu_seqlens_q[bi + 1] = cum_seq_len;
@@ -64,24 +59,21 @@ std::vector<paddle::Tensor> GetPaddingOffsetV2(const paddle::Tensor& input_ids,
   auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
 
   const int token_num_data = cpu_token_num.data<int64_t>()[0];
-  auto x_remove_padding = paddle::full({token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
-  auto padding_offset = paddle::full({token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
-  auto cu_seqlens_q = paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
-  auto cu_seqlens_k = paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
-  int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128);
-  GetPaddingOffsetKernelV2<<<bsz, 128, 0, cu_stream>>>(
+
+  auto x_remove_padding = GetEmptyTensor({token_num_data}, paddle::DataType::INT64, input_ids.place());
+  auto padding_offset = GetEmptyTensor({token_num_data}, paddle::DataType::INT32, input_ids.place());
+  auto cu_seqlens_q = GetEmptyTensor({bsz + 1}, paddle::DataType::INT32, input_ids.place());
+  auto cu_seqlens_k = GetEmptyTensor({bsz + 1}, paddle::DataType::INT32, input_ids.place());
+
+  GetPaddingOffsetV2Kernel<<<bsz, 128, 0, cu_stream>>>(
     padding_offset.data<int>(),
     cum_offsets_out.data<int>(),
     cu_seqlens_q.data<int>(),
     cu_seqlens_k.data<int>(),
-    cum_offsets.data<int>(),
-    seq_len.data<int>(),
-    seq_length);
-  RemovePaddingV2<<<bsz, blockSize, 0, cu_stream>>>(
     x_remove_padding.data<int64_t>(),
     input_ids.data<int64_t>(),
+    cum_offsets.data<int>(),
     seq_len.data<int>(),
-    cum_offsets_out.data<int>(),
     seq_length);
   return {x_remove_padding, cum_offsets_out, padding_offset, cu_seqlens_q, cu_seqlens_k}; // , enc_token_num, dec_token_num};
 }
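The change in this file folds the former RemovePaddingV2 gather into GetPaddingOffsetV2Kernel, so one launch both records the padding offsets and copies the non-padded token ids into the compacted buffer, and the outputs are allocated with GetEmptyTensor instead of zero-filled paddle::full since every element is written by the kernel. A minimal standalone sketch of the fused per-block loop, on toy data and with hypothetical names (fused_offset_and_gather is not part of this repo), could look like this:

// Standalone sketch (illustrative only): one block per batch row writes padding
// offsets and gathers unpadded tokens in a single pass, mirroring the fused kernel.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void fused_offset_and_gather(int *padding_offset,
                                        int64_t *out_tokens,       // compacted tokens
                                        const int64_t *in_tokens,  // padded [bsz, max_seq_len]
                                        const int *cum_offsets,    // cumulative pad counts per row
                                        const int *seq_lens,
                                        int max_seq_len) {
  const int bi = blockIdx.x;
  const int ti = threadIdx.x;
  const int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
  for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
    const int tgt = bi * max_seq_len - cum_offset + i;   // position in the compacted layout
    padding_offset[tgt] = cum_offset;                    // what the offset kernel already wrote
    out_tokens[tgt] = in_tokens[bi * max_seq_len + i];   // gather folded in (was RemovePaddingV2)
  }
}

int main() {
  const int bsz = 2, max_seq_len = 4, token_num = 5;
  const int64_t h_in[bsz * max_seq_len] = {1, 2, 3, 0, 4, 5, 0, 0};  // two padded rows
  const int h_seq_lens[bsz] = {3, 2};
  const int h_cum_offsets[bsz] = {1, 3};  // 1 pad after row 0, 3 pads after row 1

  int64_t *d_in, *d_out;
  int *d_off, *d_cum, *d_lens;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, token_num * sizeof(int64_t));
  cudaMalloc(&d_off, token_num * sizeof(int));
  cudaMalloc(&d_cum, sizeof(h_cum_offsets));
  cudaMalloc(&d_lens, sizeof(h_seq_lens));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  cudaMemcpy(d_cum, h_cum_offsets, sizeof(h_cum_offsets), cudaMemcpyHostToDevice);
  cudaMemcpy(d_lens, h_seq_lens, sizeof(h_seq_lens), cudaMemcpyHostToDevice);

  fused_offset_and_gather<<<bsz, 128>>>(d_off, d_out, d_in, d_cum, d_lens, max_seq_len);

  int64_t h_out[token_num];
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  for (int i = 0; i < token_num; ++i) printf("%lld ", (long long)h_out[i]);  // expect: 1 2 3 4 5
  printf("\n");
  return 0;
}

Each block handles one batch row, so folding the copy into the offset kernel drops a separate launch and a second pass over seq_lens without changing the compacted layout.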
Lines changed: 258 additions & 0 deletions
@@ -0,0 +1,258 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

template<typename T>
__global__ void set_preids_token_penalty_multi_scores_kernel(const bool *stop_flags,
                                                              int64_t *pre_ids,
                                                              const int64_t *input_ids,
                                                              const int *seq_lens_encoder,
                                                              const int *seq_lens_decoder,
                                                              const int64_t *step_idx,
                                                              const T *penalty_scores,
                                                              const T *frequency_score,
                                                              const T *presence_score,
                                                              const float *temperatures,
                                                              const int64_t *cur_len,
                                                              const int64_t *min_len,
                                                              const int64_t *eos_token_id,
                                                              const int64_t *bad_words_list,
                                                              int *repeat_times,
                                                              T *logits,
                                                              const int64_t bs,
                                                              const int64_t length,
                                                              const int64_t end_length,
                                                              const int64_t length_id,
                                                              const int64_t bad_words_length,
                                                              const int64_t length_input_ids) {
  int bi = blockIdx.x;
  T *logits_now = logits + bi * length;
  int tid = threadIdx.x;

  if (tid < bs && !stop_flags[tid]) {
    int64_t *pre_ids_now = pre_ids + tid * length;
    const int64_t *input_ids_now = input_ids + tid * length_input_ids;
    const int seq_len_dec = seq_lens_decoder[tid];
    const int seq_len_enc = seq_lens_encoder[tid];
    if (seq_len_dec == 0 && seq_len_enc == 0) return;  // stopped

    const int step_idx_now = step_idx[bi];
    if (tid == 0 && step_idx_now >= 0) {
      if (seq_len_enc > 0) {  // encoder: take the last token according to seq_lens_encoder
        pre_ids_now[step_idx_now] = input_ids_now[seq_len_enc - 1];
      } else {  // decoder: take the first token
        pre_ids_now[step_idx_now] = input_ids_now[0];
      }
    }
  }
  __syncthreads();
  // min_length process
  if (bi < bs) {
    if (cur_len[bi] < min_len[bi]) {
      if (tid < end_length) {
        logits_now[eos_token_id[tid]] = -1e10;
      }
    }
  }
  // update repeat_times
  int *repeat_times_now = repeat_times + bi * length;
  const int64_t *pre_ids_now = pre_ids + bi * length_id;
  for (int i = tid; i < length_id; i += blockDim.x) {
    int64_t id = pre_ids_now[i];
    if (id < 0) break;
    atomicAdd(&repeat_times_now[id], 1);
  }
  __syncthreads();
  // penalty_scores process
  float alpha = static_cast<float>(penalty_scores[bi]);
  float beta = static_cast<float>(frequency_score[bi]);
  float gamma = static_cast<float>(presence_score[bi]);
  for (int i = tid; i < length; i += blockDim.x) {
    int times = repeat_times_now[i];
    float logit_now = static_cast<float>(logits_now[i]);
    if (times != 0) {
      logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
      logit_now = logit_now - times * beta - gamma;
    }
    logits_now[i] = static_cast<T>(logit_now / temperatures[bi]);
  }
  __syncthreads();
  // bad_words process
  for (int i = tid; i < bad_words_length; i += blockDim.x) {
    const int64_t bad_words_token_id = bad_words_list[i];
    if (bad_words_token_id >= length || bad_words_token_id < 0) continue;
    logits_now[bad_words_token_id] = -1e10;
  }
}

template <paddle::DataType D>
void set_preids_token_penalty_multi_scores(const paddle::Tensor& pre_ids,
                                           const paddle::Tensor& input_ids,
                                           const paddle::Tensor& seq_lens_encoder,
                                           const paddle::Tensor& seq_lens_decoder,
                                           const paddle::Tensor& step_idx,
                                           const paddle::Tensor& stop_flags,
                                           const paddle::Tensor& logits,
                                           const paddle::Tensor& penalty_scores,
                                           const paddle::Tensor& frequency_score,
                                           const paddle::Tensor& presence_score,
                                           const paddle::Tensor& temperatures,
                                           const paddle::Tensor& bad_tokens,
                                           const paddle::Tensor& cur_len,
                                           const paddle::Tensor& min_len,
                                           const paddle::Tensor& eos_token_id) {

  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;
  auto cu_stream = logits.stream();
  std::vector<int64_t> shape = logits.shape();
  auto repeat_times = paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
  int64_t bs = shape[0];
  int64_t length = shape[1];
  int64_t length_id = pre_ids.shape()[1];
  int64_t length_bad_words = bad_tokens.shape()[0];
  int64_t length_input_ids = input_ids.shape()[1];

  int64_t end_length = eos_token_id.shape()[0];

  set_preids_token_penalty_multi_scores_kernel<DataType_><<<bs, 1024, 0, cu_stream>>>(
      stop_flags.data<bool>(),
      const_cast<int64_t*>(pre_ids.data<int64_t>()),
      input_ids.data<int64_t>(),
      seq_lens_encoder.data<int>(),
      seq_lens_decoder.data<int>(),
      step_idx.data<int64_t>(),
      reinterpret_cast<DataType_*>(const_cast<data_t*>(penalty_scores.data<data_t>())),
      reinterpret_cast<DataType_*>(const_cast<data_t*>(frequency_score.data<data_t>())),
      reinterpret_cast<DataType_*>(const_cast<data_t*>(presence_score.data<data_t>())),
      temperatures.data<float>(),
      cur_len.data<int64_t>(),
      min_len.data<int64_t>(),
      eos_token_id.data<int64_t>(),
      bad_tokens.data<int64_t>(),
      repeat_times.data<int>(),
      reinterpret_cast<DataType_*>(const_cast<data_t*>(logits.data<data_t>())),
      bs,
      length,
      end_length,
      length_id,
      length_bad_words,
      length_input_ids
  );
}

void SetPreidsTokenPenaltyMultiScores(const paddle::Tensor& pre_ids,
                                      const paddle::Tensor& input_ids,
                                      const paddle::Tensor& seq_lens_encoder,
                                      const paddle::Tensor& seq_lens_decoder,
                                      const paddle::Tensor& step_idx,
                                      const paddle::Tensor& stop_flags,
                                      const paddle::Tensor& logits,
                                      const paddle::Tensor& penalty_scores,
                                      const paddle::Tensor& frequency_scores,
                                      const paddle::Tensor& presence_scores,
                                      const paddle::Tensor& temperatures,
                                      const paddle::Tensor& bad_tokens,
                                      const paddle::Tensor& cur_len,
                                      const paddle::Tensor& min_len,
                                      const paddle::Tensor& eos_token_id) {

  switch (logits.type()) {
    case paddle::DataType::BFLOAT16: {
      return set_preids_token_penalty_multi_scores<paddle::DataType::BFLOAT16>(
          pre_ids,
          input_ids,
          seq_lens_encoder,
          seq_lens_decoder,
          step_idx,
          stop_flags,
          logits,
          penalty_scores,
          frequency_scores,
          presence_scores,
          temperatures,
          bad_tokens,
          cur_len,
          min_len,
          eos_token_id
      );
    }
    case paddle::DataType::FLOAT16: {
      return set_preids_token_penalty_multi_scores<paddle::DataType::FLOAT16>(
          pre_ids,
          input_ids,
          seq_lens_encoder,
          seq_lens_decoder,
          step_idx,
          stop_flags,
          logits,
          penalty_scores,
          frequency_scores,
          presence_scores,
          temperatures,
          bad_tokens,
          cur_len,
          min_len,
          eos_token_id
      );
    }
    case paddle::DataType::FLOAT32: {
      return set_preids_token_penalty_multi_scores<paddle::DataType::FLOAT32>(
          pre_ids,
          input_ids,
          seq_lens_encoder,
          seq_lens_decoder,
          step_idx,
          stop_flags,
          logits,
          penalty_scores,
          frequency_scores,
          presence_scores,
          temperatures,
          bad_tokens,
          cur_len,
          min_len,
          eos_token_id
      );
    }
    default: {
      PD_THROW(
          "NOT supported data type. "
          "Only float16, bfloat16 and float32 are supported. ");
      break;
    }
  }
}

PD_BUILD_OP(set_preids_token_penalty_multi_scores)
    .Inputs({"pre_ids",
             "input_ids",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "step_idx",
             "stop_flags",
             "logits",
             "penalty_scores",
             "frequency_scores",
             "presence_scores",
             "temperatures",
             "bad_tokens",
             "cur_len",
             "min_len",
             "eos_token_id"})
    .Outputs({"logits_out", "pre_ids_out"})
    .SetInplaceMap({{"logits", "logits_out"}, {"pre_ids", "pre_ids_out"}})
    .SetKernelFn(PD_KERNEL(SetPreidsTokenPenaltyMultiScores));
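For reference, the per-logit update this fused kernel applies (repetition penalty alpha, frequency penalty beta, presence penalty gamma, then temperature scaling) can be reproduced in a small host-side sketch; the helper name penalize_logit and the sample values are illustrative only and are not part of this commit:

// Sketch of the per-token penalty arithmetic used by the fused kernel above
// (hypothetical standalone helper, not part of the PaddleNLP sources).
#include <cstdio>

float penalize_logit(float logit, int times, float alpha /* repetition */,
                     float beta /* frequency */, float gamma /* presence */,
                     float temperature) {
  if (times != 0) {
    // Repetition penalty: positive logits shrink, negative logits grow more negative.
    logit = logit < 0 ? logit * alpha : logit / alpha;
    // Frequency penalty scales with the repeat count; presence penalty is a flat offset.
    logit = logit - times * beta - gamma;
  }
  return logit / temperature;  // temperature is applied whether or not the token repeated
}

int main() {
  // Token already generated twice: penalty 1.2, frequency 0.1, presence 0.05, temperature 0.8.
  printf("repeated: %f\n", penalize_logit(3.0f, 2, 1.2f, 0.1f, 0.05f, 0.8f));
  // Unseen token: only the temperature scaling applies.
  printf("unseen:   %f\n", penalize_logit(3.0f, 0, 1.2f, 0.1f, 0.05f, 0.8f));
  return 0;
}

Tokens that never appeared keep their raw value and are only divided by the temperature, matching the times != 0 guard in the kernel.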
