[Experimental] Prefix Caching Support #1669

Merged on Jan 18, 2024 (44 commits)
Commits (44)
8d8ebb1
add the prefix prefill triton kernel
caoshiyi Nov 15, 2023
49aaf42
add the prefix class
caoshiyi Nov 15, 2023
ac96e7f
modify block manager & scheduler
caoshiyi Nov 15, 2023
bf3faa2
prepare inputs to the kernel
caoshiyi Nov 15, 2023
074a90f
fix
caoshiyi Nov 16, 2023
de69ca4
add prefix_pos
caoshiyi Nov 16, 2023
f5bf25a
fix prefix state transition
caoshiyi Nov 17, 2023
49b19bb
tested on single request
caoshiyi Nov 17, 2023
1619d33
add prefix_pos for offline inference
caoshiyi Nov 18, 2023
2309330
minor
caoshiyi Nov 18, 2023
d8b2809
fix blocktable padding
caoshiyi Nov 21, 2023
50e7497
fix multi-gpu state transition
caoshiyi Jan 1, 2024
9b4a63d
Merge branch 'main' into prefix
caoshiyi Jan 1, 2024
33bfcff
clean
caoshiyi Jan 2, 2024
1cef69f
format
caoshiyi Jan 2, 2024
98ee509
add support for alibi bias
caoshiyi Jan 2, 2024
6e07602
Merge branch 'main' into prefix
caoshiyi Jan 5, 2024
a948cd3
format
caoshiyi Jan 5, 2024
ead42a2
clean
caoshiyi Jan 6, 2024
7bcb509
clean
caoshiyi Jan 6, 2024
8bc52ca
clean & minor
caoshiyi Jan 7, 2024
a534d6c
Merge branch 'main' into prefix
caoshiyi Jan 7, 2024
abb843b
format
caoshiyi Jan 7, 2024
dc08e14
move prefix prefill kernel test to a separate file
caoshiyi Jan 7, 2024
f0f8f66
format
caoshiyi Jan 7, 2024
2389fdb
Merge branch 'main' into prefix
zhuohan123 Jan 13, 2024
3678af6
fix test
zhuohan123 Jan 13, 2024
037950b
format kernel
zhuohan123 Jan 13, 2024
b414c77
fix format
zhuohan123 Jan 13, 2024
bb4ca73
[WIP] Refactor
zhuohan123 Jan 15, 2024
37cd3fc
fix comments
zhuohan123 Jan 15, 2024
b09bbaa
Merge branch 'prefix' of https://github.com/caoshiyi/vllm into prefix
zhuohan123 Jan 16, 2024
58d2839
add example and test
zhuohan123 Jan 16, 2024
8370e7b
Merge branch 'main' into prefix
zhuohan123 Jan 16, 2024
58c5cff
add prefix caching test to ci
zhuohan123 Jan 16, 2024
dc6e959
fix ci
zhuohan123 Jan 16, 2024
7dc2e87
add comment
zhuohan123 Jan 17, 2024
49b2684
add TODO
zhuohan123 Jan 17, 2024
c9050d3
fix
zhuohan123 Jan 17, 2024
cfe1444
fix swapping logic
zhuohan123 Jan 17, 2024
3bb9802
Merge branch 'main' into prefix
zhuohan123 Jan 17, 2024
29f4f96
fix bug
zhuohan123 Jan 17, 2024
bd56a69
fix correctness
zhuohan123 Jan 17, 2024
6b00283
add notes and small fix
zhuohan123 Jan 18, 2024
4 changes: 4 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -31,6 +31,10 @@ steps:
- pytest -v -s models --forked
soft_fail: true

- label: Prefix Caching Test
commands:
- pytest -v -s prefix_caching

- label: Samplers Test
command: pytest -v -s samplers --forked

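The new CI step runs the prefix caching suite added in this PR; a rough local equivalent, assuming a source checkout of this branch, is pytest -v -s tests/prefix_caching from the repository root.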
51 changes: 51 additions & 0 deletions examples/offline_inference_with_prefix.py
@@ -0,0 +1,51 @@
from vllm import LLM, SamplingParams

prefix = (
    "You are an expert school principal, skilled in effectively managing "
    "faculty and staff. Draft 10-15 questions for a potential first grade "
    "Head Teacher for my K-12, all-girls', independent school that emphasizes "
    "community, joyful discovery, and life-long learning. The candidate is "
    "coming in for a first-round panel interview for a 8th grade Math "
    "teaching role. They have 5 years of previous teaching experience "
    "as an assistant teacher at a co-ed, public school with experience "
    "in middle school math teaching. Based on these information, fulfill "
    "the following paragraph: ")

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0)

# Create an LLM.
llm = LLM(model="facebook/opt-125m")

generating_prompts = [prefix + prompt for prompt in prompts]

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(generating_prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

print("-" * 80)

# -1 since the last token can change when concatenating prompts.
prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1

# Generate with prefix
outputs = llm.generate(generating_prompts, sampling_params,
                       prefix_pos=[prefix_pos] * len(generating_prompts))

# Print the outputs. You should see the same outputs as before
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
168 changes: 168 additions & 0 deletions tests/kernels/test_prefix_prefill.py
@@ -0,0 +1,168 @@
import random
import pytest
import time

import torch
from vllm.model_executor.layers.triton_kernel.prefix_prefill import (
    context_attention_fwd)
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask

NUM_HEADS = [12]
HEAD_SIZES = [128]
DTYPES = [torch.float16]


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_contexted_kv_attention(
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
) -> None:
    random.seed(0)
    torch.manual_seed(0)
    MAX_SEQ_LEN = 1024
    MAX_CTX_LEN = 1024
    BS = 10
    cache_size = 640
    block_size = 32
    max_block_per_request = 64
    subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
    ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
    seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)]

    num_tokens = sum(subquery_lens)
    query = torch.empty(num_tokens,
                        num_heads,
                        head_size,
                        dtype=dtype,
                        device='cuda')
    query.uniform_(-1e-3, 1e-3)
    output = torch.empty(num_tokens,
                         num_heads,
                         head_size,
                         dtype=dtype,
                         device='cuda')

    kv = torch.empty(sum(seq_lens),
                     2,
                     num_heads,
                     head_size,
                     dtype=dtype,
                     device='cuda')
    kv.uniform_(-1e-3, 1e-3)
    key, value = kv.unbind(dim=1)

    k_cache = torch.zeros(cache_size,
                          block_size,
                          num_heads,
                          head_size,
                          dtype=dtype,
                          device='cuda')
    v_cache = torch.zeros(cache_size,
                          block_size,
                          num_heads,
                          head_size,
                          dtype=dtype,
                          device='cuda')
    k = torch.zeros(sum(subquery_lens),
                    num_heads,
                    head_size,
                    dtype=dtype,
                    device='cuda')
    v = torch.zeros(sum(subquery_lens),
                    num_heads,
                    head_size,
                    dtype=dtype,
                    device='cuda')
    values = torch.arange(0, cache_size, dtype=torch.long, device='cuda')
    values = values[torch.randperm(cache_size)]
    block_table = values[:BS * max_block_per_request].view(
        BS, max_block_per_request)
    b_seq_len = torch.tensor(seq_lens, dtype=torch.long, device='cuda')
    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long, device='cuda')
    b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1],
                                            dtype=torch.long,
                                            device='cuda'),
                               dim=0)
    max_input_len = MAX_SEQ_LEN
    # copy kv to cache
    b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
                                                dtype=torch.long,
                                                device='cuda'),
                                   dim=0)
    for i in range(BS):
        for j in range(subquery_lens[i]):
            k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
                                            j])
            v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
                                              b_ctx_len[i] + j])
        cur_ctx = 0
        block_id = 0
        while cur_ctx < b_ctx_len[i]:
            start_loc = b_seq_start_loc[i] + cur_ctx
            if cur_ctx + block_size > b_ctx_len[i]:
                end_loc = b_seq_start_loc[i] + b_ctx_len[i]
            else:
                end_loc = start_loc + block_size
            start_slot = block_table[i, block_id] * block_size
            end_slot = start_slot + end_loc - start_loc
            k_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(
                key[start_loc:end_loc])
            v_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(
                value[start_loc:end_loc])
            cur_ctx += block_size
            block_id += 1
    # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
    k_cache = k_cache.view(-1, block_size, num_heads, head_size // 8,
                           8).permute(0, 2, 3, 1, 4).contiguous()
    # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
    v_cache = v_cache.view(-1, block_size, num_heads,
                           head_size).permute(0, 2, 3, 1).contiguous()

    context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table,
                          b_start_loc, b_seq_len, b_ctx_len, max_input_len)
    torch.cuda.synchronize()
    start_time = time.time()
    context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table,
                          b_start_loc, b_seq_len, b_ctx_len, max_input_len)
    torch.cuda.synchronize()
    end_time = time.time()
    print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")

    scale = float(1.0 / (head_size**0.5))

    attn_op = xops.fmha.cutlass.FwOp()

    attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
        subquery_lens, seq_lens)
    output_ref = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
        op=attn_op,
    )
    torch.cuda.synchronize()
    start_time = time.time()
    output_ref = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
        op=attn_op,
    )
    torch.cuda.synchronize()
    end_time = time.time()
    print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
    output_ref = output_ref.squeeze(0)
    assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
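For reference, the two reshape/permute steps above put the caches into the layout the Triton kernel consumes. A shape-only sketch with the same sizes as the test (640 blocks, block_size 32, 12 heads, head_size 128), runnable on CPU:

import torch

num_blocks, block_size, num_heads, head_size = 640, 32, 12, 128

k_cache = torch.zeros(num_blocks, block_size, num_heads, head_size)
# [num_blocks, block_size, num_heads, head_size]
#   -> [num_blocks, num_heads, head_size // 8, block_size, 8]
k_cache = k_cache.view(num_blocks, block_size, num_heads, head_size // 8,
                       8).permute(0, 2, 3, 1, 4).contiguous()
print(k_cache.shape)  # torch.Size([640, 12, 16, 32, 8])

v_cache = torch.zeros(num_blocks, block_size, num_heads, head_size)
# [num_blocks, block_size, num_heads, head_size]
#   -> [num_blocks, num_heads, head_size, block_size]
v_cache = v_cache.permute(0, 2, 3, 1).contiguous()
print(v_cache.shape)  # torch.Size([640, 12, 128, 32])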
41 changes: 41 additions & 0 deletions tests/prefix_caching/test_prefix_caching.py
@@ -0,0 +1,41 @@
"""Compare the with and without prefix caching.

Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
import pytest

from vllm import LLM, SamplingParams

prefix = (
"You are an expert school principal, skilled in effectively managing "
"faculty and staff. Draft 10-15 questions for a potential first grade "
"Head Teacher for my K-12, all-girls', independent school that emphasizes "
"community, joyful discovery, and life-long learning. The candidate is "
"coming in for a first-round panel interview for a 8th grade Math "
"teaching role. They have 5 years of previous teaching experience "
"as an assistant teacher at a co-ed, public school with experience "
"in middle school math teaching. Based on these information, fulfill "
"the following paragraph: ")


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("max_tokens", [16])
def test_prefix_caching(
example_prompts,
model: str,
max_tokens: int,
):
llm = LLM(model=model)
# -1 since the last token can change when concatenating prompts.
prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1
prompts = [prefix + prompt for prompt in example_prompts]
sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs_without_prefix = llm.generate(prompts, sampling_params)
outputs_with_prefix = llm.generate(prompts,
sampling_params,
prefix_pos=[prefix_pos] * len(prompts))
for output_without_prefix, output_with_prefix in zip(
outputs_without_prefix, outputs_with_prefix):
assert (output_without_prefix.outputs[0].token_ids ==
output_with_prefix.outputs[0].token_ids)
assert len(llm.llm_engine.scheduler.prefix_pool.prefixes) == 1
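The final assert expects the scheduler to deduplicate identical prefixes: every prompt here shares the same prefix_pos tokens, so the pool should hold a single entry. Conceptually, and only as a hypothetical sketch rather than the PrefixPool implementation in this PR, the pool can key prefixes by their token ids:

from typing import Dict, Sequence, Tuple


class ToyPrefixPool:
    """Hypothetical illustration: one entry per distinct prefix token sequence."""

    def __init__(self) -> None:
        self.prefixes: Dict[Tuple[int, ...], dict] = {}

    def add_or_get(self, token_ids: Sequence[int]) -> dict:
        key = tuple(token_ids)
        # Requests with the same prefix tokens map to the same pool entry,
        # so their KV blocks can be shared instead of recomputed.
        return self.prefixes.setdefault(key, {"token_ids": key, "computed": False})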
18 changes: 12 additions & 6 deletions tests/samplers/test_sampler.py
@@ -66,7 +66,8 @@ def test_sampler_all_greedy(seed: int):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
@@ -105,7 +106,8 @@ def test_sampler_all_random(seed: int):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
@@ -140,7 +142,8 @@ def test_sampler_all_beam(seed: int):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)
sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
@@ -193,7 +196,8 @@ def test_sampler_mixed(seed: int):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
@@ -234,7 +238,8 @@ def pick_ith(token_ids, logits):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
@@ -288,7 +293,8 @@ def test_sampler_top_k_top_p(seed: int):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)

sample_probs = None

5 changes: 3 additions & 2 deletions tests/worker/test_model_runner.py
@@ -33,11 +33,12 @@ def test_prepare_prompt():
expected_selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += max_seq_len
input_tokens, input_positions, _, return_prompt_lens = (
input_tokens, input_positions, _, return_prompt_lens, _ = (
model_runner._prepare_prompt(seq_group_metadata_list))
assert return_prompt_lens == prompt_lens
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)
assert input_tokens.shape == (batch_size, max_seq_len)
assert input_positions.shape == (batch_size, max_seq_len)
torch.testing.assert_close(input_tokens, input_positions)
4 changes: 4 additions & 0 deletions vllm/block.py
@@ -66,3 +66,7 @@ def __repr__(self) -> str:
        return (f'PhysicalTokenBlock(device={self.device}, '
                f'block_number={self.block_number}, '
                f'ref_count={self.ref_count})')


# Mapping: logical block number -> physical block.
BlockTable = List[PhysicalTokenBlock]
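As a hedged illustration of the new alias (not from this diff; the stand-in class below is hypothetical and only mirrors the fields shown in the repr above), a BlockTable maps a sequence's logical block index to a physical KV-cache block, so sequences that share a cached prefix can point their leading entries at the same physical blocks and track the sharing via ref_count:

from dataclasses import dataclass
from typing import List


@dataclass
class Block:  # hypothetical stand-in for PhysicalTokenBlock
    block_number: int
    ref_count: int = 0


BlockTable = List[Block]

# Two sequences share a two-block prefix: the same physical blocks appear at
# the start of both tables, and ref_count records how many sequences use them.
shared_prefix = [Block(block_number=7, ref_count=2), Block(block_number=12, ref_count=2)]
seq_a: BlockTable = shared_prefix + [Block(block_number=3, ref_count=1)]
seq_b: BlockTable = shared_prefix + [Block(block_number=9, ref_count=1)]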