Commit: simplify store

masahi committed May 17, 2022
1 parent 1adcb77 commit 9489434
Showing 1 changed file with 26 additions and 37 deletions.
tests/python/unittest/test_mma_16x8x16_4k_tune.py (26 additions & 37 deletions)

@@ -204,10 +204,9 @@ def mma_store_desc(a: T.handle, c: T.handle) -> None:
     with T.block("root"):
         T.reads(C_warp[0:32, 0:8])
         T.writes(C[0:16, 0:16])
-        for ax1_0, i0, i1 in T.grid(2, 32, 4):
+        for i0, i1 in T.grid(16, 16):
             with T.block("C_warp"):
-                v0 = T.axis.spatial(16, i1 // 2 * 8 + i0 // 4)
-                v1 = T.axis.spatial(16, ax1_0 * 8 + i0 % 4 * 2 + i1 % 2)
+                v0, v1 = T.axis.remap("SS", [i0, i1])
                 thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
                 T.reads(C_warp[thread_id, local_id])
                 T.writes(C[v0, v1])
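
Note: `shared_16x16_to_ldmatrix_32x8_layout` is defined elsewhere in this file and is not shown in the hunk. A minimal sketch of the mapping it computes, assuming the standard ldmatrix/mma accumulator layout (the exact arithmetic here is an assumption, not a quote of the file):

    def shared_16x16_to_ldmatrix_32x8_layout(i, j):
        # Map element (i, j) of a 16x16 tile to (thread_id, local_id):
        # 32 threads per warp, 8 accumulator registers per thread.
        thread_id = 4 * (i % 8) + (j % 8) // 2
        local_id = 4 * (j // 8) + (i // 8) * 2 + j % 2
        return thread_id, local_id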
@@ -375,7 +374,8 @@ def fetch_to_shared(block, idx, ndim):
     jo, ji = sch.split(jj, factors=[None, 16])
     sch.reorder(io, jo, ii, ji)
 
-    block_init_c = sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
+    sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
+    block_init_c = sch.get_block("C_init")
 
     def tile_wmma_fragment(block_read, height):
         i, j = sch.get_loops(block_read)[-2:]
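
The two spellings bind the same block: `decompose_reduction` returns the init block it creates, and that block can also be fetched by name afterwards. A minimal sketch of the equivalence, assuming the reduction block is named "C" so its init block becomes "C_init":

    # Option 1: use the return value directly (the old code).
    block_init_c = sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
    # Option 2: fetch it later by the "<name>_init" naming convention (the new code).
    block_init_c = sch.get_block("C_init")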
@@ -386,7 +386,6 @@ def tile_wmma_fragment(block_read, height):

     loop_a = tile_wmma_fragment(A_warp, 16)
     loop_b = tile_wmma_fragment(B_warp, 16)
-    mma_loop = sch.get_loops(block_inner)[-3]
 
     def index_map(i, j):
         return (
@@ -401,18 +400,10 @@ def index_map(i, j):

     sch.tensorize(loop_a, "mma.ldmatrix_a")
     sch.tensorize(loop_b, "mma.ldmatrix_b")
-    sch.tensorize(mma_loop, "mma_sync")
-
-    block_init_c = sch.get_block("C_init")
+    sch.tensorize(sch.get_loops(block_inner)[-3], "mma_sync")
     sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill")
+    sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store")
 
-    warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
-    f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
-    outer, f_2, f_3 = sch.split(warp_loop2, factors=[2, 4, 2])
-    sch.reorder(outer, f_1, f_2, f_0, f_3)
-    fused_1 = sch.fuse(f_1, f_2)
-    fused_2 = sch.fuse(f_0, f_3)
-    sch.tensorize(outer, "mma_store")
     # print(sch.mod.script())
     # return
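
The manual split/reorder/fuse sequence before tensorizing the store is no longer needed: the rewritten `mma_store_desc` iterates the tile as a plain 16x16 grid and pushes the warp layout into the index mapping, so `tensorize` can match the two innermost loops of `C_warp` directly. A standalone sanity check of that idea (hypothetical helper code, not part of the test), emulating `mma_store` by scattering a 32x8 warp buffer back into a 16x16 tile with the layout function sketched above:

    import numpy as np

    C_warp = np.arange(32 * 8).reshape(32, 8)   # per-thread accumulator values
    C = np.empty((16, 16), dtype=C_warp.dtype)  # destination tile
    for i in range(16):
        for j in range(16):
            tid, lid = shared_16x16_to_ldmatrix_32x8_layout(i, j)
            C[i, j] = C_warp[tid, lid]

    # The flat 16x16 walk covers every (thread_id, local_id) pair exactly once,
    # i.e. the mapping is a bijection between the two layouts.
    assert len({shared_16x16_to_ldmatrix_32x8_layout(i, j)
                for i in range(16) for j in range(16)}) == 32 * 8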

Expand Down Expand Up @@ -441,25 +432,23 @@ def index_map(i, j):
     # else:
     #     print(sch.mod.script())
     #     print(sch.trace)
-    # else:
-    #     target = "cuda"
-    #     f = tvm.build(sch.mod["main"], target=target, name="dense")
-
-    #     dev = tvm.device("cuda", 0)
-    #     a_np = np.random.uniform(size=(N, K)).astype("float16")
-    #     b_np = np.random.uniform(size=(K, M)).astype("float16")
-    #     c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
-    #     a = tvm.nd.array(a_np, dev)
-    #     b = tvm.nd.array(b_np, dev)
-    #     c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
-    # f = tvm.build(sch.mod["main"], target="cuda", name="dense")
-
-    # print(f.imported_modules[0].get_source())
-    # f(a, b, c)
-    # tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
-    # print("ok")
-
-    # evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
-    # gflops = (N * M * K) * 2 / 1e9
-    # time_ms = evaluator(a, b, c).mean * 1e3
-    # print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
+
+    f = tvm.build(sch.mod["main"], target="cuda", name="dense")
+    dev = tvm.device("cuda", 0)
+    a_np = np.random.uniform(size=(N, K)).astype("float16")
+    b_np = np.random.uniform(size=(K, M)).astype("float16")
+    c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
+    a = tvm.nd.array(a_np, dev)
+    b = tvm.nd.array(b_np, dev)
+    c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
+
+
+    print(f.imported_modules[0].get_source())
+    f(a, b, c)
+    tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
+    print("ok")
+
+    evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
+    gflops = (N * M * K) * 2 / 1e9
+    time_ms = evaluator(a, b, c).mean * 1e3
+    print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
