added topp_sampling kernel
guocuimi committed Dec 7, 2023
1 parent 9257567 commit 35375ae
Showing 3 changed files with 199 additions and 70 deletions.
1 change: 1 addition & 0 deletions src/kernels/CMakeLists.txt
@@ -18,6 +18,7 @@ cc_library(
sampling/penalty_kernels.cu
sampling/softmax_kernels.cu
sampling/topk_kernels.cu
sampling/topp_kernels.cu
DEPS
glog::glog
torch
157 changes: 87 additions & 70 deletions src/kernels/sampling/topk_kernels.cu
@@ -14,13 +14,13 @@ namespace llm::kernel {

// reduce topk for each thread block and save the result to temp storage
template <typename T, int BLOCK_SIZE, int BLOCKS_PER_SEQ>
__global__ void topk_within_block(const T* __restrict logits,
T* __restrict tmp_logits,
int* __restrict tmp_topk_ids,
T* __restrict tmp_topk_logits,
int max_top_k,
const int* __restrict top_ks,
int vocab_size) {
__global__ void partial_topk_within_block(const T* __restrict logits,
T* __restrict tmp_logits,
int* __restrict tmp_topk_ids,
T* __restrict tmp_topk_logits,
int max_top_k,
const int* __restrict top_ks,
int vocab_size) {
typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;

@@ -33,15 +33,15 @@ __global__ void topk_within_block(const T* __restrict logits,
const int block_lane = bid % BLOCKS_PER_SEQ;
const int k = top_ks[batch_id];

const int tmp_log_buf_idx = batch_id * vocab_size;
const int tmp_topk_buf_idx =
const int tmp_logits_base = batch_id * vocab_size;
const int tmp_topk_base =
batch_id * BLOCKS_PER_SEQ * max_top_k + block_lane * k;
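// each sequence owns BLOCKS_PER_SEQ * max_top_k temp slots; block `block_lane`
// writes its k candidates at offset block_lane * k, so a sequence's
// BLOCKS_PER_SEQ * k valid candidates form a contiguous prefix of its slot
// region (the prefix scanned by the second kernel)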

// copy logits to tmp_logits so they can be modified in place
#pragma unroll
for (int elem_id = tid + block_lane * BLOCK_SIZE; elem_id < vocab_size;
elem_id += BLOCK_SIZE * BLOCKS_PER_SEQ) {
const int index = elem_id + tmp_log_buf_idx;
for (int id = tid + block_lane * BLOCK_SIZE; id < vocab_size;
id += BLOCK_SIZE * BLOCKS_PER_SEQ) {
const int index = id + tmp_logits_base;
tmp_logits[index] = logits[index];
}

@@ -53,52 +53,55 @@ __global__ void topk_within_block(const T* __restrict logits,
partial.init();

#pragma unroll
for (int elem_id = tid + block_lane * BLOCK_SIZE; elem_id < vocab_size;
elem_id += BLOCK_SIZE * BLOCKS_PER_SEQ) {
const int index = elem_id + tmp_log_buf_idx;
for (int id = tid + block_lane * BLOCK_SIZE; id < vocab_size;
id += BLOCK_SIZE * BLOCKS_PER_SEQ) {
const int index = id + tmp_logits_base;
partial.insert(tmp_logits[index], index);
}

// reduce within each block
TopK_2<T> total =
BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2<T>);

// save the topk idx and value to temp storage
if (tid == 0) {
const int index = tmp_topk_buf_idx + ite;
const int index = tmp_topk_base + ite;
tmp_topk_ids[index] = total.p;
tmp_topk_logits[index] = total.u;
// remove the largest item by setting the score to -MAX_T_VAL
tmp_logits[total.p] = -MAX_T_VAL;
}

// wait for all threads to finish
__syncthreads();
}
}

// reduce topk across blocks for each sequence
template <typename T, int BLOCK_SIZE, int BLOCKS_PER_SEQ>
__global__ void topk_sampling(int* output_ids,
float* output_log_probs,
const int* __restrict tmp_topk_ids,
T* __restrict tmp_topk_logits,
int max_top_k,
const int* __restrict top_ks,
const float* __restrict top_ps,
curandState_t* curandstate) {
__global__ void topk_sampling_across_blocks(int* output_ids,
float* output_log_probs,
const int* __restrict tmp_topk_ids,
T* __restrict tmp_topk_logits,
int max_top_k,
const int* __restrict top_ks,
const float* __restrict top_ps,
curandState_t* curandstate) {
typedef cub::BlockReduce<TopK_2<float>, BLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;

const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
const T MAX_T_VAL = (std::is_same_v<T, half>) ? HALF_FLT_MAX : FLT_MAX;

const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int k = top_ks[bid];
const float p = top_ps != nullptr ? top_ps[bid] : 1.0f;
const int size = k * BLOCKS_PER_SEQ;
// each block processes one sequence
const int batch_id = blockIdx.x;
const int k = top_ks[batch_id];
const float p = top_ps != nullptr ? top_ps[batch_id] : 1.0f;
const int stride = max_top_k * BLOCKS_PER_SEQ;

// move the pointer to the corresponding batch
T* topk_logits = tmp_topk_logits + bid * stride;
const int* topk_ids = tmp_topk_ids + bid * stride;
T* topk_logits = tmp_topk_logits + batch_id * stride;
const int* topk_ids = tmp_topk_ids + batch_id * stride;

// use shared memory to save temp topk idxs and values
extern __shared__ char smem[]; // idxs + vals for topk
@@ -113,19 +116,25 @@ __global__ void topk_sampling(int* output_ids,
s_sum_val = 1e-6f;
}

// each block has a partial topk
const int size = k * BLOCKS_PER_SEQ;
// use float to record the largest value
TopK_2<float> partial;
// calculate topk and softmax for each sequence
for (int ite = 0; ite < k; ++ite) {
partial.init();

// merge partial topk from all blocks
#pragma unroll
for (int i = tid; i < size; i += BLOCK_SIZE) {
partial.insert(topk_logits[i], i);
}

// reduce within each block to get the top idx and value
TopK_2<float> total =
BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2<float>);

// save the topk idx and value to shared memory
if (tid == 0) {
if (ite == 0) {
s_max_val = total.u;
@@ -134,54 +143,59 @@ __global__ void topk_sampling(int* output_ids,
// remove the largest item by setting the score to -MAX_T_VAL
topk_logits[total.p] = -MAX_T_VAL;

// calculate expf(x - max_val) and sum
// calculate expf(x - max_val) and sum for softmax
const float exp_logit = __expf(total.u - s_max_val);
s_vals[ite] = exp_logit;
s_sum_val += exp_logit;
}
__syncthreads();
}

// let thread 0 to sample the id from topk candidates
// let thread 0 sample the id from topk candidates
if (tid == 0) {
float rand_num = curand_uniform(curandstate + bid) * p * s_sum_val;
float rand_num = curand_uniform(curandstate + batch_id) * p * s_sum_val;
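// multiplying by p confines the draw to the top-p prefix of the top-k
// candidates: s_vals is in descending order, so any mass beyond p * s_sum_val
// can never be reached by the subtraction loop below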
for (int i = 0; i < k; ++i) {
const float exp_logit = s_vals[i];
rand_num -= exp_logit;
if (rand_num <= 0 || i == k - 1) {
output_ids[bid] = topk_ids[s_idxs[i]];
output_ids[batch_id] = topk_ids[s_idxs[i]];
// the log_prob is the log probability of the selected token
const float log_prob = logf(exp_logit / s_sum_val);
output_log_probs[bid] = log_prob;
output_log_probs[batch_id] = log_prob;
break;
}
}
}
}

#define CASE_K(K_MIN, K_MAX, BLOCK_SIZE_1, BLOCK_SIZE_2, BLOCKS_PER_SEQ) \
case K_MIN ... K_MAX: \
topk_within_block<scalar_t, BLOCK_SIZE_1, BLOCKS_PER_SEQ> \
<<<batch_size * BLOCKS_PER_SEQ, BLOCK_SIZE_1, 0, stream>>>( \
_logits, \
tmp_logits, \
tmp_topk_ids, \
tmp_topk_logits, \
max_top_k, \
_top_ks, \
vocab_size); \
topk_sampling<scalar_t, BLOCK_SIZE_2, BLOCKS_PER_SEQ> \
<<<batch_size, \
BLOCK_SIZE_2, \
K_MAX * sizeof(int) + K_MAX * sizeof(float), \
stream>>>(_output_ids, \
_output_log_probs, \
tmp_topk_ids, \
tmp_topk_logits, \
max_top_k, \
_top_ks, \
_top_ps, \
curandstate); \
// topk sampling kernel launcher that calculates the topk for each sequence in
// the following steps:
// 1. split the sequence into BLOCKS_PER_SEQ blocks for parallel processing and
// calculate the partial topk for each block
// 2. reduce the partial topk across blocks for each sequence
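// For example, CASE_K_RANGE(17, 32, 256, 128, 8) splits each sequence across
// 8 blocks of 256 threads; step 1 emits up to 8 * 32 candidates per sequence,
// which step 2 reduces and samples from with one 128-thread block per sequence.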
#define CASE_K_RANGE(K_MIN, K_MAX, BLOCK_SIZE_1, BLOCK_SIZE_2, BLOCKS_PER_SEQ) \
case K_MIN ... K_MAX: \
partial_topk_within_block<scalar_t, BLOCK_SIZE_1, BLOCKS_PER_SEQ> \
<<<batch_size * BLOCKS_PER_SEQ, BLOCK_SIZE_1, 0, stream>>>( \
_logits, \
tmp_logits, \
tmp_topk_ids, \
tmp_topk_logits, \
max_top_k, \
_top_ks, \
vocab_size); \
topk_sampling_across_blocks<scalar_t, BLOCK_SIZE_2, BLOCKS_PER_SEQ> \
<<<batch_size, \
BLOCK_SIZE_2, \
K_MAX * sizeof(int) + K_MAX * sizeof(float), \
stream>>>(_output_ids, \
_output_log_probs, \
tmp_topk_ids, \
tmp_topk_logits, \
max_top_k, \
_top_ks, \
_top_ps, \
curandstate); \
break;

void invoke_topk_sampling(torch::Tensor& output_ids,
@@ -196,17 +210,19 @@ void invoke_topk_sampling(torch::Tensor& output_ids,
const int vocab_size = logits.size(1);
const int max_blocks_per_seq = 8;

int tmp_logits_size = batch_size * vocab_size; // type float
int tmp_topk_size =
batch_size * max_blocks_per_seq * max_top_k; // type int + float
// tmp_logits to save modified logits
size_t tmp_logits_size = batch_size * vocab_size;
// tmp_topk_* to save topk ids and logits for each block
size_t tmp_topk_size = batch_size * max_blocks_per_seq * max_top_k;
// round up to prevent memory misalignment
tmp_logits_size = ((tmp_logits_size + 3) / 4) * 4;
tmp_topk_size = ((tmp_topk_size + 3) / 4) * 4;
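// e.g. batch_size = 4, vocab_size = 32000, max_top_k = 64: tmp_logits holds
// 4 * 32000 scalars and tmp_topk holds 4 * 8 * 64 = 2048 (id, logit) pairs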

DISPATCH_FLOATING_TYPES(logits.scalar_type(), "topk_kernel", [&] {
CHECK_GE(workspace.numel(),
tmp_logits_size * sizeof(scalar_t) +
tmp_topk_size * (sizeof(int) + sizeof(scalar_t)));
const size_t min_workspace_size =
tmp_logits_size * sizeof(scalar_t) +
tmp_topk_size * (sizeof(int) + sizeof(scalar_t));
assert(workspace.numel() >= min_workspace_size);

// scratch space for topk
scalar_t* tmp_logits = workspace.data_ptr<scalar_t>();
@@ -224,10 +240,11 @@ void invoke_topk_sampling(torch::Tensor& output_ids,

auto stream = at::cuda::getCurrentCUDAStream();
switch (max_top_k) {
CASE_K(1, 16, 128, 128, 8);
CASE_K(17, 32, 256, 128, 8);
CASE_K(33, 64, 256, 256, 8);
CASE_K(65, 1024, 256, 256, 8);
// K_MIN, K_MAX, BLOCK_SIZE_1, BLOCK_SIZE_2, BLOCKS_PER_SEQ
CASE_K_RANGE(1, 16, 128, 128, 8);
CASE_K_RANGE(17, 32, 256, 128, 8);
CASE_K_RANGE(33, 64, 256, 256, 8);
CASE_K_RANGE(65, 1024, 256, 256, 8);
default:
GLOG(FATAL) << "topk_sampling only supports max_top_k <= 1024 but got "
<< max_top_k;
111 changes: 111 additions & 0 deletions src/kernels/sampling/topp_kernels.cu
@@ -0,0 +1,111 @@
// adapted from https://github.com/NVIDIA/FasterTransformer

#include <curand_kernel.h>
#include <cub/cub.cuh>

namespace llm::kernel {

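// Prefix-callback functor for cub::BlockScan: cub calls operator() once per
// scanned tile with that tile's block-wide aggregate, and the returned value
// (the running total before the tile) seeds the next tile's scan, carrying
// the cumulative probability across loop iterations.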
struct RunningTotalOp {
// Running prefix
float running_total;
// Constructor
__device__ RunningTotalOp(float running_total)
: running_total(running_total) {}
__device__ float operator()(float block_aggregate) {
float old_prefix = running_total;
running_total += block_aggregate;
return old_prefix;
}
};

template <typename T, int BLOCK_SIZE>
__global__ void topp_sampling(int* output_ids,
float* output_log_probs,
const int* __restrict sorted_ids,
const T* __restrict sorted_log_probs,
const float* __restrict top_ps,
int vocab_size,
curandState_t* curandstate) {
// shared variables used to communicate between threads in a block
// flag to indicate if the thread should stop scanning
__shared__ int s_stop;
// the random number generated by curand
__shared__ float s_random_num;

constexpr int kWarpSize = 32;
constexpr int kNumWarps = BLOCK_SIZE / kWarpSize;
const int tid = threadIdx.x;
const int batch_id = blockIdx.x;
const int lane_id = tid % kWarpSize;
const int warp_id = tid / kWarpSize;
const float top_p = top_ps[batch_id];

// let thread 0 initialize the shared variables
if (tid == 0) {
s_stop = 0;
s_random_num = curand_uniform(curandstate + batch_id) * top_p;
}

// TODO: quick path?

// scan the sorted log probs to find the stopping position
typedef cub::BlockScan<float, BLOCK_SIZE> BlockScan;
__shared__ typename BlockScan::TempStorage temp_storage;
// a shared variable to record which lane in each warp has found the stopping
// position
__shared__ uint32_t s_selected_lane[kNumWarps];
// a running sum of the probs
RunningTotalOp running_total_op(0.0f);

// let lane 0 in each warp initialize the shared variable
if (lane_id == 0) {
s_selected_lane[warp_id] = 0;
}
__syncthreads();

const int offset = batch_id * vocab_size;
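// round the loop bound up to a multiple of BLOCK_SIZE so all threads execute
// the same number of iterations and reach the __syncthreads below uniformly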
const int end = (vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE;
float thread_log_prob = 0.0f;
int active_idx = 0;
for (int i = tid; i < end; i += BLOCK_SIZE) {
const float log_prob =
(i < vocab_size) ? static_cast<float>(sorted_log_probs[offset + i])
: 0.f;
BlockScan(temp_storage)
.InclusiveSum(log_prob, thread_log_prob, running_total_op);
// gathers predicate bits from each thread in the warp
const uint32_t lane_active_mask =
__ballot_sync(0xFFFFFFFF, s_random_num <= thread_log_prob);
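// the inclusive prefix sums increase with lane id, so a non-zero mask is a
// contiguous run of high bits starting at the first lane past the threshold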

active_idx = i;
if (lane_active_mask != 0) {
if (lane_id == 0) {
atomicAdd(&s_stop, 1);
s_selected_lane[warp_id] = lane_active_mask;
}
}
__syncthreads();
if (s_stop > 0) {
break;
}
}

// select the first warp that found the stopping position; skip the others
// (the scan must start at warp 0, otherwise a candidate found by warp 0
// would be ignored whenever a later warp also found one)
bool skip = s_selected_lane[warp_id] == 0;
for (int i = 0; i < warp_id; ++i) {
if (s_selected_lane[i] != 0) {
skip = true;
}
}

if (!skip) {
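// with a contiguous high-bit ballot mask, the first set lane is
// kWarpSize - popcount(mask)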
const int active_lane_id = kWarpSize - __popc(s_selected_lane[warp_id]);
if (lane_id == active_lane_id) {
output_ids[batch_id] = sorted_ids[offset + active_idx];
const float log_prob = logf(sorted_log_probs[offset + active_idx]);
output_log_probs[batch_id] = log_prob;
}
}
}

} // namespace llm::kernel
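A minimal host-side sketch of how this kernel might be launched (not part of
this commit; the helper names, seed choice, and 256-thread block size are
illustrative assumptions, and it presumes the file above is in the same
translation unit, with each sequence's probabilities already sorted in
descending order):

// hypothetical driver sketch for topp_sampling; all names here are
// illustrative and not part of the commit
#include <curand_kernel.h>

// one curand state per sequence in the batch
__global__ void init_curand_states(curandState_t* states,
                                   unsigned long long seed,
                                   int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    curand_init(seed, /*subsequence=*/i, /*offset=*/0, &states[i]);
  }
}

void launch_topp_sampling(int* output_ids,
                          float* output_log_probs,
                          const int* sorted_ids,
                          const float* sorted_probs,  // probabilities, despite
                                                      // the kernel param name
                          const float* top_ps,
                          int batch_size,
                          int vocab_size,
                          curandState_t* states,
                          cudaStream_t stream) {
  constexpr int kBlockSize = 256;
  // one block per sequence: the kernel maps blockIdx.x to batch_id
  llm::kernel::topp_sampling<float, kBlockSize>
      <<<batch_size, kBlockSize, 0, stream>>>(output_ids,
                                              output_log_probs,
                                              sorted_ids,
                                              sorted_probs,
                                              top_ps,
                                              vocab_size,
                                              states);
}

A grid such as <<<(batch_size + 255) / 256, 256>>> suffices for
init_curand_states, run once before the first sampling call.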
