-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
380 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,380 @@ | ||
import numpy as np | ||
|
||
import tvm | ||
import tvm.testing | ||
import tvm.meta_schedule.testing.te_workload as te_workload | ||
from tvm import te | ||
from tvm.te import create_prim_func | ||
from tvm.tir import Schedule | ||
from tvm.script import tir as T | ||
from tvm import tir | ||
|
||
|
||
# Description half of the "mma.ldmatrix_a" tensor intrinsic.
# Copies each element of a 16x32 int8 "shared"-scope buffer into a 32x16 int8
# "warp"-scope buffer at the index
#   (v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2).
@T.prim_func
def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
    A_shared = T.match_buffer(a, (16, 32), "int8", align=128, offset_factor=16, scope="shared")
    A_warp = T.match_buffer(c, (32, 16), "int8", align=128, offset_factor=16, scope="warp")

    with T.block("root"):
        T.reads(A_shared[0:16, 0:32])
        T.writes(A_warp[0:32, 0:16])

        # One spatial block instance per element of the shared tile.
        for ax0, ax1 in T.grid(16, 32):
            with T.block("A_shared_warp"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A_shared[v0, v1])
                T.writes(A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
                A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = A_shared[
                    v0, v1
                ]
|
||
|
||
# Implementation half of the "mma.ldmatrix_a" tensor intrinsic.
# Binds threadIdx.x with extent 32 and issues a single T.ptx_ldmatrix call that
# targets A_warp.data at offset 16 * tx and reads from A_shared.data.
@T.prim_func
def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
    # Symbolic strides so the shared tile may be matched against a strided region.
    s1 = T.var("int32")
    s0 = T.var("int32")
    A_shared = T.match_buffer(
        a,
        (16, 32),
        "int8",
        align=128,
        offset_factor=16,
        scope="shared",
        strides=[s1, s0],
    )
    A_warp = T.match_buffer(c, (32, 16), "int8", align=128, offset_factor=16, scope="warp")
    with T.block("root"):
        T.reads(A_shared[0:16, 0:32])
        T.writes(A_warp[0:32, 0:16])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        # Leading positional arguments are 0, 4, ".b16" -- presumably the transpose
        # flag, matrix count and element type; confirm against the ptx_ldmatrix docs.
        # NOTE(review): the shared offset hard-codes a row stride of 32 and does not
        # use s1 / elem_offset -- confirm this is intended for strided tiles.
        T.evaluate(
            T.ptx_ldmatrix(
                0,
                4,
                ".b16",
                A_warp.data,
                16 * tx,
                A_shared.data,
                32 * (tx % 16) + 16 * (tx // 16),
                dtype="int8",
            )
        )
|
||
|
||
# Description half of the "mma.ldmatrix_b" tensor intrinsic.
# Copies each element of a 16x32 int8 "shared"-scope buffer into a 32x16 int8
# "warp"-scope buffer at the index
#   (v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2).
@T.prim_func
def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
    B_shared = T.match_buffer(a, (16, 32), "int8", align=128, offset_factor=16, scope="shared")
    B_warp = T.match_buffer(c, (32, 16), "int8", align=128, offset_factor=16, scope="warp")

    with T.block("root"):
        T.reads(B_shared[0:16, 0:32])
        T.writes(B_warp[0:32, 0:16])

        # One spatial block instance per element of the shared tile.
        for ax0, ax1 in T.grid(16, 32):
            with T.block("B_shared_warp"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(B_shared[v0, v1])
                T.writes(B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
                B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = B_shared[
                    v0, v1
                ]
|
||
|
||
# Implementation half of the "mma.ldmatrix_b" tensor intrinsic.
# Same structure as ldmatrix_a_impl, except the first ptx_ldmatrix argument is 1.
@T.prim_func
def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
    # Symbolic strides so the shared tile may be matched against a strided region.
    s1 = T.var("int32")
    s0 = T.var("int32")
    B_shared = T.match_buffer(
        a,
        (16, 32),
        "int8",
        align=128,
        offset_factor=16,
        scope="shared",
        strides=[s1, s0],
    )
    B_warp = T.match_buffer(c, (32, 16), "int8", align=128, offset_factor=16, scope="warp")
    with T.block("root"):
        T.reads(B_shared[0:16, 0:32])
        T.writes(B_warp[0:32, 0:16])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        # Leading positional arguments are 1, 4, ".b16" -- presumably the transpose
        # flag, matrix count and element type; confirm against the ptx_ldmatrix docs.
        # NOTE(review): the shared offset hard-codes a row stride of 32 and does not
        # use s1 / elem_offset -- confirm this is intended for strided tiles.
        T.evaluate(
            T.ptx_ldmatrix(
                1,
                4,
                ".b16",
                B_warp.data,
                16 * tx,
                B_shared.data,
                32 * (tx % 16) + 16 * (tx // 16),
                dtype="int8",
            )
        )
|
||
|
||
# Description half of the "mma.mma_sync" tensor intrinsic.
# Accumulates int32(A) * int32(B) into C, with all three operands held in
# "warp"-scope buffers and addressed through the permuted index expressions below.
# NOTE(review): the loop extents are (32, 8, 16) while mma_sync_impl requests the
# "m16n8k32" shape -- confirm the extents match the intended tile.
@T.prim_func
def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (32, 16), "int8", align=128, offset_factor=16, scope="warp")
    B = T.match_buffer(b, (32, 16), "int8", align=128, offset_factor=16, scope="warp")
    C = T.match_buffer(c, (32, 8), "int32", align=128, offset_factor=16, scope="warp")

    with T.block("root"):
        T.reads(C[0:32, 0:8], A[0:32, 0:16], B[0:32, 0:16])
        T.writes(C[0:32, 0:8])
        for i, j, k in T.grid(32, 8, 16):
            with T.block("C"):
                # i and j are spatial axes, k is the reduction axis.
                i, j, k = T.axis.remap("SSR", [i, j, k])
                T.reads(
                    C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2],
                    A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2],
                    B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2],
                )
                T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
                C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = C[
                    i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2
                ] + T.cast(
                    A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2], "int32"
                ) * T.cast(
                    B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2], "int32"
                )
