diff --git a/csrc/gpu/sample_kernels/sampling.cuh b/csrc/gpu/sample_kernels/sampling.cuh
index 334747dd0126..4940070d2dfa 100644
--- a/csrc/gpu/sample_kernels/sampling.cuh
+++ b/csrc/gpu/sample_kernels/sampling.cuh
@@ -33,6 +33,15 @@ namespace sampling {
 
 using namespace cub;
 
+#define DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, ...) \
+  if (compute_capacity.first >= 8) {                                           \
+    constexpr uint32_t BLOCK_THREADS = 1024;                                   \
+    __VA_ARGS__                                                                \
+  } else {                                                                     \
+    constexpr uint32_t BLOCK_THREADS = 512;                                    \
+    __VA_ARGS__                                                                \
+  }
+
 constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS;
 constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS;
 
@@ -277,17 +286,12 @@
-    DType aggregate_gt_pivot = DType(0);
+    Pair<DType> aggregate_gt_pivot{DType(0), 0};
     for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
       probs_vec.fill(DType(0));
       if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
-        probs_vec.load(probs + row_idx * d +
-                       (i * BLOCK_THREADS + tx) * VEC_SIZE);
+        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
       }
 
-      DType probs_gt_pivot[VEC_SIZE];
+      Pair<DType> probs_gt_pivot[VEC_SIZE];
 #pragma unroll
       for (uint32_t j = 0; j < VEC_SIZE; ++j) {
-        probs_gt_pivot[j] = (probs_vec[j] > pivot) ? probs_vec[j] : DType(0);
+        probs_gt_pivot[j] = {(probs_vec[j] > pivot) ? probs_vec[j] : DType(0),
+                             (probs_vec[j] > pivot && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
       }
 
-      aggregate_gt_pivot +=
-          BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
-              .Sum(probs_gt_pivot);
+      aggregate_gt_pivot += BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                                temp_storage.block_prim.reduce_pair)
+                                .Sum(probs_gt_pivot);
       if (tx == 0) {
-        temp_storage.data.block_aggregate.value = aggregate_gt_pivot;
+        temp_storage.data.block_aggregate.pair = aggregate_gt_pivot;
       }
       __syncthreads();
     }
-    q = temp_storage.data.block_aggregate.value;
-    if (float(q) < top_p) {
+    q = temp_storage.data.block_aggregate.pair.value;
+    if (float(q) > 0 && float(q) < top_p) {
+      // top_p is not 0
       break;
+    } else {
+      // top_p is 0, use top_k, k=1
+      if (temp_storage.data.block_aggregate.pair.count < 1) {
+        break;
+      }
     }
   }
   __syncthreads();
   if (tx == 0) {
     output[bx] = sampled_id;
-    if (float(q) >= top_p) {
-      // failed to sample within MAX_TOP_P_ROUNDS
-      if (success != nullptr) {
-        success[bx] = false;
-      }
-    } else {
-      if (success != nullptr) {
-        success[bx] = true;
-      }
-    }
   }
 }
-
 
 template <typename T, typename IdType>
 cudaError_t TopPSamplingFromProb(T* probs,
                                  T* uniform_samples,
                                  IdType* output,
-                                 bool* success,
-                                 T* top_p_arr,
                                  uint32_t batch_size,
                                  const T* top_p_val,
                                  uint32_t d,
@@ -395,13 +392,9 @@ cudaError_t TopPSamplingFromProb(T* probs,
       sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
   dim3 nblks(batch_size);
   dim3 nthrs(BLOCK_THREADS);
-  IdType* row_indices_placeholder = nullptr;
   void* args[] = {&probs,
                   &uniform_samples,
                   &output,
-                  &success,
-                  &row_indices_placeholder,
-                  &top_p_arr,
                   &top_p_val,
                   &d,
                   &max_top_p_rounds};
@@ -425,4 +418,4 @@ cudaError_t TopPSamplingFromProb(T* probs,
 
   return cudaSuccess;
 }
-} // namespace sampling
+} // namespace sampling
\ No newline at end of file
diff --git a/csrc/gpu/sample_kernels/top_p_sampling_reject.cu b/csrc/gpu/sample_kernels/top_p_sampling_reject.cu
index 1e98b0a81cd5..df62f2c12efe 100644
--- a/csrc/gpu/sample_kernels/top_p_sampling_reject.cu
+++ b/csrc/gpu/sample_kernels/top_p_sampling_reject.cu
@@ -16,7 +16,8 @@
 #include "sample_kernels/sampling.cuh"
 
 std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor& probs,
-                                               const paddle::Tensor& top_p) {
+                                               const paddle::Tensor& top_p,
+                                               int seed) {
   std::vector<int64_t> probs_shape = probs.shape();
   unsigned int batch_size = probs_shape[0];
   unsigned int vocab_size = probs_shape[1];
@@ -24,40 +25,37 @@ std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor& probs,
   // default is 32
   unsigned int max_top_p_rounds = 32;
   std::vector<int64_t> uniform_samples_shape = {batch_size, max_top_p_rounds};
-  paddle::Tensor uniform_samples = paddle::experimental::uniform(
-      uniform_samples_shape, paddle::DataType::FLOAT32, 0, 1, 0, probs.place());
+  paddle::Tensor uniform_samples =
+      paddle::experimental::uniform(uniform_samples_shape,
+                                    paddle::DataType::FLOAT32,
+                                    0,
+                                    1,
+                                    seed,
+                                    probs.place());
 
-  // todo: add parameter for deterministic, now default is true
-  bool deterministic = true;
-  paddle::Tensor probs_input;
-
-  probs_input = paddle::experimental::cast(probs, paddle::DataType::FLOAT32);
   auto cu_stream = probs.stream();
 
   auto samples =
-      paddle::full({batch_size}, 0, paddle::DataType::INT32, probs.place());
-  auto success =
-      paddle::full({batch_size}, 0, paddle::DataType::BOOL, probs.place());
+      paddle::empty({batch_size, 1}, paddle::DataType::INT64, probs.place());
+
+  cudaError_t status;
 
-  cudaError_t status =
-      sampling::TopPSamplingFromProb(probs_input.data<float>(),
-                                     uniform_samples.data<float>(),
-                                     samples.data<int>(),
-                                     success.data<bool>(),
-                                     nullptr,
-                                     batch_size,
-                                     top_p.data<float>(),
-                                     vocab_size,
-                                     max_top_p_rounds,
-                                     deterministic,
-                                     cu_stream);
+  status = sampling::TopPSamplingFromProb(
+      const_cast<float*>(probs.data<float>()),
+      uniform_samples.data<float>(),
+      samples.data<int64_t>(),
+      batch_size,
+      top_p.data<float>(),
+      vocab_size,
+      max_top_p_rounds,
+      true,
+      cu_stream);
+
   PD_CHECK(status == cudaSuccess,
            "SamplingFromProbs failed with error code " + std::string(cudaGetErrorString(status)));
 
-  paddle::Tensor samples_output;
-  samples_output = paddle::experimental::cast(samples, paddle::DataType::INT64);
-  return {samples_output};
+  return {samples};
 }
 
@@ -69,12 +67,13 @@ std::vector<std::vector<int64_t>> TopPSamplingRejectInferShape(
 
 std::vector<paddle::DataType> TopPSamplingRejectInferDtype(
     const paddle::DataType& probs_dtype, const paddle::DataType& top_p_shape) {
-  return {probs_dtype};
+  return {paddle::DataType::INT64};
 }
 
 PD_BUILD_OP(top_p_sampling_reject)
     .Inputs({"probs", "top_p"})
     .Outputs({"samples"})
+    .Attrs({"seed: int"})
     .SetKernelFn(PD_KERNEL(TopPSamplingReject))
     .SetInferShapeFn(PD_INFER_SHAPE(TopPSamplingRejectInferShape))
-    .SetInferDtypeFn(PD_INFER_DTYPE(TopPSamplingRejectInferDtype));
+    .SetInferDtypeFn(PD_INFER_DTYPE(TopPSamplingRejectInferDtype));
\ No newline at end of file
diff --git a/csrc/gpu/test/python/test_top_p_sampling_reject.py b/csrc/gpu/test/python/test_top_p_sampling_reject.py
new file mode 100644
index 000000000000..7a0605a3bdbe
--- /dev/null
+++ b/csrc/gpu/test/python/test_top_p_sampling_reject.py
@@ -0,0 +1,63 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import paddle
+from paddlenlp_ops import top_p_sampling_reject
+
+paddle.seed(2023)
+
+batch_size = 3
+vocab_size = 40080
+max_rounds = 32
+
+class SetPreidsTokenPenaltyMultiScores(unittest.TestCase):
+    def test_top_p_sampling_reject_case1(self):
+        # top_p is 1, different seeds
+        pre_norm_prob_np = np.random.rand(batch_size, vocab_size).astype(np.float32)
+
+        paddle_pre_norm_prob = paddle.to_tensor(pre_norm_prob_np)
+        paddle_norm_prob = paddle_pre_norm_prob / paddle_pre_norm_prob.sum(axis=-1, keepdim=True)
+        top_p_paddle = paddle.full((batch_size,), 1)
+        samples = top_p_sampling_reject(paddle_norm_prob, top_p_paddle, 0)
+        print(samples)
+        samples = top_p_sampling_reject(paddle_norm_prob, top_p_paddle, 1024)
+        print(samples)
+        samples = top_p_sampling_reject(paddle_norm_prob, top_p_paddle, 2033)
+        print(samples)
+
+    def test_top_p_sampling_reject_case2(self):
+        # top_p is 0
+        pre_norm_prob_np = np.random.rand(batch_size, vocab_size).astype(np.float32)
+
+        paddle_pre_norm_prob = paddle.to_tensor(pre_norm_prob_np)
+        paddle_norm_prob = paddle_pre_norm_prob / paddle_pre_norm_prob.sum(axis=-1, keepdim=True)
+        top_p_paddle = paddle.full((batch_size,), 0)
+        samples = top_p_sampling_reject(paddle_norm_prob, top_p_paddle, 0)
+        print(samples)
+
+    def test_top_p_sampling_reject_case3(self):
+        # different top_p values across the batch
+        pre_norm_prob_np = np.random.rand(batch_size, vocab_size).astype(np.float32)
+
+        paddle_pre_norm_prob = paddle.to_tensor(pre_norm_prob_np)
+        paddle_norm_prob = paddle_pre_norm_prob / paddle_pre_norm_prob.sum(axis=-1, keepdim=True)
+        top_p_paddle = paddle.uniform(shape=[batch_size, 1], min=0, max=1)
+        samples = top_p_sampling_reject(paddle_norm_prob, top_p_paddle, 0)
+        print(samples)
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file
diff --git a/paddlenlp/experimental/transformers/generation_utils.py b/paddlenlp/experimental/transformers/generation_utils.py
index 5133c944cdf6..336cf9de2397 100644
--- a/paddlenlp/experimental/transformers/generation_utils.py
+++ b/paddlenlp/experimental/transformers/generation_utils.py
@@ -333,7 +333,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
         try:
             from paddlenlp_ops import top_p_sampling_reject
 
-            next_tokens = top_p_sampling_reject(probs, top_p)
+            next_tokens = top_p_sampling_reject(probs, top_p, 0)
         except:
             _, next_tokens = paddle.tensor.top_p_sampling(probs, top_p)
 
@@ -677,7 +677,7 @@ def _post_process_(
         try:
             from paddlenlp_ops import top_p_sampling_reject
 
-            next_tokens = top_p_sampling_reject(probs, top_p)
+            next_tokens = top_p_sampling_reject(probs, top_p, 0)
         except:
             _, next_tokens = paddle.tensor.top_p_sampling(probs, top_p)
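For reference, a minimal usage sketch of the op as it behaves after this patch (names and shapes mirror the unit test above; it assumes `paddlenlp_ops` has been rebuilt from this diff). The op now takes an integer `seed` that is forwarded to `paddle::experimental::uniform`, and it returns int64 token ids of shape `[batch_size, 1]`; per the kernel comment, a row whose `top_p` is 0 degenerates to top-k sampling with k = 1.

```python
# Hypothetical smoke test for the patched op; mirrors csrc/gpu/test/python/test_top_p_sampling_reject.py.
import paddle
from paddlenlp_ops import top_p_sampling_reject  # requires the custom ops built from this diff

batch_size, vocab_size = 3, 40080

# The kernel expects a row-normalized probability matrix.
logits = paddle.rand([batch_size, vocab_size], dtype="float32")
probs = logits / logits.sum(axis=-1, keepdim=True)

# One top_p value per row; a value of 0 falls back to the single most probable token.
top_p = paddle.uniform(shape=[batch_size, 1], min=0, max=1)

# The third argument is the new "seed" attribute used for the internal uniform draws.
next_tokens = top_p_sampling_reject(probs, top_p, 0)
print(next_tokens.shape, next_tokens.dtype)  # [3, 1], paddle.int64
```

generation_utils.py passes seed 0, matching the hard-coded 0 the uniform call used before this change, and its existing try/except still falls back to paddle.tensor.top_p_sampling if the installed paddlenlp_ops predates the new attribute.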