From 4c89decadc8ae9f261cae97c350064156e66bc09 Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Sat, 27 Jul 2024 03:37:27 -0700
Subject: [PATCH] feat: add llama 3.1 style rope (#401)

Reference implementation:
https://github.com/meta-llama/llama-models/blob/709a61fd810157f75fbb314e7287089eec06d9c3/models/llama3_1/api/model.py#L41

This PR also exposes the `BatchQKApplyRotaryInPlaceKernel` to the PyTorch API;
previously it was only used in the TVM wrappers.
---
 docs/api/python/rope.rst       |  14 +++
 docs/index.rst                 |   1 +
 include/flashinfer/pos_enc.cuh | 178 +++++++++++++++++++++++++++------
 python/csrc/flashinfer_ops.cu  |   3 +
 python/csrc/flashinfer_ops.h   |   8 ++
 python/csrc/rope.cu            | 105 +++++++++++++++++++
 python/flashinfer/__init__.py  |   1 +
 python/flashinfer/rope.py      | 149 +++++++++++++++++++++++++++
 python/setup.py                |   1 +
 python/tests/rope_reference.py |  70 +++++++++++++
 python/tests/test_rope.py      | 138 +++++++++++++++++++++++++
 src/tvm_wrapper.cu             |   7 +-
 12 files changed, 641 insertions(+), 34 deletions(-)
 create mode 100644 docs/api/python/rope.rst
 create mode 100644 python/csrc/rope.cu
 create mode 100644 python/flashinfer/rope.py
 create mode 100644 python/tests/rope_reference.py
 create mode 100644 python/tests/test_rope.py

diff --git a/docs/api/python/rope.rst b/docs/api/python/rope.rst
new file mode 100644
index 00000000..b27ac7e9
--- /dev/null
+++ b/docs/api/python/rope.rst
@@ -0,0 +1,14 @@
+.. _apirope:
+
+flashinfer.rope
+===============
+
+Kernels for applying rotary embeddings.
+
+.. currentmodule:: flashinfer.rope
+
+.. autosummary::
+    :toctree: _generate
+
+    apply_rope_inplace
+    apply_llama31_rope_inplace
diff --git a/docs/index.rst b/docs/index.rst
index c8ed40e0..ce0129f7 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -35,4 +35,5 @@ FlashInfer is a library for Large Language Models that provides high-performance
    api/python/sampling
    api/python/group_gemm
    api/python/norm
+   api/python/rope
    api/python/quantization
diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh
index 7f6cab39..efa5c8bf 100644
--- a/include/flashinfer/pos_enc.cuh
+++ b/include/flashinfer/pos_enc.cuh
@@ -16,6 +16,7 @@
 #ifndef FLASHINFER_POS_ENC_CUH_
 #define FLASHINFER_POS_ENC_CUH_
 
+#include
 #include
 
 #include "layout.cuh"
@@ -93,20 +94,56 @@ __device__ __forceinline__ vec_t vec_apply_llama_rope(
   return vec;
 }
 
-template
-__global__ void BatchQKApplyRotaryInPlaceKernel(DType* __restrict__ q, DType* __restrict__ k,
-                                                IdType* __restrict__ indptr,
-                                                IdType* __restrict__ offsets, uint32_t batch_size,
-                                                uint32_t num_qo_heads, uint32_t num_kv_heads,
-                                                float rope_rcp_scale, float rope_rcp_theta) {
+/*!
+ * \brief Apply RoPE (Rotary Positional Embeddings) to x[0: head_dim] with interleave,
+ *   return thread-local vector.
+ * \tparam vec_size A template integer indicates the vector size used + * in the kernel + * \tparam bdx A template integer indicates the blockDim.x + * \tparam T A template type indicates the x data type + * \param x A pointer to the start of x data + * \param freq A vector of float indicates the thread-local rope frequency + * \param offset A integer indicates the offset of the position in RoPE + */ +template +__device__ __forceinline__ vec_t vec_apply_llama_rope_interleave( + const T* x, const vec_t& freq, int32_t offset) { + vec_t vec, vec_before; + vec.cast_load(x + threadIdx.x * vec_size); + vec_before = vec; + +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + float embed = float(offset) * freq[i]; + float cos, sin; + __sincosf(embed, &sin, &cos); + vec[i] = vec[i] * cos + ((i % 2 == 0) ? -vec_before[i ^ 1] : vec_before[i ^ 1]) * sin; + } + return vec; +} + +template +__global__ void BatchQKApplyRotaryInPlaceKernel( + DType* __restrict__ q, DType* __restrict__ k, IdType* __restrict__ indptr, + IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, float smooth_a, + float smooth_b, float rope_rcp_scale, float rope_rcp_theta) { uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; const uint32_t bdy = blockDim.y; vec_t freq; #pragma unroll for (uint32_t i = 0; i < vec_size; ++i) { - freq[i] = - rope_rcp_scale * - __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim)); + if constexpr (interleave) { + freq[i] = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(head_dim)); + } else { + freq[i] = __powf(rope_rcp_theta, + float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim)); + } + + float smooth = freq[i] * smooth_a + smooth_b; + smooth = max(0.0f, min(1.0f, smooth)); // clamp to [0, 1] + freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i]; } if (bx < batch_size * num_qo_heads) { @@ -120,8 +157,13 @@ __global__ void BatchQKApplyRotaryInPlaceKernel(DType* __restrict__ q, DType* __ vec_t q_vec; if (i * bdy + ty < seq_len) { DType* q_ptr = q + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0, - num_qo_heads * head_dim, head_dim); - q_vec = vec_apply_llama_rope(q_ptr, freq, offset + i * bdy + ty); + q_stride_n, q_stride_h); + if constexpr (interleave) { + q_vec = + vec_apply_llama_rope_interleave(q_ptr, freq, offset + i * bdy + ty); + } else { + q_vec = vec_apply_llama_rope(q_ptr, freq, offset + i * bdy + ty); + } q_vec.cast_store(q_ptr + tx * vec_size); } } @@ -136,42 +178,112 @@ __global__ void BatchQKApplyRotaryInPlaceKernel(DType* __restrict__ q, DType* __ vec_t k_vec; if (i * bdy + ty < seq_len) { DType* k_ptr = k + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0, - num_kv_heads * head_dim, head_dim); - k_vec = vec_apply_llama_rope(k_ptr, freq, offset + i * bdy + ty); + k_stride_n, k_stride_h); + if constexpr (interleave) { + k_vec = + vec_apply_llama_rope_interleave(k_ptr, freq, offset + i * bdy + ty); + } else { + k_vec = vec_apply_llama_rope(k_ptr, freq, offset + i * bdy + ty); + } k_vec.cast_store(k_ptr + tx * vec_size); } } } } +#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) 
\ + if (interleave) { \ + const bool INTERLEAVE = true; \ + __VA_ARGS__ \ + } else { \ + const bool INTERLEAVE = false; \ + __VA_ARGS__ \ + } + template cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__ k, IdType* __restrict__ indptr, IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, - uint32_t num_kv_heads, uint32_t head_dim, - float rope_scale = 1.f, float rope_theta = 1e4, + uint32_t num_kv_heads, uint32_t head_dim, size_t q_stride_n, + size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, + bool interleave, float rope_scale, float rope_theta, cudaStream_t stream = nullptr) { float rope_rcp_scale = 1.0f / rope_scale; float rope_rcp_theta = 1.0f / rope_theta; + float smooth_a = 0.f; + float smooth_b = 0.f; + + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32); + constexpr uint32_t bdx = HEAD_DIM / vec_size; + uint32_t num_threads = std::max(128U, bdx); + uint32_t bdy = num_threads / bdx; + dim3 nblks(batch_size * (num_qo_heads + num_kv_heads)); + dim3 nthrs(bdx, bdy); + auto kernel = + BatchQKApplyRotaryInPlaceKernel; + void* args[] = {(void*)&q, + (void*)&k, + (void*)&indptr, + (void*)&offsets, + (void*)&batch_size, + (void*)&num_qo_heads, + (void*)&num_kv_heads, + (void*)&q_stride_n, + (void*)&q_stride_h, + (void*)&k_stride_n, + (void*)&k_stride_h, + (void*)&smooth_a, + (void*)&smooth_b, + (void*)&rope_rcp_scale, + (void*)&rope_rcp_theta}; + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + }); + }); + + return cudaSuccess; +} + +template +cudaError_t BatchQKApplyLlama31RotaryInPlace( + DType* __restrict__ q, DType* __restrict__ k, IdType* __restrict__ indptr, + IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, + bool interleave, float rope_scale, float rope_theta, float low_freq_factor, + float high_freq_factor, float old_context_length, cudaStream_t stream = nullptr) { + float rope_rcp_scale = 1.0f / rope_scale; + float rope_rcp_theta = 1.0f / rope_theta; + float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor); + float smooth_b = -1.0f / (high_freq_factor / low_freq_factor - 1.0f); - DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { - constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32); - constexpr uint32_t bdx = HEAD_DIM / vec_size; - uint32_t num_threads = std::max(128U, bdx); - uint32_t bdy = num_threads / bdx; - dim3 nblks(batch_size * (num_qo_heads + num_kv_heads)); - dim3 nthrs(bdx, bdy); - auto kernel = BatchQKApplyRotaryInPlaceKernel; - void* args[] = {(void*)&q, - (void*)&k, - (void*)&indptr, - (void*)&offsets, - (void*)&batch_size, - (void*)&num_qo_heads, - (void*)&num_kv_heads, - (void*)&rope_rcp_scale, - (void*)&rope_rcp_theta}; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32); + constexpr uint32_t bdx = HEAD_DIM / vec_size; + uint32_t num_threads = std::max(128U, bdx); + uint32_t bdy = num_threads / bdx; + dim3 nblks(batch_size * (num_qo_heads + num_kv_heads)); + dim3 nthrs(bdx, bdy); + auto kernel = + BatchQKApplyRotaryInPlaceKernel; + void* args[] = {(void*)&q, + 
(void*)&k, + (void*)&indptr, + (void*)&offsets, + (void*)&batch_size, + (void*)&num_qo_heads, + (void*)&num_kv_heads, + (void*)&q_stride_n, + (void*)&q_stride_h, + (void*)&k_stride_n, + (void*)&k_stride_h, + (void*)&smooth_a, + (void*)&smooth_b, + (void*)&rope_rcp_scale, + (void*)&rope_rcp_theta}; + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + }); }); return cudaSuccess; diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu index 4193f304..79c34b21 100644 --- a/python/csrc/flashinfer_ops.cu +++ b/python/csrc/flashinfer_ops.cu @@ -42,6 +42,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("chain_speculative_sampling", &chain_speculative_sampling, "Speculative sampling from sequence of probabilities"); m.def("rmsnorm", &rmsnorm, "Root mean square normalization"); + m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place"); + m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace, + "Apply Llama 3.1 style RoPE in-place"); m.def("packbits", &packbits, "GPU packbits operator"); m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator"); py::class_(m, diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index d837528f..32617c69 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -75,6 +75,14 @@ torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tenso torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps); +void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, + torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta); + +void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, + torch::Tensor offsets, bool interleave, float rope_scale, + float rope_theta, float low_freq_factor, float high_freq_factor, + float old_context_length); + torch::Tensor packbits(torch::Tensor x, const std::string& bitorder); torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr, diff --git a/python/csrc/rope.cu b/python/csrc/rope.cu new file mode 100644 index 00000000..7fb9f483 --- /dev/null +++ b/python/csrc/rope.cu @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +#include "flashinfer_ops.h" +#include "pytorch_extension_utils.h" + +using namespace flashinfer; + +void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, + torch::Tensor offsets, bool interleave, float rope_scale, + float rope_theta) { + CHECK_CUDA(q); // not necessarily contiguous + CHECK_CUDA(k); // not necessarily contiguous + CHECK_INPUT(indptr); + CHECK_INPUT(offsets); + + auto device = q.device(); + CHECK_EQ(k.device(), device); + CHECK_DIM(3, q); // q: (nnz, H_Q, D) + CHECK_DIM(3, k); // k: (nnz, H_K, D) + CHECK_DIM(1, indptr); // indptr: (B + 1) + CHECK_DIM(1, offsets); // offsets: (B) + CHECK_EQ(q.size(0), k.size(0)); + CHECK_EQ(q.size(2), k.size(2)); + unsigned int num_qo_heads = q.size(1); + unsigned int num_kv_heads = k.size(1); + unsigned int head_dim = q.size(2); + unsigned int batch_size = offsets.size(0); + CHECK_EQ(indptr.size(0), batch_size + 1); + size_t q_stride_n = q.stride(0); + size_t q_stride_h = q.stride(1); + size_t k_stride_n = k.stride(0); + size_t k_stride_h = k.stride(1); + indptr = indptr.to(torch::kInt32); + offsets = offsets.to(torch::kInt32); + + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { + cudaError_t status = BatchQKApplyRotaryInPlace( + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(indptr.data_ptr()), static_cast(offsets.data_ptr()), + batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n, + k_stride_h, interleave, rope_scale, rope_theta, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, "BatchQKApplyRotaryInPlace failed with error code " + + std::string(cudaGetErrorString(status))); + return true; + }); +} + +void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, + torch::Tensor offsets, bool interleave, float rope_scale, + float rope_theta, float low_freq_factor, float high_freq_factor, + float old_context_length) { + CHECK_CUDA(q); // not necessarily contiguous + CHECK_CUDA(k); // not necessarily contiguous + CHECK_INPUT(indptr); + CHECK_INPUT(offsets); + + auto device = q.device(); + CHECK_EQ(k.device(), device); + CHECK_DIM(3, q); // q: (nnz, H_Q, D) + CHECK_DIM(3, k); // k: (nnz, H_K, D) + CHECK_DIM(1, indptr); // indptr: (B + 1) + CHECK_DIM(1, offsets); // offsets: (B) + CHECK_EQ(q.size(0), k.size(0)); + CHECK_EQ(q.size(2), k.size(2)); + unsigned int num_qo_heads = q.size(1); + unsigned int num_kv_heads = k.size(1); + unsigned int head_dim = q.size(2); + unsigned int batch_size = offsets.size(0); + CHECK_EQ(indptr.size(0), batch_size + 1); + size_t q_stride_n = q.stride(0); + size_t q_stride_h = q.stride(1); + size_t k_stride_n = k.stride(0); + size_t k_stride_h = k.stride(1); + indptr = indptr.to(torch::kInt32); + offsets = offsets.to(torch::kInt32); + + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { + cudaError_t status = BatchQKApplyLlama31RotaryInPlace( + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(indptr.data_ptr()), static_cast(offsets.data_ptr()), + batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n, + k_stride_h, interleave, rope_scale, rope_theta, low_freq_factor, high_freq_factor, + old_context_length, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, "BatchQKApplyLlama31RotaryInPlace failed with error code " + + 
+                                         std::string(cudaGetErrorString(status)));
+    return true;
+  });
+}
\ No newline at end of file
diff --git a/python/flashinfer/__init__.py b/python/flashinfer/__init__.py
index 184116b2..db818d98 100644
--- a/python/flashinfer/__init__.py
+++ b/python/flashinfer/__init__.py
@@ -44,6 +44,7 @@
     chain_speculative_sampling,
 )
 from .norm import rmsnorm
+from .rope import apply_rope_inplace, apply_llama31_rope_inplace
 from .group_gemm import SegmentGEMMWrapper
 from .quantization import packbits, segment_packbits
diff --git a/python/flashinfer/rope.py b/python/flashinfer/rope.py
new file mode 100644
index 00000000..6bb67eb8
--- /dev/null
+++ b/python/flashinfer/rope.py
@@ -0,0 +1,149 @@
+"""
+Copyright (c) 2024 by FlashInfer team.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import torch
+
+# mypy: disable-error-code="attr-defined"
+try:
+    from . import _kernels
+except ImportError as e:
+    import os
+    import logging
+
+    if os.environ.get("BUILD_DOC", "0") == "1":
+        _kernels = None
+        logging.warning("Kernels are not loaded in documentation build mode.")
+    else:
+        raise e
+
+
+def apply_rope_inplace(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    indptr: torch.Tensor,
+    offsets: torch.Tensor,
+    interleave: bool = False,
+    rope_scale: float = 1,
+    rope_theta: float = 1e4,
+) -> None:
+    r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor) inplace.
+
+    We use :attr:`indptr` to denote the start pointer of each segment in the batch: the query
+    of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the i-th segment is
+    ``k[indptr[i]:indptr[i+1]]``; the first element of :attr:`indptr` is always 0 and the
+    last element of :attr:`indptr` is the total number of queries/keys in the batch.
+    Please see :ref:`Ragged Tensor tutorial ` for more details about the
+    ragged tensor.
+
+    Parameters
+    ----------
+    q : torch.Tensor
+        Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the last
+        element of ``indptr``.
+    k : torch.Tensor
+        Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last
+        element of ``indptr``.
+    indptr : torch.Tensor
+        Indptr tensor, shape: ``(batch_size + 1)``.
+    offsets : torch.Tensor
+        The relative position offsets of each query in the batch, shape: ``(batch_size)``.
+    interleave : bool
+        Whether to use interleaved layout in the last dimension, default: ``False``.
+
+        * If ``True``, the last dimension of the query/key tensor is interleaved, i.e.,
+          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
+
+        * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e.,
+          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
+          dimensions ``([..., head_dim//2:])``.
+
+    rope_scale : float
+        The scaling factor used in the rope embedding, default: ``1``.
+    rope_theta : float
+        The theta value used in the rope embedding, default: ``1e4``.
+    """
+    return _kernels.apply_rope_inplace(
+        q, k, indptr, offsets, interleave, rope_scale, rope_theta
+    )
+
+
+def apply_llama31_rope_inplace(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    indptr: torch.Tensor,
+    offsets: torch.Tensor,
+    interleave: bool = True,
+    rope_scale: float = 8,
+    rope_theta: float = 5e5,
+    low_freq_factor: float = 1,
+    high_freq_factor: float = 4,
+    old_context_len: int = 8192,
+) -> None:
+    r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as
+    RaggedTensor) inplace.
+
+    We use :attr:`indptr` to denote the start pointer of each segment in the batch: the query
+    of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the i-th segment is
+    ``k[indptr[i]:indptr[i+1]]``; the first element of :attr:`indptr` is always 0 and the
+    last element of :attr:`indptr` is the total number of queries/keys in the batch.
+    Please see :ref:`Ragged Tensor tutorial ` for more details about the
+    ragged tensor.
+
+    Parameters
+    ----------
+    q : torch.Tensor
+        Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the last
+        element of ``indptr``.
+    k : torch.Tensor
+        Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last
+        element of ``indptr``.
+    indptr : torch.Tensor
+        Indptr tensor, shape: ``(batch_size + 1)``.
+    offsets : torch.Tensor
+        The relative position offsets of each query in the batch, shape: ``(batch_size)``.
+    interleave : bool
+        Whether to use interleaved layout in the last dimension, default: ``True``.
+
+        * If ``True``, the last dimension of the query/key tensor is interleaved, i.e.,
+          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
+
+        * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e.,
+          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
+          dimensions ``([..., head_dim//2:])``.
+
+    rope_scale : float
+        The scaling factor used in the rope embedding, default: ``8``.
+    rope_theta : float
+        The theta value used in the rope embedding, default: ``5e5``.
+    low_freq_factor : float
+        The low frequency factor used in Llama 3.1 RoPE, default: ``1``.
+    high_freq_factor : float
+        The high frequency factor used in Llama 3.1 RoPE, default: ``4``.
+    old_context_len : int
+        The old context length used in Llama 3.1 RoPE, default: ``8192``.
+    """
+    return _kernels.apply_llama31_rope_inplace(
+        q,
+        k,
+        indptr,
+        offsets,
+        interleave,
+        rope_scale,
+        rope_theta,
+        low_freq_factor,
+        high_freq_factor,
+        float(old_context_len),
+    )
diff --git a/python/setup.py b/python/setup.py
index 448e6acd..b6424a21 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -318,6 +318,7 @@ def __init__(self, *args, **kwargs) -> None:
            "csrc/batch_prefill.cu",
            "csrc/sampling.cu",
            "csrc/norm.cu",
+            "csrc/rope.cu",
            "csrc/group_gemm.cu",
            "csrc/quantization.cu",
        ]
diff --git a/python/tests/rope_reference.py b/python/tests/rope_reference.py
new file mode 100644
index 00000000..4b1daa07
--- /dev/null
+++ b/python/tests/rope_reference.py
@@ -0,0 +1,70 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# top-level folder for each specific model found within the models/ directory at
+# the top-level of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
+import torch +import math +from typing import Tuple + + +def apply_scaling(freqs: torch.Tensor): + # Values obtained from grid search + scale_factor = 8 + low_freq_factor = 1 + high_freq_factor = 4 + old_context_len = 8192 # original llama3 length + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + new_freqs = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / scale_factor) + else: + assert low_freq_wavelen != high_freq_wavelen + smooth = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) + return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + + +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False +): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + if use_scaled: + freqs = apply_scaling(freqs) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) diff --git a/python/tests/test_rope.py b/python/tests/test_rope.py new file mode 100644 index 00000000..e0676126 --- /dev/null +++ b/python/tests/test_rope.py @@ -0,0 +1,138 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import torch +import numpy as np +import flashinfer +import pytest +from rope_reference import * + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("qkv_len", [1, 4, 19, 204]) +@pytest.mark.parametrize("num_qo_heads", [8, 16]) +@pytest.mark.parametrize("num_kv_heads", [8]) +@pytest.mark.parametrize("offset", [0, 15, 99]) +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +def test_llama_rope( + batch_size, + qkv_len, + num_qo_heads, + num_kv_heads, + offset, + head_dim, +): + nnz = batch_size * qkv_len + qkv_packed = torch.randn( + nnz, + (num_qo_heads + 2 * num_kv_heads) * head_dim, + dtype=torch.float16, + device="cuda:0", + ) + q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim) + k = qkv_packed[ + :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim + ].reshape(nnz, num_kv_heads, head_dim) + indptr = torch.tensor( + [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0" + ) + offsets = torch.full((batch_size,), offset, dtype=torch.int32, device="cuda:0") + + # reference implementation + freqs_cis = precompute_freqs_cis( + head_dim, qkv_len + offset, 10000.0, use_scaled=False + ).to("cuda:0") + q_rope, k_rope = apply_rotary_emb( + q.reshape(batch_size, qkv_len, num_qo_heads, head_dim), + k.reshape(batch_size, qkv_len, num_kv_heads, head_dim), + freqs_cis[offset : offset + qkv_len], + ) + q_rope = q_rope.reshape(nnz, num_qo_heads, head_dim) + k_rope = k_rope.reshape(nnz, num_kv_heads, head_dim) + + # flashinfer implementation + flashinfer.apply_rope_inplace( + q, k, indptr, offsets, interleave=True, rope_theta=1e4 + ) + + # compare + np.testing.assert_allclose( + q_rope.cpu().numpy(), q.cpu().numpy(), rtol=1e-3, atol=1e-3 + ) + np.testing.assert_allclose( + k_rope.cpu().numpy(), k.cpu().numpy(), rtol=1e-3, atol=1e-3 + ) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("qkv_len", [1, 4, 19, 204]) +@pytest.mark.parametrize("num_qo_heads", [8, 16]) +@pytest.mark.parametrize("num_kv_heads", [8]) +@pytest.mark.parametrize("offset", [0, 15, 99]) +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +def test_llama31_rope( + batch_size, + qkv_len, + num_qo_heads, + num_kv_heads, + offset, + head_dim, +): + nnz = batch_size * qkv_len + qkv_packed = torch.randn( + nnz, + (num_qo_heads + 2 * num_kv_heads) * head_dim, + dtype=torch.float16, + device="cuda:0", + ) + q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim) + k = qkv_packed[ + :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim + ].reshape(nnz, num_kv_heads, head_dim) + indptr = torch.tensor( + [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0" + ) + offsets = torch.full((batch_size,), offset, dtype=torch.int32, device="cuda:0") + + # reference implementation + freqs_cis = precompute_freqs_cis( + head_dim, qkv_len + offset, 5e5, use_scaled=True + ).to("cuda:0") + q_rope, k_rope = apply_rotary_emb( + q.reshape(batch_size, qkv_len, num_qo_heads, head_dim), + k.reshape(batch_size, qkv_len, num_kv_heads, head_dim), + freqs_cis[offset : offset + qkv_len], + ) + q_rope = q_rope.reshape(nnz, num_qo_heads, head_dim) + k_rope = k_rope.reshape(nnz, num_kv_heads, head_dim) + + # flashinfer implementation + flashinfer.apply_llama31_rope_inplace( + q, k, indptr, offsets, interleave=True, rope_theta=5e5 + ) + + # compare + np.testing.assert_allclose( + q_rope.cpu().numpy(), q.cpu().numpy(), rtol=1e-3, atol=1e-3 + ) + 
np.testing.assert_allclose( + k_rope.cpu().numpy(), k.cpu().numpy(), rtol=1e-3, atol=1e-3 + ) + + +if __name__ == "__main__": + test_llama_rope(2, 1, 8, 8, 1, 128) + test_llama31_rope(1, 1, 8, 8, 0, 128) diff --git a/src/tvm_wrapper.cu b/src/tvm_wrapper.cu index 73c41dbf..809fb585 100644 --- a/src/tvm_wrapper.cu +++ b/src/tvm_wrapper.cu @@ -681,12 +681,17 @@ void _FlashInferBatchQKApplyRotaryInPlace(DLTensor* q, DLTensor* k, DLTensor* in DLTensor* offsets, int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim, double rope_scale, double rope_theta) { + size_t q_stride_n = q->strides[0]; + size_t q_stride_h = q->strides[1]; + size_t k_stride_n = k->strides[0]; + size_t k_stride_h = k->strides[1]; DISPATCH_TVM_CUDA_DTYPE( q->dtype, dtype, {DISPATCH_TVM_CUDA_IDTYPE(indptr->dtype, idtype, { cudaError_t status = BatchQKApplyRotaryInPlace( static_cast(q->data), static_cast(k->data), static_cast(indptr->data), static_cast(offsets->data), batch_size, - num_qo_heads, num_kv_heads, head_dim, rope_scale, rope_theta); + num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h, + /*interleave=*/false, rope_scale, rope_theta); if (status != cudaSuccess) { LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status); }
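
Usage note: the snippet below is a minimal sketch of the new PyTorch-facing APIs added by this
patch, distilled from python/flashinfer/rope.py and python/tests/test_rope.py above. The head
counts, sequence length, dtype, and device are illustrative assumptions, not requirements of the
API.

    import torch
    import flashinfer

    # Illustrative sizes: 2 sequences of length 16, GQA with 32 query / 8 kv heads.
    batch_size, seq_len = 2, 16
    num_qo_heads, num_kv_heads, head_dim = 32, 8, 128
    nnz = batch_size * seq_len

    # Ragged query/key tensors of shape (nnz, num_heads, head_dim), fp16 on GPU.
    q = torch.randn(nnz, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
    k = torch.randn(nnz, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0")

    # indptr marks segment boundaries; offsets holds each segment's start position.
    indptr = torch.arange(0, nnz + seq_len, seq_len, dtype=torch.int32, device="cuda:0")
    offsets = torch.zeros(batch_size, dtype=torch.int32, device="cuda:0")

    # Plain RoPE, applied to q and k in place.
    flashinfer.apply_rope_inplace(q, k, indptr, offsets, interleave=False, rope_theta=1e4)

    # Llama 3.1 style RoPE with frequency smoothing (for a given model you would use
    # one of the two calls, not both); the defaults rope_scale=8, low_freq_factor=1,
    # high_freq_factor=4, old_context_len=8192 follow the reference implementation.
    flashinfer.apply_llama31_rope_inplace(q, k, indptr, offsets, interleave=True, rope_theta=5e5)

Design note on the kernel side: instead of re-evaluating the reference apply_scaling per
frequency, BatchQKApplyLlama31RotaryInPlace folds it into two scalars,
smooth_a = old_context_length / (2*pi*(high_freq_factor - low_freq_factor)) and
smooth_b = -low_freq_factor / (high_freq_factor - low_freq_factor), so each thread only computes
smooth = clamp(freq * smooth_a + smooth_b, 0, 1) and blends the scaled and unscaled frequencies.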