Support FP8-E5M2 KV Cache #2279
```diff
@@ -24,6 +24,7 @@ def main(args: argparse.Namespace):
         trust_remote_code=args.trust_remote_code,
         dtype=args.dtype,
         enforce_eager=args.enforce_eager,
+        kv_cache_dtype=args.kv_cache_dtype,
     )

     sampling_params = SamplingParams(

@@ -115,6 +116,11 @@ def run_to_completion(profile_dir: Optional[str] = None):
     parser.add_argument('--enforce-eager',
                         action='store_true',
                         help='enforce eager mode and disable CUDA graph')
+    parser.add_argument('--kv-cache-dtype',
+                        type=str,
+                        choices=['fp8', None],
+                        default=None,
+                        help='Data type for kv cache storage.')
```
> **Review comment:** Please specify the default behavior here. And one question, why […]

> **Review comment:** Additionally, can we call […]

> **Author reply:** The default kv cache data type is the same as the model dtype, so it may not be suitable to make […]
```diff
     parser.add_argument(
         '--profile',
         action='store_true',
```
---

```diff
@@ -4,7 +4,7 @@

 import torch

-from vllm._C import ops
+from vllm._C import ops, cache_ops

 NUM_BLOCKS = 1024
 PARTITION_SIZE = 512

@@ -21,6 +21,7 @@ def main(
     use_alibi: bool,
     block_size: int,
     dtype: torch.dtype,
+    use_fp8_kv_cache: bool,
```
> **Review comment:** Please use a string option instead of bool.
```diff
     seed: int,
     do_profile: bool,
 ) -> None:

@@ -59,15 +60,36 @@ def main(
     block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")

     # Create the KV cache.
-    x = 16 // torch.tensor([], dtype=dtype).element_size()
+    cache_dtype = dtype if not use_fp8_kv_cache else torch.uint8
+    x = 16 // torch.tensor([], dtype=cache_dtype).element_size()
     key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
-    key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
-    key_cache.uniform_(-scale, scale)
+    key_cache = torch.empty(size=key_cache_shape,
+                            dtype=cache_dtype,
+                            device="cuda")
+    if not use_fp8_kv_cache:
+        key_cache.uniform_(-scale, scale)
+    else:
+        # NOTE(zhaoyang): Due to the NaN and Inf representations of the fp8
+        # data type, Inf or NaN may occur if we directly use torch.randint
+        # to generate random data for the fp8 cache.
+        # For example, s.11111.00 in fp8e5m2 format represents Inf.
+        #      | E4M3       | E5M2
+        # -----|------------|-------------------
+        #  Inf | N/A        | s.11111.00
+        #  NaN | s.1111.111 | s.11111.{01,10,11}
+        key_cache_tmp = torch.empty_like(key_cache, dtype=dtype)
+        key_cache_tmp.uniform_(-scale, scale)
+        cache_ops.convert_fp8(key_cache_tmp, key_cache)
```
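The Inf/NaN table in the NOTE above can be verified directly. A small sketch, assuming a PyTorch build (2.1+) that exposes `torch.float8_e5m2`: enumerate all 256 byte patterns and count how many decode to Inf or NaN, which is exactly why random bytes are unsafe as fp8 cache contents:

```python
import torch

# Reinterpret every possible byte as an fp8_e5m2 value.
bits = torch.arange(256, dtype=torch.uint8)
vals = bits.view(torch.float8_e5m2).float()

print(torch.isinf(vals).sum().item())  # 2: s.11111.00, one per sign
print(torch.isnan(vals).sum().item())  # 6: s.11111.{01,10,11}, per sign
```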
The value cache is filled the same way:

```diff
     value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
     value_cache = torch.empty(size=value_cache_shape,
-                              dtype=dtype,
+                              dtype=cache_dtype,
                               device="cuda")
-    value_cache.uniform_(-scale, scale)
+    if not use_fp8_kv_cache:
```
```diff
+        value_cache.uniform_(-scale, scale)
+    else:
+        value_cache_tmp = torch.empty_like(value_cache, dtype=dtype)
+        value_cache_tmp.uniform_(-scale, scale)
+        cache_ops.convert_fp8(value_cache_tmp, value_cache)

     # Prepare for the paged attention kernel.
     output = torch.empty_like(query)

@@ -106,6 +128,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
                 block_size,
                 max_context_len,
                 alibi_slopes,
+                use_fp8_kv_cache,
             )
         elif version == "v2":
             ops.paged_attention_v2(

@@ -123,6 +146,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
                 block_size,
                 max_context_len,
                 alibi_slopes,
+                use_fp8_kv_cache,
             )
         else:
             raise ValueError(f"Invalid version: {version}")

@@ -166,6 +190,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
                         type=str,
                         choices=["half", "bfloat16", "float"],
                         default="half")
+    parser.add_argument("--use-fp8-kv-cache", action="store_true")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    args = parser.parse_args()

@@ -188,6 +213,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
         block_size=args.block_size,
         use_alibi=args.use_alibi,
         dtype=dtype_to_torch_dtype[args.dtype],
+        use_fp8_kv_cache=args.use_fp8_kv_cache,
         seed=args.seed,
         do_profile=args.profile,
     )
```
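For reference, the round-trip that `cache_ops.convert_fp8` performs on the temporary fp16/fp32 tensor can be approximated in pure PyTorch (2.1+). This is only a rough emulation, under the assumption that the custom CUDA kernel rounds to e5m2 and stores the raw bytes; the real op lives in vLLM's C++ extension:

```python
import torch

def convert_fp8_reference(src: torch.Tensor, dst: torch.Tensor) -> None:
    """Rough stand-in for cache_ops.convert_fp8 (e5m2 path)."""
    assert dst.dtype == torch.uint8 and dst.shape == src.shape
    # Round to fp8_e5m2, then reinterpret the bytes as uint8 for storage.
    dst.copy_(src.to(torch.float8_e5m2).view(torch.uint8))

scale = 1.0
tmp = torch.empty(4, 8).uniform_(-scale, scale)
cache = torch.empty_like(tmp, dtype=torch.uint8)
convert_fp8_reference(tmp, cache)
# Decoding the stored bytes recovers the (lossy) fp8 values.
print(cache.view(torch.float8_e5m2).float())
```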
---

```diff
@@ -4,3 +4,4 @@
 #include "dtype_float16.cuh"
 #include "dtype_float32.cuh"
 #include "dtype_bfloat16.cuh"
+#include "dtype_fp8.cuh"
```
> **Review comment:** In general, please be explicit in this PR that `fp8` means `fp8_e5m2`. Given that there are multiple ways to implement fp8, this will make things more clear.

> **Author reply:** done
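The reviewer's point about ambiguity is easy to see numerically: the two common fp8 encodings trade range for precision in opposite directions. A quick check with PyTorch (2.1+), which exposes both variants:

```python
import torch

# e5m2: 5 exponent bits -> wide range, 2 mantissa bits -> coarse precision.
# e4m3: 4 exponent bits -> narrow range, 3 mantissa bits -> finer precision.
print(torch.finfo(torch.float8_e5m2).max)    # 57344.0
print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0
```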