1
+ // Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
1
15
#include " helper.h"
2
16
3
17
template <int THREADBLOCK_SIZE>
4
- __global__ void update_inputs_kernel (
5
- bool *not_need_stop,
6
- int *seq_lens_this_time,
7
- int *seq_lens_encoder,
8
- int *seq_lens_decoder,
9
- int64_t *input_ids,
10
- const int64_t *stop_nums,
11
- const bool *stop_flags,
12
- const bool *is_block_step,
13
- const int64_t *next_tokens,
14
- const int bsz,
15
- const int max_bsz,
16
- const int input_ids_stride) {
18
+ __global__ void update_inputs_kernel (bool *not_need_stop,
19
+ int *seq_lens_this_time,
20
+ int *seq_lens_encoder,
21
+ int *seq_lens_decoder,
22
+ int64_t *input_ids,
23
+ const int64_t *stop_nums,
24
+ const bool *stop_flags,
25
+ const bool *is_block_step,
26
+ const int64_t *next_tokens,
27
+ const int bsz,
28
+ const int max_bsz,
29
+ const int input_ids_stride) {
17
30
int thread_idx = threadIdx .x ;
18
31
typedef cub::BlockReduce<int64_t , THREADBLOCK_SIZE> BlockReduce;
19
32
__shared__ typename BlockReduce::TempStorage temp_storage;
@@ -37,7 +50,10 @@ __global__ void update_inputs_kernel(
37
50
const int seq_len_encoder = seq_lens_encoder[thread_idx];
38
51
const int seq_len_decoder = seq_lens_decoder[thread_idx];
39
52
40
- seq_lens_decoder[thread_idx] = stop_flag_now ? 0 : (seq_len_decoder == 0 ? seq_len_encoder : seq_len_decoder + 1 );
53
+ seq_lens_decoder[thread_idx] =
54
+ stop_flag_now
55
+ ? 0
56
+ : (seq_len_decoder == 0 ? seq_len_encoder : seq_len_decoder + 1 );
41
57
42
58
seq_lens_this_time[thread_idx] = stop_flag_now ? 0 : 1 ;
43
59
seq_lens_encoder[thread_idx] = 0 ;
@@ -51,43 +67,38 @@ __global__ void update_inputs_kernel(
51
67
}
52
68
}
53
69
54
// Host-side launcher for update_inputs_kernel.
//
// Reads the batch sizes and the input_ids row stride from the tensor shapes,
// then launches update_inputs_kernel<1024> as a single block of 1024 threads
// on the stream associated with `input_ids`. Nothing is returned; results are
// produced by the kernel writing through the pointers it is given.
//
// Although every tensor is taken by const reference, `not_need_stop`,
// `seq_lens_this_time`, `seq_lens_encoder`, `seq_lens_decoder` and `input_ids`
// are handed to the kernel as mutable pointers (via const_cast), so callers
// must expect those tensors to be modified in place.
//
// Parameters:
//   stop_flags          bool tensor; its dim 0 is passed to the kernel as
//                       max_bsz.
//   not_need_stop       bool tensor, passed to the kernel as a mutable
//                       pointer.
//   seq_lens_this_time  int tensor; its dim 0 is passed to the kernel as bsz.
//   seq_lens_encoder    int tensor, mutable in the kernel.
//   seq_lens_decoder    int tensor, mutable in the kernel.
//   input_ids           int64 tensor; its dim 1 is passed as the row stride
//                       and its stream is used for the launch.
//   stop_nums           int64 tensor, read-only in the kernel.
//   next_tokens         int64 tensor, read-only in the kernel.
//   is_block_step       bool tensor, read-only in the kernel.
void UpdateInputes(const paddle::Tensor& stop_flags,
                   const paddle::Tensor& not_need_stop,
                   const paddle::Tensor& seq_lens_this_time,
                   const paddle::Tensor& seq_lens_encoder,
                   const paddle::Tensor& seq_lens_decoder,
                   const paddle::Tensor& input_ids,
                   const paddle::Tensor& stop_nums,
                   const paddle::Tensor& next_tokens,
                   const paddle::Tensor& is_block_step) {
    const int max_bsz = stop_flags.shape()[0];
    const int now_bsz = seq_lens_this_time.shape()[0];
    const int input_ids_stride = input_ids.shape()[1];

    // NOTE(review): the raw data pointer of `not_need_stop` is dereferenced
    // inside the kernel, so this presumably requires the tensor to live in
    // device-accessible memory. No host<->device staging copy is performed
    // here -- confirm the placement against the callers.
    //
    // NOTE(review): the launch is fixed at one block of 1024 threads and the
    // kernel indexes per-sequence arrays with threadIdx.x, so batches larger
    // than 1024 look unsupported -- TODO confirm the upper bound on max_bsz.
    update_inputs_kernel<1024><<<1, 1024, 0, input_ids.stream()>>>(
        const_cast<bool*>(not_need_stop.data<bool>()),
        const_cast<int*>(seq_lens_this_time.data<int>()),
        const_cast<int*>(seq_lens_encoder.data<int>()),
        const_cast<int*>(seq_lens_decoder.data<int>()),
        const_cast<int64_t*>(input_ids.data<int64_t>()),
        stop_nums.data<int64_t>(),
        stop_flags.data<bool>(),
        is_block_step.data<bool>(),
        next_tokens.data<int64_t>(),
        now_bsz,
        max_bsz,
        input_ids_stride);
}
85
96
86
97
PD_BUILD_OP (update_inputs)
87
- .Inputs({" stop_flags" ,
88
- " not_need_stop" ,
89
- " seq_lens_this_time" ,
90
- " seq_lens_encoder" ,
98
+ .Inputs({" stop_flags" ,
99
+ " not_need_stop" ,
100
+ " seq_lens_this_time" ,
101
+ " seq_lens_encoder" ,
91
102
" seq_lens_decoder" ,
92
103
" input_ids" ,
93
104
" stop_nums" ,
0 commit comments