Skip to content

Commit

Permalink
feat: non-inplace rope operators (#405)
Browse files Browse the repository at this point in the history
As requested in #403, this PR implements non-inplace rope operators.
  • Loading branch information
yzh119 committed Jul 29, 2024
1 parent 2496f5b commit 74ffba1
Show file tree
Hide file tree
Showing 10 changed files with 555 additions and 17 deletions.
2 changes: 2 additions & 0 deletions docs/api/python/rope.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ Kernels for applying rotary embeddings.

apply_rope_inplace
apply_llama31_rope_inplace
apply_rope
apply_llama31_rope
4 changes: 2 additions & 2 deletions include/flashinfer/attention/prefill.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1440,7 +1440,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg
write_o_reg_gmem<num_warps_x, num_warps_z, num_frags_x, num_frags_y>(
o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len,
/*o_stride_n=*/
partition_kv ? num_qo_heads * head_dim * num_kv_chunks : num_qo_heads * head_dim,
partition_kv ? num_qo_heads * head_dim * num_kv_chunks : num_qo_heads * head_dim,
/*o_stride_h=*/head_dim, group_size);

// write lse
Expand Down Expand Up @@ -1732,7 +1732,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage
write_o_reg_gmem<num_warps_x, num_warps_z, num_frags_x, num_frags_y>(
o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len,
/*o_stride_n=*/
partition_kv ? num_qo_heads * head_dim * num_kv_chunks : num_qo_heads * head_dim,
partition_kv ? num_qo_heads * head_dim * num_kv_chunks : num_qo_heads * head_dim,
/*o_stride_h=*/head_dim, group_size);

// write lse
Expand Down
174 changes: 174 additions & 0 deletions include/flashinfer/pos_enc.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,86 @@ __global__ void BatchQKApplyRotaryInPlaceKernel(
}
}

template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
typename IdType>
__global__ void BatchQKApplyRotaryKernel(DType* __restrict__ q, DType* __restrict__ k,
DType* __restrict__ q_rope, DType* __restrict__ k_rope,
IdType* __restrict__ indptr, IdType* __restrict__ offsets,
uint32_t batch_size, uint32_t num_qo_heads,
uint32_t num_kv_heads, size_t q_stride_n,
size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
float smooth_a, float smooth_b, float rope_rcp_scale,
float rope_rcp_theta) {
uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
const uint32_t bdy = blockDim.y;
vec_t<float, vec_size> freq;
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
if constexpr (interleave) {
freq[i] = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(head_dim));
} else {
freq[i] = __powf(rope_rcp_theta,
float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim));
}

float smooth = freq[i] * smooth_a + smooth_b;
smooth = max(0.0f, min(1.0f, smooth)); // clamp to [0, 1]
freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i];
}

if (bx < batch_size * num_qo_heads) {
// apply rotary to q
const uint32_t batch_idx = bx / num_qo_heads;
const uint32_t qo_head_idx = bx % num_qo_heads;
const uint32_t seq_len = indptr[batch_idx + 1] - indptr[batch_idx];
const uint32_t offset = offsets[batch_idx];
#pragma unroll 2
for (uint32_t i = 0; i < (seq_len + bdy - 1) / bdy; ++i) {
vec_t<float, vec_size> q_vec;
if (i * bdy + ty < seq_len) {
DType* q_ptr = q + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0,
q_stride_n, q_stride_h);
DType* q_rope_ptr =
q_rope + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0,
/*q_stride_n=*/num_qo_heads * head_dim,
/*q_stride_h=*/head_dim);
if constexpr (interleave) {
q_vec =
vec_apply_llama_rope_interleave<vec_size, bdx>(q_ptr, freq, offset + i * bdy + ty);
} else {
q_vec = vec_apply_llama_rope<vec_size, bdx>(q_ptr, freq, offset + i * bdy + ty);
}
q_vec.cast_store(q_rope_ptr + tx * vec_size);
}
}
} else {
// apply rotary to k
uint32_t batch_idx = (bx - batch_size * num_qo_heads) / num_kv_heads;
uint32_t kv_head_idx = (bx - batch_size * num_qo_heads) % num_kv_heads;
const uint32_t seq_len = indptr[batch_idx + 1] - indptr[batch_idx];
const uint32_t offset = offsets[batch_idx];
#pragma unroll 2
for (uint32_t i = 0; i < (seq_len + bdy - 1) / bdy; ++i) {
vec_t<float, vec_size> k_vec;
if (i * bdy + ty < seq_len) {
DType* k_ptr = k + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0,
k_stride_n, k_stride_h);
DType* k_rope_ptr =
k_rope + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0,
/*kv_stride_n=*/num_kv_heads * head_dim,
/*kv_stride_h=*/head_dim);
if constexpr (interleave) {
k_vec =
vec_apply_llama_rope_interleave<vec_size, bdx>(k_ptr, freq, offset + i * bdy + ty);
} else {
k_vec = vec_apply_llama_rope<vec_size, bdx>(k_ptr, freq, offset + i * bdy + ty);
}
k_vec.cast_store(k_rope_ptr + +tx * vec_size);
}
}
}
}

