[LLM INFER] top_p_sampling_reject support top_p=0 and custom seed #9202


Merged 7 commits on Oct 11, 2024
Changes from 3 commits
csrc/gpu/sample_kernels/sampling.cuh (125 additions, 1 deletion)

@@ -33,6 +33,15 @@ namespace sampling {

using namespace cub;

#define DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, ...) \
if (compute_capacity.first >= 8) { \
constexpr uint32_t BLOCK_THREADS = 1024; \
__VA_ARGS__ \
} else { \
constexpr uint32_t BLOCK_THREADS = 512; \
__VA_ARGS__ \
}
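The macro exists because BLOCK_THREADS must be a compile-time constant: it is passed as a kernel template argument, so each branch has to introduce its own constexpr binding. A rough functional equivalent, offered only as a sketch (the DispatchComputeCapNumThreads name and the std::integral_constant plumbing are mine, not the PR's):

    #include <cstdint>
    #include <type_traits>
    #include <utility>

    // Sketch: same dispatch as the macro, but the compile-time thread count is
    // carried in an integral_constant so the body can still use it as a
    // template argument. compute_capacity is assumed to be {major, minor},
    // e.g. {8, 0} for an A100.
    template <typename Body>
    void DispatchComputeCapNumThreads(std::pair<int, int> compute_capacity, Body body) {
      if (compute_capacity.first >= 8) {
        body(std::integral_constant<uint32_t, 1024>{});  // SM 8.x and newer
      } else {
        body(std::integral_constant<uint32_t, 512>{});   // older architectures
      }
    }

Usage mirrors the macro: DispatchComputeCapNumThreads(cc, [&](auto bt) { constexpr uint32_t BLOCK_THREADS = bt(); /* launch kernel<BLOCK_THREADS, ...> */ });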

constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS;
constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS;

@@ -376,6 +385,92 @@ __global__ void TopPSamplingFromProbKernel(DType* probs,
}


template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
typename DType, typename IdType>
__global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples, IdType* output,
bool* success, IdType* top_k_arr, uint32_t top_k_val,
uint32_t d, uint32_t max_top_k_rounds) {
const uint32_t batch_size = gridDim.x;
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];

extern __shared__ __align__(
alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
uint8_t smem_sampling[];
auto& temp_storage = reinterpret_cast<
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);

vec_t<DType, VEC_SIZE> probs_vec;
DType aggregate;
DType q = DType(1);
DType pivot = DType(0);
IdType sampled_id;
for (uint32_t round = 0; round < max_top_k_rounds; ++round) {
temp_storage.data.sampled_id = d - 1;
__syncthreads();
DType u = uniform_samples[round * batch_size + bx] * q;
aggregate = DType(0);
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(DType(0));
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
}

DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
DETERMINISTIC, DType>(i, d, pivot, u, probs_vec, aggregate,
&temp_storage);
if (aggregate > u) {
break;
}
}
__syncthreads();
sampled_id = temp_storage.data.sampled_id;
pivot = max(pivot, probs[bx * d + sampled_id]);

Pair<DType> aggregate_gt_pivot{DType(0), 0};
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(DType(0));
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
}

Pair<DType> probs_gt_pivot[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_gt_pivot[j] = {(probs_vec[j] > pivot) ? probs_vec[j] : DType(0),
(probs_vec[j] > pivot && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
}

aggregate_gt_pivot += BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_pair)
.Sum<VEC_SIZE>(probs_gt_pivot);
if (tx == 0) {
temp_storage.data.block_aggregate.pair = aggregate_gt_pivot;
}
__syncthreads();
}
q = temp_storage.data.block_aggregate.pair.value;
if (temp_storage.data.block_aggregate.pair.count < k) {
break;
}
}
__syncthreads();
if (tx == 0) {
output[bx] = sampled_id;
if (temp_storage.data.block_aggregate.pair.count >= k) {
      // failed to sample within max_top_k_rounds
if (success != nullptr) {
success[bx] = false;
}
} else {
if (success != nullptr) {
success[bx] = true;
}
}
}
}
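For readers new to this kernel: it is a rejection-style top-k sampler. Each round draws a token from the probability mass above the current pivot, raises the pivot to the sampled token's probability, and accepts once fewer than k entries remain above the pivot, which guarantees the accepted token is among the k most probable. A minimal sequential C++ sketch of that loop, my reconstruction rather than the PR's code, assuming probs sums to 1:

    #include <algorithm>
    #include <cstdint>
    #include <random>
    #include <vector>

    // Sequential sketch of TopKSamplingFromProbKernel's rejection loop
    // (single "row", no CUB reductions). Returns the sampled index.
    int TopKRejectionSampleRef(const std::vector<float>& probs, uint32_t k,
                               uint32_t max_rounds, std::mt19937& rng,
                               bool* success) {
      std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
      float pivot = 0.0f;
      float q = 1.0f;  // probability mass above the current pivot
      int sampled_id = static_cast<int>(probs.size()) - 1;
      for (uint32_t round = 0; round < max_rounds; ++round) {
        // Inverse-CDF sample over entries strictly above the pivot.
        float u = uniform(rng) * q;
        float acc = 0.0f;
        sampled_id = static_cast<int>(probs.size()) - 1;  // fallback, mirrors d - 1
        for (size_t i = 0; i < probs.size(); ++i) {
          acc += (probs[i] > pivot) ? probs[i] : 0.0f;
          if (acc > u) { sampled_id = static_cast<int>(i); break; }
        }
        pivot = std::max(pivot, probs[sampled_id]);
        // Recompute the mass and count of entries above the raised pivot.
        q = 0.0f;
        uint32_t count = 0;
        for (float p : probs) {
          if (p > pivot) { q += p; ++count; }
        }
        if (count < k) {  // sampled token is within the top-k: accept
          if (success) *success = true;
          return sampled_id;
        }
      }
      if (success) *success = false;  // did not converge within max_rounds
      return sampled_id;
    }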

template <typename T, typename IdType>
cudaError_t TopPSamplingFromProb(T* probs,
T* uniform_samples,
@@ -425,4 +520,33 @@ cudaError_t TopPSamplingFromProb(T* probs,
return cudaSuccess;
}

template <typename T, typename IdType>
cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, bool* success,
T* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d,
uint32_t max_top_k_rounds, bool deterministic,
cudaStream_t stream = 0) {
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

auto compute_capacity = GetCudaComputeCapability();
DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
const uint32_t smem_size =
sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
void* args[] = {&probs, &uniform_samples, &output, &success,
&top_k_arr, &top_k_val, &d, &max_top_k_rounds};

DISPATCH_ALIGNED_VEC_SIZE(
vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
auto kernel = TopKSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO, VEC_SIZE,
DETERMINISTIC, T, IdType>;
CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
CUDA_CALL(
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
})});
return cudaSuccess;
});
}

} // namespace sampling
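One contract worth noting for callers of both launchers: the kernels index uniform_samples[round * batch_size + bx], so the buffer must hold max_rounds * batch_size values in round-major order. In this PR the buffer comes from paddle::experimental::uniform; purely as an illustration of the same contract with a different generator, a hedged cuRAND sketch (host-API calls only; error handling omitted):

    #include <cstdint>
    #include <curand.h>

    // Sketch: fill a round-major uniform buffer of max_rounds * batch_size
    // floats. d_uniform is a device pointer; seed matches the PR's new
    // custom-seed knob.
    void FillUniformSamples(float* d_uniform, uint32_t batch_size,
                            uint32_t max_rounds, unsigned long long seed) {
      curandGenerator_t gen;
      curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_PHILOX4_32_10);
      curandSetPseudoRandomGeneratorSeed(gen, seed);
      curandGenerateUniform(gen, d_uniform, (size_t)max_rounds * batch_size);
      curandDestroyGenerator(gen);
    }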
csrc/gpu/sample_kernels/top_p_sampling_reject.cu (47 additions, 24 deletions)

@@ -16,41 +16,63 @@
#include "sample_kernels/sampling.cuh"

std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor& probs,
-                                               const paddle::Tensor& top_p) {
+                                               const paddle::Tensor& top_p,
+                                               int seed) {
std::vector<int64_t> probs_shape = probs.shape();
unsigned int batch_size = probs_shape[0];
unsigned int vocab_size = probs_shape[1];

// default is 32
unsigned int max_top_p_rounds = 32;
std::vector<int64_t> uniform_samples_shape = {batch_size, max_top_p_rounds};
-  paddle::Tensor uniform_samples = paddle::experimental::uniform(
-      uniform_samples_shape, paddle::DataType::FLOAT32, 0, 1, 0, probs.place());
+  paddle::Tensor uniform_samples =
+      paddle::experimental::uniform(uniform_samples_shape,
+                                    paddle::DataType::FLOAT32,
+                                    0,
+                                    1,
+                                    seed,
+                                    probs.place());

-  // todo: add parameter for deterministic, now default is true
-  bool deterministic = true;
-  paddle::Tensor probs_input;
-
-  probs_input = paddle::experimental::cast(probs, paddle::DataType::FLOAT32);
auto cu_stream = probs.stream();

   auto samples =
-      paddle::full({batch_size}, 0, paddle::DataType::INT32, probs.place());
+      paddle::full({batch_size, 1}, 0, paddle::DataType::INT32, probs.place());
   auto success =
-      paddle::full({batch_size}, 0, paddle::DataType::BOOL, probs.place());
+      paddle::full({batch_size, 1}, 0, paddle::DataType::BOOL, probs.place());

-  cudaError_t status =
-      sampling::TopPSamplingFromProb<float, int>(probs_input.data<float>(),
-                                                 uniform_samples.data<float>(),
-                                                 samples.data<int>(),
-                                                 success.data<bool>(),
-                                                 nullptr,
-                                                 batch_size,
-                                                 top_p.data<float>(),
-                                                 vocab_size,
-                                                 max_top_p_rounds,
-                                                 deterministic,
-                                                 cu_stream);
+  auto top_p_host =
+      paddle::experimental::copy_to(top_p, paddle::CPUPlace(), false);
+  float top_p_val = top_p_host.data<float>()[0];
+  cudaError_t status;
+  if (top_p_val == 0.0) {
+    // top_p is 0, so fall back to top-k sampling with k = 1 (greedy).
+    status = sampling::TopKSamplingFromProb<float, int>(
+        const_cast<float*>(probs.data<float>()),
+        uniform_samples.data<float>(),
+        samples.data<int>(),
+        success.data<bool>(),
+        nullptr,
+        batch_size,
+        1,
+        vocab_size,
+        max_top_p_rounds,
+        true,
+        cu_stream);
+  } else {
+    status = sampling::TopPSamplingFromProb<float, int>(
+        const_cast<float*>(probs.data<float>()),
+        uniform_samples.data<float>(),
+        samples.data<int>(),
+        success.data<bool>(),
+        nullptr,
+        batch_size,
+        top_p.data<float>(),
+        vocab_size,
+        max_top_p_rounds,
+        true,
+        cu_stream);
+  }

PD_CHECK(status == cudaSuccess,
"SamplingFromProbs failed with error code " +
std::string(cudaGetErrorString(status)));
Expand All @@ -69,12 +91,13 @@ std::vector<std::vector<int64_t>> TopPSamplingRejectInferShape(

std::vector<paddle::DataType> TopPSamplingRejectInferDtype(
const paddle::DataType& probs_dtype, const paddle::DataType& top_p_shape) {
-  return {probs_dtype};
+  return {paddle::DataType::INT64};
}

PD_BUILD_OP(top_p_sampling_reject)
.Inputs({"probs", "top_p"})
.Outputs({"samples"})
.Attrs({"seed: int"})
.SetKernelFn(PD_KERNEL(TopPSamplingReject))
.SetInferShapeFn(PD_INFER_SHAPE(TopPSamplingRejectInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(TopPSamplingRejectInferDtype));
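On the top_p_val == 0.0 branch above: with k = 1, the rejection loop can only accept once no probability mass remains above the pivot, which forces the pivot up to the maximum probability, so the accepted token is the argmax. In other words, top_p = 0 is served as greedy decoding. A quick check against the sequential sketch from the sampling.cuh section (hypothetical values; assumes a unique maximum):

    #include <cassert>
    #include <random>
    #include <vector>

    // Uses TopKRejectionSampleRef from the sketch above.
    int main() {
      std::mt19937 rng(42);  // arbitrary demo seed
      std::vector<float> probs = {0.1f, 0.6f, 0.2f, 0.1f};
      bool ok = false;
      int id = TopKRejectionSampleRef(probs, /*k=*/1, /*max_rounds=*/32, rng, &ok);
      assert(ok && id == 1);  // k = 1 can only accept the argmax (index 1)
      return 0;
    }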
paddlenlp/experimental/transformers/generation_utils.py (2 additions, 2 deletions)

@@ -333,7 +333,7 @@
try:
from paddlenlp_ops import top_p_sampling_reject

-    next_tokens = top_p_sampling_reject(probs, top_p)
+    next_tokens = top_p_sampling_reject(probs, top_p, 0)

except:
_, next_tokens = paddle.tensor.top_p_sampling(probs, top_p)

@@ -677,7 +677,7 @@
try:
from paddlenlp_ops import top_p_sampling_reject

-    next_tokens = top_p_sampling_reject(probs, top_p)
+    next_tokens = top_p_sampling_reject(probs, top_p, 0)

except:
_, next_tokens = paddle.tensor.top_p_sampling(probs, top_p)
