starting 16x8x32 int8
masahi committed May 17, 2022
1 parent 441fd19 commit 20321fa
Showing 1 changed file with 380 additions and 0 deletions.
380 changes: 380 additions & 0 deletions tests/python/unittest/test_mma_16x8x32_int8.py
@@ -0,0 +1,380 @@
import numpy as np

import tvm
import tvm.testing
import tvm.meta_schedule.testing.te_workload as te_workload
from tvm import te
from tvm.te import create_prim_func
from tvm.tir import Schedule
from tvm.script import tir as T
from tvm import tir


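# Each tensor intrinsic below is registered as a (description, implementation) pair:
# the *_desc PrimFunc spells out the computation that the scheduler pattern-matches,
# and the *_impl PrimFunc is the body that replaces it after tensorization.
# ldmatrix_a copies a 16x32 int8 tile from shared memory into the 32x16 per-warp
# fragment layout expected by the m16n8k32 mma instructions.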
@T.prim_func
def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
    A_shared = T.match_buffer(a, (16, 32), "int8", align=128, offset_factor=16, scope="shared")
    A_warp = T.match_buffer(c, (32, 16), "int8", align=128, offset_factor=16, scope="warp")

    with T.block("root"):
        T.reads(A_shared[0:16, 0:32])
        T.writes(A_warp[0:32, 0:16])

        for ax0, ax1 in T.grid(16, 32):
            with T.block("A_shared_warp"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A_shared[v0, v1])
                T.writes(A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
                A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = A_shared[
                    v0, v1
                ]


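# Implementation of the A-side load: each of the 32 threads issues one ptx_ldmatrix.
# The arguments are (transpose flag, number of 8x8 matrices, matrix type, destination
# fragment buffer, destination offset, shared-memory source buffer, source offset).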
@T.prim_func
def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
    s1 = T.var("int32")
    s0 = T.var("int32")
    A_shared = T.match_buffer(
        a,
        (16, 32),
        "int8",
        align=128,
        offset_factor=16,
        scope="shared",
        strides=[s1, s0],
    )
    A_warp = T.match_buffer(c, (32, 16), "int8", align=128, offset_factor=16, scope="warp")
    with T.block("root"):
        T.reads(A_shared[0:16, 0:32])
        T.writes(A_warp[0:32, 0:16])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        T.evaluate(
            T.ptx_ldmatrix(
                0,
                4,
                ".b16",
                A_warp.data,
                16 * tx,
                A_shared.data,
                32 * (tx % 16) + 16 * (tx // 16),
                dtype="int8",
            )
        )


@T.prim_func
def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
    B_shared = T.match_buffer(a, (16, 32), "int8", align=128, offset_factor=16, scope="shared")
    B_warp = T.match_buffer(c, (32, 16), "int8", align=128, offset_factor=16, scope="warp")

    with T.block("root"):
        T.reads(B_shared[0:16, 0:32])
        T.writes(B_warp[0:32, 0:16])

        for ax0, ax1 in T.grid(16, 32):
            with T.block("B_shared_warp"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(B_shared[v0, v1])
                T.writes(B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
                B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = B_shared[
                    v0, v1
                ]


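# Same as the A-side load, but with the transpose flag set to 1 so B is loaded with the
# transposed ldmatrix variant into the fragment layout the mma expects for its second operand.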
@T.prim_func
def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
    s1 = T.var("int32")
    s0 = T.var("int32")
    B_shared = T.match_buffer(
        a,
        (16, 32),
        "int8",
        align=128,
        offset_factor=16,
        scope="shared",
        strides=[s1, s0],
    )
    B_warp = T.match_buffer(c, (32, 16), "int8", align=128, offset_factor=16, scope="warp")
    with T.block("root"):
        T.reads(B_shared[0:16, 0:32])
        T.writes(B_warp[0:32, 0:16])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        T.evaluate(
            T.ptx_ldmatrix(
                1,
                4,
                ".b16",
                B_warp.data,
                16 * tx,
                B_shared.data,
                32 * (tx % 16) + 16 * (tx // 16),
                dtype="int8",
            )
        )


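# mma_sync_desc describes the warp-level matmul-accumulate directly in fragment index
# space: the int32 accumulator fragment C (32x8) is updated with products of int8 values
# read from the A and B fragments (32x16 each), cast to int32 before multiplying.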
@T.prim_func
def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (32, 16), "int8", align=128, offset_factor=16, scope="warp")
    B = T.match_buffer(b, (32, 16), "int8", align=128, offset_factor=16, scope="warp")
    C = T.match_buffer(c, (32, 8), "int32", align=128, offset_factor=16, scope="warp")

    with T.block("root"):
        T.reads(C[0:32, 0:8], A[0:32, 0:16], B[0:32, 0:16])
        T.writes(C[0:32, 0:8])
        for i, j, k in T.grid(32, 8, 16):
            with T.block("C"):
                i, j, k = T.axis.remap("SSR", [i, j, k])
                T.reads(
                    C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2],
                    A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2],
                    B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2],
                )
                T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
                C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = C[
                    i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2
                ] + T.cast(
                    A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2], "int32"
                ) * T.cast(
                    B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2], "int32"
                )


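# The implementation issues two m16n8k32 ptx_mma calls per warp: each call produces a
# 16x8 int32 tile, so the second call offsets into B (+8) and C (+4) to cover the other
# half of the 16x16 accumulator tile.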
@T.prim_func
def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (32, 16), "int8", align=128, offset_factor=16, scope="warp")
    B = T.match_buffer(b, (32, 16), "int8", align=128, offset_factor=16, scope="warp")
    C = T.match_buffer(c, (32, 8), "int32", align=128, offset_factor=16, scope="warp")

    with T.block("root"):
        T.reads(C[0:32, 0:8], A[0:32, 0:16], B[0:32, 0:16])
        T.writes(C[0:32, 0:8])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        T.evaluate(
            T.ptx_mma(
                "m16n8k32",
                "row",
                "col",
                "int8",
                "int8",
                "int32",
                A.data,
                A.elem_offset + tx * 16,
                B.data,
                B.elem_offset + tx * 16,
                C.data,
                C.elem_offset + tx * 8,
                False,
                dtype="int32",
            )
        )

        T.evaluate(
            T.ptx_mma(
                "m16n8k32",
                "row",
                "col",
                "int8",
                "int8",
                "int32",
                A.data,
                A.elem_offset + tx * 16,
                B.data,
                B.elem_offset + tx * 16 + 8,
                C.data,
                C.elem_offset + tx * 8 + 4,
                False,
                dtype="int32",
            )
        )


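# mma_store copies the per-warp 32x8 int32 accumulator fragment back to a 16x16 tile in
# global memory, undoing the fragment index mapping.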
@T.prim_func
def mma_store_desc(a: T.handle, c: T.handle) -> None:
    C_warp = T.match_buffer(a, [32, 8], dtype="int32", scope="warp")
    C = T.match_buffer(c, [16, 16], dtype="int32", scope="global")

    with T.block("root"):
        T.reads(C_warp[0:32, 0:8])
        T.writes(C[0:16, 0:16])
        for ax1_0, i0, i1 in T.grid(2, 32, 4):
            with T.block("C_warp"):
                v0 = T.axis.spatial(16, i1 // 2 * 8 + i0 // 4)
                v1 = T.axis.spatial(16, ax1_0 * 8 + i0 % 4 * 2 + i1 % 2)

                T.reads(C_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
                T.writes(C[v0, v1])
                C[v0, v1] = C_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2]


@T.prim_func
def mma_store_impl(a: T.handle, c: T.handle) -> None:
    s1 = T.var("int32")
    s0 = T.var("int32")

    C_warp = T.match_buffer(a, [32, 8], dtype="int32", scope="warp", offset_factor=1)
    C = T.match_buffer(
        c, [16, 16], dtype="int32", scope="global", offset_factor=1, strides=[s1, s0]
    )

    with T.block("root"):
        T.reads(C_warp[0:32, 0:8])
        T.writes(C[0:16, 0:16])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        T.evaluate(
            T.mma_store(
                16, 16, C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="int32"
            )
        )


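# mma_fill zero-initializes the 32x8 int32 accumulator fragment, replacing the init block
# produced by decompose_reduction further below.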
@T.prim_func
def mma_fill_desc(a: T.handle) -> None:
    C_warp = T.match_buffer(a, [32, 8], dtype="int32", scope="warp")

    with T.block("root"):
        T.reads()
        T.writes(C_warp[0:32, 0:8])
        for i0, i1 in T.grid(32, 8):
            with T.block("C_warp"):
                i = T.axis.spatial(16, i1 // 4 * 8 + i0 // 4)
                j = T.axis.spatial(16, (i0 % 4) * 4 + i1 % 4)
                T.reads()
                T.writes(C_warp[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
                C_warp[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = T.int32(0)


@T.prim_func
def mma_fill_impl(a: T.handle) -> None:
    C_warp = T.match_buffer(a, [32, 8], dtype="int32", scope="warp", offset_factor=1)

    with T.block("root"):
        T.reads()
        T.writes(C_warp[0:32, 0:8])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="int32"))


tir.TensorIntrin.register("mma.ldmatrix_a", ldmatrix_a_desc, ldmatrix_a_impl)
tir.TensorIntrin.register("mma.ldmatrix_b", ldmatrix_b_desc, ldmatrix_b_impl)
tir.TensorIntrin.register("mma.mma_sync", mma_sync_desc, mma_sync_impl)
tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl)
tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl)


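# Workload: a plain int8 x int8 -> int32 GEMM expressed in te, turned into a PrimFunc and
# scheduled towards the intrinsics registered above. Only mma_fill is tensorized in this
# script; the ldmatrix / mma_sync / mma_store tensorize calls below are commented out.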
M = N = K = 16

def matmul_int8(n, m, k):
    a = te.placeholder((n, k), name="A", dtype="int8")
    b = te.placeholder((k, m), name="B", dtype="int8")
    k = te.reduce_axis((0, k), name="k")

    def f_compute(i, j):
        v_a = tir.Cast(dtype="int32", value=a[i, k])
        v_b = tir.Cast(dtype="int32", value=b[k, j])
        return te.sum(v_a * v_b, axis=[k])

    c = te.compute((n, m), f_compute, name="C")
    return (a, b, c)


matmul = create_prim_func(matmul_int8(n=N, m=M, k=K))

sch = Schedule(matmul)
block = sch.get_block("C")

i, j, k = sch.get_loops(block)

i1, i2 = sch.split(i, factors=[None, 16])
sch.bind(i1, "blockIdx.x")


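# Cooperative fetch: cache the A / B operand into shared memory under the block loop and
# bind the fused copy loop to the 32 threads of a warp.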
def fetch_to_shared(block, idx):
    block_read = sch.cache_read(block, idx, "shared")
    sch.compute_at(block_read, i1, True)
    warp_size = 32
    loops = sch.get_loops(block_read)
    fused = sch.fuse(*loops[-2:])
    f_0, f_1 = sch.split(fused, factors=[None, warp_size])
    sch.bind(f_1, "threadIdx.x")

    return block_read


A_shared = fetch_to_shared(block, 0)
B_shared = fetch_to_shared(block, 1)


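# Index map from a 16x16 tile to the (thread_id, local_id) coordinates of the 32x8
# accumulator fragment; it matches the indexing used in the intrinsic descriptions above.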
def shared_16x16_to_ldmatrix_32x8_layout(i, j):
    thread_id = 4 * (i % 8) + (j % 8) // 2
    return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2

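# A quick, self-contained sanity check (plain Python, no TVM API; the underscore names are
# just local helpers for this check): the map above should be a bijection from the 16x16
# tile onto (thread_id, local_id) pairs in [0, 32) x [0, 8).
_seen = set()
for _i in range(16):
    for _j in range(16):
        _tid, _local = shared_16x16_to_ldmatrix_32x8_layout(_i, _j)
        assert 0 <= _tid < 32 and 0 <= _local < 8
        _seen.add((_tid, _local))
assert len(_seen) == 16 * 16  # every (thread, lane-local) slot is hit exactly once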

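# Cache A and B into "warp" scope; the layout transforms and ldmatrix tensorize steps for
# these fragments are left commented out below.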
block = sch.get_block("C")

A_warp = sch.cache_read(block, 0, "warp")

# sch.transform_layout(A_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)

B_warp = sch.cache_read(block, 1, "warp")

# sch.transform_layout(B_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)

# sch.tensorize(sch.get_loops(A_warp)[1], "mma.ldmatrix_a")
# sch.tensorize(sch.get_loops(B_warp)[1], "mma.ldmatrix_b")

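# Write C through a "warp"-scope accumulator, placed back under the outer (blockIdx.x)
# loop with reverse_compute_at, and rewrite its layout with the 16x16 -> 32x8 fragment map.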
C_warp = sch.cache_write(block, 0, "warp")
sch.reverse_compute_at(C_warp, sch.get_loops(block)[0])
sch.transform_layout(C_warp, 0, "read", index_map=shared_16x16_to_ldmatrix_32x8_layout)

warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
outer, f_2, f_3 = sch.split(warp_loop2, factors=[2, 4, 2])
sch.reorder(outer, f_1, f_2, f_0, f_3)
fused_1 = sch.fuse(f_1, f_2)
fused_2 = sch.fuse(f_0, f_3)

# sch.tensorize(outer, "mma_store")

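# Split the initialization of C out of the reduction so the init block can be tensorized
# with mma_fill; its loops are re-tiled into the 32-thread x 8-element fragment shape first.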
block_init_c = sch.decompose_reduction(block, sch.get_loops(block)[1])

init_loop1, init_loop2 = sch.get_loops(block_init_c)[-2:]
f_0, f_1 = sch.split(init_loop1, factors=[None, 8])
f_2, f_3 = sch.split(init_loop2, factors=[None, 4])
sch.reorder(f_1, f_2, f_0, f_3)
fused_1 = sch.fuse(f_1, f_2)
fused_2 = sch.fuse(f_0, f_3)
sch.tensorize(fused_1, "mma_fill")

# sch.tensorize(sch.get_loops(block)[1], "mma.mma_sync")

print(sch.mod.script())

# lowered = tvm.lower(sch.mod["main"])

# target = "cuda"

# f = tvm.build(sch.mod["main"], target=target, name="dense")
# dev = tvm.device(target, 0)

# a_np = np.random.randint(-128, 128, (M, K)).astype("int8")
# b_np = np.random.randint(-128, 128, (K, N)).astype("int8")
# c_np = np.dot(a_np.astype("int32"), b_np.astype("int32"))

# a = tvm.nd.array(a_np, dev)
# b = tvm.nd.array(b_np, dev)
# c = tvm.nd.array(np.zeros((M, N), dtype="int32"), dev)

# # print(f.imported_modules[0].get_source())
# f(a, b, c)
# np.testing.assert_equal(c.numpy(), c_np)
