From 034e0ee0106fafe729faadb26ab5ef1d8f1a0f14 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Mon, 16 Oct 2023 00:59:57 -0700
Subject: [PATCH] Implement PagedAttention V2 (#1348)

---
 .../kernels/benchmark_paged_attention.py |  197 +++++++
 csrc/attention.cpp                       |   28 +-
 csrc/attention/attention_kernels.cu      |  484 +++++++++++++++---
 csrc/attention/dtype_bfloat16.cuh        |    5 +
 tests/kernels/test_attention.py          |   71 ++-
 vllm/model_executor/layers/attention.py  |  118 +++--
 6 files changed, 764 insertions(+), 139 deletions(-)
 create mode 100644 benchmarks/kernels/benchmark_paged_attention.py

diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py
new file mode 100644
index 000000000000..0ef803076767
--- /dev/null
+++ b/benchmarks/kernels/benchmark_paged_attention.py
@@ -0,0 +1,197 @@
+import argparse
+import random
+import time
+
+import torch
+
+from vllm import attention_ops
+
+NUM_BLOCKS = 1024
+PARTITION_SIZE = 512
+
+
+@torch.inference_mode()
+def main(
+    version: str,
+    num_seqs: int,
+    context_len: int,
+    num_query_heads: int,
+    num_kv_heads: int,
+    head_size: int,
+    use_alibi: bool,
+    block_size: int,
+    dtype: torch.dtype,
+    seed: int,
+    do_profile: bool,
+) -> None:
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    scale = float(1.0 / (head_size**0.5))
+    query = torch.empty(num_seqs,
+                        num_query_heads,
+                        head_size,
+                        dtype=dtype,
+                        device="cuda")
+    query.uniform_(-scale, scale)
+
+    assert num_query_heads % num_kv_heads == 0
+    num_queries_per_kv = num_query_heads // num_kv_heads
+    head_mapping = torch.repeat_interleave(
+        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
+        num_queries_per_kv)
+    alibi_slopes = None
+    if use_alibi:
+        alibi_slopes = torch.randn(num_query_heads,
+                                   dtype=torch.float,
+                                   device="cuda")
+
+    context_lens = [context_len for _ in range(num_seqs)]
+    max_context_len = max(context_lens)
+    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
+
+    # Create the block tables.
+    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
+    block_tables = []
+    for _ in range(num_seqs):
+        block_table = [
+            random.randint(0, NUM_BLOCKS - 1)
+            for _ in range(max_num_blocks_per_seq)
+        ]
+        block_tables.append(block_table)
+    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
+
+    # Create the KV cache.
+    x = 16 // torch.tensor([], dtype=dtype).element_size()
+    key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
+    key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
+    key_cache.uniform_(-scale, scale)
+    value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
+    value_cache = torch.empty(size=value_cache_shape,
+                              dtype=dtype,
+                              device="cuda")
+    value_cache.uniform_(-scale, scale)
+
+    # Prepare for the paged attention kernel.
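+    # For the v2 kernel, the context is split into fixed-size partitions of
+    # PARTITION_SIZE tokens. The kernel writes one partial result per partition
+    # into tmp_output, along with the per-partition softmax statistics
+    # (exp_sums, max_logits) that the reduce kernel uses to combine them.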
+    output = torch.empty_like(query)
+    if version == "v2":
+        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
+                          PARTITION_SIZE)
+        tmp_output = torch.empty(
+            size=(num_seqs, num_query_heads, num_partitions, head_size),
+            dtype=output.dtype,
+            device=output.device,
+        )
+        exp_sums = torch.empty(
+            size=(num_seqs, num_query_heads, num_partitions),
+            dtype=torch.float32,
+            device=output.device,
+        )
+        max_logits = torch.empty_like(exp_sums)
+
+    def run_benchmark(num_iters: int, profile: bool = False) -> float:
+        torch.cuda.synchronize()
+        if profile:
+            torch.cuda.cudart().cudaProfilerStart()
+        start_time = time.perf_counter()
+
+        for _ in range(num_iters):
+            if version == "v1":
+                attention_ops.paged_attention_v1(
+                    output,
+                    query,
+                    key_cache,
+                    value_cache,
+                    head_mapping,
+                    scale,
+                    block_tables,
+                    context_lens,
+                    block_size,
+                    max_context_len,
+                    alibi_slopes,
+                )
+            elif version == "v2":
+                attention_ops.paged_attention_v2(
+                    output,
+                    exp_sums,
+                    max_logits,
+                    tmp_output,
+                    query,
+                    key_cache,
+                    value_cache,
+                    head_mapping,
+                    scale,
+                    block_tables,
+                    context_lens,
+                    block_size,
+                    max_context_len,
+                    alibi_slopes,
+                )
+            else:
+                raise ValueError(f"Invalid version: {version}")
+        torch.cuda.synchronize()
+
+        end_time = time.perf_counter()
+        if profile:
+            torch.cuda.cudart().cudaProfilerStop()
+        return (end_time - start_time) / num_iters
+
+    # Warmup.
+    print("Warming up...")
+    run_benchmark(num_iters=3, profile=False)
+
+    # Benchmark.
+    if do_profile:
+        latency = run_benchmark(num_iters=1, profile=True)
+    else:
+        latency = run_benchmark(num_iters=100, profile=False)
+    print(f"Kernel running time: {latency * 1000000:.3f} us")
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description="Benchmark the paged attention kernel.")
+    parser.add_argument("--version",
+                        type=str,
+                        choices=["v1", "v2"],
+                        default="v2")
+    parser.add_argument("--batch-size", type=int, default=8)
+    parser.add_argument("--context-len", type=int, default=4096)
+    parser.add_argument("--num-query-heads", type=int, default=64)
+    parser.add_argument("--num-kv-heads", type=int, default=8)
+    parser.add_argument("--head-size",
+                        type=int,
+                        choices=[64, 80, 96, 112, 128, 256],
+                        default=128)
+    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
+    parser.add_argument("--use-alibi", action="store_true")
+    parser.add_argument("--dtype",
+                        type=str,
+                        choices=["half", "bfloat16", "float"],
+                        default="half")
+    parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument("--profile", action="store_true")
+    args = parser.parse_args()
+    print(args)
+
+    if args.num_query_heads % args.num_kv_heads != 0:
+        raise ValueError("num_query_heads must be divisible by num_kv_heads")
+    dtype_to_torch_dtype = {
+        "half": torch.half,
+        "bfloat16": torch.bfloat16,
+        "float": torch.float,
+    }
+    main(
+        version=args.version,
+        num_seqs=args.batch_size,
+        context_len=args.context_len,
+        num_query_heads=args.num_query_heads,
+        num_kv_heads=args.num_kv_heads,
+        head_size=args.head_size,
+        block_size=args.block_size,
+        use_alibi=args.use_alibi,
+        dtype=dtype_to_torch_dtype[args.dtype],
+        seed=args.seed,
+        do_profile=args.profile,
+    )
diff --git a/csrc/attention.cpp b/csrc/attention.cpp
index 6be8a6d25ae4..bd93fd71b733 100644
--- a/csrc/attention.cpp
+++ b/csrc/attention.cpp
@@ -1,7 +1,7 @@
 #include <torch/extension.h>
 #include <c10/util/Optional.h>
 
-void single_query_cached_kv_attention(
+void paged_attention_v1(
   torch::Tensor& out,
   torch::Tensor& query,
   torch::Tensor& key_cache,
@@ -14,9 +14,29 @@
int max_context_len, const c10::optional& alibi_slopes); +void paged_attention_v2( + torch::Tensor& out, + torch::Tensor& exp_sums, + torch::Tensor& max_logits, + torch::Tensor& tmp_out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& head_mapping, + float scale, + torch::Tensor& block_tables, + torch::Tensor& context_lens, + int block_size, + int max_context_len, + const c10::optional& alibi_slopes); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def( - "single_query_cached_kv_attention", - &single_query_cached_kv_attention, - "Compute the attention between an input query and the cached key/value tensors"); + "paged_attention_v1", + &paged_attention_v1, + "Compute the attention between an input query and the cached keys/values using PagedAttention."); + m.def( + "paged_attention_v2", + &paged_attention_v2, + "PagedAttention V2."); } diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 505c63d2efd7..ee6b715adaef 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -26,6 +26,7 @@ #define WARP_SIZE 32 #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) namespace vllm { @@ -65,14 +66,18 @@ inline __device__ float block_sum(float* red_smem, float sum) { return __shfl_sync(uint32_t(-1), sum, 0); } -// Grid: (num_heads, num_seqs). +// TODO(woosuk): Merge the last two dimensions of the grid. +// Grid: (num_heads, num_seqs, max_num_partitions). template< typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, - int NUM_THREADS> -__global__ void single_query_cached_kv_attention_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + int NUM_THREADS, + int PARTITION_SIZE = 0> // Zero means no partitioning. +__device__ void paged_attention_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] @@ -85,10 +90,33 @@ __global__ void single_query_cached_kv_attention_kernel( const int q_stride, const int kv_block_stride, const int kv_head_stride) { + const int seq_idx = blockIdx.y; + const int partition_idx = blockIdx.z; + const int max_num_partitions = gridDim.z; + constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; + const int context_len = context_lens[seq_idx]; + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) { + // No work to do. Terminate the thread block. + return; + } + + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; + + // [start_block_idx, end_block_idx) is the range of blocks to process. + const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; + const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); + const int num_blocks = end_block_idx - start_block_idx; + + // [start_token_idx, end_token_idx) is the range of tokens to process. 
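+  // For example, with PARTITION_SIZE = 512 and BLOCK_SIZE = 16, each partition
+  // covers 32 KV-cache blocks (512 tokens). The last partition of a sequence
+  // may be shorter, so end_token_idx is clamped to context_len.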
+ const int start_token_idx = start_block_idx * BLOCK_SIZE; + const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); + const int num_tokens = end_token_idx - start_token_idx; + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); - constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE; + constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int thread_idx = threadIdx.x; const int warp_idx = thread_idx / WARP_SIZE; @@ -97,7 +125,6 @@ __global__ void single_query_cached_kv_attention_kernel( const int head_idx = blockIdx.x; const int num_heads = gridDim.x; const int kv_head_idx = head_mapping[head_idx]; - const int seq_idx = blockIdx.y; const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; // A vector type to store a part of a key or a query. @@ -142,15 +169,12 @@ __global__ void single_query_cached_kv_attention_kernel( constexpr int x = 16 / sizeof(scalar_t); float qk_max = -FLT_MAX; - const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - const int context_len = context_lens[seq_idx]; - const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; - // Iterate over the key blocks. // Each warp fetches a block of keys for each iteration. // Each thread group in a warp fetches a key from the block, and computes // dot product with the query. - for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { const int physical_block_number = block_table[block_idx]; // Load a key to registers. @@ -184,7 +208,7 @@ __global__ void single_query_cached_kv_attention_kernel( // Store the partial reductions to shared memory. // NOTE(woosuk): It is required to zero out the masked logits. const bool mask = token_idx >= context_len; - logits[token_idx] = mask ? 0.f : qk; + logits[token_idx - start_token_idx] = mask ? 0.f : qk; // Update the max value. qk_max = mask ? qk_max : fmaxf(qk_max, qk); } @@ -215,7 +239,7 @@ __global__ void single_query_cached_kv_attention_kernel( // Get the sum of the exp values. float exp_sum = 0.f; - for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { float val = __expf(logits[i] - qk_max); logits[i] = val; exp_sum += val; @@ -224,11 +248,23 @@ __global__ void single_query_cached_kv_attention_kernel( // Compute softmax. const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); - for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { logits[i] *= inv_sum; } __syncthreads(); + // If partitioning is enabled, store the max logit and exp_sum. 
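+  // These per-partition statistics are what paged_attention_v2_reduce_kernel
+  // needs later: exp_sum_i * exp(max_i - global_max) is partition i's weight
+  // in the global softmax.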
+ if (USE_PARTITIONING && thread_idx == 0) { + float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + + partition_idx; + *max_logits_ptr = qk_max; + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + + partition_idx; + *exp_sums_ptr = exp_sum; + } + // Each thread will fetch 16 bytes from the value cache at a time. constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); using V_vec = typename Vec::Type; @@ -237,7 +273,7 @@ __global__ void single_query_cached_kv_attention_kernel( constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; - constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER; + constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. float accs[NUM_ROWS_PER_THREAD]; @@ -248,12 +284,12 @@ __global__ void single_query_cached_kv_attention_kernel( scalar_t zero_value; zero(zero_value); - for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { const int physical_block_number = block_table[block_idx]; const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; L_vec logits_vec; - from_float(logits_vec, *reinterpret_cast(logits + token_idx)); + from_float(logits_vec, *reinterpret_cast(logits + token_idx - start_token_idx)); const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride + kv_head_idx * kv_head_stride; @@ -263,7 +299,7 @@ __global__ void single_query_cached_kv_attention_kernel( if (row_idx < HEAD_SIZE) { const int offset = row_idx * BLOCK_SIZE + physical_block_offset; V_vec v_vec = *reinterpret_cast(v_ptr + offset); - if (block_idx == num_blocks - 1) { + if (block_idx == num_context_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the context, // we should explicitly zero out the values since they may contain NaNs. // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 @@ -327,7 +363,9 @@ __global__ void single_query_cached_kv_attention_kernel( // Write the final output. if (warp_idx == 0) { - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; @@ -338,13 +376,167 @@ __global__ void single_query_cached_kv_attention_kernel( } } +// Grid: (num_heads, num_seqs, 1). 
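+// The v1 kernel is the non-partitioned case: it instantiates the shared
+// paged_attention_kernel with the default PARTITION_SIZE of 0, so exp_sums and
+// max_logits are unused (nullptr) and the output is written in its final form.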
+template< + typename scalar_t, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS> +__global__ void paged_attention_v1_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int* __restrict__ head_mapping, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride) { + paged_attention_kernel( + /* exp_sums */ nullptr, /* max_logits */ nullptr, + out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens, + max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); +} + +// Grid: (num_heads, num_seqs, max_num_partitions). +template< + typename scalar_t, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS, + int PARTITION_SIZE> +__global__ void paged_attention_v2_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int* __restrict__ head_mapping, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride) { + paged_attention_kernel( + exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale, + block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, + q_stride, kv_block_stride, kv_head_stride); +} + +// Grid: (num_heads, num_seqs). +template< + typename scalar_t, + int HEAD_SIZE, + int NUM_THREADS, + int PARTITION_SIZE> +__global__ void paged_attention_v2_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_partitions) { + const int num_heads = gridDim.x; + const int head_idx = blockIdx.x; + const int seq_idx = blockIdx.y; + const int context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + if (num_partitions == 1) { + // No need to reduce. Only copy tmp_out to out. 
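+    // A single partition covered the whole context, so its partial result is
+    // already the exact attention output and no rescaling is needed.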
+ scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { + out_ptr[i] = tmp_out_ptr[i]; + } + // Terminate the thread block. + return; + } + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warp_idx = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + + // Size: 2 * num_partitions. + extern __shared__ char shared_mem[]; + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // Load max logits to shared memory. + float* shared_max_logits = reinterpret_cast(shared_mem); + const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float max_logit = -FLT_MAX; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + const float l = max_logits_ptr[i]; + shared_max_logits[i] = l; + max_logit = fmaxf(max_logit, l); + } + __syncthreads(); + + // Get the global max logit. + // Reduce within the warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = max_logit; + } + __syncthreads(); + // Reduce across warps. + max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); + } + // Broadcast the max value to all threads. + max_logit = __shfl_sync(uint32_t(-1), max_logit, 0); + + // Load rescaled exp sums to shared memory. + float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); + const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float global_exp_sum = 0.0f; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + float l = shared_max_logits[i]; + float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit); + global_exp_sum += rescaled_exp_sum; + shared_exp_sums[i] = rescaled_exp_sum; + } + __syncthreads(); + global_exp_sum = block_sum(&red_smem[NUM_WARPS], global_exp_sum); + const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); + + // Aggregate tmp_out to out. 
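+  // out[i] = sum_j tmp_out[j][i] * shared_exp_sums[j] / global_exp_sum, where
+  // shared_exp_sums[j] = exp_sums[j] * exp(max_logits[j] - max_logit) is the
+  // rescaled weight of partition j under the global softmax.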
+ const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; +#pragma unroll + for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { + float acc = 0.0f; + for (int j = 0; j < num_partitions; ++j) { + acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum; + } + from_float(out_ptr[i], acc); + } +} + } // namespace vllm -#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ cudaFuncSetAttribute( \ - vllm::single_query_cached_kv_attention_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ - vllm::single_query_cached_kv_attention_kernel \ + vllm::paged_attention_v1_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ + vllm::paged_attention_v1_kernel \ <<>>( \ out_ptr, \ query_ptr, \ @@ -365,7 +557,7 @@ template< typename T, int BLOCK_SIZE, int NUM_THREADS = 128> -void single_query_cached_kv_attention_launcher( +void paged_attention_v1_launcher( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, @@ -401,45 +593,206 @@ void single_query_cached_kv_attention_launcher( int* context_lens_ptr = context_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; + int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE; int logits_size = padded_max_context_len * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len // Keep that in sync with the logic here! int shared_mem_size = std::max(logits_size, outputs_size); - dim3 grid(num_heads, num_seqs); + dim3 grid(num_heads, num_seqs, 1); + dim3 block(NUM_THREADS); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. + case 64: + LAUNCH_PAGED_ATTENTION_V1(64); + break; + case 80: + LAUNCH_PAGED_ATTENTION_V1(80); + break; + case 96: + LAUNCH_PAGED_ATTENTION_V1(96); + break; + case 112: + LAUNCH_PAGED_ATTENTION_V1(112); + break; + case 128: + LAUNCH_PAGED_ATTENTION_V1(128); + break; + case 256: + LAUNCH_PAGED_ATTENTION_V1(256); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V1_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v1_launcher( \ + out, \ + query, \ + key_cache, \ + value_cache, \ + head_mapping, \ + scale, \ + block_tables, \ + context_lens, \ + max_context_len, \ + alibi_slopes); + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. 
+#define CALL_V1_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 8: \ + CALL_V1_LAUNCHER(T, 8); \ + break; \ + case 16: \ + CALL_V1_LAUNCHER(T, 16); \ + break; \ + case 32: \ + CALL_V1_LAUNCHER(T, 32); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +void paged_attention_v1( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& head_mapping, // [num_heads] + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] + int block_size, + int max_context_len, + const c10::optional& alibi_slopes) { + if (query.dtype() == at::ScalarType::Float) { + CALL_V1_LAUNCHER_BLOCK_SIZE(float); + } else if (query.dtype() == at::ScalarType::Half) { + CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } +} + +#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ + vllm::paged_attention_v2_kernel \ + <<>>( \ + exp_sums_ptr, \ + max_logits_ptr, \ + tmp_out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + head_mapping_ptr, \ + scale, \ + block_tables_ptr, \ + context_lens_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride); \ + vllm::paged_attention_v2_reduce_kernel \ + <<>>( \ + out_ptr, \ + exp_sums_ptr, \ + max_logits_ptr, \ + tmp_out_ptr, \ + context_lens_ptr, \ + max_num_partitions); + +template< + typename T, + int BLOCK_SIZE, + int NUM_THREADS = 128, + int PARTITION_SIZE = 512> +void paged_attention_v2_launcher( + torch::Tensor& out, + torch::Tensor& exp_sums, + torch::Tensor& max_logits, + torch::Tensor& tmp_out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& head_mapping, + float scale, + torch::Tensor& block_tables, + torch::Tensor& context_lens, + int max_context_len, + const c10::optional& alibi_slopes) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % thread_group_size == 0); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = alibi_slopes ? 
+ reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* head_mapping_ptr = reinterpret_cast(head_mapping.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + int logits_size = PARTITION_SIZE * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + + // For paged attention v2 kernel. + dim3 grid(num_heads, num_seqs, max_num_partitions); + int shared_mem_size = std::max(logits_size, outputs_size); + // For paged attention v2 reduce kernel. + dim3 reduce_grid(num_heads, num_seqs); + int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); + dim3 block(NUM_THREADS); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); switch (head_size) { - // NOTE(woosuk): To reduce the compilation time, we omitted head sizes - // 32, 160, 192. - // case 32: - // LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS); - // break; + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. case 64: - LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V2(64); break; case 80: - LAUNCH_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V2(80); break; case 96: - LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V2(96); break; case 112: - LAUNCH_ATTENTION_KERNEL(T, 112, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V2(112); break; case 128: - LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V2(128); break; - // case 160: - // LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS); - // break; - // case 192: - // LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS); - // break; case 256: - LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V2(256); break; default: TORCH_CHECK(false, "Unsupported head size: ", head_size); @@ -447,9 +800,12 @@ void single_query_cached_kv_attention_launcher( } } -#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - single_query_cached_kv_attention_launcher( \ +#define CALL_V2_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v2_launcher( \ out, \ + exp_sums, \ + max_logits, \ + tmp_out, \ query, \ key_cache, \ value_cache, \ @@ -462,42 +818,27 @@ void single_query_cached_kv_attention_launcher( // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. 
-#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ +#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \ switch (block_size) { \ - /* case 1: */ \ - /* CALL_KERNEL_LAUNCHER(T, 1); */ \ - /* break; */ \ - /* case 2: */ \ - /* CALL_KERNEL_LAUNCHER(T, 2); */ \ - /* break; */ \ - /* case 4: */ \ - /* CALL_KERNEL_LAUNCHER(T, 4); */ \ - /* break; */ \ case 8: \ - CALL_KERNEL_LAUNCHER(T, 8); \ + CALL_V2_LAUNCHER(T, 8); \ break; \ case 16: \ - CALL_KERNEL_LAUNCHER(T, 16); \ + CALL_V2_LAUNCHER(T, 16); \ break; \ case 32: \ - CALL_KERNEL_LAUNCHER(T, 32); \ + CALL_V2_LAUNCHER(T, 32); \ break; \ - /* case 64: */ \ - /* CALL_KERNEL_LAUNCHER(T, 64); */ \ - /* break; */ \ - /* case 128: */ \ - /* CALL_KERNEL_LAUNCHER(T, 128); */ \ - /* break; */ \ - /* case 256: */ \ - /* CALL_KERNEL_LAUNCHER(T, 256); */ \ - /* break; */ \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ break; \ } -void single_query_cached_kv_attention( +void paged_attention_v2( torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] @@ -509,11 +850,11 @@ void single_query_cached_kv_attention( int max_context_len, const c10::optional& alibi_slopes) { if (query.dtype() == at::ScalarType::Float) { - CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float); + CALL_V2_LAUNCHER_BLOCK_SIZE(float); } else if (query.dtype() == at::ScalarType::Half) { - CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t); + CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t); } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); + CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); } else { TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); } @@ -522,3 +863,4 @@ void single_query_cached_kv_attention( #undef WARP_SIZE #undef MAX #undef MIN +#undef DIVIDE_ROUND_UP diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index 2154bfcf8631..5786f77f7bca 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -420,6 +420,11 @@ inline __device__ void from_float(bf16_8_t& dst, Float8_ src) { #endif } +// From bfloat16 to float32. +inline __device__ float to_float(__nv_bfloat16 u) { + return __bfloat162float(u); +} + // Zero-out a variable. 
inline __device__ void zero(__nv_bfloat16& dst) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 59d8b0a59ce6..31d78dd1bcf9 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -14,13 +14,14 @@ # - 512 as a buffer MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 NUM_BLOCKS = 128 # Arbitrary values for testing +PARTITION_SIZE = 512 DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_GEN_SEQS = [7] # Arbitrary values for testing -NUM_PREFILL_SEQS = [1, 3, 7] # Arbitrary values for testing +NUM_PREFILL_SEQS = [3] # Arbitrary values for testing NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing HEAD_SIZES = [64, 80, 96, 112, 128, 256] -BLOCK_SIZES = [8, 16, 32] +BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] SEEDS = [0] @@ -96,6 +97,7 @@ def ref_single_query_cached_kv_attention( output[i].copy_(out, non_blocking=True) +@pytest.mark.parametrize("version", ["v1", "v2"]) @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -103,9 +105,9 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -@torch.inference_mode() -def test_single_query_cached_kv_attention( +def test_paged_attention( kv_cache_factory, + version: str, num_seqs: int, num_heads: Tuple[int, int], head_size: int, @@ -162,19 +164,54 @@ def test_single_query_cached_kv_attention( # Call the paged attention kernel. output = torch.empty_like(query) - attention_ops.single_query_cached_kv_attention( - output, - query, - key_cache, - value_cache, - head_mapping, - scale, - block_tables, - context_lens, - block_size, - max_context_len, - alibi_slopes, - ) + if version == "v1": + attention_ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + head_mapping, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + ) + elif version == "v2": + num_partitions = ((max_context_len + PARTITION_SIZE - 1) // + PARTITION_SIZE) + assert PARTITION_SIZE % block_size == 0 + num_seqs, num_heads, head_size = output.shape + tmp_output = torch.empty( + size=(num_seqs, num_heads, num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + attention_ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + head_mapping, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + ) + else: + assert False, f"Unknown version: {version}" # Run the reference implementation. ref_output = torch.empty_like(query) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index b1d0588d97f7..0677ebbae792 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -15,6 +15,8 @@ RotaryEmbedding) _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] +# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. 
+_PARTITION_SIZE = 512 class PagedAttention(nn.Module): @@ -130,6 +132,14 @@ def multi_query_kv_attention( output.copy_(out.squeeze(0)) return output + def get_alibi_slopes(self) -> Optional[torch.Tensor]: + """Returns the slopes for the alibi attention bias. + + Returns: + slopes: shape = [num_heads] + """ + return None + def single_query_cached_kv_attention( self, output: torch.Tensor, @@ -137,6 +147,7 @@ def single_query_cached_kv_attention( key_cache: torch.Tensor, value_cache: torch.Tensor, input_metadata: InputMetadata, + alibi_slopes: Optional[torch.Tensor], ) -> None: """PagedAttention for the generation tokens. @@ -148,21 +159,65 @@ def single_query_cached_kv_attention( value_cache: shape = [num_blocks, num_kv_heads, head_size, block_size] input_metadata: metadata for paged attention. + alibi_slopes: shape = [num_heads] """ block_size = value_cache.shape[3] - attention_ops.single_query_cached_kv_attention( - output, - query, - key_cache, - value_cache, - self.head_mapping, - self.scale, - input_metadata.block_tables, - input_metadata.context_lens, - block_size, - input_metadata.max_context_len, - None, # alibi_slopes - ) + num_seqs, num_heads, head_size = query.shape + max_num_partitions = ( + (input_metadata.max_context_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. + # TODO(woosuk): Tune this heuristic. + use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512 + if use_v1: + # Run PagedAttention V1. + attention_ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + self.head_mapping, + self.scale, + input_metadata.block_tables, + input_metadata.context_lens, + block_size, + input_metadata.max_context_len, + alibi_slopes, + ) + else: + # Run PagedAttention V2. + assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + attention_ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + self.head_mapping, + self.scale, + input_metadata.block_tables, + input_metadata.context_lens, + block_size, + input_metadata.max_context_len, + alibi_slopes, + ) def forward( self, @@ -253,7 +308,7 @@ def forward( self.single_query_cached_kv_attention( output[num_prompt_tokens:num_valid_tokens], query[num_prompt_tokens:num_valid_tokens], key_cache, - value_cache, input_metadata) + value_cache, input_metadata, self.get_alibi_slopes()) # Reshape the output tensor. # NOTE(woosuk): The output tensor may include paddings. @@ -431,36 +486,5 @@ def multi_query_kv_attention( start += prompt_len return output - def single_query_cached_kv_attention( - self, - output: torch.Tensor, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - input_metadata: InputMetadata, - ) -> None: - """PagedAttention with ALiBi bias for the generation tokens. 
-
-        Args:
-            output: shape = [num_generation_tokens, num_heads, head_size]
-            query: shape = [num_generation_tokens, num_heads, head_size]
-            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
-                block_size, x]
-            value_cache: shape = [num_blocks, num_kv_heads, head_size,
-                block_size]
-            input_metadata: metadata for paged attention.
-        """
-        block_size = value_cache.shape[3]
-        attention_ops.single_query_cached_kv_attention(
-            output,
-            query,
-            key_cache,
-            value_cache,
-            self.head_mapping,
-            self.scale,
-            input_metadata.block_tables,
-            input_metadata.context_lens,
-            block_size,
-            input_metadata.max_context_len,
-            self.alibi_slopes,
-        )
+    def get_alibi_slopes(self) -> Optional[torch.Tensor]:
+        return self.alibi_slopes
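
For reference (not part of the patch): a minimal PyTorch sketch of the reduction that
paged_attention_v2_reduce_kernel performs, assuming a single sequence, a single head,
and no paging; the helper name and toy sizes below are illustrative only. Each
partition contributes its max logit, its exp sum, and a partial output that is already
normalized within the partition; combining them with log-sum-exp rescaling reproduces
ordinary softmax attention over the full context.

import torch


def combine_partitions(tmp_out: torch.Tensor,
                       exp_sums: torch.Tensor,
                       max_logits: torch.Tensor) -> torch.Tensor:
    # tmp_out:    [num_partitions, head_size], each row normalized by its own exp sum.
    # exp_sums:   [num_partitions], sum of exp(logit - partition_max) per partition.
    # max_logits: [num_partitions], max logit per partition.
    global_max = max_logits.max()
    rescaled = exp_sums * torch.exp(max_logits - global_max)
    weights = rescaled / rescaled.sum()
    return (weights[:, None] * tmp_out).sum(dim=0)


if __name__ == "__main__":
    torch.manual_seed(0)
    num_tokens, head_size, partition_size = 96, 8, 32
    logits = torch.randn(num_tokens)
    values = torch.randn(num_tokens, head_size)

    # Reference: softmax attention over the full context in one pass.
    ref = (torch.softmax(logits, dim=0)[:, None] * values).sum(dim=0)

    # Per-partition statistics, analogous to what paged_attention_v2_kernel emits.
    maxes, sums, outs = [], [], []
    for start in range(0, num_tokens, partition_size):
        part_logits = logits[start:start + partition_size]
        part_values = values[start:start + partition_size]
        m = part_logits.max()
        e = torch.exp(part_logits - m)
        maxes.append(m)
        sums.append(e.sum())
        outs.append((e[:, None] * part_values).sum(dim=0) / e.sum())

    out = combine_partitions(torch.stack(outs), torch.stack(sums), torch.stack(maxes))
    assert torch.allclose(out, ref, atol=1e-5)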