
Commit 6902c3e

delete csrc/generation/reset_need_stop_value.cc (#8413)
1 parent: 05acad5

File tree

6 files changed: +67, -78 lines

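In short: this commit deletes the reset_need_stop_value custom op and stops pinning the not_need_stop flag to CPU. The update_inputs op now writes the flag in place on the device, save_output copies it to CPU before reading it on the host, and the Python predictor resets it with a plain tensor assignment. A minimal sketch of the new flag lifecycle on the Python side, assuming a working paddle install (inputs stands in for self.inputs in llm/predictor.py):

    import paddle

    # The flag now lives on the default device (GPU when available); the old
    # code chained .cpu() onto it to pin it to host memory.
    inputs = {"not_need_stop": paddle.full(shape=[1], fill_value=False, dtype="bool")}

    # Re-arming it is now a plain assignment; previously this called the
    # paddlenlp_ops custom op reset_stop_value(inputs["not_need_stop"]).
    inputs["not_need_stop"][0] = True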

csrc/generation/reset_need_stop_value.cc

Lines changed: 0 additions & 12 deletions
This file was deleted.

csrc/generation/save_with_output_msg.cc

Lines changed: 5 additions & 3 deletions
@@ -32,13 +32,15 @@ void SaveOutMmsg(const paddle::Tensor& x,
   if (rank_id > 0) return;
   auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
   int64_t *x_data = x_cpu.data<int64_t>();
+  auto not_need_stop_cpu = not_need_stop.copy_to(paddle::CPUPlace(), false);
+  bool* not_need_stop_data = not_need_stop_cpu.data<bool>();
+
   static struct msgdata msg_sed;
   static key_t key = ftok("./", 1);
   static int msgid = msgget(key, IPC_CREAT | 0666);
 
   msg_sed.mtype = 1;
-  bool not_need_stop_data = not_need_stop.data<bool>()[0];
-  msg_sed.mtext[0] = not_need_stop_data ? 1 : -1;
+  msg_sed.mtext[0] = not_need_stop_data[0] ? 1 : -1;
   int bsz = x.shape()[0];
   msg_sed.mtext[1] = bsz;
   for (int i = 2; i < bsz + 2; i++) {
@@ -55,4 +57,4 @@ PD_BUILD_OP(save_output)
     .Attrs({"rank_id: int64_t"})
     .Outputs({"x_out"})
     .SetInplaceMap({{"x", "x_out"}})
-    .SetKernelFn(PD_KERNEL(SaveOutMmsg));
\ No newline at end of file
+    .SetKernelFn(PD_KERNEL(SaveOutMmsg));
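Why the lines added at the top of SaveOutMmsg: with the flag no longer guaranteed to sit in host memory, the op cannot dereference not_need_stop.data<bool>() directly anymore; it first copies the tensor to CPUPlace, mirroring the existing handling of x. The same rule from the Python side, as a minimal illustration assuming a working paddle install (the variable names are made up):

    import paddle

    flag = paddle.full(shape=[1], fill_value=True, dtype="bool")  # may live on GPU
    flag_cpu = flag.cpu()  # explicit device-to-host copy, like copy_to(CPUPlace(), false)
    print(bool(flag_cpu.numpy()[0]))  # safe to read on the host only after the copy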

csrc/generation/update_inputs.cu

Lines changed: 55 additions & 44 deletions
@@ -1,19 +1,32 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 #include "helper.h"
 
 template <int THREADBLOCK_SIZE>
-__global__ void update_inputs_kernel(
-    bool *not_need_stop,
-    int *seq_lens_this_time,
-    int *seq_lens_encoder,
-    int *seq_lens_decoder,
-    int64_t *input_ids,
-    const int64_t *stop_nums,
-    const bool *stop_flags,
-    const bool *is_block_step,
-    const int64_t *next_tokens,
-    const int bsz,
-    const int max_bsz,
-    const int input_ids_stride) {
+__global__ void update_inputs_kernel(bool *not_need_stop,
+                                     int *seq_lens_this_time,
+                                     int *seq_lens_encoder,
+                                     int *seq_lens_decoder,
+                                     int64_t *input_ids,
+                                     const int64_t *stop_nums,
+                                     const bool *stop_flags,
+                                     const bool *is_block_step,
+                                     const int64_t *next_tokens,
+                                     const int bsz,
+                                     const int max_bsz,
+                                     const int input_ids_stride) {
   int thread_idx = threadIdx.x;
   typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
   __shared__ typename BlockReduce::TempStorage temp_storage;
@@ -37,7 +50,10 @@ __global__ void update_inputs_kernel(
     const int seq_len_encoder = seq_lens_encoder[thread_idx];
     const int seq_len_decoder = seq_lens_decoder[thread_idx];
 
-    seq_lens_decoder[thread_idx] = stop_flag_now ? 0 : (seq_len_decoder == 0 ? seq_len_encoder : seq_len_decoder + 1);
+    seq_lens_decoder[thread_idx] =
+        stop_flag_now
+            ? 0
+            : (seq_len_decoder == 0 ? seq_len_encoder : seq_len_decoder + 1);
 
     seq_lens_this_time[thread_idx] = stop_flag_now ? 0 : 1;
     seq_lens_encoder[thread_idx] = 0;
@@ -51,43 +67,38 @@ __global__ void update_inputs_kernel(
   }
 }
 
-void UpdateInputes(const paddle::Tensor& stop_flags,
-                   const paddle::Tensor& not_need_stop, // cpu
-                   const paddle::Tensor& seq_lens_this_time,
-                   const paddle::Tensor& seq_lens_encoder,
-                   const paddle::Tensor& seq_lens_decoder,
-                   const paddle::Tensor& input_ids,
-                   const paddle::Tensor& stop_nums,
-                   const paddle::Tensor& next_tokens,
-                   const paddle::Tensor& is_block_step) {
+void UpdateInputes(const paddle::Tensor &stop_flags,
+                   const paddle::Tensor &not_need_stop,
+                   const paddle::Tensor &seq_lens_this_time,
+                   const paddle::Tensor &seq_lens_encoder,
+                   const paddle::Tensor &seq_lens_decoder,
+                   const paddle::Tensor &input_ids,
+                   const paddle::Tensor &stop_nums,
+                   const paddle::Tensor &next_tokens,
+                   const paddle::Tensor &is_block_step) {
   const int max_bsz = stop_flags.shape()[0];
  const int now_bsz = seq_lens_this_time.shape()[0];
   const int input_ids_stride = input_ids.shape()[1];
-  auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
   update_inputs_kernel<1024><<<1, 1024, 0, input_ids.stream()>>>(
-    const_cast<bool*>(not_need_stop_gpu.data<bool>()),
-    const_cast<int*>(seq_lens_this_time.data<int>()),
-    const_cast<int*>(seq_lens_encoder.data<int>()),
-    const_cast<int*>(seq_lens_decoder.data<int>()),
-    const_cast<int64_t*>(input_ids.data<int64_t>()),
-    stop_nums.data<int64_t>(),
-    stop_flags.data<bool>(),
-    is_block_step.data<bool>(),
-    next_tokens.data<int64_t>(),
-    now_bsz,
-    max_bsz,
-    input_ids_stride
-  );
-  auto not_need_stop_cpu = not_need_stop_gpu.copy_to(not_need_stop.place(), false);
-  bool *not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
-  not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
+      const_cast<bool *>(not_need_stop.data<bool>()),
+      const_cast<int *>(seq_lens_this_time.data<int>()),
+      const_cast<int *>(seq_lens_encoder.data<int>()),
+      const_cast<int *>(seq_lens_decoder.data<int>()),
+      const_cast<int64_t *>(input_ids.data<int64_t>()),
+      stop_nums.data<int64_t>(),
+      stop_flags.data<bool>(),
+      is_block_step.data<bool>(),
+      next_tokens.data<int64_t>(),
+      now_bsz,
+      max_bsz,
+      input_ids_stride);
 }
 
 PD_BUILD_OP(update_inputs)
-    .Inputs({"stop_flags",
-             "not_need_stop",
-             "seq_lens_this_time",
-             "seq_lens_encoder",
+    .Inputs({"stop_flags",
+             "not_need_stop",
+             "seq_lens_this_time",
+             "seq_lens_encoder",
              "seq_lens_decoder",
             "input_ids",
             "stop_nums",

csrc/setup_cuda.py

Lines changed: 0 additions & 1 deletion
@@ -72,7 +72,6 @@ def get_gencode_flags():
         "./generation/stop_generation_multi_ends_v2.cu",
         "./generation/update_inputs.cu",
         "./generation/get_output.cc",
-        "./generation/reset_need_stop_value.cc",
         "./generation/save_with_output_msg.cc",
         "./generation/write_int8_cache_kv.cu",
         "./generation/step.cu",

llm/predictor.py

Lines changed: 5 additions & 16 deletions
@@ -49,25 +49,15 @@
     AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
-    ChatGLMv2Tokenizer,
     ChatGLMTokenizer,
+    ChatGLMv2Tokenizer,
     LlamaTokenizer,
     PretrainedModel,
     PretrainedTokenizer,
 )
 from paddlenlp.utils.import_utils import import_module, is_paddlenlp_ops_available
 from paddlenlp.utils.log import logger
 
-try:
-    from paddlenlp_ops import reset_stop_value
-except (ImportError, ModuleNotFoundError):
-    logger.warning(
-        "if you run predictor.py with --inference_model argument, please ensure you install "
-        "the paddlenlp_ops by following the instructions "
-        "provided at https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
-    )
-
-
 # Note(@RochardWooSJTU): MAX_BSZ must be the same as definition in get_output / save_output
 MAX_BSZ = 512
 
@@ -242,7 +232,8 @@ def _preprocess(self, source):
             padding=True,
             # when use chat_template, it should not add special tokens
             # chatglm2 prefix-tokens can not be tokenized into ids
-            add_special_tokens=self.tokenizer.chat_template is None or isinstance(self.tokenizer, (ChatGLMv2Tokenizer, ChatGLMTokenizer)),
+            add_special_tokens=self.tokenizer.chat_template is None
+            or isinstance(self.tokenizer, (ChatGLMv2Tokenizer, ChatGLMTokenizer)),
         )
         return tokenized_source
 
@@ -877,7 +868,7 @@ def init_inputs(self, config: PredictorArgument):
         self.inputs["seq_lens_encoder"] = paddle.full(shape=[config.batch_size, 1], fill_value=0, dtype="int32")
         self.inputs["seq_lens_decoder"] = paddle.full(shape=[config.batch_size, 1], fill_value=0, dtype="int32")
         self.inputs["step_idx"] = paddle.full(shape=[config.batch_size, 1], fill_value=0, dtype="int64")
-        self.inputs["not_need_stop"] = paddle.full(shape=[1], fill_value=False, dtype="bool").cpu()
+        self.inputs["not_need_stop"] = paddle.full(shape=[1], fill_value=False, dtype="bool")
         self.inputs["stop_flags"] = paddle.full(shape=[config.batch_size, 1], fill_value=True, dtype="bool")
         self.inputs["next_tokens"] = paddle.full(shape=[config.batch_size, 1], fill_value=-1, dtype="int64")
         self.inputs["is_block_step"] = paddle.full(shape=[config.batch_size], fill_value=False, dtype="bool")
@@ -945,7 +936,7 @@ def _preprocess(self, source):
             self.inputs["seq_lens_decoder"][i : i + 1] = 0
             self.inputs["step_idx"][i : i + 1] = 0
             self.inputs["stop_flags"][i : i + 1] = False
-            reset_stop_value(self.inputs["not_need_stop"])
+            self.inputs["not_need_stop"][0] = True
             need_block_nums = (
                 length + self.config.max_length + self.pre_cache_length + self.block_size - 1
             ) // self.block_size
@@ -1010,7 +1001,6 @@ def predict(self, input_texts: str | list[str]):
         for i in range(self.config.batch_size):
             self.free_list.extend(self.used_list[i])
             self.used_list[i] = []
-        reset_stop_value(self.inputs["not_need_stop"])
 
         outputs = []
         while len(outputs) < self.batch_size:
@@ -1147,7 +1137,6 @@ def predict(self, input_texts: str | list[str]):
         for i in range(self.config.batch_size):
             self.free_list.extend(self.used_list[i])
             self.used_list[i] = []
-        reset_stop_value(self.inputs["not_need_stop"])
 
         outputs = []
         while len(outputs) < self.batch_size:
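Taken together, the Python changes remove every call site of the deleted op: init_inputs no longer pins not_need_stop to CPU, _preprocess re-arms it with a direct assignment, and the two resets after the predict loops are dropped outright, since the update_inputs op now rewrites the flag on the device at every decoding step. (The import reordering and the add_special_tokens line wrap are formatting-only.) A condensed sketch of the new per-slot re-arm, with rearm_slot as a made-up name and inputs standing in for self.inputs:

    def rearm_slot(inputs, i):
        # Reset one batch slot before inserting a new request.
        inputs["seq_lens_decoder"][i : i + 1] = 0
        inputs["step_idx"][i : i + 1] = 0
        inputs["stop_flags"][i : i + 1] = False
        inputs["not_need_stop"][0] = True  # plain assignment replaces reset_stop_value(...)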

paddlenlp/experimental/transformers/generation_utils.py

Lines changed: 2 additions & 2 deletions
@@ -671,15 +671,15 @@ def _post_process_(
 
         step_idx = paddle.where(model_kwargs["stop_flags"], model_kwargs["step_idx"], model_kwargs["step_idx"] + 1)
         paddle.assign(step_idx, model_kwargs["step_idx"])
-        length_cond = paddle.greater_equal(model_kwargs["step_idx"], model_kwargs["max_dec_len"])
+        length_cond = paddle.greater_equal(step_idx, model_kwargs["max_dec_len"])
         stop_flags = paddle.logical_or(model_kwargs["stop_flags"], length_cond)
         set_stop_value_multi_ends_v2(
             next_tokens, stop_flags, model_kwargs["seq_lens_this_time"], eos_token_id, model_kwargs["next_tokens"]
         )  # multi ends
         paddle.assign(stop_flags, model_kwargs["stop_flags"])
         # update inputs
         update_inputs(
-            model_kwargs["stop_flags"],
+            stop_flags,
             model_kwargs["not_need_stop"],
             model_kwargs["seq_lens_this_time"],
             model_kwargs["seq_lens_encoder"],
