1
+ // Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ #include " helper.h"
16
+
17
// Fused logits pre-processing kernel. For the batch row handled by each block
// it performs, in order:
//   1. records one token from input_ids into pre_ids at position step_idx,
//   2. masks the EOS tokens while cur_len < min_len,
//   3. counts how often each token id occurs in pre_ids (into repeat_times),
//   4. applies the repetition (penalty_scores), frequency and presence
//      penalties and divides by the temperature,
//   5. masks every token listed in bad_words_list.
//
// Launch layout: blockIdx.x selects the batch row; the threads of the block
// stride over that row's entries. Any block size works. No shared memory used.
//
// Buffer layout, as indexed by this kernel:
//   - stop_flags, seq_lens_encoder, seq_lens_decoder, step_idx, penalty_scores,
//     frequency_score, presence_score, temperatures, cur_len, min_len:
//     one entry per batch row.
//   - pre_ids: rows of `length_id` entries; the counting pass stops at the
//     first negative entry of a row.
//   - input_ids: rows of `length_input_ids` entries.
//   - logits, repeat_times: rows of `length` entries.
//   - eos_token_id: `end_length` entries; bad_words_list: `bad_words_length`.
//
// repeat_times is only ever incremented here, so the caller must pass it
// zero-filled.
template <typename T>
__global__ void set_preids_token_penalty_multi_scores_kernel(
    const bool *stop_flags,
    int64_t *pre_ids,
    const int64_t *input_ids,
    const int *seq_lens_encoder,
    const int *seq_lens_decoder,
    const int64_t *step_idx,
    const T *penalty_scores,
    const T *frequency_score,
    const T *presence_score,
    const float *temperatures,
    const int64_t *cur_len,
    const int64_t *min_len,
    const int64_t *eos_token_id,
    const int64_t *bad_words_list,
    int *repeat_times,
    T *logits,
    const int64_t bs,
    const int64_t length,
    const int64_t end_length,
    const int64_t length_id,
    const int64_t bad_words_length,
    const int64_t length_input_ids) {
    const int64_t bi = blockIdx.x;
    const int tid = threadIdx.x;

    // Block-uniform exit: either every thread of the block leaves here or none
    // does, so the barriers below are still reached by the whole block.
    if (bi >= bs) return;

    T *logits_now = logits + bi * length;
    int *repeat_times_now = repeat_times + bi * length;
    int64_t *pre_ids_now = pre_ids + bi * length_id;

    // 1. Record the newest token of this row. A single thread performs the
    //    write; all per-row state is indexed with the block (row) index.
    //    Stopped rows only skip this step -- there is no per-thread return, so
    //    every thread still takes part in the passes and barriers that follow.
    if (tid == 0 && !stop_flags[bi]) {
        const int seq_len_enc = seq_lens_encoder[bi];
        const int seq_len_dec = seq_lens_decoder[bi];
        const int64_t step_idx_now = step_idx[bi];
        // Both lengths being zero marks a row that has already stopped.
        const bool row_active = (seq_len_enc != 0) || (seq_len_dec != 0);
        if (row_active && step_idx_now >= 0 && step_idx_now < length_id) {
            const int64_t *input_ids_now = input_ids + bi * length_input_ids;
            if (seq_len_enc > 0) {
                // Encoder step: take the last token according to seq_lens_encoder.
                pre_ids_now[step_idx_now] = input_ids_now[seq_len_enc - 1];
            } else {
                // Decoder step: take the first token.
                pre_ids_now[step_idx_now] = input_ids_now[0];
            }
        }
    }
    // Make the pre_ids write visible to the counting pass of the whole block.
    __syncthreads();

    // 2. Minimum length: suppress every EOS token until min_len is reached.
    if (cur_len[bi] < min_len[bi]) {
        for (int64_t i = tid; i < end_length; i += blockDim.x) {
            const int64_t eos_id = eos_token_id[i];
            if (eos_id >= 0 && eos_id < length) {
                logits_now[eos_id] = -1e10;
            }
        }
    }

    // 3. Count occurrences of each previously recorded token id.
    for (int64_t i = tid; i < length_id; i += blockDim.x) {
        const int64_t id = pre_ids_now[i];
        if (id < 0) break;           // negative entry ends this thread's scan
        if (id >= length) continue;  // id does not address a logit of this row
        atomicAdd(&repeat_times_now[id], 1);
    }
    // Counts and the EOS masking must be complete before logits are rescaled.
    __syncthreads();

    // 4. Penalties and temperature, computed in float regardless of T.
    const float alpha = static_cast<float>(penalty_scores[bi]);
    const float beta = static_cast<float>(frequency_score[bi]);
    const float gamma = static_cast<float>(presence_score[bi]);
    const float temperature = temperatures[bi];
    for (int64_t i = tid; i < length; i += blockDim.x) {
        const int times = repeat_times_now[i];
        float logit_now = static_cast<float>(logits_now[i]);
        if (times != 0) {
            // Repetition penalty moves the logit towards "less likely" for
            // either sign, then frequency and presence penalties are subtracted.
            logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
            logit_now = logit_now - times * beta - gamma;
        }
        logits_now[i] = static_cast<T>(logit_now / temperature);
    }
    // Bad-word masking must win over the rescaled values written above.
    __syncthreads();

    // 5. Bad words: out-of-range ids are ignored.
    for (int64_t i = tid; i < bad_words_length; i += blockDim.x) {
        const int64_t bad_words_token_id = bad_words_list[i];
        if (bad_words_token_id >= length || bad_words_token_id < 0) continue;
        logits_now[bad_words_token_id] = -1e10;
    }
}
99
+
100
// Host-side launcher for set_preids_token_penalty_multi_scores_kernel,
// instantiated for the element type that PDTraits<D> maps D to.
//
// Sizes are read from the tensors as follows:
//   - logits.shape()      -> {bs, length}
//   - pre_ids.shape()[1]  -> length_id
//   - input_ids.shape()[1]-> length_input_ids
//   - bad_tokens.shape()[0], eos_token_id.shape()[0] -> list lengths
// A zero-filled INT32 scratch tensor with the shape of logits is allocated on
// pre_ids' place for the per-token repeat counts. The kernel is enqueued on
// logits' stream with one block of 1024 threads per batch row; logits and
// pre_ids are written through const_cast, i.e. updated in place.
template <paddle::DataType D>
void set_preids_token_penalty_multi_scores(const paddle::Tensor& pre_ids,
                                           const paddle::Tensor& input_ids,
                                           const paddle::Tensor& seq_lens_encoder,
                                           const paddle::Tensor& seq_lens_decoder,
                                           const paddle::Tensor& step_idx,
                                           const paddle::Tensor& stop_flags,
                                           const paddle::Tensor& logits,
                                           const paddle::Tensor& penalty_scores,
                                           const paddle::Tensor& frequency_score,
                                           const paddle::Tensor& presence_score,
                                           const paddle::Tensor& temperatures,
                                           const paddle::Tensor& bad_tokens,
                                           const paddle::Tensor& cur_len,
                                           const paddle::Tensor& min_len,
                                           const paddle::Tensor& eos_token_id) {
    typedef PDTraits<D> traits_;
    typedef typename traits_::DataType DataType_;
    typedef typename traits_::data_t data_t;

    std::vector<int64_t> shape = logits.shape();
    int64_t bs = shape[0];
    int64_t length = shape[1];
    int64_t length_id = pre_ids.shape()[1];
    int64_t length_bad_words = bad_tokens.shape()[0];
    int64_t length_input_ids = input_ids.shape()[1];
    int64_t end_length = eos_token_id.shape()[0];

    // An empty batch has nothing to process, and a grid of zero blocks is not
    // a valid launch configuration -- return before touching the device.
    if (bs == 0) {
        return;
    }

    auto cu_stream = logits.stream();
    // The kernel only increments these counters, so they must start at zero.
    auto repeat_times =
        paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());

    set_preids_token_penalty_multi_scores_kernel<DataType_><<<bs, 1024, 0, cu_stream>>>(
        stop_flags.data<bool>(),
        const_cast<int64_t*>(pre_ids.data<int64_t>()),
        input_ids.data<int64_t>(),
        seq_lens_encoder.data<int>(),
        seq_lens_decoder.data<int>(),
        step_idx.data<int64_t>(),
        reinterpret_cast<DataType_*>(const_cast<data_t*>(penalty_scores.data<data_t>())),
        reinterpret_cast<DataType_*>(const_cast<data_t*>(frequency_score.data<data_t>())),
        reinterpret_cast<DataType_*>(const_cast<data_t*>(presence_score.data<data_t>())),
        temperatures.data<float>(),
        cur_len.data<int64_t>(),
        min_len.data<int64_t>(),
        eos_token_id.data<int64_t>(),
        bad_tokens.data<int64_t>(),
        repeat_times.data<int>(),
        reinterpret_cast<DataType_*>(const_cast<data_t*>(logits.data<data_t>())),
        bs,
        length,
        end_length,
        length_id,
        length_bad_words,
        length_input_ids);
}
156
+
157
// Op entry point: selects the launcher instantiation that matches the element
// type of `logits` and forwards every tensor to it unchanged. BFLOAT16,
// FLOAT16 and FLOAT32 are handled; any other type raises through PD_THROW.
void SetPreidsTokenPenaltyMultiScores(const paddle::Tensor& pre_ids,
                                      const paddle::Tensor& input_ids,
                                      const paddle::Tensor& seq_lens_encoder,
                                      const paddle::Tensor& seq_lens_decoder,
                                      const paddle::Tensor& step_idx,
                                      const paddle::Tensor& stop_flags,
                                      const paddle::Tensor& logits,
                                      const paddle::Tensor& penalty_scores,
                                      const paddle::Tensor& frequency_scores,
                                      const paddle::Tensor& presence_scores,
                                      const paddle::Tensor& temperatures,
                                      const paddle::Tensor& bad_tokens,
                                      const paddle::Tensor& cur_len,
                                      const paddle::Tensor& min_len,
                                      const paddle::Tensor& eos_token_id) {
    const auto logits_dtype = logits.type();

    if (logits_dtype == paddle::DataType::BFLOAT16) {
        set_preids_token_penalty_multi_scores<paddle::DataType::BFLOAT16>(
            pre_ids, input_ids, seq_lens_encoder, seq_lens_decoder, step_idx,
            stop_flags, logits, penalty_scores, frequency_scores,
            presence_scores, temperatures, bad_tokens, cur_len, min_len,
            eos_token_id);
        return;
    }

    if (logits_dtype == paddle::DataType::FLOAT16) {
        set_preids_token_penalty_multi_scores<paddle::DataType::FLOAT16>(
            pre_ids, input_ids, seq_lens_encoder, seq_lens_decoder, step_idx,
            stop_flags, logits, penalty_scores, frequency_scores,
            presence_scores, temperatures, bad_tokens, cur_len, min_len,
            eos_token_id);
        return;
    }

    if (logits_dtype == paddle::DataType::FLOAT32) {
        set_preids_token_penalty_multi_scores<paddle::DataType::FLOAT32>(
            pre_ids, input_ids, seq_lens_encoder, seq_lens_decoder, step_idx,
            stop_flags, logits, penalty_scores, frequency_scores,
            presence_scores, temperatures, bad_tokens, cur_len, min_len,
            eos_token_id);
        return;
    }

    // No instantiation exists for the remaining element types.
    PD_THROW(
        " NOT supported data type. "
        " Only float16, bfloat16 and float32 are supported. ");
}
239
+
240
// Registers the custom op `set_preids_token_penalty_multi_scores`.
// The two outputs are declared in place: `logits_out` shares storage with
// `logits` and `pre_ids_out` with `pre_ids`. The kernel function is
// SetPreidsTokenPenaltyMultiScores; the input order below matches its
// parameter order.
PD_BUILD_OP(set_preids_token_penalty_multi_scores)
    .Inputs({" pre_ids", " input_ids",
             " seq_lens_encoder", " seq_lens_decoder",
             " step_idx", " stop_flags",
             " logits",
             " penalty_scores", " frequency_scores", " presence_scores",
             " temperatures",
             " bad_tokens",
             " cur_len", " min_len",
             " eos_token_id"})
    .Outputs({" logits_out", " pre_ids_out"})
    .SetInplaceMap({{" logits", " logits_out"},
                    {" pre_ids", " pre_ids_out"}})
    .SetKernelFn(PD_KERNEL(SetPreidsTokenPenaltyMultiScores));
0 commit comments