From 42632e84d2f1abf47a4c9f92345e917daa46ddd4 Mon Sep 17 00:00:00 2001
From: gzy19990617
Date: Thu, 26 Sep 2024 04:13:44 +0000
Subject: [PATCH 1/5] fix top_p reject

---
 csrc/gpu/sample_kernels/sampling.cuh          | 126 +++-
 .../sample_kernels/top_p_sampling_reject.cu   | 592 ++++++++++++++++--
 .../transformers/generation_utils.py          |   4 +-
 3 files changed, 659 insertions(+), 63 deletions(-)

diff --git a/csrc/gpu/sample_kernels/sampling.cuh b/csrc/gpu/sample_kernels/sampling.cuh
index 334747dd0126..99fe2ad7d6e5 100644
--- a/csrc/gpu/sample_kernels/sampling.cuh
+++ b/csrc/gpu/sample_kernels/sampling.cuh
@@ -33,6 +33,15 @@ namespace sampling {
 
 using namespace cub;
 
+#define DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, ...) \
+  if (compute_capacity.first >= 8) {                                           \
+    constexpr uint32_t BLOCK_THREADS = 1024;                                   \
+    __VA_ARGS__                                                                \
+  } else {                                                                     \
+    constexpr uint32_t BLOCK_THREADS = 512;                                    \
+    __VA_ARGS__                                                                \
+  }
+
 constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS;
 constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS;
 
@@ -376,6 +385,92 @@ __global__ void TopPSamplingFromProbKernel(DType* probs,
 }
 
+template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
+          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
+          bool DETERMINISTIC, typename DType, typename IdType>
+__global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples, IdType* output,
+                                           bool* success, IdType* top_k_arr, uint32_t top_k_val,
+                                           uint32_t d, uint32_t max_top_k_rounds) {
+  const uint32_t batch_size = gridDim.x;
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
+
+  extern __shared__ __align__(
+      alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
+      uint8_t smem_sampling[];
+  auto& temp_storage = reinterpret_cast<
+      SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
+
+  vec_t<DType, VEC_SIZE> probs_vec;
+  DType aggregate;
+  DType q = DType(1);
+  DType pivot = DType(0);
+  IdType sampled_id;
+  for (uint32_t round = 0; round < max_top_k_rounds; ++round) {
+    temp_storage.data.sampled_id = d - 1;
+    __syncthreads();
+    DType u = uniform_samples[round * batch_size + bx] * q;
+    aggregate = DType(0);
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      }
+
+      DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
+                             DETERMINISTIC>(i, d, pivot, u, probs_vec, aggregate,
+                                            &temp_storage);
+      if (aggregate > u) {
+        break;
+      }
+    }
+    __syncthreads();
+    sampled_id = temp_storage.data.sampled_id;
+    pivot = max(pivot, probs[bx * d + sampled_id]);
+
+    Pair<DType> aggregate_gt_pivot{DType(0), 0};
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      }
+
+      Pair<DType> probs_gt_pivot[VEC_SIZE];
+#pragma unroll
+      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        probs_gt_pivot[j] = {(probs_vec[j] > pivot) ? probs_vec[j] : DType(0),
+                             (probs_vec[j] > pivot &&
+                              (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
+      }
+
+      aggregate_gt_pivot += BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                                temp_storage.block_prim.reduce_pair)
+                                .Sum<VEC_SIZE>(probs_gt_pivot);
+      if (tx == 0) {
+        temp_storage.data.block_aggregate.pair = aggregate_gt_pivot;
+      }
+      __syncthreads();
+    }
+    q = temp_storage.data.block_aggregate.pair.value;
+    if (temp_storage.data.block_aggregate.pair.count < k) {
+      break;
+    }
+  }
+  __syncthreads();
+  if (tx == 0) {
+    output[bx] = sampled_id;
+    if (temp_storage.data.block_aggregate.pair.count >= k) {
+      // failed to sample within MAX_TOP_K_ROUNDS
+      if (success != nullptr) {
+        success[bx] = false;
+      }
+    } else {
+      if (success != nullptr) {
+        success[bx] = true;
+      }
+    }
+  }
+}
+
 template <typename T, typename IdType>
 cudaError_t TopPSamplingFromProb(T* probs,
                                  T* uniform_samples,
@@ -425,4 +520,33 @@ cudaError_t TopPSamplingFromProb(T* probs,
   return cudaSuccess;
 }
 
-}  // namespace sampling
+template <typename T, typename IdType>
+cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, bool* success,
+                                 T* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d,
+                                 uint32_t max_top_k_rounds, bool deterministic,
+                                 cudaStream_t stream = 0) {
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+  auto compute_capacity = GetCudaComputeCapability();
+  DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
+    const uint32_t smem_size =
+        sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
+    dim3 nblks(batch_size);
+    dim3 nthrs(BLOCK_THREADS);
+    void* args[] = {&probs, &uniform_samples, &output, &success,
+                    &top_k_arr, &top_k_val, &d, &max_top_k_rounds};
+
+    DISPATCH_ALIGNED_VEC_SIZE(
+        vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
+          auto kernel = TopKSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
+                                                   VEC_SIZE, DETERMINISTIC, T, IdType>;
+          CUDA_CALL(
+              cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+          CUDA_CALL(
+              cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+        })});
+    return cudaSuccess;
+  });
+}
+
+}  // namespace sampling
\ No newline at end of file
diff --git a/csrc/gpu/sample_kernels/top_p_sampling_reject.cu b/csrc/gpu/sample_kernels/top_p_sampling_reject.cu
index 1e98b0a81cd5..99fe2ad7d6e5 100644
--- a/csrc/gpu/sample_kernels/top_p_sampling_reject.cu
+++ b/csrc/gpu/sample_kernels/top_p_sampling_reject.cu
@@ -12,69 +12,541 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "helper.h" -#include "sample_kernels/sampling.cuh" - -std::vector TopPSamplingReject(const paddle::Tensor& probs, - const paddle::Tensor& top_p) { - std::vector probs_shape = probs.shape(); - unsigned int batch_size = probs_shape[0]; - unsigned int vocab_size = probs_shape[1]; - - // default is 32 - unsigned int max_top_p_rounds = 32; - std::vector uniform_samples_shape = {batch_size, max_top_p_rounds}; - paddle::Tensor uniform_samples = paddle::experimental::uniform( - uniform_samples_shape, paddle::DataType::FLOAT32, 0, 1, 0, probs.place()); - - // todo: add parameter for deterministic, now default is true - bool deterministic = true; - paddle::Tensor probs_input; - - probs_input = paddle::experimental::cast(probs, paddle::DataType::FLOAT32); - auto cu_stream = probs.stream(); - - auto samples = - paddle::full({batch_size}, 0, paddle::DataType::INT32, probs.place()); - auto success = - paddle::full({batch_size}, 0, paddle::DataType::BOOL, probs.place()); - - cudaError_t status = - sampling::TopPSamplingFromProb(probs_input.data(), - uniform_samples.data(), - samples.data(), - success.data(), - nullptr, - batch_size, - top_p.data(), - vocab_size, - max_top_p_rounds, - deterministic, - cu_stream); - PD_CHECK(status == cudaSuccess, - "SamplingFromProbs failed with error code " + - std::string(cudaGetErrorString(status))); - - paddle::Tensor samples_output; - samples_output = paddle::experimental::cast(samples, paddle::DataType::INT64); - return {samples_output}; +// This code is partially inspired by and references the implementation found +// in FlashInfer.Specifically, the implementation of Top-p Sampling functionality +// in this code is inspired by the logic of +// FlashInfer’s flashinfer.sampling.top_p_sampling_from_probs . +// For more details on FlashInfer’s documentation, please refer to: +// https://docs.flashinfer.ai/generated/flashinfer.sampling.top_p_sampling_from_probs.html + +#pragma once + +#include +#include +#include +#include + +#include "sample_kernels/utils.cuh" + + +namespace sampling { + +using namespace cub; + +#define DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, ...) \ + if (compute_capacity.first >= 8) { \ + constexpr uint32_t BLOCK_THREADS = 1024; \ + __VA_ARGS__ \ + } else { \ + constexpr uint32_t BLOCK_THREADS = 512; \ + __VA_ARGS__ \ + } + +constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS; +constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS; + +#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120100) +#define SAMPLING_CUB_SUBTRACTLEFT_DEFINED +#endif + +template +struct Pair { + T value; + int count; + + __device__ Pair operator+(const Pair& other) const { + return {value + other.value, count + other.count}; + } + __device__ Pair& operator+=(const Pair& other) { + value += other.value; + count += other.count; + return *this; + } +}; + +struct BoolDiffOp { + __device__ __forceinline__ bool operator()(const bool& lhs, + const bool& rhs) const { + return lhs != rhs; + } +}; + +template +struct SamplingTempStorage { + union { + T deterministic_scan[BLOCK_THREADS / 32]; + typename BlockScan::TempStorage scan; + typename BlockReduce::TempStorage + reduce; + typename BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage + reduce_pair; + typename BlockAdjacentDifference::TempStorage adj_diff; + } block_prim; + struct { + int32_t sampled_id; + union { + T value; + Pair pair; + T max_p; + } block_aggregate; + } data; +}; + +/*! 
+ * \brief Deterministic inclusive scan implementation, use Belloch scan + * algorithm. \note This implementation is slower than the cub::BlockScan, but + * it is deterministic. + */ +template +__device__ __forceinline__ void DeterministicInclusiveSum( + const T* in_data, + T* out_data, + SamplingTempStorage* + temp_storage) { + T* smem_prefix_sum = temp_storage->block_prim.deterministic_scan; + T thread_data[VEC_SIZE]; + T thread_sum = 0; +#pragma unroll + for (uint32_t i = 0; i < VEC_SIZE; ++i) { + thread_sum += in_data[i]; + thread_data[i] = thread_sum; + } + + T thread_exclusive_prefix_sum = thread_sum; + +#pragma unroll + for (uint32_t offset = 1; offset < 32; offset *= 2) { + T tmp = __shfl_up_sync(0xffffffff, thread_exclusive_prefix_sum, offset); + if ((threadIdx.x + 1) % (offset * 2) == 0) { + thread_exclusive_prefix_sum += tmp; + } + } + + T warp_sum = __shfl_sync( + 0xffffffff, thread_exclusive_prefix_sum, threadIdx.x | 0xffffffff); + if (threadIdx.x % 32 == 31) { + thread_exclusive_prefix_sum = 0; + } + +#pragma unroll + for (uint32_t offset = 16; offset >= 1; offset /= 2) { + T tmp = __shfl_xor_sync(0xffffffff, thread_exclusive_prefix_sum, offset); + if ((threadIdx.x + 1) % (offset * 2) == 0) { + thread_exclusive_prefix_sum = tmp + thread_exclusive_prefix_sum; + } + if ((threadIdx.x + 1) % (offset * 2) == offset) { + thread_exclusive_prefix_sum = tmp; + } + } + + smem_prefix_sum[threadIdx.x / 32] = warp_sum; + __syncthreads(); + + if (threadIdx.x < 32) { + T warp_exclusive_prefix_sum = + (threadIdx.x < BLOCK_THREADS / 32) ? smem_prefix_sum[threadIdx.x] : 0; + +#pragma unroll + for (uint32_t offset = 1; offset < 32; offset *= 2) { + T tmp = __shfl_up_sync(0xffffffff, warp_exclusive_prefix_sum, offset); + if ((threadIdx.x + 1) % (offset * 2) == 0) { + warp_exclusive_prefix_sum += tmp; + } + } + + if (threadIdx.x % 32 == 31) { + warp_exclusive_prefix_sum = 0; + } + +#pragma unroll + for (uint32_t offset = 16; offset >= 1; offset /= 2) { + T tmp = __shfl_xor_sync(0xffffffff, warp_exclusive_prefix_sum, offset); + if ((threadIdx.x + 1) % (offset * 2) == 0) { + warp_exclusive_prefix_sum = tmp + warp_exclusive_prefix_sum; + } + if ((threadIdx.x + 1) % (offset * 2) == offset) { + warp_exclusive_prefix_sum = tmp; + } + } + if (threadIdx.x < BLOCK_THREADS / 32) { + smem_prefix_sum[threadIdx.x] = warp_exclusive_prefix_sum; + } + } + __syncthreads(); + +#pragma unroll + for (uint32_t i = 0; i < VEC_SIZE; ++i) { + out_data[i] = smem_prefix_sum[threadIdx.x / 32] + + thread_exclusive_prefix_sum + thread_data[i]; + } } -std::vector> TopPSamplingRejectInferShape( - const std::vector& probs_shape, - const std::vector& top_p_shape) { - int64_t bs = probs_shape[0]; - return {{bs, 1}}; +template +__device__ __forceinline__ void DeviceSamplingFromProb( + uint32_t i, + uint32_t d, + T threshold, + T u, + vec_t prob_vec, + T& aggregate, + SamplingTempStorage* + temp_storage) { + const uint32_t tx = threadIdx.x; + T prob_greater_than_threshold[VEC_SIZE]; + T inclusive_cdf[VEC_SIZE]; + bool greater_than_u[VEC_SIZE], valid[VEC_SIZE]; +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + prob_greater_than_threshold[j] = + (prob_vec[j] > threshold) ? 
prob_vec[j] : T(0); + valid[j] = + prob_vec[j] > threshold && (i * BLOCK_THREADS + tx) * VEC_SIZE < d; + } + T aggregate_local = BlockReduce( + temp_storage->block_prim.reduce) + .Sum(prob_greater_than_threshold); + if (tx == 0) { + temp_storage->data.block_aggregate.value = aggregate_local; + } + __syncthreads(); + aggregate_local = temp_storage->data.block_aggregate.value; + + if (aggregate + aggregate_local > u) { + if constexpr (DETERMINISTIC) { + DeterministicInclusiveSum( + prob_greater_than_threshold, inclusive_cdf, temp_storage); + } else { + BlockScan(temp_storage->block_prim.scan) + .InclusiveSum(prob_greater_than_threshold, inclusive_cdf); + + __syncthreads(); + } + +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + greater_than_u[j] = inclusive_cdf[j] + aggregate > u; + } + + bool greater_than_u_diff[VEC_SIZE]; +#ifdef SAMPLING_CUB_SUBTRACTLEFT_DEFINED + BlockAdjacentDifference( + temp_storage->block_prim.adj_diff) + .SubtractLeft( + greater_than_u, greater_than_u_diff, BoolDiffOp()); +#else + BlockAdjacentDifference( + temp_storage->block_prim.adj_diff) + .FlagHeads( + greater_than_u_diff, greater_than_u, BoolDiffOp(), 0); +#endif + __syncthreads(); + +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + if (greater_than_u_diff[j] && valid[j]) { + if constexpr (DETERMINISTIC) { + temp_storage->data.sampled_id = + (i * BLOCK_THREADS + tx) * VEC_SIZE + j; + } else { + // cub's block scan result might not be monotonic, so we need to find + // the first element + atomicMin(&(temp_storage->data.sampled_id), + (i * BLOCK_THREADS + tx) * VEC_SIZE + j); + } + } + } + __syncthreads(); + } + aggregate += aggregate_local; +} + +template +__global__ void TopPSamplingFromProbKernel(DType* probs, + DType* uniform_samples, + IdType* output, + bool* success, + IdType* row_indices, + float* top_p_arr, + float* top_p_val, + uint32_t d, + uint32_t max_top_p_rounds) { + const uint32_t batch_size = gridDim.x; + const uint32_t bx = blockIdx.x, tx = threadIdx.x; + float top_p = (top_p_arr == nullptr) ? top_p_val[bx] : top_p_arr[bx]; + + const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx]; + + extern __shared__ __align__(alignof(SamplingTempStorage)) + uint8_t smem_sampling[]; + auto& temp_storage = + reinterpret_cast&>(smem_sampling); + + vec_t probs_vec; + DType aggregate; + DType q = DType(1); + DType pivot = DType(0); + IdType sampled_id; + for (uint32_t round = 0; round < max_top_p_rounds; ++round) { + temp_storage.data.sampled_id = d - 1; + __syncthreads(); + DType u = uniform_samples[round * batch_size + bx] * q; + aggregate = DType(0); + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { + probs_vec.fill(DType(0)); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + probs_vec.load(probs + row_idx * d + + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + + DeviceSamplingFromProb( + i, d, pivot, u, probs_vec, aggregate, &temp_storage); + if (aggregate > u) { + break; + } + } + __syncthreads(); + sampled_id = temp_storage.data.sampled_id; + pivot = max(pivot, probs[row_idx * d + sampled_id]); + + DType aggregate_gt_pivot = DType(0); + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { + probs_vec.fill(DType(0)); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + probs_vec.load(probs + row_idx * d + + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + + DType probs_gt_pivot[VEC_SIZE]; +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + probs_gt_pivot[j] = (probs_vec[j] > pivot) ? 
probs_vec[j] : DType(0); + } + + aggregate_gt_pivot += + BlockReduce(temp_storage.block_prim.reduce) + .Sum(probs_gt_pivot); + if (tx == 0) { + temp_storage.data.block_aggregate.value = aggregate_gt_pivot; + } + __syncthreads(); + } + q = temp_storage.data.block_aggregate.value; + if (float(q) < top_p) { + break; + } + } + __syncthreads(); + if (tx == 0) { + output[bx] = sampled_id; + if (float(q) >= top_p) { + // failed to sample within MAX_TOP_P_ROUNDS + if (success != nullptr) { + success[bx] = false; + } + } else { + if (success != nullptr) { + success[bx] = true; + } + } + } +} + + +template +__global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples, IdType* output, + bool* success, IdType* top_k_arr, uint32_t top_k_val, + uint32_t d, uint32_t max_top_k_rounds) { + const uint32_t batch_size = gridDim.x; + const uint32_t bx = blockIdx.x, tx = threadIdx.x; + uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx]; + + extern __shared__ __align__( + alignof(SamplingTempStorage)) + uint8_t smem_sampling[]; + auto& temp_storage = reinterpret_cast< + SamplingTempStorage&>(smem_sampling); + + vec_t probs_vec; + DType aggregate; + DType q = DType(1); + DType pivot = DType(0); + IdType sampled_id; + for (uint32_t round = 0; round < max_top_k_rounds; ++round) { + temp_storage.data.sampled_id = d - 1; + __syncthreads(); + DType u = uniform_samples[round * batch_size + bx] * q; + aggregate = DType(0); + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { + probs_vec.fill(DType(0)); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + + DeviceSamplingFromProb(i, d, pivot, u, probs_vec, aggregate, + &temp_storage); + if (aggregate > u) { + break; + } + } + __syncthreads(); + sampled_id = temp_storage.data.sampled_id; + pivot = max(pivot, probs[bx * d + sampled_id]); + + Pair aggregate_gt_pivot{DType(0), 0}; + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { + probs_vec.fill(DType(0)); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + + Pair probs_gt_pivot[VEC_SIZE]; +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + probs_gt_pivot[j] = {(probs_vec[j] > pivot) ? 
probs_vec[j] : DType(0), + (probs_vec[j] > pivot && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; + } + + aggregate_gt_pivot += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( + temp_storage.block_prim.reduce_pair) + .Sum(probs_gt_pivot); + if (tx == 0) { + temp_storage.data.block_aggregate.pair = aggregate_gt_pivot; + } + __syncthreads(); + } + q = temp_storage.data.block_aggregate.pair.value; + if (temp_storage.data.block_aggregate.pair.count < k) { + break; + } + } + __syncthreads(); + if (tx == 0) { + output[bx] = sampled_id; + if (temp_storage.data.block_aggregate.pair.count >= k) { + // failed to sample within MAX_TOP_P_ROUNDS + if (success != nullptr) { + success[bx] = false; + } + } else { + if (success != nullptr) { + success[bx] = true; + } + } + } } -std::vector TopPSamplingRejectInferDtype( - const paddle::DataType& probs_dtype, const paddle::DataType& top_p_shape) { - return {probs_dtype}; +template +cudaError_t TopPSamplingFromProb(T* probs, + T* uniform_samples, + IdType* output, + bool* success, + T* top_p_arr, + uint32_t batch_size, + const T* top_p_val, + uint32_t d, + uint32_t max_top_p_rounds, + bool deterministic, + cudaStream_t stream = 0) { + constexpr uint32_t BLOCK_THREADS = 1024; + const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + + const uint32_t smem_size = + sizeof(SamplingTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + IdType* row_indices_placeholder = nullptr; + void* args[] = {&probs, + &uniform_samples, + &output, + &success, + &row_indices_placeholder, + &top_p_arr, + &top_p_val, + &d, + &max_top_p_rounds}; + + DISPATCH_ALIGNED_VEC_SIZE( + vec_size, + VEC_SIZE, + {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { + auto kernel = TopPSamplingFromProbKernel; + CUDA_CALL(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + CUDA_CALL(cudaLaunchKernel( + (void*)kernel, nblks, nthrs, args, smem_size, stream)); + })}); + return cudaSuccess; +} + +template +cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, bool* success, + T* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, + uint32_t max_top_k_rounds, bool deterministic, + cudaStream_t stream = 0) { + const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + + auto compute_capacity = GetCudaComputeCapability(); + DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, { + const uint32_t smem_size = + sizeof(SamplingTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&probs, &uniform_samples, &output, &success, + &top_k_arr, &top_k_val, &d, &max_top_k_rounds}; + + DISPATCH_ALIGNED_VEC_SIZE( + vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { + auto kernel = TopKSamplingFromProbKernel; + CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + })}); + return cudaSuccess; + }); } -PD_BUILD_OP(top_p_sampling_reject) - .Inputs({"probs", "top_p"}) - .Outputs({"samples"}) - .SetKernelFn(PD_KERNEL(TopPSamplingReject)) - .SetInferShapeFn(PD_INFER_SHAPE(TopPSamplingRejectInferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(TopPSamplingRejectInferDtype)); +} // namespace sampling \ No newline at end of file diff --git a/paddlenlp/experimental/transformers/generation_utils.py b/paddlenlp/experimental/transformers/generation_utils.py index 5133c944cdf6..336cf9de2397 100644 --- 
a/paddlenlp/experimental/transformers/generation_utils.py +++ b/paddlenlp/experimental/transformers/generation_utils.py @@ -333,7 +333,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs): try: from paddlenlp_ops import top_p_sampling_reject - next_tokens = top_p_sampling_reject(probs, top_p) + next_tokens = top_p_sampling_reject(probs, top_p, 0) except: _, next_tokens = paddle.tensor.top_p_sampling(probs, top_p) @@ -677,7 +677,7 @@ def _post_process_( try: from paddlenlp_ops import top_p_sampling_reject - next_tokens = top_p_sampling_reject(probs, top_p) + next_tokens = top_p_sampling_reject(probs, top_p, 0) except: _, next_tokens = paddle.tensor.top_p_sampling(probs, top_p) From 5c6330f411fd7e28943a0a184f9e92029f7c0d08 Mon Sep 17 00:00:00 2001 From: gzy19990617 Date: Thu, 26 Sep 2024 04:18:14 +0000 Subject: [PATCH 2/5] fix top_p reject --- .../sample_kernels/top_p_sampling_reject.cu | 615 +++--------------- 1 file changed, 83 insertions(+), 532 deletions(-) diff --git a/csrc/gpu/sample_kernels/top_p_sampling_reject.cu b/csrc/gpu/sample_kernels/top_p_sampling_reject.cu index 99fe2ad7d6e5..e23734ef4c49 100644 --- a/csrc/gpu/sample_kernels/top_p_sampling_reject.cu +++ b/csrc/gpu/sample_kernels/top_p_sampling_reject.cu @@ -12,541 +12,92 @@ // See the License for the specific language governing permissions and // limitations under the License. -// This code is partially inspired by and references the implementation found -// in FlashInfer.Specifically, the implementation of Top-p Sampling functionality -// in this code is inspired by the logic of -// FlashInfer’s flashinfer.sampling.top_p_sampling_from_probs . -// For more details on FlashInfer’s documentation, please refer to: -// https://docs.flashinfer.ai/generated/flashinfer.sampling.top_p_sampling_from_probs.html - -#pragma once - -#include -#include -#include -#include - -#include "sample_kernels/utils.cuh" - - -namespace sampling { - -using namespace cub; - -#define DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, ...) \ - if (compute_capacity.first >= 8) { \ - constexpr uint32_t BLOCK_THREADS = 1024; \ - __VA_ARGS__ \ - } else { \ - constexpr uint32_t BLOCK_THREADS = 512; \ - __VA_ARGS__ \ - } - -constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS; -constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS; - -#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120100) -#define SAMPLING_CUB_SUBTRACTLEFT_DEFINED -#endif - -template -struct Pair { - T value; - int count; - - __device__ Pair operator+(const Pair& other) const { - return {value + other.value, count + other.count}; - } - __device__ Pair& operator+=(const Pair& other) { - value += other.value; - count += other.count; - return *this; - } -}; - -struct BoolDiffOp { - __device__ __forceinline__ bool operator()(const bool& lhs, - const bool& rhs) const { - return lhs != rhs; - } -}; - -template -struct SamplingTempStorage { - union { - T deterministic_scan[BLOCK_THREADS / 32]; - typename BlockScan::TempStorage scan; - typename BlockReduce::TempStorage - reduce; - typename BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage - reduce_pair; - typename BlockAdjacentDifference::TempStorage adj_diff; - } block_prim; - struct { - int32_t sampled_id; - union { - T value; - Pair pair; - T max_p; - } block_aggregate; - } data; -}; - -/*! - * \brief Deterministic inclusive scan implementation, use Belloch scan - * algorithm. 
\note This implementation is slower than the cub::BlockScan, but - * it is deterministic. - */ -template -__device__ __forceinline__ void DeterministicInclusiveSum( - const T* in_data, - T* out_data, - SamplingTempStorage* - temp_storage) { - T* smem_prefix_sum = temp_storage->block_prim.deterministic_scan; - T thread_data[VEC_SIZE]; - T thread_sum = 0; -#pragma unroll - for (uint32_t i = 0; i < VEC_SIZE; ++i) { - thread_sum += in_data[i]; - thread_data[i] = thread_sum; - } - - T thread_exclusive_prefix_sum = thread_sum; - -#pragma unroll - for (uint32_t offset = 1; offset < 32; offset *= 2) { - T tmp = __shfl_up_sync(0xffffffff, thread_exclusive_prefix_sum, offset); - if ((threadIdx.x + 1) % (offset * 2) == 0) { - thread_exclusive_prefix_sum += tmp; - } - } - - T warp_sum = __shfl_sync( - 0xffffffff, thread_exclusive_prefix_sum, threadIdx.x | 0xffffffff); - if (threadIdx.x % 32 == 31) { - thread_exclusive_prefix_sum = 0; - } - -#pragma unroll - for (uint32_t offset = 16; offset >= 1; offset /= 2) { - T tmp = __shfl_xor_sync(0xffffffff, thread_exclusive_prefix_sum, offset); - if ((threadIdx.x + 1) % (offset * 2) == 0) { - thread_exclusive_prefix_sum = tmp + thread_exclusive_prefix_sum; - } - if ((threadIdx.x + 1) % (offset * 2) == offset) { - thread_exclusive_prefix_sum = tmp; - } - } - - smem_prefix_sum[threadIdx.x / 32] = warp_sum; - __syncthreads(); - - if (threadIdx.x < 32) { - T warp_exclusive_prefix_sum = - (threadIdx.x < BLOCK_THREADS / 32) ? smem_prefix_sum[threadIdx.x] : 0; - -#pragma unroll - for (uint32_t offset = 1; offset < 32; offset *= 2) { - T tmp = __shfl_up_sync(0xffffffff, warp_exclusive_prefix_sum, offset); - if ((threadIdx.x + 1) % (offset * 2) == 0) { - warp_exclusive_prefix_sum += tmp; - } - } - - if (threadIdx.x % 32 == 31) { - warp_exclusive_prefix_sum = 0; - } - -#pragma unroll - for (uint32_t offset = 16; offset >= 1; offset /= 2) { - T tmp = __shfl_xor_sync(0xffffffff, warp_exclusive_prefix_sum, offset); - if ((threadIdx.x + 1) % (offset * 2) == 0) { - warp_exclusive_prefix_sum = tmp + warp_exclusive_prefix_sum; - } - if ((threadIdx.x + 1) % (offset * 2) == offset) { - warp_exclusive_prefix_sum = tmp; - } - } - if (threadIdx.x < BLOCK_THREADS / 32) { - smem_prefix_sum[threadIdx.x] = warp_exclusive_prefix_sum; - } - } - __syncthreads(); - -#pragma unroll - for (uint32_t i = 0; i < VEC_SIZE; ++i) { - out_data[i] = smem_prefix_sum[threadIdx.x / 32] + - thread_exclusive_prefix_sum + thread_data[i]; - } -} - -template -__device__ __forceinline__ void DeviceSamplingFromProb( - uint32_t i, - uint32_t d, - T threshold, - T u, - vec_t prob_vec, - T& aggregate, - SamplingTempStorage* - temp_storage) { - const uint32_t tx = threadIdx.x; - T prob_greater_than_threshold[VEC_SIZE]; - T inclusive_cdf[VEC_SIZE]; - bool greater_than_u[VEC_SIZE], valid[VEC_SIZE]; -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - prob_greater_than_threshold[j] = - (prob_vec[j] > threshold) ? 
prob_vec[j] : T(0); - valid[j] = - prob_vec[j] > threshold && (i * BLOCK_THREADS + tx) * VEC_SIZE < d; - } - T aggregate_local = BlockReduce( - temp_storage->block_prim.reduce) - .Sum(prob_greater_than_threshold); - if (tx == 0) { - temp_storage->data.block_aggregate.value = aggregate_local; - } - __syncthreads(); - aggregate_local = temp_storage->data.block_aggregate.value; - - if (aggregate + aggregate_local > u) { - if constexpr (DETERMINISTIC) { - DeterministicInclusiveSum( - prob_greater_than_threshold, inclusive_cdf, temp_storage); - } else { - BlockScan(temp_storage->block_prim.scan) - .InclusiveSum(prob_greater_than_threshold, inclusive_cdf); - - __syncthreads(); - } - -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - greater_than_u[j] = inclusive_cdf[j] + aggregate > u; - } - - bool greater_than_u_diff[VEC_SIZE]; -#ifdef SAMPLING_CUB_SUBTRACTLEFT_DEFINED - BlockAdjacentDifference( - temp_storage->block_prim.adj_diff) - .SubtractLeft( - greater_than_u, greater_than_u_diff, BoolDiffOp()); -#else - BlockAdjacentDifference( - temp_storage->block_prim.adj_diff) - .FlagHeads( - greater_than_u_diff, greater_than_u, BoolDiffOp(), 0); -#endif - __syncthreads(); - -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - if (greater_than_u_diff[j] && valid[j]) { - if constexpr (DETERMINISTIC) { - temp_storage->data.sampled_id = - (i * BLOCK_THREADS + tx) * VEC_SIZE + j; - } else { - // cub's block scan result might not be monotonic, so we need to find - // the first element - atomicMin(&(temp_storage->data.sampled_id), - (i * BLOCK_THREADS + tx) * VEC_SIZE + j); - } - } - } - __syncthreads(); - } - aggregate += aggregate_local; +#include "helper.h" +#include "sample_kernels/sampling.cuh" + +std::vector TopPSamplingReject(const paddle::Tensor& probs, + const paddle::Tensor& top_p, + int seed) { + std::vector probs_shape = probs.shape(); + unsigned int batch_size = probs_shape[0]; + unsigned int vocab_size = probs_shape[1]; + + // default is 32 + unsigned int max_top_p_rounds = 32; + std::vector uniform_samples_shape = {batch_size, max_top_p_rounds}; + paddle::Tensor uniform_samples = + paddle::experimental::uniform(uniform_samples_shape, + paddle::DataType::FLOAT32, + 0, + 1, + seed, + probs.place()); + + auto cu_stream = probs.stream(); + + auto samples = + paddle::full({batch_size, 1}, 0, paddle::DataType::INT32, probs.place()); + auto success = + paddle::full({batch_size, 1}, 0, paddle::DataType::BOOL, probs.place()); + + auto top_p_host = + paddle::experimental::copy_to(top_p, paddle::CPUPlace(), true); + float top_p_val = top_p_host.data()[0]; + cudaError_t status; + if (top_p_val == 0.0) { + // top_p is 0,use top_k sampling . 
+ status = sampling::TopKSamplingFromProb( + const_cast(probs.data()), + uniform_samples.data(), + samples.data(), + success.data(), + nullptr, + batch_size, + 1, + vocab_size, + max_top_p_rounds, + true, + cu_stream); + } else { + status = sampling::TopPSamplingFromProb( + const_cast(probs.data()), + uniform_samples.data(), + samples.data(), + success.data(), + nullptr, + batch_size, + top_p.data(), + vocab_size, + max_top_p_rounds, + true, + cu_stream); + } + + PD_CHECK(status == cudaSuccess, + "SamplingFromProbs failed with error code " + + std::string(cudaGetErrorString(status))); + + paddle::Tensor samples_output; + samples_output = paddle::experimental::cast(samples, paddle::DataType::INT64); + return {samples_output}; } -template -__global__ void TopPSamplingFromProbKernel(DType* probs, - DType* uniform_samples, - IdType* output, - bool* success, - IdType* row_indices, - float* top_p_arr, - float* top_p_val, - uint32_t d, - uint32_t max_top_p_rounds) { - const uint32_t batch_size = gridDim.x; - const uint32_t bx = blockIdx.x, tx = threadIdx.x; - float top_p = (top_p_arr == nullptr) ? top_p_val[bx] : top_p_arr[bx]; - - const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx]; - - extern __shared__ __align__(alignof(SamplingTempStorage)) - uint8_t smem_sampling[]; - auto& temp_storage = - reinterpret_cast&>(smem_sampling); - - vec_t probs_vec; - DType aggregate; - DType q = DType(1); - DType pivot = DType(0); - IdType sampled_id; - for (uint32_t round = 0; round < max_top_p_rounds; ++round) { - temp_storage.data.sampled_id = d - 1; - __syncthreads(); - DType u = uniform_samples[round * batch_size + bx] * q; - aggregate = DType(0); - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + row_idx * d + - (i * BLOCK_THREADS + tx) * VEC_SIZE); - } - - DeviceSamplingFromProb( - i, d, pivot, u, probs_vec, aggregate, &temp_storage); - if (aggregate > u) { - break; - } - } - __syncthreads(); - sampled_id = temp_storage.data.sampled_id; - pivot = max(pivot, probs[row_idx * d + sampled_id]); - - DType aggregate_gt_pivot = DType(0); - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + row_idx * d + - (i * BLOCK_THREADS + tx) * VEC_SIZE); - } - - DType probs_gt_pivot[VEC_SIZE]; -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - probs_gt_pivot[j] = (probs_vec[j] > pivot) ? 
probs_vec[j] : DType(0); - } - - aggregate_gt_pivot += - BlockReduce(temp_storage.block_prim.reduce) - .Sum(probs_gt_pivot); - if (tx == 0) { - temp_storage.data.block_aggregate.value = aggregate_gt_pivot; - } - __syncthreads(); - } - q = temp_storage.data.block_aggregate.value; - if (float(q) < top_p) { - break; - } - } - __syncthreads(); - if (tx == 0) { - output[bx] = sampled_id; - if (float(q) >= top_p) { - // failed to sample within MAX_TOP_P_ROUNDS - if (success != nullptr) { - success[bx] = false; - } - } else { - if (success != nullptr) { - success[bx] = true; - } - } - } +std::vector> TopPSamplingRejectInferShape( + const std::vector& probs_shape, + const std::vector& top_p_shape) { + int64_t bs = probs_shape[0]; + return {{bs, 1}}; } - -template -__global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples, IdType* output, - bool* success, IdType* top_k_arr, uint32_t top_k_val, - uint32_t d, uint32_t max_top_k_rounds) { - const uint32_t batch_size = gridDim.x; - const uint32_t bx = blockIdx.x, tx = threadIdx.x; - uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx]; - - extern __shared__ __align__( - alignof(SamplingTempStorage)) - uint8_t smem_sampling[]; - auto& temp_storage = reinterpret_cast< - SamplingTempStorage&>(smem_sampling); - - vec_t probs_vec; - DType aggregate; - DType q = DType(1); - DType pivot = DType(0); - IdType sampled_id; - for (uint32_t round = 0; round < max_top_k_rounds; ++round) { - temp_storage.data.sampled_id = d - 1; - __syncthreads(); - DType u = uniform_samples[round * batch_size + bx] * q; - aggregate = DType(0); - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); - } - - DeviceSamplingFromProb(i, d, pivot, u, probs_vec, aggregate, - &temp_storage); - if (aggregate > u) { - break; - } - } - __syncthreads(); - sampled_id = temp_storage.data.sampled_id; - pivot = max(pivot, probs[bx * d + sampled_id]); - - Pair aggregate_gt_pivot{DType(0), 0}; - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); - } - - Pair probs_gt_pivot[VEC_SIZE]; -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - probs_gt_pivot[j] = {(probs_vec[j] > pivot) ? 
probs_vec[j] : DType(0), - (probs_vec[j] > pivot && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; - } - - aggregate_gt_pivot += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( - temp_storage.block_prim.reduce_pair) - .Sum(probs_gt_pivot); - if (tx == 0) { - temp_storage.data.block_aggregate.pair = aggregate_gt_pivot; - } - __syncthreads(); - } - q = temp_storage.data.block_aggregate.pair.value; - if (temp_storage.data.block_aggregate.pair.count < k) { - break; - } - } - __syncthreads(); - if (tx == 0) { - output[bx] = sampled_id; - if (temp_storage.data.block_aggregate.pair.count >= k) { - // failed to sample within MAX_TOP_P_ROUNDS - if (success != nullptr) { - success[bx] = false; - } - } else { - if (success != nullptr) { - success[bx] = true; - } - } - } -} - -template -cudaError_t TopPSamplingFromProb(T* probs, - T* uniform_samples, - IdType* output, - bool* success, - T* top_p_arr, - uint32_t batch_size, - const T* top_p_val, - uint32_t d, - uint32_t max_top_p_rounds, - bool deterministic, - cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - - const uint32_t smem_size = - sizeof(SamplingTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - IdType* row_indices_placeholder = nullptr; - void* args[] = {&probs, - &uniform_samples, - &output, - &success, - &row_indices_placeholder, - &top_p_arr, - &top_p_val, - &d, - &max_top_p_rounds}; - - DISPATCH_ALIGNED_VEC_SIZE( - vec_size, - VEC_SIZE, - {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = TopPSamplingFromProbKernel; - CUDA_CALL(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - CUDA_CALL(cudaLaunchKernel( - (void*)kernel, nblks, nthrs, args, smem_size, stream)); - })}); - return cudaSuccess; -} - -template -cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, bool* success, - T* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, - uint32_t max_top_k_rounds, bool deterministic, - cudaStream_t stream = 0) { - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - - auto compute_capacity = GetCudaComputeCapability(); - DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, { - const uint32_t smem_size = - sizeof(SamplingTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &uniform_samples, &output, &success, - &top_k_arr, &top_k_val, &d, &max_top_k_rounds}; - - DISPATCH_ALIGNED_VEC_SIZE( - vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = TopKSamplingFromProbKernel; - CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - CUDA_CALL( - cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - })}); - return cudaSuccess; - }); +std::vector TopPSamplingRejectInferDtype( + const paddle::DataType& probs_dtype, const paddle::DataType& top_p_shape) { + return {paddle::DataType::INT64}; } -} // namespace sampling \ No newline at end of file +PD_BUILD_OP(top_p_sampling_reject) + .Inputs({"probs", "top_p"}) + .Outputs({"samples"}) + .Attrs({"seed: int"}) + .SetKernelFn(PD_KERNEL(TopPSamplingReject)) + .SetInferShapeFn(PD_INFER_SHAPE(TopPSamplingRejectInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(TopPSamplingRejectInferDtype)); \ No newline at end of file From c9e74d7d540990840e43a7a810013b3e839f6704 Mon Sep 17 00:00:00 2001 From: gzy19990617 Date: Thu, 26 Sep 2024 05:14:50 +0000 Subject: [PATCH 
3/5] fix top_p reject --- csrc/gpu/sample_kernels/top_p_sampling_reject.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/gpu/sample_kernels/top_p_sampling_reject.cu b/csrc/gpu/sample_kernels/top_p_sampling_reject.cu index e23734ef4c49..ad525bd3526d 100644 --- a/csrc/gpu/sample_kernels/top_p_sampling_reject.cu +++ b/csrc/gpu/sample_kernels/top_p_sampling_reject.cu @@ -41,7 +41,7 @@ std::vector TopPSamplingReject(const paddle::Tensor& probs, paddle::full({batch_size, 1}, 0, paddle::DataType::BOOL, probs.place()); auto top_p_host = - paddle::experimental::copy_to(top_p, paddle::CPUPlace(), true); + paddle::experimental::copy_to(top_p, paddle::CPUPlace(), false); float top_p_val = top_p_host.data()[0]; cudaError_t status; if (top_p_val == 0.0) { From 2ee0b27b7bb432641b66d01b66948770cebb9a3b Mon Sep 17 00:00:00 2001 From: gaoziyuan Date: Sun, 29 Sep 2024 22:19:49 +0800 Subject: [PATCH 4/5] just for test,need change --- csrc/gpu/sample_kernels/sampling.cuh | 24 ++++----- .../sample_kernels/top_p_sampling_reject.cu | 51 +++++-------------- 2 files changed, 25 insertions(+), 50 deletions(-) diff --git a/csrc/gpu/sample_kernels/sampling.cuh b/csrc/gpu/sample_kernels/sampling.cuh index 99fe2ad7d6e5..9c8062bf5284 100644 --- a/csrc/gpu/sample_kernels/sampling.cuh +++ b/csrc/gpu/sample_kernels/sampling.cuh @@ -286,7 +286,6 @@ template = top_p) { - // failed to sample within MAX_TOP_P_ROUNDS - if (success != nullptr) { - success[bx] = false; - } - } else { - if (success != nullptr) { - success[bx] = true; - } - } + // todo:delete + // if (float(q) >= top_p) { + // // failed to sample within MAX_TOP_P_ROUNDS + // if (success != nullptr) { + // success[bx] = false; + // } + // } else { + // if (success != nullptr) { + // success[bx] = true; + // } + // } } } @@ -475,7 +475,6 @@ template cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, - bool* success, T* top_p_arr, uint32_t batch_size, const T* top_p_val, @@ -494,7 +493,6 @@ cudaError_t TopPSamplingFromProb(T* probs, void* args[] = {&probs, &uniform_samples, &output, - &success, &row_indices_placeholder, &top_p_arr, &top_p_val, diff --git a/csrc/gpu/sample_kernels/top_p_sampling_reject.cu b/csrc/gpu/sample_kernels/top_p_sampling_reject.cu index ad525bd3526d..ec2fe7eeaec3 100644 --- a/csrc/gpu/sample_kernels/top_p_sampling_reject.cu +++ b/csrc/gpu/sample_kernels/top_p_sampling_reject.cu @@ -36,50 +36,27 @@ std::vector TopPSamplingReject(const paddle::Tensor& probs, auto cu_stream = probs.stream(); auto samples = - paddle::full({batch_size, 1}, 0, paddle::DataType::INT32, probs.place()); - auto success = - paddle::full({batch_size, 1}, 0, paddle::DataType::BOOL, probs.place()); + paddle::empty({batch_size, 1}, paddle::DataType::INT64, probs.place()); - auto top_p_host = - paddle::experimental::copy_to(top_p, paddle::CPUPlace(), false); - float top_p_val = top_p_host.data()[0]; cudaError_t status; - if (top_p_val == 0.0) { - // top_p is 0,use top_k sampling . 
- status = sampling::TopKSamplingFromProb( - const_cast(probs.data()), - uniform_samples.data(), - samples.data(), - success.data(), - nullptr, - batch_size, - 1, - vocab_size, - max_top_p_rounds, - true, - cu_stream); - } else { - status = sampling::TopPSamplingFromProb( - const_cast(probs.data()), - uniform_samples.data(), - samples.data(), - success.data(), - nullptr, - batch_size, - top_p.data(), - vocab_size, - max_top_p_rounds, - true, - cu_stream); - } + status = sampling::TopPSamplingFromProb( + const_cast(probs.data()), + uniform_samples.data(), + samples.data(), + nullptr, + batch_size, + top_p.data(), + vocab_size, + max_top_p_rounds, + true, + cu_stream); + PD_CHECK(status == cudaSuccess, "SamplingFromProbs failed with error code " + std::string(cudaGetErrorString(status))); - paddle::Tensor samples_output; - samples_output = paddle::experimental::cast(samples, paddle::DataType::INT64); - return {samples_output}; + return {samples}; } std::vector> TopPSamplingRejectInferShape( From 8e3096f32c5a9c6749bfd9b1d351dbb74c32b239 Mon Sep 17 00:00:00 2001 From: gaoziyuan Date: Wed, 9 Oct 2024 17:39:51 +0800 Subject: [PATCH 5/5] optimize top_p --- csrc/gpu/sample_kernels/sampling.cuh | 147 ++---------------- .../sample_kernels/top_p_sampling_reject.cu | 1 - .../test/python/test_top_p_sampling_reject.py | 63 ++++++++ 3 files changed, 72 insertions(+), 139 deletions(-) create mode 100644 csrc/gpu/test/python/test_top_p_sampling_reject.py diff --git a/csrc/gpu/sample_kernels/sampling.cuh b/csrc/gpu/sample_kernels/sampling.cuh index 9c8062bf5284..4940070d2dfa 100644 --- a/csrc/gpu/sample_kernels/sampling.cuh +++ b/csrc/gpu/sample_kernels/sampling.cuh @@ -286,16 +286,12 @@ template pivot) ? probs_vec[j] : DType(0); - } - - aggregate_gt_pivot += - BlockReduce(temp_storage.block_prim.reduce) - .Sum(probs_gt_pivot); - if (tx == 0) { - temp_storage.data.block_aggregate.value = aggregate_gt_pivot; - } - __syncthreads(); - } - q = temp_storage.data.block_aggregate.value; - if (float(q) < top_p) { - break; - } - } - __syncthreads(); - if (tx == 0) { - output[bx] = sampled_id; - // todo:delete - // if (float(q) >= top_p) { - // // failed to sample within MAX_TOP_P_ROUNDS - // if (success != nullptr) { - // success[bx] = false; - // } - // } else { - // if (success != nullptr) { - // success[bx] = true; - // } - // } - } -} - - -template -__global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples, IdType* output, - bool* success, IdType* top_k_arr, uint32_t top_k_val, - uint32_t d, uint32_t max_top_k_rounds) { - const uint32_t batch_size = gridDim.x; - const uint32_t bx = blockIdx.x, tx = threadIdx.x; - uint32_t k = top_k_arr == nullptr ? 
top_k_val : top_k_arr[bx]; - - extern __shared__ __align__( - alignof(SamplingTempStorage)) - uint8_t smem_sampling[]; - auto& temp_storage = reinterpret_cast< - SamplingTempStorage&>(smem_sampling); - - vec_t probs_vec; - DType aggregate; - DType q = DType(1); - DType pivot = DType(0); - IdType sampled_id; - for (uint32_t round = 0; round < max_top_k_rounds; ++round) { - temp_storage.data.sampled_id = d - 1; - __syncthreads(); - DType u = uniform_samples[round * batch_size + bx] * q; - aggregate = DType(0); - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); - } - - DeviceSamplingFromProb(i, d, pivot, u, probs_vec, aggregate, - &temp_storage); - if (aggregate > u) { - break; - } - } - __syncthreads(); - sampled_id = temp_storage.data.sampled_id; pivot = max(pivot, probs[bx * d + sampled_id]); Pair aggregate_gt_pivot{DType(0), 0}; @@ -451,23 +359,19 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples, __syncthreads(); } q = temp_storage.data.block_aggregate.pair.value; - if (temp_storage.data.block_aggregate.pair.count < k) { + if (float(q) > 0 && float(q) < top_p) { + // top_p is not 0 break; + } else { + // top_p is 0, use top_k, k=1 + if (temp_storage.data.block_aggregate.pair.count < 1) { + break; + } } } __syncthreads(); if (tx == 0) { output[bx] = sampled_id; - if (temp_storage.data.block_aggregate.pair.count >= k) { - // failed to sample within MAX_TOP_P_ROUNDS - if (success != nullptr) { - success[bx] = false; - } - } else { - if (success != nullptr) { - success[bx] = true; - } - } } } @@ -475,7 +379,6 @@ template cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, - T* top_p_arr, uint32_t batch_size, const T* top_p_val, uint32_t d, @@ -489,12 +392,9 @@ cudaError_t TopPSamplingFromProb(T* probs, sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - IdType* row_indices_placeholder = nullptr; void* args[] = {&probs, &uniform_samples, &output, - &row_indices_placeholder, - &top_p_arr, &top_p_val, &d, &max_top_p_rounds}; @@ -518,33 +418,4 @@ cudaError_t TopPSamplingFromProb(T* probs, return cudaSuccess; } -template -cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, bool* success, - T* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, - uint32_t max_top_k_rounds, bool deterministic, - cudaStream_t stream = 0) { - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - - auto compute_capacity = GetCudaComputeCapability(); - DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, { - const uint32_t smem_size = - sizeof(SamplingTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &uniform_samples, &output, &success, - &top_k_arr, &top_k_val, &d, &max_top_k_rounds}; - - DISPATCH_ALIGNED_VEC_SIZE( - vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = TopKSamplingFromProbKernel; - CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - CUDA_CALL( - cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - })}); - return cudaSuccess; - }); -} - } // namespace sampling \ No newline at end of file diff --git a/csrc/gpu/sample_kernels/top_p_sampling_reject.cu b/csrc/gpu/sample_kernels/top_p_sampling_reject.cu index ec2fe7eeaec3..df62f2c12efe 100644 
--- a/csrc/gpu/sample_kernels/top_p_sampling_reject.cu
+++ b/csrc/gpu/sample_kernels/top_p_sampling_reject.cu
@@ -44,7 +44,6 @@ std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor& probs,
       const_cast<float*>(probs.data<float>()),
       uniform_samples.data<float>(),
       samples.data<int64_t>(),
-      nullptr,
       batch_size,
       top_p.data<float>(),
       vocab_size,
diff --git a/csrc/gpu/test/python/test_top_p_sampling_reject.py b/csrc/gpu/test/python/test_top_p_sampling_reject.py
new file mode 100644
index 000000000000..7a0605a3bdbe
--- /dev/null
+++ b/csrc/gpu/test/python/test_top_p_sampling_reject.py
@@ -0,0 +1,63 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import paddle
+from paddlenlp_ops import top_p_sampling_reject
+
+paddle.seed(2023)
+
+batch_size = 3
+vocab_size = 40080
+max_rounds = 32
+
+
+class TestTopPSamplingReject(unittest.TestCase):
+    def test_top_p_sampling_reject_case1(self):
+        # top_p is 1, sampled with different seeds
+        pre_norm_prob_np = np.random.rand(batch_size, vocab_size).astype(np.float32)
+
+        paddle_pre_norm_prob = paddle.to_tensor(pre_norm_prob_np)
+        paddle_norm_prob = paddle_pre_norm_prob / paddle_pre_norm_prob.sum(axis=-1, keepdim=True)
+        top_p_paddle = paddle.full((batch_size,), 1)
+        samples = top_p_sampling_reject(paddle_norm_prob, top_p_paddle, 0)
+        print(samples)
+        samples = top_p_sampling_reject(paddle_norm_prob, top_p_paddle, 1024)
+        print(samples)
+        samples = top_p_sampling_reject(paddle_norm_prob, top_p_paddle, 2033)
+        print(samples)
+
+    def test_top_p_sampling_reject_case2(self):
+        # top_p is 0 (falls back to top-k sampling with k=1)
+        pre_norm_prob_np = np.random.rand(batch_size, vocab_size).astype(np.float32)
+
+        paddle_pre_norm_prob = paddle.to_tensor(pre_norm_prob_np)
+        paddle_norm_prob = paddle_pre_norm_prob / paddle_pre_norm_prob.sum(axis=-1, keepdim=True)
+        top_p_paddle = paddle.full((batch_size,), 0)
+        samples = top_p_sampling_reject(paddle_norm_prob, top_p_paddle, 0)
+        print(samples)
+
+    def test_top_p_sampling_reject_case3(self):
+        # a different top_p value for each row of the batch
+        pre_norm_prob_np = np.random.rand(batch_size, vocab_size).astype(np.float32)
+
+        paddle_pre_norm_prob = paddle.to_tensor(pre_norm_prob_np)
+        paddle_norm_prob = paddle_pre_norm_prob / paddle_pre_norm_prob.sum(axis=-1, keepdim=True)
+        top_p_paddle = paddle.uniform(shape=[batch_size, 1], min=0, max=1)
+        samples = top_p_sampling_reject(paddle_norm_prob, top_p_paddle, 0)
+        print(samples)
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file
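
Note: the rejection loop these patches implement is easier to follow in scalar
form. Below is a minimal single-row NumPy sketch of the accept/reject scheme
used by TopPSamplingFromProbKernel: draw a token from the mass above the
current pivot, raise the pivot to the drawn token's probability, and accept
once the mass still above the pivot drops below top_p. This sketch is an
illustration written for this review, not part of the patches; the function
name top_p_sample_reference and the NumPy formulation are assumptions, and it
models only the top_p > 0 path, without the CUDA block-level reductions or the
k=1 fallback merged in PATCH 5/5.

import numpy as np

def top_p_sample_reference(probs, top_p, max_rounds=32, rng=None):
    """Sample one token id from a 1-D normalized `probs` vector (sketch only)."""
    rng = np.random.default_rng() if rng is None else rng
    pivot, q = 0.0, 1.0        # current threshold and the mass above it
    sampled_id = len(probs) - 1
    for _ in range(max_rounds):
        u = rng.random() * q   # one uniform draw per round, as in the kernel
        masked = np.where(probs > pivot, probs, 0.0)
        cdf = np.cumsum(masked)
        # first index whose inclusive cdf exceeds u
        sampled_id = min(int(np.searchsorted(cdf, u, side="right")), len(probs) - 1)
        pivot = max(pivot, float(probs[sampled_id]))  # shrink the candidate set
        q = float(np.where(probs > pivot, probs, 0.0).sum())
        if q < top_p:          # the drawn token lies inside the top-p nucleus
            break
    return sampled_id

Because each round conditions on "probability > pivot", a rejected round
permanently rules out the low-probability mass it rejected, so the kernels only
need to pre-draw max_top_p_rounds (32 here) uniform numbers per row, which is
exactly the {batch_size, max_top_p_rounds} uniform_samples tensor allocated in
TopPSamplingReject.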