#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
if (interleave) { \
const bool INTERLEAVE = true; \
Expand Down Expand Up @@ -289,6 +369,100 @@ cudaError_t BatchQKApplyLlama31RotaryInPlace(
return cudaSuccess;
}

template <typename DType, typename IdType>
cudaError_t BatchQKApplyRotary(DType* __restrict__ q, DType* __restrict__ k,
DType* __restrict__ q_rope, DType* __restrict__ k_rope,
IdType* __restrict__ indptr, IdType* __restrict__ offsets,
uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
uint32_t head_dim, size_t q_stride_n, size_t q_stride_h,
size_t k_stride_n, size_t k_stride_h, bool interleave,
float rope_scale, float rope_theta, cudaStream_t stream = nullptr) {
float rope_rcp_scale = 1.0f / rope_scale;
float rope_rcp_theta = 1.0f / rope_theta;
float smooth_a = 0.f;
float smooth_b = 0.f;

DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
constexpr uint32_t bdx = HEAD_DIM / vec_size;
uint32_t num_threads = std::max(128U, bdx);
uint32_t bdy = num_threads / bdx;
dim3 nblks(batch_size * (num_qo_heads + num_kv_heads));
dim3 nthrs(bdx, bdy);
auto kernel = BatchQKApplyRotaryKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
void* args[] = {(void*)&q,
(void*)&k,
(void*)&q_rope,
(void*)&k_rope,
(void*)&indptr,
(void*)&offsets,
(void*)&batch_size,
(void*)&num_qo_heads,
(void*)&num_kv_heads,
(void*)&q_stride_n,
(void*)&q_stride_h,
(void*)&k_stride_n,
(void*)&k_stride_h,
(void*)&smooth_a,
(void*)&smooth_b,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
});
});

return cudaSuccess;
}

template <typename DType, typename IdType>
cudaError_t BatchQKApplyLlama31Rotary(DType* __restrict__ q, DType* __restrict__ k,
DType* __restrict__ q_rope, DType* __restrict__ k_rope,
IdType* __restrict__ indptr, IdType* __restrict__ offsets,
uint32_t batch_size, uint32_t num_qo_heads,
uint32_t num_kv_heads, uint32_t head_dim, size_t q_stride_n,
size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
bool interleave, float rope_scale, float rope_theta,
float low_freq_factor, float high_freq_factor,
float old_context_length, cudaStream_t stream = nullptr) {
float rope_rcp_scale = 1.0f / rope_scale;
float rope_rcp_theta = 1.0f / rope_theta;
float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor);
float smooth_b = -1.0f / (high_freq_factor / low_freq_factor - 1.0f);

DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
constexpr uint32_t bdx = HEAD_DIM / vec_size;
uint32_t num_threads = std::max(128U, bdx);
uint32_t bdy = num_threads / bdx;
dim3 nblks(batch_size * (num_qo_heads + num_kv_heads));
dim3 nthrs(bdx, bdy);
auto kernel = BatchQKApplyRotaryKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
void* args[] = {(void*)&q,
(void*)&k,
(void*)&q_rope,
(void*)&k_rope,
(void*)&indptr,
(void*)&offsets,
(void*)&batch_size,
(void*)&num_qo_heads,
(void*)&num_kv_heads,
(void*)&q_stride_n,
(void*)&q_stride_h,
(void*)&k_stride_n,
(void*)&k_stride_h,
(void*)&smooth_a,
(void*)&smooth_b,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
});
});

return cudaSuccess;
}

} // namespace flashinfer

#endif // FLASHINFER_POS_ENC_CUH_
2 changes: 2 additions & 0 deletions python/csrc/flashinfer_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place");
m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace,
"Apply Llama 3.1 style RoPE in-place");
m.def("apply_rope", &apply_rope, "Apply RoPE");
m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE");
m.def("packbits", &packbits, "GPU packbits operator");
m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator");
py::class_<BatchDecodeWithPagedKVCachePyTorchWrapper>(m,
Expand Down
10 changes: 10 additions & 0 deletions python/csrc/flashinfer_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,16 @@ void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor
float rope_theta, float low_freq_factor, float high_freq_factor,
float old_context_length);

std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
torch::Tensor offsets, bool interleave, float rope_scale,
float rope_theta);

std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
torch::Tensor indptr, torch::Tensor offsets,
bool interleave, float rope_scale, float rope_theta,
float low_freq_factor, float high_freq_factor,
float old_context_length);

torch::Tensor packbits(torch::Tensor x, const std::string& bitorder);

torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr,
Expand Down
100 changes: 99 additions & 1 deletion python/csrc/rope.cu
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,102 @@ void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor
std::string(cudaGetErrorString(status)));
return true;
});
}
}

