From d10f8e1d43bfb0656b6848ad0c681ecbdec812d6 Mon Sep 17 00:00:00 2001 From: "shiyi.c_98" Date: Wed, 17 Jan 2024 16:32:10 -0800 Subject: [PATCH] [Experimental] Prefix Caching Support (#1669) Co-authored-by: DouHappy <2278958187@qq.com> Co-authored-by: Zhuohan Li --- .buildkite/test-pipeline.yaml | 4 + examples/offline_inference_with_prefix.py | 51 ++ tests/kernels/test_prefix_prefill.py | 168 ++++ tests/prefix_caching/test_prefix_caching.py | 41 + tests/samplers/test_sampler.py | 18 +- tests/worker/test_model_runner.py | 5 +- vllm/block.py | 4 + vllm/core/block_manager.py | 47 +- vllm/core/scheduler.py | 5 + vllm/engine/async_llm_engine.py | 16 +- vllm/engine/llm_engine.py | 18 +- vllm/entrypoints/api_server.py | 6 +- vllm/entrypoints/llm.py | 17 +- vllm/model_executor/input_metadata.py | 6 + vllm/model_executor/layers/attention.py | 88 ++- .../layers/triton_kernel/__init__.py | 0 .../layers/triton_kernel/prefix_prefill.py | 728 ++++++++++++++++++ vllm/prefix.py | 87 +++ vllm/sequence.py | 7 +- vllm/worker/model_runner.py | 111 ++- 20 files changed, 1356 insertions(+), 71 deletions(-) create mode 100644 examples/offline_inference_with_prefix.py create mode 100644 tests/kernels/test_prefix_prefill.py create mode 100644 tests/prefix_caching/test_prefix_caching.py create mode 100644 vllm/model_executor/layers/triton_kernel/__init__.py create mode 100644 vllm/model_executor/layers/triton_kernel/prefix_prefill.py create mode 100644 vllm/prefix.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index a6f3a3f0a2e3..adf2bb2b43c1 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -31,6 +31,10 @@ steps: - pytest -v -s models --forked soft_fail: true +- label: Prefix Caching Test + commands: + - pytest -v -s prefix_caching + - label: Samplers Test command: pytest -v -s samplers --forked diff --git a/examples/offline_inference_with_prefix.py b/examples/offline_inference_with_prefix.py new file mode 100644 index 000000000000..df9f1364ee51 --- /dev/null +++ b/examples/offline_inference_with_prefix.py @@ -0,0 +1,51 @@ +from vllm import LLM, SamplingParams + +prefix = ( + "You are an expert school principal, skilled in effectively managing " + "faculty and staff. Draft 10-15 questions for a potential first grade " + "Head Teacher for my K-12, all-girls', independent school that emphasizes " + "community, joyful discovery, and life-long learning. The candidate is " + "coming in for a first-round panel interview for a 8th grade Math " + "teaching role. They have 5 years of previous teaching experience " + "as an assistant teacher at a co-ed, public school with experience " + "in middle school math teaching. Based on these information, fulfill " + "the following paragraph: ") + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.0) + +# Create an LLM. +llm = LLM(model="facebook/opt-125m") + +generating_prompts = [prefix + prompt for prompt in prompts] + +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(generating_prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +print("-" * 80) + +# -1 since the last token can change when concatenating prompts. +prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1 + +# Generate with prefix +outputs = llm.generate(generating_prompts, sampling_params, + prefix_pos=[prefix_pos] * len(generating_prompts)) + +# Print the outputs. You should see the same outputs as before +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py new file mode 100644 index 000000000000..8fa6358d3ec7 --- /dev/null +++ b/tests/kernels/test_prefix_prefill.py @@ -0,0 +1,168 @@ +import random +import pytest +import time + +import torch +from vllm.model_executor.layers.triton_kernel.prefix_prefill import ( + context_attention_fwd) +from xformers import ops as xops +from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask + +NUM_HEADS = [12] +HEAD_SIZES = [128] +DTYPES = [torch.float16] + + +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@torch.inference_mode() +def test_contexted_kv_attention( + num_heads: int, + head_size: int, + dtype: torch.dtype, +) -> None: + random.seed(0) + torch.manual_seed(0) + MAX_SEQ_LEN = 1024 + MAX_CTX_LEN = 1024 + BS = 10 + cache_size = 640 + block_size = 32 + max_block_per_request = 64 + subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] + ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] + seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)] + + num_tokens = sum(subquery_lens) + query = torch.empty(num_tokens, + num_heads, + head_size, + dtype=dtype, + device='cuda') + query.uniform_(-1e-3, 1e-3) + output = torch.empty(num_tokens, + num_heads, + head_size, + dtype=dtype, + device='cuda') + + kv = torch.empty(sum(seq_lens), + 2, + num_heads, + head_size, + dtype=dtype, + device='cuda') + kv.uniform_(-1e-3, 1e-3) + key, value = kv.unbind(dim=1) + + k_cache = torch.zeros(cache_size, + block_size, + num_heads, + head_size, + dtype=dtype, + device='cuda') + v_cache = torch.zeros(cache_size, + block_size, + num_heads, + head_size, + dtype=dtype, + device='cuda') + k = torch.zeros(sum(subquery_lens), + num_heads, + head_size, + dtype=dtype, + device='cuda') + v = torch.zeros(sum(subquery_lens), + num_heads, + head_size, + dtype=dtype, + device='cuda') + values = torch.arange(0, cache_size, dtype=torch.long, device='cuda') + values = values[torch.randperm(cache_size)] + block_table = values[:BS * max_block_per_request].view( + BS, max_block_per_request) + b_seq_len = torch.tensor(seq_lens, dtype=torch.long, device='cuda') + b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long, device='cuda') + b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1], + dtype=torch.long, + device='cuda'), + dim=0) + max_input_len = MAX_SEQ_LEN + # copy kv to cache + b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], + dtype=torch.long, + device='cuda'), + dim=0) + for i in range(BS): + for j in range(subquery_lens[i]): + k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + + j]) + v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + + b_ctx_len[i] + j]) + cur_ctx = 0 + block_id = 0 + while cur_ctx < b_ctx_len[i]: + start_loc = b_seq_start_loc[i] + cur_ctx + if cur_ctx + block_size > b_ctx_len[i]: + end_loc = b_seq_start_loc[i] + b_ctx_len[i] + else: + end_loc = start_loc + block_size + start_slot = block_table[i, block_id] * block_size + end_slot = start_slot + end_loc - start_loc + k_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_( + key[start_loc:end_loc]) + v_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_( + value[start_loc:end_loc]) + cur_ctx += block_size + block_id += 1 + # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] + # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] + k_cache = k_cache.view(-1, block_size, num_heads, head_size // 8, + 8).permute(0, 2, 3, 1, 4).contiguous() + # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] + # to V_cache[num_blocks, num_kv_heads, head_size, block_size] + v_cache = v_cache.view(-1, block_size, num_heads, + head_size).permute(0, 2, 3, 1).contiguous() + + context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table, + b_start_loc, b_seq_len, b_ctx_len, max_input_len) + torch.cuda.synchronize() + start_time = time.time() + context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table, + b_start_loc, b_seq_len, b_ctx_len, max_input_len) + torch.cuda.synchronize() + end_time = time.time() + print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") + + scale = float(1.0 / (head_size**0.5)) + + attn_op = xops.fmha.cutlass.FwOp() + + attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens( + subquery_lens, seq_lens) + output_ref = xops.memory_efficient_attention_forward( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + attn_bias=attn_bias, + p=0.0, + scale=scale, + op=attn_op, + ) + torch.cuda.synchronize() + start_time = time.time() + output_ref = xops.memory_efficient_attention_forward( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + attn_bias=attn_bias, + p=0.0, + scale=scale, + op=attn_op, + ) + torch.cuda.synchronize() + end_time = time.time() + print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") + output_ref = output_ref.squeeze(0) + assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py new file mode 100644 index 000000000000..1e301bedfc21 --- /dev/null +++ b/tests/prefix_caching/test_prefix_caching.py @@ -0,0 +1,41 @@ +"""Compare the with and without prefix caching. + +Run `pytest tests/prefix_caching/test_prefix_caching.py`. +""" +import pytest + +from vllm import LLM, SamplingParams + +prefix = ( + "You are an expert school principal, skilled in effectively managing " + "faculty and staff. Draft 10-15 questions for a potential first grade " + "Head Teacher for my K-12, all-girls', independent school that emphasizes " + "community, joyful discovery, and life-long learning. The candidate is " + "coming in for a first-round panel interview for a 8th grade Math " + "teaching role. They have 5 years of previous teaching experience " + "as an assistant teacher at a co-ed, public school with experience " + "in middle school math teaching. Based on these information, fulfill " + "the following paragraph: ") + + +@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +@pytest.mark.parametrize("max_tokens", [16]) +def test_prefix_caching( + example_prompts, + model: str, + max_tokens: int, +): + llm = LLM(model=model) + # -1 since the last token can change when concatenating prompts. + prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1 + prompts = [prefix + prompt for prompt in example_prompts] + sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) + outputs_without_prefix = llm.generate(prompts, sampling_params) + outputs_with_prefix = llm.generate(prompts, + sampling_params, + prefix_pos=[prefix_pos] * len(prompts)) + for output_without_prefix, output_with_prefix in zip( + outputs_without_prefix, outputs_with_prefix): + assert (output_without_prefix.outputs[0].token_ids == + output_with_prefix.outputs[0].token_ids) + assert len(llm.llm_engine.scheduler.prefix_pool.prefixes) == 1 diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 996aa8e0a8d9..bcd0cd60bfc5 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -66,7 +66,8 @@ def test_sampler_all_greedy(seed: int): prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens) + prompt_lens, + subquery_lens=prompt_lens) sampler_output = sampler(embedding=None, hidden_states=input_tensor, sampling_metadata=sampling_metadata) @@ -105,7 +106,8 @@ def test_sampler_all_random(seed: int): prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens) + prompt_lens, + subquery_lens=prompt_lens) sampler_output = sampler(embedding=None, hidden_states=input_tensor, sampling_metadata=sampling_metadata) @@ -140,7 +142,8 @@ def test_sampler_all_beam(seed: int): prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens) + prompt_lens, + subquery_lens=prompt_lens) sampler(embedding=None, hidden_states=input_tensor, sampling_metadata=sampling_metadata) @@ -193,7 +196,8 @@ def test_sampler_mixed(seed: int): prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens) + prompt_lens, + subquery_lens=prompt_lens) sampler_output = sampler(embedding=None, hidden_states=input_tensor, sampling_metadata=sampling_metadata) @@ -234,7 +238,8 @@ def pick_ith(token_ids, logits): prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens) + prompt_lens, + subquery_lens=prompt_lens) sampler_output = sampler(embedding=None, hidden_states=input_tensor, sampling_metadata=sampling_metadata) @@ -288,7 +293,8 @@ def test_sampler_top_k_top_p(seed: int): prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens) + prompt_lens, + subquery_lens=prompt_lens) sample_probs = None diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 250d84caf56d..edbe10684741 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -33,11 +33,12 @@ def test_prepare_prompt(): expected_selected_token_indices.append(selected_token_start_idx + prompt_len - 1) selected_token_start_idx += max_seq_len - input_tokens, input_positions, _, return_prompt_lens = ( + input_tokens, input_positions, _, return_prompt_lens, _ = ( model_runner._prepare_prompt(seq_group_metadata_list)) assert return_prompt_lens == prompt_lens sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens) + prompt_lens, + subquery_lens=prompt_lens) assert input_tokens.shape == (batch_size, max_seq_len) assert input_positions.shape == (batch_size, max_seq_len) torch.testing.assert_close(input_tokens, input_positions) diff --git a/vllm/block.py b/vllm/block.py index 435aa50ca22e..5fe39ed47b2f 100644 --- a/vllm/block.py +++ b/vllm/block.py @@ -66,3 +66,7 @@ def __repr__(self) -> str: return (f'PhysicalTokenBlock(device={self.device}, ' f'block_number={self.block_number}, ' f'ref_count={self.ref_count})') + + +# Mapping: logical block number -> physical block. +BlockTable = List[PhysicalTokenBlock] diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 3bde005997bd..7f91051f03ac 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -2,13 +2,10 @@ import enum from typing import Dict, List, Optional, Set, Tuple -from vllm.block import PhysicalTokenBlock +from vllm.block import BlockTable, PhysicalTokenBlock from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device -# Mapping: logical block number -> physical block. -BlockTable = List[PhysicalTokenBlock] - class BlockAllocator: """Manages free physical token blocks for a device. @@ -105,6 +102,10 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # the same prompt. This may not be true for preempted sequences. seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] num_required_blocks = len(seq.logical_token_blocks) + + if seq_group.prefix is not None and seq_group.prefix.allocated: + num_required_blocks -= seq_group.prefix.get_num_blocks() + if self.block_sliding_window is not None: num_required_blocks = min(num_required_blocks, self.block_sliding_window) @@ -125,8 +126,21 @@ def allocate(self, seq_group: SequenceGroup) -> None: seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] # Allocate new physical token blocks that will store the prompt tokens. + num_prompt_blocks = len(seq.logical_token_blocks) + block_table: BlockTable = [] - for logical_idx in range(len(seq.logical_token_blocks)): + prefix_block_table: BlockTable = [] + num_prefix_blocks = 0 + + prefix = seq_group.prefix + if prefix is not None and prefix.allocated: + # Prefix has already been allocated. Use the existing block table. + num_prompt_blocks -= prefix.get_num_blocks() + for block in prefix.block_table: + block.ref_count += seq_group.num_seqs() + block_table.append(block) + + for logical_idx in range(num_prompt_blocks): if (self.block_sliding_window is not None and logical_idx >= self.block_sliding_window): block = block_table[logical_idx % self.block_sliding_window] @@ -136,6 +150,15 @@ def allocate(self, seq_group: SequenceGroup) -> None: block.ref_count = seq_group.num_seqs() block_table.append(block) + if prefix is not None and not prefix.allocated: + # Allocate blocks for the prefix, we will compute the prefix's + # KV cache in this run. + num_prefix_blocks = prefix.get_num_blocks() + prefix_block_table = block_table[:num_prefix_blocks] + for block in prefix_block_table: + block.ref_count += 1 + prefix.set_block_table(prefix_block_table) + # Assign the block table for each sequence. for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): self.block_tables[seq.seq_id] = block_table.copy() @@ -210,10 +233,18 @@ def can_swap_in(self, seq_group: SequenceGroup) -> bool: def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: # CPU block -> GPU block. + if seq_group.prefix is not None: + # make sure to swap in the prefix first + assert seq_group.prefix.allocated and seq_group.prefix.computed + mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): new_block_table: BlockTable = [] block_table = self.block_tables[seq.seq_id] + if seq_group.prefix is not None: + for block in seq_group.prefix.block_table: + new_block_table.append(block) + block.ref_count += 1 for cpu_block in block_table: if cpu_block in mapping: @@ -245,6 +276,12 @@ def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: block_table = self.block_tables[seq.seq_id] for gpu_block in block_table: + if (seq_group.prefix is not None + and gpu_block in seq_group.prefix.block_table): + # NOTE: We do not swap out the prefix blocks for now. + self.gpu_allocator.free(gpu_block) + continue + if gpu_block in mapping: cpu_block = mapping[gpu_block] cpu_block.ref_count += 1 diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 9fe01a14aedc..eb46d43968f5 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -9,6 +9,7 @@ from vllm.logger import init_logger from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceStatus) +from vllm.prefix import PrefixPool logger = init_logger(__name__) @@ -76,6 +77,9 @@ def __init__( num_cpu_blocks=self.cache_config.num_cpu_blocks, sliding_window=self.cache_config.sliding_window) + # Create the prefix pool to cache the prefixes. + self.prefix_pool = PrefixPool(self.cache_config.block_size) + # Sequence groups in the WAITING state. self.waiting: Deque[SequenceGroup] = deque() # Sequence groups in the RUNNING state. @@ -316,6 +320,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: seq_data=seq_data, sampling_params=seq_group.sampling_params, block_tables=block_tables, + prefix=seq_group.prefix, ) seq_group_metadata_list.append(seq_group_metadata) return seq_group_metadata_list, scheduler_outputs diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 8a5b00ca7f7c..cbf2978c01c2 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -371,6 +371,7 @@ async def add_request( sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, + prefix_pos: Optional[int] = None, ) -> AsyncStream: if self.log_requests: shortened_prompt = prompt @@ -383,6 +384,7 @@ async def add_request( max_log_len] logger.info(f"Received request {request_id}: " f"prompt: {shortened_prompt!r}, " + f"prefix_pos: {prefix_pos}," f"sampling params: {sampling_params}, " f"prompt token ids: {shortened_token_ids}.") @@ -401,7 +403,8 @@ async def add_request( prompt=prompt, sampling_params=sampling_params, prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time) + arrival_time=arrival_time, + prefix_pos=prefix_pos) return stream @@ -410,7 +413,8 @@ async def generate( prompt: Optional[str], sampling_params: SamplingParams, request_id: str, - prompt_token_ids: Optional[List[int]] = None + prompt_token_ids: Optional[List[int]] = None, + prefix_pos: Optional[int] = None, ) -> AsyncIterator[RequestOutput]: """Generate outputs for a request. @@ -425,6 +429,11 @@ async def generate( request_id: The unique id of the request. prompt_token_ids: The token IDs of the prompt. If None, we use the tokenizer to convert the prompts to token IDs. + prefix_pos: If not None, we use the given position as the prefix + position for each prompt. We will cache the prefix's KV + cache and reuse it for the next request with the same prefix. + This is an experimental feature, and may be replaced with + automatic prefix caching in the future. Yields: The output `RequestOutput` objects from the LLMEngine for the @@ -482,7 +491,8 @@ async def generate( prompt, sampling_params, prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time) + arrival_time=arrival_time, + prefix_pos=prefix_pos) async for request_output in stream: yield request_output diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e30bf5db4928..7072a8bbc5b3 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -337,6 +337,7 @@ def add_request( sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, + prefix_pos: Optional[int] = None, ) -> None: """Add a request to the engine's request pool. @@ -353,6 +354,11 @@ def add_request( use the tokenizer to convert the prompts to token IDs. arrival_time: The arrival time of the request. If None, we use the current monotonic time. + prefix_pos: If not None, we use the given position as the prefix + position for each prompt. We will cache the prefix's KV + cache and reuse it for the next request with the same prefix. + This is an experimental feature, and may be replaced with + automatic prefix caching in the future. Details: - Set arrival_time to the current time if it is None. @@ -389,9 +395,13 @@ def add_request( seq_id = next(self.seq_counter) seq = Sequence(seq_id, prompt, prompt_token_ids, block_size) + # Check whether the input specifies prefix + prefix = self.scheduler.prefix_pool.add_or_get_prefix( + prompt_token_ids[:prefix_pos]) if prefix_pos is not None else None + # Create the sequence group. seq_group = SequenceGroup(request_id, [seq], sampling_params, - arrival_time) + arrival_time, prefix) # Add the sequence group to the scheduler. self.scheduler.add_seq_group(seq_group) @@ -662,6 +672,12 @@ def _process_model_outputs( request_output = RequestOutput.from_seq_group(seq_group) request_outputs.append(request_output) + # Update prefix state, now all the uncomputed prefixes are computed. + for seq_group in scheduled_seq_groups: + if (seq_group.prefix is not None and seq_group.prefix.allocated + and not seq_group.prefix.computed): + seq_group.prefix.computed = True + if self.log_stats: # Log the system stats. self._log_system_stats(scheduler_outputs.prompt_run, diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 9c27bcf2636c..f7b8d258fae4 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -33,11 +33,15 @@ async def generate(request: Request) -> Response: """ request_dict = await request.json() prompt = request_dict.pop("prompt") + prefix_pos = request_dict.pop("prefix_pos", None) stream = request_dict.pop("stream", False) sampling_params = SamplingParams(**request_dict) request_id = random_uuid() - results_generator = engine.generate(prompt, sampling_params, request_id) + results_generator = engine.generate(prompt, + sampling_params, + request_id, + prefix_pos=prefix_pos) # Streaming case async def stream_results() -> AsyncGenerator[bytes, None]: diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 0700298b03a3..b819e233c06b 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -120,6 +120,7 @@ def generate( prompts: Optional[Union[str, List[str]]] = None, sampling_params: Optional[SamplingParams] = None, prompt_token_ids: Optional[List[List[int]]] = None, + prefix_pos: Optional[Union[int, List[int]]] = None, use_tqdm: bool = True, ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -134,6 +135,11 @@ def generate( None, we use the default sampling parameters. prompt_token_ids: A list of token IDs for the prompts. If None, we use the tokenizer to convert the prompts to token IDs. + prefix_pos: If not None, we use the given position as the prefix + position for each prompt. We will cache the prefix's KV + cache and reuse it for the next request with the same prefix. + This is an experimental feature, and may be replaced with + automatic prefix caching in the future. use_tqdm: Whether to use tqdm to display the progress bar. Returns: @@ -159,9 +165,10 @@ def generate( prompt_token_ids) for i in range(num_requests): prompt = prompts[i] if prompts is not None else None + prefix_pos_i = prefix_pos[i] if prefix_pos is not None else None token_ids = None if prompt_token_ids is None else prompt_token_ids[ i] - self._add_request(prompt, sampling_params, token_ids) + self._add_request(prompt, sampling_params, token_ids, prefix_pos_i) return self._run_engine(use_tqdm) def _add_request( @@ -169,10 +176,14 @@ def _add_request( prompt: Optional[str], sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]], + prefix_pos: Optional[int] = None, ) -> None: request_id = str(next(self.request_counter)) - self.llm_engine.add_request(request_id, prompt, sampling_params, - prompt_token_ids) + self.llm_engine.add_request(request_id, + prompt, + sampling_params, + prompt_token_ids, + prefix_pos=prefix_pos) def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # Initialize tqdm. diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index da615ecccf99..ef49cc5902ea 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -18,12 +18,18 @@ def __init__( self, is_prompt: bool, slot_mapping: torch.Tensor, + prompt_lens: Optional[torch.Tensor], + max_seq_len: Optional[int], + start_loc: Optional[torch.Tensor], max_context_len: Optional[int], context_lens: Optional[torch.Tensor], block_tables: Optional[torch.Tensor], use_cuda_graph: bool, ) -> None: self.is_prompt = is_prompt + self.prompt_lens = prompt_lens + self.max_seq_len = max_seq_len + self.start_loc = start_loc self.max_context_len = max_context_len self.slot_mapping = slot_mapping self.context_lens = context_lens diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index f1008ec8159f..8b5c6ab30d7b 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -10,6 +10,8 @@ from vllm._C import ops from vllm._C import cache_ops from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.triton_kernel.prefix_prefill import ( + context_attention_fwd) from vllm.utils import is_hip _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] @@ -115,45 +117,65 @@ def forward( self.num_kv_heads, self.num_queries_per_kv, value.shape[-1]) + # normal attention + if (key_cache is None or value_cache is None + or input_metadata.block_tables.numel() == 0): + # Set attention bias if not provided. This typically happens at + # the very attention layer of every iteration. + # FIXME(woosuk): This is a hack. + if input_metadata.attn_bias is None: + if self.alibi_slopes is None: + attn_bias = BlockDiagonalCausalMask.from_seqlens( + [seq_len] * batch_size) + if self.sliding_window is not None: + attn_bias = attn_bias.make_local_attention( + self.sliding_window) + input_metadata.attn_bias = attn_bias + else: + input_metadata.attn_bias = _make_alibi_bias( + self.alibi_slopes, self.num_kv_heads, batch_size, + seq_len, query.dtype) - # Set attention bias if not provided. This typically happens at the - # very attention layer of every iteration. - # FIXME(woosuk): This is a hack. - if input_metadata.attn_bias is None: + # TODO(woosuk): Too many view operations. Let's try to reduce + # them in the future for code readability. if self.alibi_slopes is None: - attn_bias = BlockDiagonalCausalMask.from_seqlens( - [seq_len] * batch_size) - if self.sliding_window is not None: - attn_bias = attn_bias.make_local_attention( - self.sliding_window) - input_metadata.attn_bias = attn_bias + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) else: - input_metadata.attn_bias = _make_alibi_bias( - self.alibi_slopes, self.num_kv_heads, batch_size, - seq_len, query.dtype) + query = query.unflatten(0, (batch_size, seq_len)) + key = key.unflatten(0, (batch_size, seq_len)) + value = value.unflatten(0, (batch_size, seq_len)) - # TODO(woosuk): Too many view operations. Let's try to reduce them - # in the future for code readability. - if self.alibi_slopes is None: - query = query.unsqueeze(0) - key = key.unsqueeze(0) - value = value.unsqueeze(0) + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=input_metadata.attn_bias, + p=0.0, + scale=self.scale, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if + (is_hip()) else None, + ) + output = out.view_as(query) else: - query = query.unflatten(0, (batch_size, seq_len)) - key = key.unflatten(0, (batch_size, seq_len)) - value = value.unflatten(0, (batch_size, seq_len)) + # prefix-enabled attention + output = torch.empty_like(query) + context_attention_fwd( + query, + key, + value, + output, + key_cache, + value_cache, + input_metadata.block_tables, # [BS, max_block_per_request] + input_metadata.start_loc, + input_metadata.prompt_lens, + input_metadata.context_lens, + input_metadata.max_seq_len, + getattr(self, "alibi_slopes", None), + ) - out = xops.memory_efficient_attention_forward( - query, - key, - value, - attn_bias=input_metadata.attn_bias, - p=0.0, - scale=self.scale, - op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if - (is_hip()) else None, - ) - output = out.view_as(query) else: # Decoding run. output = _paged_attention( diff --git a/vllm/model_executor/layers/triton_kernel/__init__.py b/vllm/model_executor/layers/triton_kernel/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py new file mode 100644 index 000000000000..8fa70054f02c --- /dev/null +++ b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py @@ -0,0 +1,728 @@ +# The kernels in this file are adapted from LightLLM's context_attention_fwd: +# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py + +import torch +import triton +import triton.language as tl + +if triton.__version__ >= "2.1.0": + + @triton.jit + def _fwd_kernel( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + q = tl.load( + Q + off_q, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = (bn[None, :] * stride_k_cache_bs + + cur_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = ( + bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k = tl.load(K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + return + + @triton.jit + def _fwd_kernel_flash_attn_v2( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + q = tl.load( + Q + off_q, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = (bn[None, :] * stride_k_cache_bs + + cur_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = ( + bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k = tl.load(K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + # acc /= l_i[:, None] + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + return + + @triton.jit + def _fwd_kernel_alibi( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + Alibi_slopes, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + # attn_bias[] + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + # cur_batch_seq_len: the length of prompts + # cur_batch_ctx_len: the length of prefix + # cur_batch_in_all_start_index: the start id of the dim=0 + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + q = tl.load( + Q + off_q, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange( + 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = 0 + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = (bn[None, :] * stride_k_cache_bs + + cur_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = ( + bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k = tl.load(K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), + alibi, float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v, allow_tf32=False) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + # init alibi + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange( + 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = cur_batch_ctx_len + # # init debuger + # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc + # offset_db_k = tl.arange(0, BLOCK_N) + # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, allow_tf32=False) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), + alibi, float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v, allow_tf32=False) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + acc = acc / l_i[:, None] + + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + return + + @torch.inference_mode() + def context_attention_fwd(q, + k, + v, + o, + k_cache, + v_cache, + b_loc, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + alibi_slopes=None): + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + + sm_scale = 1.0 / (Lq**0.5) + batch, head = b_seq_len.shape[0], q.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, + + num_warps = 8 if Lk <= 64 else 8 + if alibi_slopes is not None: + _fwd_kernel_alibi[grid]( + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + b_start_loc, + b_seq_len, + b_ctx_len, + alibi_slopes, + v_cache.shape[3], + 8, + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4 + ), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride( + 3), #[num_blocks, num_kv_heads, head_size, block_size] + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + _fwd_kernel[grid]( + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + b_start_loc, + b_seq_len, + b_ctx_len, + v_cache.shape[3], + 8, + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride( + 3), #[num_blocks, num_kv_heads, head_size, block_size] + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return diff --git a/vllm/prefix.py b/vllm/prefix.py new file mode 100644 index 000000000000..415da1fc6d2b --- /dev/null +++ b/vllm/prefix.py @@ -0,0 +1,87 @@ +from typing import Dict, List, Sequence, Tuple, Optional + +from vllm.block import BlockTable + + +class Prefix: + """Data and states associated with a prefix of prompt tokens for multiple + sequence groups. + + NOTE: This feature is experimental and may be replaced with automatic + prefix caching in the future. + + Args: + prefix_id: The id of the prefix in the prefix pool. + token_ids: The token ids of the prefix. + block_size: The block size of the executed model. + """ + + def __init__( + self, + token_ids: Sequence[int], + block_size: int, + ) -> None: + self.token_ids = tuple(token_ids) + self.block_size = block_size + self.length = len(token_ids) + self.hash = hash(token_ids) + assert self.length % block_size == 0 + self.block_table: Optional[BlockTable] = None + self.computed = False + + @property + def allocated(self) -> bool: + return self.block_table is not None + + def get_num_blocks(self) -> int: + return self.length // self.block_size + + def get_block_numbers(self) -> List[int]: + return [block.block_number for block in self.block_table] + + def get_length(self) -> int: + return self.length + + def __hash__(self) -> int: + return self.hash + + def set_block_table(self, block_table: BlockTable) -> None: + self.block_table = block_table.copy() + + +class PrefixPool: + """Manages all the prompt prefixes. + + NOTE: This feature is experimental and may be replaced with automatic + prefix caching in the future. + + Args: + block_size: The block size of the executed model. + + Attributes: + prefixes: A list of all the prefixes. + block_size: The block size of the executed model. + """ + + def __init__( + self, + block_size: int, + ) -> None: + # TODO(zhuohan): Add a capacity limit to the prefix pool. + self.prefixes: Dict[int, Prefix] = {} + self.block_size = block_size + + def _truncate_token_ids(self, token_ids: Sequence[int]) -> Tuple[int]: + new_length = len(token_ids) // self.block_size * self.block_size + return tuple(token_ids[:new_length]) + + def add_or_get_prefix(self, token_ids: Sequence[int]) -> Optional[Prefix]: + token_ids = self._truncate_token_ids(token_ids) + if len(token_ids) == 0: + # Prefix is empty. + return None + prefix = Prefix(token_ids, self.block_size) + prefix_hash = hash(prefix) + if prefix_hash not in self.prefixes: + self.prefixes[prefix_hash] = prefix + return self.prefixes[prefix_hash] diff --git a/vllm/sequence.py b/vllm/sequence.py index 7d36eeac0aa0..fd10bc9b5b8c 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Union from vllm.block import LogicalTokenBlock +from vllm.prefix import Prefix from vllm.sampling_params import SamplingParams PromptLogprobs = List[Optional[Dict[int, float]]] @@ -236,11 +237,13 @@ def __init__( seqs: List[Sequence], sampling_params: SamplingParams, arrival_time: float, + prefix: Optional[Prefix] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} self.sampling_params = sampling_params self.arrival_time = arrival_time + self.prefix: Optional[Prefix] = prefix self.prompt_logprobs: Optional[PromptLogprobs] = None @property @@ -327,7 +330,6 @@ def __repr__(self) -> str: class SequenceGroupMetadata: """Metadata for a sequence group. Used to create `InputMetadata`. - Args: request_id: The ID of the request. is_prompt: Whether the request is at prompt stage. @@ -335,6 +337,7 @@ class SequenceGroupMetadata: sampling_params: The sampling parameters used to generate the outputs. block_tables: The block tables. (Seq id -> list of physical block numbers) + prefix: The prefix of the prompt of the sequence group. """ def __init__( @@ -344,12 +347,14 @@ def __init__( seq_data: Dict[int, SequenceData], sampling_params: SamplingParams, block_tables: Dict[int, List[int]], + prefix: Optional[Prefix] = None, ) -> None: self.request_id = request_id self.is_prompt = is_prompt self.seq_data = seq_data self.sampling_params = sampling_params self.block_tables = block_tables + self.prefix = prefix class SequenceOutput: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 460d9907e88c..d29088650650 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -74,13 +74,17 @@ def set_block_size(self, block_size: int) -> None: def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int]]: + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], + List[int]]: assert len(seq_group_metadata_list) > 0 input_tokens: List[List[int]] = [] input_positions: List[List[int]] = [] slot_mapping: List[List[int]] = [] prompt_lens: List[int] = [] + context_lens: List[int] = [] + subquery_lens: List[int] = [] + prefix_block_tables: List[List[int]] = [] for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -91,11 +95,23 @@ def _prepare_prompt( prompt_tokens = seq_data.get_token_ids() prompt_len = len(prompt_tokens) prompt_lens.append(prompt_len) + prefix_len = 0 + prefix = seq_group_metadata.prefix + if prefix is not None and prefix.computed: + prefix_len = prefix.get_length() + prompt_tokens = prompt_tokens[prefix_len:] + prefix_block_tables.append(prefix.get_block_numbers()) + else: + prefix_block_tables.append([]) + # actual prompt lens + context_lens.append(prefix_len) + subquery_lens.append(prompt_len - prefix_len) input_tokens.append(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. - input_positions.append(list(range(prompt_len))) + input_positions.append( + list(range(prefix_len, prefix_len + len(prompt_tokens)))) if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized @@ -113,8 +129,11 @@ def _prepare_prompt( # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. start_idx = 0 if self.sliding_window is not None: + assert prefix_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention") start_idx = max(0, prompt_len - self.sliding_window) - for i in range(prompt_len): + for i in range(prefix_len, prompt_len): if i < start_idx: slot_mapping[-1].append(_PAD_SLOT_ID) continue @@ -124,7 +143,7 @@ def _prepare_prompt( slot = block_number * self.block_size + block_offset slot_mapping[-1].append(slot) - max_prompt_len = max(prompt_lens) + max_prompt_len = max(subquery_lens) input_tokens = _make_tensor_with_pad(input_tokens, max_prompt_len, pad=0, @@ -137,16 +156,39 @@ def _prepare_prompt( max_prompt_len, pad=_PAD_SLOT_ID, dtype=torch.long) + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int, + device='cuda') + # Prepare prefix block tables + max_prompt_block_table_len = max(len(t) for t in prefix_block_tables) + block_tables = _make_tensor_with_pad( + prefix_block_tables, + max_len=max_prompt_block_table_len, + pad=0, + dtype=torch.int, + ) + start_loc_tensor = torch.arange(0, + len(prompt_lens) * max_prompt_len, + max_prompt_len, + dtype=torch.long, + device='cuda') + prompt_lens_tensor = torch.tensor(prompt_lens, + dtype=torch.long, + device='cuda') input_metadata = InputMetadata( is_prompt=True, slot_mapping=slot_mapping, + prompt_lens=prompt_lens_tensor, + max_seq_len=max_prompt_len, + start_loc=start_loc_tensor, max_context_len=None, - context_lens=None, - block_tables=None, + context_lens=context_lens_tensor, + block_tables=block_tables, use_cuda_graph=False, ) - return input_tokens, input_positions, input_metadata, prompt_lens + return (input_tokens, input_positions, input_metadata, prompt_lens, + subquery_lens) def _prepare_decode( self, @@ -248,6 +290,9 @@ def _prepare_decode( input_metadata = InputMetadata( is_prompt=False, slot_mapping=slot_mapping, + prompt_lens=None, + max_seq_len=None, + start_loc=None, max_context_len=max_context_len, context_lens=context_lens, block_tables=block_tables, @@ -259,6 +304,7 @@ def _prepare_sample( self, seq_group_metadata_list: List[SequenceGroupMetadata], prompt_lens: List[int], + subquery_lens: Optional[List[int]], ) -> SamplingMetadata: seq_groups: List[Tuple[List[int], SamplingParams]] = [] selected_token_indices: List[int] = [] @@ -266,7 +312,7 @@ def _prepare_sample( categorized_sample_indices = {t: [] for t in SamplingType} categorized_sample_indices_start_idx = 0 - max_prompt_len = max(prompt_lens) if prompt_lens else 1 + max_subquery_len = max(subquery_lens) if subquery_lens else 1 for i, seq_group_metadata in enumerate(seq_group_metadata_list): seq_ids = list(seq_group_metadata.seq_data.keys()) sampling_params = seq_group_metadata.sampling_params @@ -274,10 +320,11 @@ def _prepare_sample( if seq_group_metadata.is_prompt: assert len(seq_ids) == 1 - prompt_len = prompt_lens[i] + assert subquery_lens is not None + subquery_len = subquery_lens[i] if sampling_params.prompt_logprobs is not None: # NOTE: prompt token positions do not need sample, skip - categorized_sample_indices_start_idx += prompt_len - 1 + categorized_sample_indices_start_idx += subquery_len - 1 categorized_sample_indices[ sampling_params.sampling_type].append( @@ -287,10 +334,10 @@ def _prepare_sample( if sampling_params.prompt_logprobs is not None: selected_token_indices.extend( range(selected_token_start_idx, - selected_token_start_idx + prompt_len - 1)) + selected_token_start_idx + subquery_len - 1)) selected_token_indices.append(selected_token_start_idx + - prompt_len - 1) - selected_token_start_idx += max_prompt_len + subquery_len - 1) + selected_token_start_idx += max_subquery_len else: num_seqs = len(seq_ids) selected_token_indices.extend( @@ -335,14 +382,16 @@ def prepare_input_tensors( is_prompt = seq_group_metadata_list[0].is_prompt # Prepare input tensors. if is_prompt: - (input_tokens, input_positions, input_metadata, - prompt_lens) = self._prepare_prompt(seq_group_metadata_list) + (input_tokens, input_positions, input_metadata, prompt_lens, + subquery_lens) = self._prepare_prompt(seq_group_metadata_list) else: (input_tokens, input_positions, input_metadata ) = self._prepare_decode(seq_group_metadata_list) + subquery_lens = None prompt_lens = [] sampling_metadata = self._prepare_sample(seq_group_metadata_list, - prompt_lens) + prompt_lens, + subquery_lens) def get_size_or_none(x: Optional[torch.Tensor]): return x.size() if x is not None else None @@ -359,6 +408,12 @@ def get_size_or_none(x: Optional[torch.Tensor]): input_metadata.is_prompt, "slot_mapping_size": get_size_or_none(input_metadata.slot_mapping), + "prompt_lens_size": + get_size_or_none(input_metadata.prompt_lens), + "max_seq_len": + input_metadata.max_seq_len, + "start_loc_size": + get_size_or_none(input_metadata.start_loc), "max_context_len": input_metadata.max_context_len, "context_lens_size": @@ -376,6 +431,10 @@ def get_size_or_none(x: Optional[torch.Tensor]): broadcast(input_positions, src=0) if input_metadata.slot_mapping is not None: broadcast(input_metadata.slot_mapping, src=0) + if input_metadata.prompt_lens is not None: + broadcast(input_metadata.prompt_lens, src=0) + if input_metadata.start_loc is not None: + broadcast(input_metadata.start_loc, src=0) if input_metadata.context_lens is not None: broadcast(input_metadata.context_lens, src=0) if input_metadata.block_tables is not None: @@ -400,6 +459,20 @@ def get_size_or_none(x: Optional[torch.Tensor]): broadcast(slot_mapping, src=0) else: slot_mapping = None + if py_data["prompt_lens_size"] is not None: + prompt_lens = torch.empty(*py_data["prompt_lens_size"], + dtype=torch.long, + device="cuda") + broadcast(prompt_lens, src=0) + else: + prompt_lens = None + if py_data["start_loc_size"] is not None: + start_loc = torch.empty(*py_data["start_loc_size"], + dtype=torch.long, + device="cuda") + broadcast(start_loc, src=0) + else: + start_loc = None if py_data["context_lens_size"] is not None: context_lens = torch.empty(*py_data["context_lens_size"], dtype=torch.int, @@ -422,6 +495,9 @@ def get_size_or_none(x: Optional[torch.Tensor]): input_metadata = InputMetadata( is_prompt=py_data["is_prompt"], slot_mapping=slot_mapping, + prompt_lens=prompt_lens, + max_seq_len=py_data["max_seq_len"], + start_loc=start_loc, max_context_len=py_data["max_context_len"], context_lens=context_lens, block_tables=block_tables, @@ -534,6 +610,9 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: input_metadata = InputMetadata( is_prompt=False, slot_mapping=slot_mapping[:batch_size], + prompt_lens=None, + max_seq_len=None, + start_loc=None, max_context_len=self.max_context_len_to_capture, context_lens=context_lens[:batch_size], block_tables=block_tables[:batch_size],