Skip to content

Commit

Permalink
feat: add llama 3.1 style rope (#401)
Browse files Browse the repository at this point in the history
Reference implementation:
https://github.com/meta-llama/llama-models/blob/709a61fd810157f75fbb314e7287089eec06d9c3/models/llama3_1/api/model.py#L41

This PR also expose the `BatchQKApplyRotaryInPlaceKernel` to pytorch
APIs, previous they are only used in TVM wrappers.
  • Loading branch information
yzh119 committed Jul 27, 2024
1 parent 73a764b commit 4c89dec
Show file tree
Hide file tree
Showing 12 changed files with 641 additions and 34 deletions.
14 changes: 14 additions & 0 deletions docs/api/python/rope.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
.. _apirope:

flashinfer.rope
===============

Kernels for applying rotary embeddings.

.. currentmodule:: flashinfer.rope

.. autosummary::
:toctree: _generate

apply_rope_inplace
apply_llama31_rope_inplace
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,5 @@ FlashInfer is a library for Large Language Models that provides high-performance
api/python/sampling
api/python/group_gemm
api/python/norm
api/python/rope
api/python/quantization
178 changes: 145 additions & 33 deletions include/flashinfer/pos_enc.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#ifndef FLASHINFER_POS_ENC_CUH_
#define FLASHINFER_POS_ENC_CUH_

#include <cmath>
#include <string>

#include "layout.cuh"
Expand Down Expand Up @@ -93,20 +94,56 @@ __device__ __forceinline__ vec_t<float, vec_size> vec_apply_llama_rope(
return vec;
}

template <uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType, typename IdType>
__global__ void BatchQKApplyRotaryInPlaceKernel(DType* __restrict__ q, DType* __restrict__ k,
IdType* __restrict__ indptr,
IdType* __restrict__ offsets, uint32_t batch_size,
uint32_t num_qo_heads, uint32_t num_kv_heads,
float rope_rcp_scale, float rope_rcp_theta) {
/*!
* \brief Apply RoPE (Rotary Positional Embeddings) to x[0: head_dim] with interleave,
* return thread-local vector.
* \tparam vec_size A template integer indicates the vector size used
* in the kernel
* \tparam bdx A template integer indicates the blockDim.x
* \tparam T A template type indicates the x data type
* \param x A pointer to the start of x data
* \param freq A vector of float indicates the thread-local rope frequency
* \param offset A integer indicates the offset of the position in RoPE
*/
template <uint32_t vec_size, uint32_t bdx, typename T>
__device__ __forceinline__ vec_t<float, vec_size> vec_apply_llama_rope_interleave(
const T* x, const vec_t<float, vec_size>& freq, int32_t offset) {
vec_t<float, vec_size> vec, vec_before;
vec.cast_load(x + threadIdx.x * vec_size);
vec_before = vec;

#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
float embed = float(offset) * freq[i];
float cos, sin;
__sincosf(embed, &sin, &cos);
vec[i] = vec[i] * cos + ((i % 2 == 0) ? -vec_before[i ^ 1] : vec_before[i ^ 1]) * sin;
}
return vec;
}

template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
typename IdType>
__global__ void BatchQKApplyRotaryInPlaceKernel(
DType* __restrict__ q, DType* __restrict__ k, IdType* __restrict__ indptr,
IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, float smooth_a,
float smooth_b, float rope_rcp_scale, float rope_rcp_theta) {
uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
const uint32_t bdy = blockDim.y;
vec_t<float, vec_size> freq;
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
freq[i] =
rope_rcp_scale *
__powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim));
if constexpr (interleave) {
freq[i] = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(head_dim));
} else {
freq[i] = __powf(rope_rcp_theta,
float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim));
}

float smooth = freq[i] * smooth_a + smooth_b;
smooth = max(0.0f, min(1.0f, smooth)); // clamp to [0, 1]
freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i];
}

if (bx < batch_size * num_qo_heads) {
Expand All @@ -120,8 +157,13 @@ __global__ void BatchQKApplyRotaryInPlaceKernel(DType* __restrict__ q, DType* __
vec_t<float, vec_size> q_vec;
if (i * bdy + ty < seq_len) {
DType* q_ptr = q + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0,
num_qo_heads * head_dim, head_dim);
q_vec = vec_apply_llama_rope<vec_size, bdx>(q_ptr, freq, offset + i * bdy + ty);
q_stride_n, q_stride_h);
if constexpr (interleave) {
q_vec =
vec_apply_llama_rope_interleave<vec_size, bdx>(q_ptr, freq, offset + i * bdy + ty);
} else {
q_vec = vec_apply_llama_rope<vec_size, bdx>(q_ptr, freq, offset + i * bdy + ty);
}
q_vec.cast_store(q_ptr + tx * vec_size);
}
}
Expand All @@ -136,42 +178,112 @@ __global__ void BatchQKApplyRotaryInPlaceKernel(DType* __restrict__ q, DType* __
vec_t<float, vec_size> k_vec;
if (i * bdy + ty < seq_len) {
DType* k_ptr = k + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0,
num_kv_heads * head_dim, head_dim);
k_vec = vec_apply_llama_rope<vec_size, bdx>(k_ptr, freq, offset + i * bdy + ty);
k_stride_n, k_stride_h);
if constexpr (interleave) {
k_vec =
vec_apply_llama_rope_interleave<vec_size, bdx>(k_ptr, freq, offset + i * bdy + ty);
} else {
k_vec = vec_apply_llama_rope<vec_size, bdx>(k_ptr, freq, offset + i * bdy + ty);
}
k_vec.cast_store(k_ptr + tx * vec_size);
}
}
}
}