|
||
|
||
# Implementation half of the "mma.mma_sync" tensor intrinsic.
# Binds threadIdx.x with extent 32 and issues two T.ptx_mma calls with shape
# "m16n8k32" ("row" / "col" layouts, int8 inputs, int32 accumulator). Both calls
# use the same A offset; the second adds 8 to the B offset and 4 to the C offset.
@T.prim_func
def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (32, 16), "int8", align=128, offset_factor=16, scope="warp")
    B = T.match_buffer(b, (32, 16), "int8", align=128, offset_factor=16, scope="warp")
    C = T.match_buffer(c, (32, 8), "int32", align=128, offset_factor=16, scope="warp")

    with T.block("root"):
        T.reads(C[0:32, 0:8], A[0:32, 0:16], B[0:32, 0:16])
        T.writes(C[0:32, 0:8])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        # First call: B at tx * 16, C at tx * 8.
        T.evaluate(
            T.ptx_mma(
                "m16n8k32",
                "row",
                "col",
                "int8",
                "int8",
                "int32",
                A.data,
                A.elem_offset + tx * 16,
                B.data,
                B.elem_offset + tx * 16,
                C.data,
                C.elem_offset + tx * 8,
                False,
                dtype="int32",
            )
        )

        # Second call: B at tx * 16 + 8, C at tx * 8 + 4.
        T.evaluate(
            T.ptx_mma(
                "m16n8k32",
                "row",
                "col",
                "int8",
                "int8",
                "int32",
                A.data,
                A.elem_offset + tx * 16,
                B.data,
                B.elem_offset + tx * 16 + 8,
                C.data,
                C.elem_offset + tx * 8 + 4,
                False,
                dtype="int32",
            )
        )