std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
torch::Tensor offsets, bool interleave, float rope_scale,
float rope_theta) {
CHECK_CUDA(q); // not necessarily contiguous
CHECK_CUDA(k); // not necessarily contiguous
CHECK_INPUT(indptr);
CHECK_INPUT(offsets);

auto device = q.device();
CHECK_EQ(k.device(), device);
CHECK_DIM(3, q); // q: (nnz, H_Q, D)
CHECK_DIM(3, k); // k: (nnz, H_K, D)
CHECK_DIM(1, indptr); // indptr: (B + 1)
CHECK_DIM(1, offsets); // offsets: (B)
CHECK_EQ(q.size(0), k.size(0));
CHECK_EQ(q.size(2), k.size(2));
unsigned int num_qo_heads = q.size(1);
unsigned int num_kv_heads = k.size(1);
unsigned int head_dim = q.size(2);
unsigned int batch_size = offsets.size(0);
CHECK_EQ(indptr.size(0), batch_size + 1);
size_t q_stride_n = q.stride(0);
size_t q_stride_h = q.stride(1);
size_t k_stride_n = k.stride(0);
size_t k_stride_h = k.stride(1);
indptr = indptr.to(torch::kInt32);
offsets = offsets.to(torch::kInt32);
// NOTE(Zihao): empty_like do not copy strides so it's okay to use it here.
auto q_rope = torch::empty_like(q);
auto k_rope = torch::empty_like(k);

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
cudaError_t status = BatchQKApplyRotary(
static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()),
static_cast<c_type*>(q_rope.data_ptr()), static_cast<c_type*>(k_rope.data_ptr()),
static_cast<int32_t*>(indptr.data_ptr()), static_cast<int32_t*>(offsets.data_ptr()),
batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n,
k_stride_h, interleave, rope_scale, rope_theta, torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "BatchQKApplyRotary failed with error code " +
std::string(cudaGetErrorString(status)));
return true;
});

return {q_rope, k_rope};
}

std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
torch::Tensor indptr, torch::Tensor offsets,
bool interleave, float rope_scale, float rope_theta,
float low_freq_factor, float high_freq_factor,
float old_context_length) {
CHECK_CUDA(q); // not necessarily contiguous
CHECK_CUDA(k); // not necessarily contiguous
CHECK_INPUT(indptr);
CHECK_INPUT(offsets);

auto device = q.device();
CHECK_EQ(k.device(), device);
CHECK_DIM(3, q); // q: (nnz, H_Q, D)
CHECK_DIM(3, k); // k: (nnz, H_K, D)
CHECK_DIM(1, indptr); // indptr: (B + 1)
CHECK_DIM(1, offsets); // offsets: (B)
CHECK_EQ(q.size(0), k.size(0));
CHECK_EQ(q.size(2), k.size(2));
unsigned int num_qo_heads = q.size(1);
unsigned int num_kv_heads = k.size(1);
unsigned int head_dim = q.size(2);
unsigned int batch_size = offsets.size(0);
CHECK_EQ(indptr.size(0), batch_size + 1);
size_t q_stride_n = q.stride(0);
size_t q_stride_h = q.stride(1);
size_t k_stride_n = k.stride(0);
size_t k_stride_h = k.stride(1);
indptr = indptr.to(torch::kInt32);
offsets = offsets.to(torch::kInt32);

// NOTE(Zihao): empty_like do not copy strides so it's okay to use it here.
auto q_rope = torch::empty_like(q);
auto k_rope = torch::empty_like(k);

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
cudaError_t status = BatchQKApplyLlama31Rotary(
static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()),
static_cast<c_type*>(q_rope.data_ptr()), static_cast<c_type*>(k_rope.data_ptr()),
static_cast<int32_t*>(indptr.data_ptr()), static_cast<int32_t*>(offsets.data_ptr()),
batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n,
k_stride_h, interleave, rope_scale, rope_theta, low_freq_factor, high_freq_factor,
old_context_length, torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "BatchQKApplyLlama31Rotary failed with error code " +
std::string(cudaGetErrorString(status)));
return true;
});

return {q_rope, k_rope};
}
2 changes: 1 addition & 1 deletion python/csrc/single_prefill.cu
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative");
const LogitsPostHook logits_post_hook =
logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone;

bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] {
Expand Down
7 changes: 6 additions & 1 deletion python/flashinfer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,12 @@
chain_speculative_sampling,
)
from .norm import rmsnorm
from .rope import apply_rope_inplace, apply_llama31_rope_inplace
from .rope import (
apply_rope_inplace,
apply_llama31_rope_inplace,
apply_rope,
apply_llama31_rope,
)
from .group_gemm import SegmentGEMMWrapper
from .quantization import packbits, segment_packbits

Expand Down
Loading

0 comments on commit 74ffba1

Please sign in to comment.