From 9489434ee52b546e2abb2ab28173eefd51525ba4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 14 May 2022 10:01:12 +0900 Subject: [PATCH] simplify store --- .../unittest/test_mma_16x8x16_4k_tune.py | 63 ++++++++----------- 1 file changed, 26 insertions(+), 37 deletions(-) diff --git a/tests/python/unittest/test_mma_16x8x16_4k_tune.py b/tests/python/unittest/test_mma_16x8x16_4k_tune.py index 1abc5168e559..5d78ec8483f2 100644 --- a/tests/python/unittest/test_mma_16x8x16_4k_tune.py +++ b/tests/python/unittest/test_mma_16x8x16_4k_tune.py @@ -204,10 +204,9 @@ def mma_store_desc(a: T.handle, c: T.handle) -> None: with T.block("root"): T.reads(C_warp[0:32, 0:8]) T.writes(C[0:16, 0:16]) - for ax1_0, i0, i1 in T.grid(2, 32, 4): + for i0, i1 in T.grid(16, 16): with T.block("C_warp"): - v0 = T.axis.spatial(16, i1 // 2 * 8 + i0 // 4) - v1 = T.axis.spatial(16, ax1_0 * 8 + i0 % 4 * 2 + i1 % 2) + v0, v1 = T.axis.remap("SS", [i0, i1]) thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1) T.reads(C_warp[thread_id, local_id]) T.writes(C[v0, v1]) @@ -375,7 +374,8 @@ def fetch_to_shared(block, idx, ndim): jo, ji = sch.split(jj, factors=[None, 16]) sch.reorder(io, jo, ii, ji) - block_init_c = sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3]) + sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3]) + block_init_c = sch.get_block("C_init") def tile_wmma_fragment(block_read, height): i, j = sch.get_loops(block_read)[-2:] @@ -386,7 +386,6 @@ def tile_wmma_fragment(block_read, height): loop_a = tile_wmma_fragment(A_warp, 16) loop_b = tile_wmma_fragment(B_warp, 16) - mma_loop = sch.get_loops(block_inner)[-3] def index_map(i, j): return ( @@ -401,18 +400,10 @@ def index_map(i, j): sch.tensorize(loop_a, "mma.ldmatrix_a") sch.tensorize(loop_b, "mma.ldmatrix_b") - sch.tensorize(mma_loop, "mma_sync") - - block_init_c = sch.get_block("C_init") + sch.tensorize(sch.get_loops(block_inner)[-3], "mma_sync") sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill") + sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store") - warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:] - f_0, f_1 = sch.split(warp_loop1, factors=[None, 8]) - outer, f_2, f_3 = sch.split(warp_loop2, factors=[2, 4, 2]) - sch.reorder(outer, f_1, f_2, f_0, f_3) - fused_1 = sch.fuse(f_1, f_2) - fused_2 = sch.fuse(f_0, f_3) - sch.tensorize(outer, "mma_store") # print(sch.mod.script()) # return @@ -441,25 +432,23 @@ def index_map(i, j): # else: # print(sch.mod.script()) # print(sch.trace) -# else: -# target = "cuda" -# f = tvm.build(sch.mod["main"], target=target, name="dense") - -# dev = tvm.device("cuda", 0) -# a_np = np.random.uniform(size=(N, K)).astype("float16") -# b_np = np.random.uniform(size=(K, M)).astype("float16") -# c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")) -# a = tvm.nd.array(a_np, dev) -# b = tvm.nd.array(b_np, dev) -# c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev) -# f = tvm.build(sch.mod["main"], target="cuda", name="dense") - -# print(f.imported_modules[0].get_source()) -# f(a, b, c) -# tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3) -# print("ok") - -# evaluator = f.time_evaluator(f.entry_name, dev, number=1000) -# gflops = (N * M * K) * 2 / 1e9 -# time_ms = evaluator(a, b, c).mean * 1e3 -# print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3))) + +f = tvm.build(sch.mod["main"], target="cuda", name="dense") +dev = tvm.device("cuda", 0) +a_np = np.random.uniform(size=(N, K)).astype("float16") +b_np = np.random.uniform(size=(K, M)).astype("float16") +c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")) +a = tvm.nd.array(a_np, dev) +b = tvm.nd.array(b_np, dev) +c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev) + + +print(f.imported_modules[0].get_source()) +f(a, b, c) +tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3) +print("ok") + +evaluator = f.time_evaluator(f.entry_name, dev, number=1000) +gflops = (N * M * K) * 2 / 1e9 +time_ms = evaluator(a, b, c).mean * 1e3 +print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))