triton: cascade kernels (#396)
LeshengJin committed Jul 29, 2024
1 parent 68c3719 commit 2496f5b
Showing 5 changed files with 421 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/flashinfer/triton/__init__.py
@@ -0,0 +1 @@
from . import cascade
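This re-export makes the new Triton kernels reachable from the package namespace. A minimal import sketch (assuming flashinfer is installed from this source tree):

import flashinfer.triton as ft

# The cascade module added below resolves through the package namespace.
print(ft.cascade.merge_state)
print(ft.cascade.merge_states)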
152 changes: 152 additions & 0 deletions python/flashinfer/triton/cascade.py
@@ -0,0 +1,152 @@
from typing import Optional

import torch

from .kernels.cascade import (
merge_state_in_place_kernel,
merge_state_kernel,
merge_states_kernel,
variable_length_merge_states_kernel,
)
from .utils import check_device, check_dim, check_input, check_shape


def merge_state(
v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
):
check_input(v_a)
check_input(s_a)
check_input(v_b)
check_input(s_b)
check_device([v_a, s_a, v_b, s_b])
check_dim(3, v_a)
check_dim(2, s_a)
check_dim(3, v_b)
check_dim(2, s_b)
check_shape(v_a, v_b)
check_shape(s_a, s_b)
assert v_a.size(0) == s_a.size(0)
    assert v_a.size(1) == s_a.size(1)
s_a = s_a.to(torch.float32)
s_b = s_b.to(torch.float32)
seq_len = v_a.size(0)
num_heads = v_a.size(1)
head_dim = v_a.size(2)
    v_merged = torch.empty_like(v_a)
    s_merged = torch.empty((seq_len, num_heads), dtype=torch.float32, device=s_a.device)
bdx = head_dim
bdy = num_heads

merge_state_kernel[lambda meta: (seq_len,)](
v_a, s_a, v_b, s_b, v_merged, s_merged, num_heads, head_dim, bdx=bdx, bdy=bdy
)

return v_merged, s_merged


def merge_state_in_place(
v: torch.Tensor,
s: torch.Tensor,
v_other: torch.Tensor,
s_other: torch.Tensor,
mask: Optional[torch.Tensor] = None,
):
check_input(v)
check_input(s)
check_input(v_other)
check_input(s_other)
check_device([v, s, v_other, s_other])
check_dim(3, v)
check_dim(2, s)
check_dim(3, v_other)
check_dim(2, s_other)
check_shape(v, v_other)
check_shape(s, s_other)
assert v.size(0) == s.size(0)
assert v.size(1) == s.size(1)
assert s.dtype == torch.float32
assert s_other.dtype == torch.float32
if mask is not None:
check_dim(1, mask)
assert v.size(0) == mask.size(0)
        assert mask.device == v.device
seq_len = v.size(0)
num_heads = v.size(1)
head_dim = v.size(2)

bdx = head_dim
bdy = num_heads
merge_state_in_place_kernel[(seq_len,)](
v, s, v_other, s_other, num_heads, head_dim, mask, bdx=bdx, bdy=bdy
)


def merge_states(v: torch.Tensor, s: torch.Tensor):
check_input(v)
check_input(s)
check_device([v, s])
check_dim(4, v)
check_dim(3, s)
assert v.size(0) == s.size(0)
assert v.size(1) == s.size(1)
assert v.size(2) == s.size(2)
seq_len = v.size(0)
num_index_sets = v.size(1)
num_heads = v.size(2)
head_dim = v.size(3)
s = s.to(torch.float32)
v_merged = torch.empty(
(seq_len, num_heads, head_dim), dtype=v.dtype, device=v.device
)
s_merged = torch.empty((seq_len, num_heads), dtype=s.dtype, device=s.device)

bdx = head_dim
bdy = num_heads
merge_states_kernel[(seq_len,)](
v,
s,
v_merged,
s_merged,
num_index_sets,
num_heads,
head_dim,
bdx=bdx,
bdy=bdy,
)
return v_merged, s_merged


def variable_length_merge_states(
v: torch.Tensor, s: torch.Tensor, indptr: torch.Tensor
):
check_input(v)
check_input(s)
check_device([v, s])
check_dim(3, v)
check_dim(2, s)
assert v.size(0) == s.size(0)
assert v.size(1) == s.size(1)
seq_len = indptr.size(0) - 1
num_heads = v.size(1)
head_dim = v.size(2)
s = s.to(torch.float32)
indptr = indptr.to(torch.int32)
v_merged = torch.empty(
(seq_len, num_heads, head_dim), dtype=v.dtype, device=v.device
)
s_merged = torch.empty((seq_len, num_heads), dtype=s.dtype, device=s.device)

bdx = head_dim
bdy = num_heads
variable_length_merge_states_kernel[(seq_len,)](
v,
s,
indptr,
v_merged,
s_merged,
num_heads,
head_dim,
bdx=bdx,
bdy=bdy,
)
return v_merged, s_merged
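The wrappers above follow the shape conventions enforced by the checks: v is [seq_len, num_heads, head_dim] and s is [seq_len, num_heads]. A minimal usage sketch for merge_state and merge_state_in_place (hypothetical shapes; assumes a CUDA device and flashinfer installed from this tree):

import torch
from flashinfer.triton.cascade import merge_state, merge_state_in_place

seq_len, num_heads, head_dim = 128, 8, 64
v_a = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
s_a = torch.randn(seq_len, num_heads, dtype=torch.float32, device="cuda")
v_b = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
s_b = torch.randn(seq_len, num_heads, dtype=torch.float32, device="cuda")

# Merge two partial attention states (v, s) into a single state.
v_merged, s_merged = merge_state(v_a, s_a, v_b, s_b)
assert v_merged.shape == (seq_len, num_heads, head_dim)
assert s_merged.shape == (seq_len, num_heads)

# In-place variant: rows whose mask entry is 0 are left untouched.
mask = torch.ones(seq_len, dtype=torch.int32, device="cuda")
merge_state_in_place(v_a, s_a, v_b, s_b, mask=mask)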
159 changes: 159 additions & 0 deletions python/flashinfer/triton/kernels/cascade.py
@@ -0,0 +1,159 @@
import triton
import triton.language as tl


@triton.jit
def state_merge(o, m, d, other_o, other_m, other_d):
m_max = tl.maximum(m, other_m)
d = d * tl.exp2(m - m_max) + other_d * tl.exp2(other_m - m_max)
o = o * tl.exp2(m - m_max) + other_o * tl.exp2(other_m - m_max)
return o, m_max, d


@triton.jit
def state_normalize(o, m, d):
o = o / d
return o, m, d


@triton.jit
def state_get_lse(o, m, d):
return m + tl.log2(d)


@triton.jit
def merge_state_kernel(
v_a_ptr,
s_a_ptr,
v_b_ptr,
s_b_ptr,
v_merged_ptr,
s_merged_ptr,
num_heads,
head_dim,
bdx: tl.constexpr,
bdy: tl.constexpr,
):
pos = tl.program_id(axis=0)
for tx in tl.range(bdx):
for head_idx in tl.range(bdy):
s_a_val = tl.load(s_a_ptr + pos * num_heads + head_idx)
s_b_val = tl.load(s_b_ptr + pos * num_heads + head_idx)

offsets = (pos * num_heads + head_idx) * head_dim + tx
v_a = tl.load(v_a_ptr + offsets)
v_b = tl.load(v_b_ptr + offsets)

v_merged, s_max, d = state_merge(
o=v_a, m=s_a_val, d=1, other_o=v_b, other_m=s_b_val, other_d=1
)
v_merged, s_max, d = state_normalize(v_merged, s_max, d)
v_merged_offset = (pos * num_heads + head_idx) * head_dim + tx
tl.store(v_merged_ptr + v_merged_offset, v_merged)

if s_merged_ptr:
tl.store(
s_merged_ptr + pos * num_heads + head_idx,
tl.log2(d) + s_max,
)


@triton.jit
def merge_state_in_place_kernel(
v_ptr,
s_ptr,
v_other_ptr,
s_other_ptr,
num_heads,
head_dim,
mask_ptr,
bdx: tl.constexpr,
bdy: tl.constexpr,
):
pos = tl.program_id(axis=0)
if mask_ptr:
if tl.load(mask_ptr + pos) == 0:
return

for head_idx in tl.range(bdy):
s_val = tl.load(s_ptr + pos * num_heads + head_idx)
s_other_val = tl.load(s_other_ptr + pos * num_heads + head_idx)
s_max = tl.maximum(s_val, s_other_val)
s_val = tl.exp2(s_val - s_max)
s_other_val = tl.exp2(s_other_val - s_max)
scale = s_val / (s_val + s_other_val)
other_scale = s_other_val / (s_val + s_other_val)
for tx in tl.range(bdx):
offset = (pos * num_heads + head_idx) * head_dim + tx
v_vec = tl.load(v_ptr + offset)
v_other_vec = tl.load(v_other_ptr + offset)
v_vec = scale * v_vec + other_scale * v_other_vec
tl.store(v_ptr + offset, v_vec)
if s_ptr:
tl.store(
s_ptr + pos * num_heads + head_idx,
tl.log2(s_val + s_other_val) + s_max,
)


@triton.jit
def merge_states_kernel(
v_ptr,
s_ptr,
v_merged_ptr,
s_merged_ptr,
num_index_sets,
num_heads,
head_dim,
bdx: tl.constexpr,
bdy: tl.constexpr,
):
pos = tl.program_id(axis=0)

for tx in tl.range(bdx):
for head_idx in tl.range(bdy):
o, m, d = 0.0, -5e4, 1.0
for iter in tl.range(num_index_sets):
s = tl.load(
s_ptr + (pos * num_index_sets + iter) * num_heads + head_idx
)
v = tl.load(
v_ptr
+ ((pos * num_index_sets + iter) * num_heads + head_idx) * head_dim
+ tx
)
o, m, d = state_merge(o, m, d, v, s, 1)
o, m, d = state_normalize(o, m, d)
tl.store(v_merged_ptr + (pos * num_heads + head_idx) * head_dim + tx, o)
if s_merged_ptr:
tl.store(
s_merged_ptr + pos * num_heads + head_idx, state_get_lse(o, m, d)
)


@triton.jit
def variable_length_merge_states_kernel(
v_ptr,
s_ptr,
indptr,
v_merged_ptr,
s_merged_ptr,
num_heads,
head_dim,
bdx: tl.constexpr,
bdy: tl.constexpr,
):
pos = tl.program_id(axis=0)
for tx in tl.range(bdx):
for head_idx in tl.range(bdy):
o, m, d = 0.0, -5e4, 1.0
for iter in tl.range(tl.load(indptr + pos), tl.load(indptr + pos + 1)):
s = tl.load(s_ptr + iter * num_heads + head_idx)
v = tl.load(v_ptr + (iter * num_heads + head_idx) * head_dim + tx)
o, m, d = state_merge(o, m, d, v, s, 1)
o, m, d = state_normalize(o, m, d)
tl.store(v_merged_ptr + (pos * num_heads + head_idx) * head_dim + tx, o)
if s_merged_ptr:
tl.store(
s_merged_ptr + pos * num_heads + head_idx, state_get_lse(o, m, d)
)
28 changes: 28 additions & 0 deletions python/flashinfer/triton/utils.py
@@ -0,0 +1,28 @@
from typing import List

import torch


def check_input(x: torch.Tensor):
assert x.is_cuda, f"{str(x)} must be a CUDA Tensor"
assert x.is_contiguous(), f"{str(x)} must be contiguous"


def check_dim(d, x: torch.Tensor):
assert x.dim() == d, f"{str(x)} must be a {d}D tensor"


def check_shape(a: torch.Tensor, b: torch.Tensor):
    assert a.dim() == b.dim(), f"tensors should have the same dim, got {a.dim()} and {b.dim()}"
for i in range(a.dim()):
assert a.size(i) == b.size(
i
), f"tensors shape mismatch, {a.size()} and {b.size()}"


def check_device(tensors: List[torch.Tensor]):
device = tensors[0].device
for t in tensors:
assert (
t.device == device
), f"All tensors should be on the same device, but got {device} and {t.device}"