
Commit

Support FP8-E5M2 KV Cache (#2279)
Co-authored-by: zhaoyang <zhao.yang16@zte.com.cn>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
3 people committed Jan 29, 2024
1 parent 7d64841 commit 9090bf0
Showing 26 changed files with 912 additions and 196 deletions.
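
The user-visible effect of this commit is a new kv_cache_dtype option ("auto" or "fp8_e5m2") threaded through the benchmark scripts and into the engine. A minimal usage sketch, not part of the diff; the model name and sampling settings are illustrative:

from vllm import LLM, SamplingParams

# Hedged sketch: request the FP8-E5M2 KV cache via the new kv_cache_dtype
# argument, mirroring how the updated benchmarks below pass it to LLM().
llm = LLM(
    model="facebook/opt-125m",   # illustrative model, not taken from this commit
    kv_cache_dtype="fp8_e5m2",   # "auto" (the default) keeps the model dtype
)
params = SamplingParams(temperature=0.8, max_tokens=32)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)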
8 changes: 8 additions & 0 deletions benchmarks/benchmark_latency.py
@@ -24,6 +24,7 @@ def main(args: argparse.Namespace):
trust_remote_code=args.trust_remote_code,
dtype=args.dtype,
enforce_eager=args.enforce_eager,
kv_cache_dtype=args.kv_cache_dtype,
)

sampling_params = SamplingParams(
@@ -117,6 +118,13 @@ def run_to_completion(profile_dir: Optional[str] = None):
parser.add_argument('--enforce-eager',
action='store_true',
help='enforce eager mode and disable CUDA graph')
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=['auto', 'fp8_e5m2'],
default='auto',
help=
'Data type for kv cache storage. If "auto", will use model data type.')
parser.add_argument(
'--profile',
action='store_true',
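For context on what the FP8-E5M2 cache buys, a back-of-the-envelope sketch of KV-cache size per sequence; the layer/head/dimension numbers are illustrative (roughly a 7B-class model) and are not taken from this diff:

def kv_cache_bytes(num_layers, num_kv_heads, head_size, context_len, bytes_per_elem):
    # 2x for keys and values; one entry per layer, KV head, head dim, and token.
    return 2 * num_layers * num_kv_heads * head_size * context_len * bytes_per_elem

cfg = dict(num_layers=32, num_kv_heads=32, head_size=128, context_len=4096)
print(kv_cache_bytes(**cfg, bytes_per_elem=2) / 2**30, "GiB")  # 2.0 GiB with a 16-bit cache
print(kv_cache_bytes(**cfg, bytes_per_elem=1) / 2**30, "GiB")  # 1.0 GiB with an 8-bit cache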
12 changes: 11 additions & 1 deletion benchmarks/benchmark_throughput.py
@@ -71,6 +71,7 @@ def run_vllm(
dtype: str,
max_model_len: Optional[int],
enforce_eager: bool,
kv_cache_dtype: str,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
@@ -83,6 +84,7 @@
dtype=dtype,
max_model_len=max_model_len,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
)

# Add the requests to the engine.
@@ -206,7 +208,8 @@ def main(args: argparse.Namespace):
args.quantization, args.tensor_parallel_size,
args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype,
args.max_model_len, args.enforce_eager)
args.max_model_len, args.enforce_eager,
args.kv_cache_dtype)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -284,6 +287,13 @@ def main(args: argparse.Namespace):
parser.add_argument("--enforce-eager",
action="store_true",
help="enforce eager execution")
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=["auto", "fp8_e5m2"],
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type.')
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
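The string parsed here is forwarded unchanged through run_vllm into LLM(). A sketch of how such a value would plausibly be resolved downstream, modeled on the STR_DTYPE_TO_TORCH_DTYPE table imported in the next file; the fp8_e5m2 -> torch.uint8 entry is an assumption (FP8 values held in a 1-byte container), not something this diff shows:

import torch

# Assumed mapping, modeled on vllm.utils.STR_DTYPE_TO_TORCH_DTYPE.
STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8_e5m2": torch.uint8,  # assumption: 1-byte storage for FP8-E5M2
}

def resolve_kv_cache_dtype(kv_cache_dtype: str, model_dtype: torch.dtype) -> torch.dtype:
    # "auto" keeps the model dtype; anything else is looked up in the table.
    if kv_cache_dtype == "auto":
        return model_dtype
    return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]

print(resolve_kv_cache_dtype("auto", torch.half))      # torch.float16
print(resolve_kv_cache_dtype("fp8_e5m2", torch.half))  # torch.uint8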
33 changes: 18 additions & 15 deletions benchmarks/kernels/benchmark_paged_attention.py
@@ -1,9 +1,11 @@
from typing import Optional
import argparse
import random
import time

import torch

from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
from vllm._C import ops

NUM_BLOCKS = 1024
@@ -23,6 +25,7 @@ def main(
dtype: torch.dtype,
seed: int,
do_profile: bool,
kv_cache_dtype: Optional[str] = None,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
@@ -59,15 +62,10 @@ def main(
block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")

# Create the KV cache.
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
key_cache.uniform_(-scale, scale)
value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
device="cuda")
value_cache.uniform_(-scale, scale)
key_caches, value_caches = create_kv_caches_with_random(
NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype,
dtype)
key_cache, value_cache = key_caches[0], value_caches[0]

# Prepare for the paged attention kernel.
output = torch.empty_like(query)
@@ -106,6 +104,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
)
elif version == "v2":
ops.paged_attention_v2(
@@ -123,6 +122,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
)
else:
raise ValueError(f"Invalid version: {version}")
@@ -168,16 +168,18 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
default="half")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--profile", action="store_true")
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=["auto", "fp8_e5m2"],
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type.')
args = parser.parse_args()
print(args)

if args.num_query_heads % args.num_kv_heads != 0:
raise ValueError("num_query_heads must be divisible by num_kv_heads")
dtype_to_torch_dtype = {
"half": torch.half,
"bfloat16": torch.bfloat16,
"float": torch.float,
}
main(
version=args.version,
num_seqs=args.batch_size,
@@ -187,7 +189,8 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
head_size=args.head_size,
block_size=args.block_size,
use_alibi=args.use_alibi,
dtype=dtype_to_torch_dtype[args.dtype],
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
seed=args.seed,
do_profile=args.profile,
kv_cache_dtype=args.kv_cache_dtype,
)
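
The create_kv_caches_with_random helper replaces the inline allocation removed above. A sketch of the shapes it is expected to produce, derived from that removed code; the head counts are illustrative, and the uint8 container for the fp8_e5m2 cache is an assumption:

import torch

def kv_cache_shapes(num_blocks, block_size, num_kv_heads, head_size,
                    cache_dtype: torch.dtype):
    # Mirrors the removed inline code: keys are packed so the innermost
    # dimension holds 16 bytes of elements; values are laid out per block.
    x = 16 // torch.tensor([], dtype=cache_dtype).element_size()
    key_cache_shape = (num_blocks, num_kv_heads, head_size // x, block_size, x)
    value_cache_shape = (num_blocks, num_kv_heads, head_size, block_size)
    return key_cache_shape, value_cache_shape

print(kv_cache_shapes(1024, 16, 40, 128, torch.half))   # x = 8 for a 16-bit cache
print(kv_cache_shapes(1024, 16, 40, 128, torch.uint8))  # x = 16, assuming 1-byte FP8 storage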
1 change: 1 addition & 0 deletions csrc/attention/attention_dtypes.h
@@ -4,3 +4,4 @@
#include "dtype_float16.cuh"
#include "dtype_float32.cuh"
#include "dtype_bfloat16.cuh"
#include "dtype_fp8_e5m2.cuh"