#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
if (interleave) { \
const bool INTERLEAVE = true; \
__VA_ARGS__ \
} else { \
const bool INTERLEAVE = false; \
__VA_ARGS__ \
}

template <typename DType, typename IdType>
cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__ k,
IdType* __restrict__ indptr, IdType* __restrict__ offsets,
uint32_t batch_size, uint32_t num_qo_heads,
uint32_t num_kv_heads, uint32_t head_dim,
float rope_scale = 1.f, float rope_theta = 1e4,
uint32_t num_kv_heads, uint32_t head_dim, size_t q_stride_n,
size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
bool interleave, float rope_scale, float rope_theta,
cudaStream_t stream = nullptr) {
float rope_rcp_scale = 1.0f / rope_scale;
float rope_rcp_theta = 1.0f / rope_theta;
float smooth_a = 0.f;
float smooth_b = 0.f;

DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
constexpr uint32_t bdx = HEAD_DIM / vec_size;
uint32_t num_threads = std::max(128U, bdx);
uint32_t bdy = num_threads / bdx;
dim3 nblks(batch_size * (num_qo_heads + num_kv_heads));
dim3 nthrs(bdx, bdy);
auto kernel =
BatchQKApplyRotaryInPlaceKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
void* args[] = {(void*)&q,
(void*)&k,
(void*)&indptr,
(void*)&offsets,
(void*)&batch_size,
(void*)&num_qo_heads,
(void*)&num_kv_heads,
(void*)&q_stride_n,
(void*)&q_stride_h,
(void*)&k_stride_n,
(void*)&k_stride_h,
(void*)&smooth_a,
(void*)&smooth_b,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
});
});

return cudaSuccess;
}

template <typename DType, typename IdType>
cudaError_t BatchQKApplyLlama31RotaryInPlace(
DType* __restrict__ q, DType* __restrict__ k, IdType* __restrict__ indptr,
IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
bool interleave, float rope_scale, float rope_theta, float low_freq_factor,
float high_freq_factor, float old_context_length, cudaStream_t stream = nullptr) {
float rope_rcp_scale = 1.0f / rope_scale;
float rope_rcp_theta = 1.0f / rope_theta;
float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor);
float smooth_b = -1.0f / (high_freq_factor / low_freq_factor - 1.0f);

DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
constexpr uint32_t bdx = HEAD_DIM / vec_size;
uint32_t num_threads = std::max(128U, bdx);
uint32_t bdy = num_threads / bdx;
dim3 nblks(batch_size * (num_qo_heads + num_kv_heads));
dim3 nthrs(bdx, bdy);
auto kernel = BatchQKApplyRotaryInPlaceKernel<HEAD_DIM, vec_size, bdx, DType, IdType>;
void* args[] = {(void*)&q,
(void*)&k,
(void*)&indptr,
(void*)&offsets,
(void*)&batch_size,
(void*)&num_qo_heads,
(void*)&num_kv_heads,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
constexpr uint32_t bdx = HEAD_DIM / vec_size;
uint32_t num_threads = std::max(128U, bdx);
uint32_t bdy = num_threads / bdx;
dim3 nblks(batch_size * (num_qo_heads + num_kv_heads));
dim3 nthrs(bdx, bdy);
auto kernel =
BatchQKApplyRotaryInPlaceKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
void* args[] = {(void*)&q,
(void*)&k,
(void*)&indptr,
(void*)&offsets,
(void*)&batch_size,
(void*)&num_qo_heads,
(void*)&num_kv_heads,
(void*)&q_stride_n,
(void*)&q_stride_h,
(void*)&k_stride_n,
(void*)&k_stride_h,
(void*)&smooth_a,
(void*)&smooth_b,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
});
});