|
||
|
||
# Description half of the "mma_store" tensor intrinsic.
# Copies a 32x8 int32 "warp"-scope buffer into a 16x16 int32 "global"-scope buffer.
# Element C[v0, v1] is read from
#   C_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2].
@T.prim_func
def mma_store_desc(a: T.handle, c: T.handle) -> None:
    C_warp = T.match_buffer(a, [32, 8], dtype="int32", scope="warp")
    C = T.match_buffer(c, [16, 16], dtype="int32", scope="global")

    with T.block("root"):
        T.reads(C_warp[0:32, 0:8])
        T.writes(C[0:16, 0:16])
        for ax1_0, i0, i1 in T.grid(2, 32, 4):
            with T.block("C_warp"):
                # Rebuild the 16x16 coordinates from the (2, 32, 4) loop nest.
                v0 = T.axis.spatial(16, i1 // 2 * 8 + i0 // 4)
                v1 = T.axis.spatial(16, ax1_0 * 8 + i0 % 4 * 2 + i1 % 2)

                T.reads(C_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
                T.writes(C[v0, v1])
                C[v0, v1] = C_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2]
|
||
|
||
# Implementation half of the "mma_store" tensor intrinsic.
# Binds threadIdx.x with extent 32 and emits one T.mma_store call that writes
# through C.access_ptr("w") from C_warp.data at C_warp.elem_offset, passing the
# symbolic leading stride s1 of C as the last positional argument.
@T.prim_func
def mma_store_impl(a: T.handle, c: T.handle) -> None:
    # Symbolic strides of the destination buffer.
    s1 = T.var("int32")
    s0 = T.var("int32")

    C_warp = T.match_buffer(a, [32, 8], dtype="int32", scope="warp", offset_factor=1)
    C = T.match_buffer(
        c, [16, 16], dtype="int32", scope="global", offset_factor=1, strides=[s1, s0]
    )

    with T.block("root"):
        T.reads(C_warp[0:32, 0:8])
        T.writes(C[0:16, 0:16])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        T.evaluate(
            T.mma_store(
                16, 16, C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="int32"
            )
        )
|
||
|
||
# Description half of the "mma_fill" tensor intrinsic.
# Writes T.int32(0) to a 32x8 int32 "warp"-scope buffer. The (32, 8) loop nest is
# first mapped to 16x16 coordinates (i, j), which then index the warp buffer.
@T.prim_func
def mma_fill_desc(a: T.handle) -> None:
    C_warp = T.match_buffer(a, [32, 8], dtype="int32", scope="warp")

    with T.block("root"):
        T.reads()
        T.writes(C_warp[0:32, 0:8])
        for i0, i1 in T.grid(32, 8):
            with T.block("C_warp"):
                i = T.axis.spatial(16, i1 // 4 * 8 + i0 // 4)
                j = T.axis.spatial(16, (i0 % 4) * 4 + i1 % 4)
                T.reads()
                T.writes(C_warp[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
                C_warp[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = T.int32(0)
|
||
|
||
# Implementation half of the "mma_fill" tensor intrinsic.
# Binds threadIdx.x with extent 32 and emits one T.mma_fill call with a leading
# argument of 8, targeting C_warp.data at C_warp.elem_offset.
@T.prim_func
def mma_fill_impl(a: T.handle) -> None:
    C_warp = T.match_buffer(a, [32, 8], dtype="int32", scope="warp", offset_factor=1)

    with T.block("root"):
        T.reads()
        T.writes(C_warp[0:32, 0:8])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="int32"))
|
||
|
||
# Register each (description, implementation) pair under the name that
# Schedule.tensorize refers to further down in this script.
tir.TensorIntrin.register("mma.ldmatrix_a", ldmatrix_a_desc, ldmatrix_a_impl)
tir.TensorIntrin.register("mma.ldmatrix_b", ldmatrix_b_desc, ldmatrix_b_impl)
tir.TensorIntrin.register("mma.mma_sync", mma_sync_desc, mma_sync_impl)
tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl)
tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl)


# Problem size of the int8 matmul workload built below.
M = N = K = 16
|
||
def matmul_int8(n, m, k):
    """Build a TE int8 matmul whose products are accumulated in int32.

    Parameters
    ----------
    n, m, k :
        Extents used for the placeholders: ``A`` is ``(n, k)``, ``B`` is
        ``(k, m)`` and the result ``C`` is ``(n, m)``.

    Returns
    -------
    tuple
        ``(a, b, c)`` -- the two int8 placeholders and the compute tensor.
    """
    a = te.placeholder((n, k), name="A", dtype="int8")
    b = te.placeholder((k, m), name="B", dtype="int8")
    # Use a separate local for the reduce axis so the ``k`` extent parameter is
    # not shadowed; the axis keeps its TE name "k".
    k_axis = te.reduce_axis((0, k), name="k")

    def f_compute(i, j):
        # Widen both operands to int32 before multiplying.
        v_a = tir.Cast(dtype="int32", value=a[i, k_axis])
        v_b = tir.Cast(dtype="int32", value=b[k_axis, j])
        return te.sum(v_a * v_b, axis=[k_axis])

    c = te.compute((n, m), f_compute, name="C")
    return (a, b, c)
|
||
|
||
# Lower the TE workload to a PrimFunc and open a schedule on it.
matmul = create_prim_func(matmul_int8(n=N, m=M, k=K))

sch = Schedule(matmul)
block = sch.get_block("C")

i, j, k = sch.get_loops(block)

# Split the outer spatial loop by 16 and bind the outer part to blockIdx.x.
i1, i2 = sch.split(i, factors=[None, 16])
sch.bind(i1, "blockIdx.x")
|
||
|
||
def fetch_to_shared(block, idx, use_gpu=True):
    """Stage read buffer ``idx`` of ``block`` through "shared" scope.

    Uses the module-level schedule ``sch`` and the ``blockIdx.x`` loop ``i1``.

    Parameters
    ----------
    block :
        The consumer block whose input is cached.
    idx :
        Index of the read buffer to cache.
    use_gpu : bool, optional
        When true (the default), the cache-read block is computed at ``i1`` and
        its two innermost loops are fused, split by the warp size and bound to
        ``threadIdx.x``. The original body tested a ``use_gpu`` name that was
        never defined, so every call raised ``NameError``; it is now a keyword
        argument so existing two-argument calls keep working.

    Returns
    -------
    The cache-read block returned by ``sch.cache_read``.
    """
    block_read = sch.cache_read(block, idx, "shared")
    if use_gpu:
        sch.compute_at(block_read, i1, True)
        warp_size = 32
        loops = sch.get_loops(block_read)
        fused = sch.fuse(*loops[-2:])
        # Only the inner, warp-sized part of the split is bound.
        _, f_1 = sch.split(fused, factors=[None, warp_size])
        sch.bind(f_1, "threadIdx.x")

    return block_read
|
||
|
||
# Stage both matmul inputs (read buffers 0 and 1 of block "C") through shared scope.
A_shared = fetch_to_shared(block, 0)
B_shared = fetch_to_shared(block, 1)
|
||
|
||
def shared_16x16_to_ldmatrix_32x8_layout(i, j):
    """Map a 16x16 tile coordinate ``(i, j)`` to a ``(thread_id, local_id)`` pair.

    For ``0 <= i, j < 16`` the result is a one-to-one mapping onto
    ``0 <= thread_id < 32`` and ``0 <= local_id < 8``. Only ``%``, ``//``, ``*``
    and ``+`` are used, so any operand type supporting them can be passed
    (the script passes this function as an ``index_map``).
    """
    thread_id = 4 * (i % 8) + (j % 8) // 2
    local_id = 4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2
    return thread_id, local_id
|
||
|
||
block = sch.get_block("C")

# Cache both inputs in "warp" scope; the layout transforms and ldmatrix
# tensorization for them are currently disabled.
A_warp = sch.cache_read(block, 0, "warp")

# sch.transform_layout(A_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)

B_warp = sch.cache_read(block, 1, "warp")

# sch.transform_layout(B_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)

# sch.tensorize(sch.get_loops(A_warp)[1], "mma.ldmatrix_a")
# sch.tensorize(sch.get_loops(B_warp)[1], "mma.ldmatrix_b")

# Cache the output in "warp" scope and re-index it with the 32x8 layout.
C_warp = sch.cache_write(block, 0, "warp")
sch.reverse_compute_at(C_warp, sch.get_loops(block)[0])
sch.transform_layout(C_warp, 0, "read", index_map=shared_16x16_to_ldmatrix_32x8_layout)

# Reshape the write-back loops into (outer=2, 32, 4), the loop nest used by
# mma_store_desc.
warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
outer, f_2, f_3 = sch.split(warp_loop2, factors=[2, 4, 2])
sch.reorder(outer, f_1, f_2, f_0, f_3)
fused_1 = sch.fuse(f_1, f_2)
fused_2 = sch.fuse(f_0, f_3)

# sch.tensorize(outer, "mma_store")

# Split the init statement out of the reduction block, then reshape its loops
# into the (32, 8) nest used by mma_fill_desc before tensorizing.
block_init_c = sch.decompose_reduction(block, sch.get_loops(block)[1])

init_loop1, init_loop2 = sch.get_loops(block_init_c)[-2:]
f_0, f_1 = sch.split(init_loop1, factors=[None, 8])
f_2, f_3 = sch.split(init_loop2, factors=[None, 4])
sch.reorder(f_1, f_2, f_0, f_3)
fused_1 = sch.fuse(f_1, f_2)
fused_2 = sch.fuse(f_0, f_3)
sch.tensorize(fused_1, "mma_fill")

# sch.tensorize(sch.get_loops(block)[1], "mma.mma_sync")

print(sch.mod.script())
|
||
# lowered = tvm.lower(sch.mod["main"])

# target = "cuda"

# f = tvm.build(sch.mod["main"], target=target, name="dense")
# dev = tvm.device(target, 0)

# a_np = np.random.randint(-128, 128, (M, K)).astype("int8")
# b_np = np.random.randint(-128, 128, (K, N)).astype("int8")
# c_np = np.dot(a_np.astype("int32"), b_np.astype("int32"))

# a = tvm.nd.array(a_np, dev)
# b = tvm.nd.array(b_np, dev)
# c = tvm.nd.array(np.zeros((M, N), dtype="int32"), dev)

# # print(f.imported_modules[0].get_source())
# f(a, b, c)
# np.testing.assert_equal(c.numpy(), c_np)