return cudaSuccess;
Expand Down
3 changes: 3 additions & 0 deletions python/csrc/flashinfer_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("chain_speculative_sampling", &chain_speculative_sampling,
"Speculative sampling from sequence of probabilities");
m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place");
m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace,
"Apply Llama 3.1 style RoPE in-place");
m.def("packbits", &packbits, "GPU packbits operator");
m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator");
py::class_<BatchDecodeWithPagedKVCachePyTorchWrapper>(m,
Expand Down
8 changes: 8 additions & 0 deletions python/csrc/flashinfer_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,14 @@ torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tenso

torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps);

void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta);

void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
torch::Tensor offsets, bool interleave, float rope_scale,
float rope_theta, float low_freq_factor, float high_freq_factor,
float old_context_length);

torch::Tensor packbits(torch::Tensor x, const std::string& bitorder);

torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr,
Expand Down
105 changes: 105 additions & 0 deletions python/csrc/rope.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <flashinfer/pos_enc.cuh>

#include "flashinfer_ops.h"
#include "pytorch_extension_utils.h"

using namespace flashinfer;

void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
torch::Tensor offsets, bool interleave, float rope_scale,
float rope_theta) {
CHECK_CUDA(q); // not necessarily contiguous
CHECK_CUDA(k); // not necessarily contiguous
CHECK_INPUT(indptr);
CHECK_INPUT(offsets);

auto device = q.device();
CHECK_EQ(k.device(), device);
CHECK_DIM(3, q); // q: (nnz, H_Q, D)
CHECK_DIM(3, k); // k: (nnz, H_K, D)
CHECK_DIM(1, indptr); // indptr: (B + 1)
CHECK_DIM(1, offsets); // offsets: (B)
CHECK_EQ(q.size(0), k.size(0));
CHECK_EQ(q.size(2), k.size(2));
unsigned int num_qo_heads = q.size(1);
unsigned int num_kv_heads = k.size(1);
unsigned int head_dim = q.size(2);
unsigned int batch_size = offsets.size(0);
CHECK_EQ(indptr.size(0), batch_size + 1);
size_t q_stride_n = q.stride(0);
size_t q_stride_h = q.stride(1);
size_t k_stride_n = k.stride(0);
size_t k_stride_h = k.stride(1);
indptr = indptr.to(torch::kInt32);
offsets = offsets.to(torch::kInt32);

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
cudaError_t status = BatchQKApplyRotaryInPlace(
static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()),
static_cast<int32_t*>(indptr.data_ptr()), static_cast<int32_t*>(offsets.data_ptr()),
batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n,
k_stride_h, interleave, rope_scale, rope_theta, torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "BatchQKApplyRotaryInPlace failed with error code " +
std::string(cudaGetErrorString(status)));
return true;
});
}

void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
torch::Tensor offsets, bool interleave, float rope_scale,
float rope_theta, float low_freq_factor, float high_freq_factor,
float old_context_length) {
CHECK_CUDA(q); // not necessarily contiguous
CHECK_CUDA(k); // not necessarily contiguous
CHECK_INPUT(indptr);
CHECK_INPUT(offsets);

auto device = q.device();
CHECK_EQ(k.device(), device);
CHECK_DIM(3, q); // q: (nnz, H_Q, D)
CHECK_DIM(3, k); // k: (nnz, H_K, D)
CHECK_DIM(1, indptr); // indptr: (B + 1)
CHECK_DIM(1, offsets); // offsets: (B)
CHECK_EQ(q.size(0), k.size(0));
CHECK_EQ(q.size(2), k.size(2));
unsigned int num_qo_heads = q.size(1);
unsigned int num_kv_heads = k.size(1);
unsigned int head_dim = q.size(2);
unsigned int batch_size = offsets.size(0);
CHECK_EQ(indptr.size(0), batch_size + 1);
size_t q_stride_n = q.stride(0);
size_t q_stride_h = q.stride(1);
size_t k_stride_n = k.stride(0);
size_t k_stride_h = k.stride(1);
indptr = indptr.to(torch::kInt32);
offsets = offsets.to(torch::kInt32);

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
cudaError_t status = BatchQKApplyLlama31RotaryInPlace(
static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()),
static_cast<int32_t*>(indptr.data_ptr()), static_cast<int32_t*>(offsets.data_ptr()),
batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n,
k_stride_h, interleave, rope_scale, rope_theta, low_freq_factor, high_freq_factor,
old_context_length, torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "BatchQKApplyLlama31RotaryInPlace failed with error code " +
std::string(cudaGetErrorString(status)));
return true;
});
}
1 change: 1 addition & 0 deletions python/flashinfer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
chain_speculative_sampling,
)
from .norm import rmsnorm
from .rope import apply_rope_inplace, apply_llama31_rope_inplace
from .group_gemm import SegmentGEMMWrapper
from .quantization import packbits, segment_packbits

Expand Down
Loading

0 comments on commit 4c89dec

Please sign in to comment.