From 936280324ea2c91429a6a85a1b8ee89c7b825928 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 14 May 2022 11:31:39 +0900 Subject: [PATCH] remove obsolet files --- .../test_mma_16x8x16_4k_tune_simple.py | 430 ------------------ .../unittest/test_mma_16x8x16_simple.py | 333 -------------- tests/python/unittest/test_mma_16x8x8_4k.py | 358 --------------- 3 files changed, 1121 deletions(-) delete mode 100644 tests/python/unittest/test_mma_16x8x16_4k_tune_simple.py delete mode 100644 tests/python/unittest/test_mma_16x8x16_simple.py delete mode 100644 tests/python/unittest/test_mma_16x8x8_4k.py diff --git a/tests/python/unittest/test_mma_16x8x16_4k_tune_simple.py b/tests/python/unittest/test_mma_16x8x16_4k_tune_simple.py deleted file mode 100644 index fe3dfcc99175..000000000000 --- a/tests/python/unittest/test_mma_16x8x16_4k_tune_simple.py +++ /dev/null @@ -1,430 +0,0 @@ -import tempfile -import tvm -from tvm.script import tir as T -import tvm.meta_schedule.testing.te_workload as te_workload -from tvm import te, tir -from tvm import meta_schedule as ms -import tvm.testing -import numpy as np - - -@T.prim_func -def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None: - A_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared") - A_warp = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="warp") - - with T.block("root"): - T.reads(A_shared[0:16, 0:16]) - T.writes(A_warp[0:16, 0:16]) - - for ax0, ax1 in T.grid(16, 16): - with T.block("A_shared_warp"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(A_shared[v0, v1]) - T.writes(A_warp[v0, v1]) - A_warp[v0, v1] = A_shared[v0, v1] - - -@T.prim_func -def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None: - s1 = T.var("int32") - s0 = T.var("int32") - A_shared = T.match_buffer( - a, - (16, 16), - "float16", - align=128, - offset_factor=16, - scope="shared", - strides=[s1, s0], - ) - A_warp = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="warp") - with T.block("root"): - T.reads(A_shared[0:16, 0:16]) - T.writes(A_warp[0:16, 0:16]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, 32) - - T.evaluate( - T.ptx_ldmatrix( - 0, - 4, - ".b16", - A_warp.data, - A_warp.elem_offset + 8 * tx, - A_shared.access_ptr("r"), - s1 * (tx % 16) + 8 * (tx // 16), - dtype="float16", - ) - ) - - -@T.prim_func -def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None: - B_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared") - B_warp = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="warp") - - with T.block("root"): - T.reads(B_shared[0:16, 0:16]) - T.writes(B_warp[0:16, 0:16]) - - for ax0, ax1 in T.grid(16, 16): - with T.block("B_shared_warp"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(B_shared[v0, v1]) - T.writes(B_warp[v0, v1]) - B_warp[v0, v1] = B_shared[v0, v1] - - -@T.prim_func -def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None: - s1 = T.var("int32") - s0 = T.var("int32") - B_shared = T.match_buffer( - a, - (16, 16), - "float16", - align=128, - offset_factor=16, - scope="shared", - strides=[s1, s0], - ) - B_warp = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="warp") - with T.block("root"): - T.reads(B_shared[0:16, 0:16]) - T.writes(B_warp[0:16, 0:16]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, 32) - - T.evaluate( - T.ptx_ldmatrix( - 1, - 4, - ".b16", - B_warp.data, - B_warp.elem_offset + 8 * tx, - B_shared.access_ptr("r"), - s1 * (tx % 16) + 8 * (tx // 16), - dtype="float16", - ) - ) - - -@T.prim_func -def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="warp") - B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16, scope="warp") - C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="warp") - - with T.block("root"): - T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16]) - T.writes(C[0:16, 0:16]) - for i, j, k in T.grid(16, 16, 16): - with T.block("C"): - i, j, k = T.axis.remap("SSR", [i, j, k]) - # T.reads(C[i % 16, j % 16], A[i % 16, k % 16], B[k % 16, j % 16]) - # T.writes(C[i % 16, j % 16]) - # C[i % 16, j % 16] = C[i % 16, j % 16] + T.cast(A[i % 16, k % 16], "float32") * T.cast(B[k % 16, j % 16], "float32") - T.reads(C[i, j], A[i, k], B[k, j]) - T.writes(C[i, j]) - C[i, j] = C[i, j] + T.cast(A[i, k], "float32") * T.cast(B[k, j], "float32") - - -@T.prim_func -def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="warp") - B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16, scope="warp") - C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="warp") - - with T.block("root"): - T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16]) - T.writes(C[0:16, 0:16]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, 32) - - T.evaluate( - T.ptx_mma( - "m16n8k16", - "row", - "col", - "fp16", - "fp16", - "fp32", - A.data, - A.elem_offset + tx * 8, - B.data, - B.elem_offset + tx * 8, - C.data, - C.elem_offset + tx * 8, - False, - dtype="float32", - ) - ) - - T.evaluate( - T.ptx_mma( - "m16n8k16", - "row", - "col", - "fp16", - "fp16", - "fp32", - A.data, - A.elem_offset + tx * 8, - B.data, - B.elem_offset + tx * 8 + 4, - C.data, - C.elem_offset + tx * 8 + 4, - False, - dtype="float32", - ) - ) - - -@T.prim_func -def mma_store_desc(a: T.handle, c: T.handle) -> None: - C_warp = T.match_buffer(a, [16, 16], dtype="float32", scope="warp") - C = T.match_buffer(c, [16, 16], dtype="float32", scope="global") - - with T.block("root"): - T.reads(C_warp[0:16, 0:16]) - T.writes(C[0:16, 0:16]) - for i0, i1 in T.grid(16, 16): - with T.block("C_warp"): - v0, v1 = T.axis.remap("SS", [i0, i1]) - T.reads(C_warp[v0, v1]) - T.writes(C[v0, v1]) - C[v0, v1] = C_warp[v0, v1] - - -@T.prim_func -def mma_store_impl(a: T.handle, c: T.handle) -> None: - s1 = T.var("int32") - s0 = T.var("int32") - - C_warp = T.match_buffer(a, [16, 16], dtype="float32", scope="warp", offset_factor=1) - C = T.match_buffer( - c, [16, 16], dtype="float32", scope="global", offset_factor=1, strides=[s1, s0] - ) - - with T.block("root"): - T.reads(C_warp[0:16, 0:16]) - T.writes(C[0:16, 0:16]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, 32) - - T.evaluate( - T.mma_store( - 16, 16, C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="float32" - ) - ) - - -@T.prim_func -def mma_fill_desc(a: T.handle) -> None: - C_warp = T.match_buffer(a, [16, 16], dtype="float32", scope="warp") - - with T.block("root"): - T.reads() - T.writes(C_warp[0:16, 0:16]) - for i0, i1 in T.grid(16, 16): - with T.block("C_warp"): - i, j = T.axis.remap("SS", [i0, i1]) - T.reads() - # T.writes(C_warp[i % 16, j % 16]) - # C_warp[i % 16, j % 16] = T.float32(0) - T.writes(C_warp[i, j]) - C_warp[i, j] = T.float32(0) - - -@T.prim_func -def mma_fill_impl(a: T.handle) -> None: - C_warp = T.match_buffer(a, [16, 16], dtype="float32", scope="warp", offset_factor=1) - - with T.block("root"): - T.reads() - T.writes(C_warp[0:16, 0:16]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, 32) - - T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="float32")) - - -tir.TensorIntrin.register("mma.ldmatrix_a", ldmatrix_a_desc, ldmatrix_a_impl) -tir.TensorIntrin.register("mma.ldmatrix_b", ldmatrix_b_desc, ldmatrix_b_impl) -tir.TensorIntrin.register("mma_sync", mma_sync_desc, mma_sync_impl) -tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl) -tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl) - -N = 4096 -M = 4096 -K = 4096 - -workload = te.create_prim_func(te_workload.matmul_fp16(n=N, m=M, k=K)) - -tune = False - - -def schedule(sch: tir.Schedule): - block = sch.get_block("C") - i, j, k = sch.get_loops(block) - i, i_tc = sch.split(i, factors=[None, 16]) - j, j_tc = sch.split(j, factors=[None, 16]) - k, k_tc = sch.split(k, factors=[None, 16]) - - sch.reorder( - i, - j, - k, - i_tc, - j_tc, - k_tc, - ) - block_inner = sch.blockize(i_tc) - - block_outer, block_inner = block_inner, block - - if tune: - i_factors = sch.sample_perfect_tile(i, n=5) - j_factors = sch.sample_perfect_tile(j, n=5) - k_factors = sch.sample_perfect_tile(k, n=3) - num_ty = sch.get(i_factors[2]) * sch.get(j_factors[2]) - else: - i_factors = [4, 8, 2, 4, 1] - j_factors = [1, 64, 2, 1, 2] - k_factors = [128, 2, 1] - - num_ty = i_factors[2] * j_factors[2] - - i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors) - j0, j1, j2, j3, j4 = sch.split(j, factors=j_factors) - k0, k1, k2 = sch.split(k, k_factors) - - sch.reorder( - i0, - j0, # S => blockIdx.x - i1, - j1, # S => blockIdx.y - j2, - i2, # S => threadIdx.y - # cache_write here - k0, # R - # vectorized cooperative fetching here - k1, # R - i3, - j3, # S - k2, # R - i4, - j4, - # S - ) - - block_idx = sch.fuse(i0, j0) - block_idy = sch.fuse(i1, j1) - thread_idy = sch.fuse(j2, i2) - sch.bind(block_idx, "blockIdx.x") - sch.bind(block_idy, "blockIdx.y") - sch.bind(thread_idy, "threadIdx.y") - - def fetch_to_shared(block, idx, ndim): - block_read = sch.cache_read(block, idx, "shared") - sch.compute_at(block_read, k0) - vector_size = 8 - warp_size = 32 - fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) - f_0, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size]) - sch.bind(f_2, "threadIdx.x") - sch.bind(f_1, "threadIdx.y") - sch.vectorize(f_3) - sch.storage_align(block_read, 0, axis=-2, factor=32, offset=8) - - return block_read - - A_sh = fetch_to_shared(block_outer, 0, 2) - B_sh = fetch_to_shared(block_outer, 1, 2) - - loop = sch.get_loops(block_outer)[-1] - - A_warp = sch.cache_read(block_outer, 0, "warp") - B_warp = sch.cache_read(block_outer, 1, "warp") - - sch.compute_at(A_warp, k1) - sch.compute_at(B_warp, k1) - - C_warp = sch.cache_write(block_outer, 0, "warp") - sch.reverse_compute_at(C_warp, thread_idy) - - ii, jj = sch.get_loops(C_warp)[-2:] - io, ii = sch.split(ii, factors=[None, 16]) - jo, ji = sch.split(jj, factors=[None, 16]) - sch.reorder(io, jo, ii, ji) - - sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3]) - block_init_c = sch.get_block("C_init") - - def tile_wmma_fragment(block_read, height): - i, j = sch.get_loops(block_read)[-2:] - i0, i1 = sch.split(i, factors=[None, height]) - j0, j1 = sch.split(j, factors=[None, 16]) - sch.reorder(i0, j0, i1, j1) - return i1 - - loop_a = tile_wmma_fragment(A_warp, 16) - loop_b = tile_wmma_fragment(B_warp, 16) - - # sch.transform_layout(A_warp, 0, "write", index_map=lambda i, j: (i // 16, j // 16, i % 16, j % 16)) - # sch.transform_layout(B_warp, 0, "write", index_map=lambda i, j: (i // 16, j // 16, i % 16, j % 16)) - # sch.transform_layout(C_warp, 0, "read", index_map=lambda i, j: (i // 16, j // 16, i % 16, j % 16)) - - sch.tensorize(loop_a, "mma.ldmatrix_a") - sch.tensorize(loop_b, "mma.ldmatrix_b") - sch.tensorize(sch.get_loops(block_inner)[-3], "mma_sync") - sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill") - sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store") - - -ir_module = tvm.IRModule({"main": workload}) -sch = tvm.tir.Schedule(ir_module) -schedule(sch) -print(sch.mod.script()) -# print(tvm.tir.transform.CompactBufferAllocation()(sch.mod)) - -# if tune: -# with tempfile.TemporaryDirectory() as work_dir: -# sch = ms.tune_tir( -# mod=workload, -# target=tvm.target.Target("nvidia/geforce-rtx-3070"), -# config=ms.TuneConfig( -# strategy="evolutionary", -# num_trials_per_iter=32, -# max_trials_per_task=128, -# max_trials_global=128, -# ), -# work_dir=work_dir, -# space=ms.space_generator.ScheduleFn(schedule), -# ) -# if sch is None: -# print("No valid schedule found!") -# else: -# print(sch.mod.script()) -# print(sch.trace) - -f = tvm.build(sch.mod["main"], target="cuda", name="dense") - -dev = tvm.device("cuda", 0) -a_np = np.random.uniform(size=(N, K)).astype("float16") -b_np = np.random.uniform(size=(K, M)).astype("float16") -c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")) -a = tvm.nd.array(a_np, dev) -b = tvm.nd.array(b_np, dev) -c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev) - -print(f.imported_modules[0].get_source()) -f(a, b, c) -# tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3) -# print("ok") - -evaluator = f.time_evaluator(f.entry_name, dev, number=10) -gflops = (N * M * K) * 2 / 1e9 -time_ms = evaluator(a, b, c).mean * 1e3 -print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3))) diff --git a/tests/python/unittest/test_mma_16x8x16_simple.py b/tests/python/unittest/test_mma_16x8x16_simple.py deleted file mode 100644 index c99c184b9c7a..000000000000 --- a/tests/python/unittest/test_mma_16x8x16_simple.py +++ /dev/null @@ -1,333 +0,0 @@ -import numpy as np - -import tvm -import tvm.testing -import tvm.meta_schedule.testing.te_workload as te_workload -from tvm import te -from tvm.te import create_prim_func -from tvm.tir import Schedule -from tvm.script import tir as T -from tvm import tir - - -@T.prim_func -def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None: - A_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared") - A_warp = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="warp") - - with T.block("root"): - T.reads(A_shared[0:16, 0:16]) - T.writes(A_warp[0:16, 0:16]) - - for ax0, ax1 in T.grid(16, 16): - with T.block("A_shared_warp"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(A_shared[v0, v1]) - T.writes(A_warp[v0, v1]) - A_warp[v0, v1] = A_shared[v0, v1] - - -@T.prim_func -def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None: - s1 = T.var("int32") - s0 = T.var("int32") - A_shared = T.match_buffer( - a, - (16, 16), - "float16", - align=128, - offset_factor=16, - scope="shared", - strides=[s1, s0], - ) - A_warp = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp") - with T.block("root"): - T.reads(A_shared[0:16, 0:16]) - T.writes(A_warp[0:32, 0:8]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, 32) - - T.evaluate( - T.ptx_ldmatrix( - 0, - 4, - ".b16", - A_warp.data, - 8 * tx, - A_shared.data, - 16 * (tx % 16) + 8 * (tx // 16), - dtype="float16", - ) - ) - - -@T.prim_func -def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None: - B_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared") - B_warp = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="warp") - - with T.block("root"): - T.reads(B_shared[0:16, 0:16]) - T.writes(B_warp[0:16, 0:16]) - - for ax0, ax1 in T.grid(16, 16): - with T.block("B_shared_warp"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(B_shared[v0, v1]) - T.writes(B_warp[v0, v1]) - B_warp[v0, v1] = B_shared[v0, v1] - - -@T.prim_func -def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None: - s1 = T.var("int32") - s0 = T.var("int32") - B_shared = T.match_buffer( - a, - (16, 16), - "float16", - align=128, - offset_factor=16, - scope="shared", - strides=[s1, s0], - ) - B_warp = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp") - with T.block("root"): - T.reads(B_shared[0:16, 0:16]) - T.writes(B_warp[0:32, 0:8]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, 32) - - T.evaluate( - T.ptx_ldmatrix( - 1, - 4, - ".b16", - B_warp.data, - 8 * tx, - B_shared.data, - 16 * (tx % 16) + 8 * (tx // 16), - dtype="float16", - ) - ) - - -@T.prim_func -def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="warp") - B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16, scope="warp") - C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="warp") - - with T.block("root"): - T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16]) - T.writes(C[0:16, 0:16]) - for i, j, k in T.grid(16, 16, 16): - with T.block("C"): - i, j, k = T.axis.remap("SSR", [i, j, k]) - T.reads(C[i, j], A[i, k], B[k, j]) - T.writes(C[i, j]) - C[i, j] = C[i, j] + T.cast(A[i, k], "float32") * T.cast(B[k, j], "float32") - - -@T.prim_func -def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp") - B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp") - C = T.match_buffer(c, (32, 8), "float32", align=128, offset_factor=16, scope="warp") - - with T.block("root"): - T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8]) - T.writes(C[0:32, 0:8]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, 32) - - T.evaluate( - T.ptx_mma( - "m16n8k16", - "row", - "col", - "fp16", - "fp16", - "fp32", - A.data, - A.elem_offset + tx * 8, - B.data, - B.elem_offset + tx * 8, - C.data, - C.elem_offset + tx * 8, - False, - dtype="float32", - ) - ) - - T.evaluate( - T.ptx_mma( - "m16n8k16", - "row", - "col", - "fp16", - "fp16", - "fp32", - A.data, - A.elem_offset + tx * 8, - B.data, - B.elem_offset + tx * 8 + 4, - C.data, - C.elem_offset + tx * 8 + 4, - False, - dtype="float32", - ) - ) - - -@T.prim_func -def mma_store_desc(a: T.handle, c: T.handle) -> None: - C_warp = T.match_buffer(a, [16, 16], dtype="float32", scope="warp") - C = T.match_buffer(c, [16, 16], dtype="float32", scope="global") - - with T.block("root"): - T.reads(C_warp[0:16, 0:16]) - T.writes(C[0:16, 0:16]) - for i0, i1 in T.grid(16, 16): - with T.block("C_warp"): - v0, v1 = T.axis.remap("SS", [i0, i1]) - T.reads(C_warp[v0, v1]) - T.writes(C[v0, v1]) - C[v0, v1] = C_warp[v0, v1] - - -@T.prim_func -def mma_store_impl(a: T.handle, c: T.handle) -> None: - s1 = T.var("int32") - s0 = T.var("int32") - - C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp", offset_factor=1) - C = T.match_buffer( - c, [16, 16], dtype="float32", scope="global", offset_factor=1, strides=[s1, s0] - ) - - with T.block("root"): - T.reads(C_warp[0:32, 0:8]) - T.writes(C[0:16, 0:16]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, 32) - - T.evaluate( - T.mma_store( - 16, 16, C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="float32" - ) - ) - - -@T.prim_func -def mma_fill_desc(a: T.handle) -> None: - C_warp = T.match_buffer(a, [16, 16], dtype="float32", scope="warp") - - with T.block("root"): - T.reads() - T.writes(C_warp[0:16, 0:16]) - for i0, i1 in T.grid(16, 16): - with T.block("C_warp"): - i, j = T.axis.remap("SS", [i0, i1]) - T.reads() - T.writes(C_warp[i, j]) - C_warp[i, j] = T.float32(0) - - -@T.prim_func -def mma_fill_impl(a: T.handle) -> None: - C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp", offset_factor=1) - - with T.block("root"): - T.reads() - T.writes(C_warp[0:32, 0:8]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, 32) - - T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="float32")) - - -tir.TensorIntrin.register("mma.ldmatrix_a", ldmatrix_a_desc, ldmatrix_a_impl) -tir.TensorIntrin.register("mma.ldmatrix_b", ldmatrix_b_desc, ldmatrix_b_impl) -tir.TensorIntrin.register("mma.mma_sync", mma_sync_desc, mma_sync_impl) -tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl) -tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl) - - -def dense(n: int, m: int, k: int): - a = te.placeholder((n, k), name="A", dtype="float16") - b = te.placeholder((m, k), name="B", dtype="float16") - k = te.reduce_axis((0, k), name="k") - c = te.compute( - (n, m), - lambda i, j: te.sum( - tvm.tir.Cast("float32", a[i, k]) * tvm.tir.Cast("float32", b[j, k]), - axis=[k], - ), - name="C", - ) - return (a, b, c) - - -M = N = K = 16 -# matmul = create_prim_func(dense(n=16, m=K, k=K)) -matmul = create_prim_func(te_workload.matmul_fp16(n=N, m=M, k=K)) - -sch = Schedule(matmul) -block = sch.get_block("C") - -i, j, k = sch.get_loops(block) - -i1, i2 = sch.split(i, factors=[None, 16]) -sch.bind(i1, "blockIdx.x") - -def fetch_to_shared(block, idx): - block_read = sch.cache_read(block, idx, "shared") - sch.compute_at(block_read, i1, True) - warp_size = 32 - loops = sch.get_loops(block_read) - fused = sch.fuse(*loops[-2:]) - f_0, f_1 = sch.split(fused, factors=[None, warp_size]) - sch.bind(f_1, "threadIdx.x") - - return block_read - - -A_shared = fetch_to_shared(block, 0) -B_shared = fetch_to_shared(block, 1) - -block = sch.get_block("C") - -A_warp = sch.cache_read(block, 0, "warp") -B_warp = sch.cache_read(block, 1, "warp") -C_warp = sch.cache_write(block, 0, "warp") -sch.reverse_compute_at(C_warp, sch.get_loops(block)[0]) -block_init_c = sch.decompose_reduction(block, sch.get_loops(block)[1]) - -sch.tensorize(sch.get_loops(A_warp)[1], "mma.ldmatrix_a") -sch.tensorize(sch.get_loops(B_warp)[1], "mma.ldmatrix_b") -sch.tensorize(sch.get_loops(block)[1], "mma.mma_sync") -sch.tensorize(sch.get_loops(C_warp)[1], "mma_store") -sch.tensorize(sch.get_loops(block_init_c)[1], "mma_fill") - -print(sch.mod.script()) - -# lowered = tvm.lower(sch.mod["main"]) - -target = "cuda" - -f = tvm.build(sch.mod["main"], target=target, name="dense") -dev = tvm.device(target, 0) - -a_np = np.random.uniform(size=(16, K)).astype("float16") -b_np = np.random.uniform(size=(K, K)).astype("float16") -c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")) - -a = tvm.nd.array(a_np, dev) -b = tvm.nd.array(b_np, dev) -c = tvm.nd.array(np.zeros((16, K), dtype="float32"), dev) - -# print(f.imported_modules[0].get_source()) -f(a, b, c) -tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3) -print("ok") diff --git a/tests/python/unittest/test_mma_16x8x8_4k.py b/tests/python/unittest/test_mma_16x8x8_4k.py deleted file mode 100644 index 0a6a3f006a98..000000000000 --- a/tests/python/unittest/test_mma_16x8x8_4k.py +++ /dev/null @@ -1,358 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import tvm -from tvm import te, tir -from tvm.script import tir as T -import tvm.testing -import numpy as np - - -@T.prim_func -def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None: - A_shared = T.match_buffer( - a, (16, 8), "float16", align=128, offset_factor=16, scope="shared" - ) - A_warp = T.match_buffer( - c, (32, 4), "float16", align=128, offset_factor=16, scope="warp" - ) - - with T.block("root"): - T.reads(A_shared[0:16, 0:8]) - T.writes(A_warp[0:32, 0:4]) - - for ax0, ax1 in T.grid(16, 8): - with T.block("A_shared_warp"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(A_shared[v0, v1]) - T.writes(A_warp[v0 % 8 * 4 + v1 // 2, v0 // 8 * 2 + v1 % 2]) - A_warp[v0 % 8 * 4 + v1 // 2, v0 // 8 * 2 + v1 % 2] = A_shared[v0, v1] - - -@T.prim_func -def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None: - s1 = T.var("int32") - s0 = T.var("int32") - A_shared = T.match_buffer( - a, - (16, 8), - "float16", - align=128, - offset_factor=16, - scope="shared", - strides=[s1, s0], - ) - A_warp = T.match_buffer( - c, (32, 4), "float16", align=128, offset_factor=16, scope="warp" - ) - with T.block("root"): - T.reads(A_shared[0:16, 0:8]) - T.writes(A_warp[0:32, 0:4]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, 32) - - T.evaluate( - T.ptx_ldmatrix( - 0, - 2, - ".b16", - A_warp.data, - 4 * tx, - A_shared.access_ptr("r"), - s1 * (tx % 16), - dtype="float16", - ) - ) - - -@T.prim_func -def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None: - B_shared = T.match_buffer( - a, (8, 8), "float16", align=128, offset_factor=16, scope="shared" - ) - B_shared_warp = T.match_buffer( - c, (32, 2), "float16", align=128, offset_factor=16, scope="warp" - ) - - with T.block("root"): - T.reads(B_shared[0:8, 0:8]) - T.writes(B_shared_warp[0:32, 0:2]) - - for ax0, ax1 in T.grid(8, 8): - with T.block("A_shared_warp"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(B_shared[v0, v1]) - T.writes(B_shared_warp[v1 * 4 + v0 // 2, v0 % 2]) - B_shared_warp[v1 * 4 + v0 // 2, v0 % 2] = B_shared[v0, v1] - - -@T.prim_func -def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None: - s1 = T.var("int32") - s0 = T.var("int32") - B_shared = T.match_buffer( - a, - (8, 8), - "float16", - align=128, - offset_factor=16, - scope="shared", - strides=[s1, s0], - ) - B_warp = T.match_buffer( - c, (32, 2), "float16", align=128, offset_factor=16, scope="warp" - ) - with T.block("root"): - T.reads(B_shared[0:8, 0:8]) - T.writes(B_warp[0:32, 0:2]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, 32) - - T.evaluate( - T.ptx_ldmatrix( - 0, - 1, - ".b16", - B_warp.data, - 2 * tx, - B_shared.access_ptr("r"), - s1 * (tx % 8), - dtype="float16", - ) - ) - - -@T.prim_func -def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, [32, 4], dtype="float16", scope="warp") - B = T.match_buffer(b, [32, 2], dtype="float16", scope="warp") - C = T.match_buffer(c, [32, 4], dtype="float32", scope="warp") - with T.block("root"): - T.reads(C[0:32, 0:4], A[0:32, 0:4], B[0:32, 0:2]) - T.writes(C[0:32, 0:4]) - for i0, i1, i2 in T.grid(16, 8, 8): - with T.block("C"): - i, j, k = T.axis.remap("SSR", [i0, i1, i2]) - - T.reads( - C[i % 8 * 4 + j % 8 // 2, i % 16 // 8 * 2 + j % 2], - A[i % 8 * 4 + k % 8 // 2, i % 16 // 8 * 2 + k % 2], - B[k % 8 * 4 + j % 8 // 2, j % 2], - ) - T.writes(C[i % 8 * 4 + j % 8 // 2, i % 16 // 8 * 2 + j % 2]) - C[i % 8 * 4 + j % 8 // 2, i % 16 // 8 * 2 + j % 2] = C[ - i % 8 * 4 + j % 8 // 2, i % 16 // 8 * 2 + j % 2 - ] + T.cast( - A[i % 8 * 4 + k % 8 // 2, i % 16 // 8 * 2 + k % 2], "float32" - ) * T.cast( - B[k % 8 * 4 + j % 8 // 2, j % 2], "float32" - ) - - -@T.prim_func -def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (32, 4), "float16", align=128, offset_factor=1, scope="warp") - B = T.match_buffer(b, (32, 2), "float16", align=128, offset_factor=1, scope="warp") - C = T.match_buffer(c, (32, 4), "float32", align=128, offset_factor=1, scope="warp") - - with T.block("root"): - T.reads(C[0:32, 0:4], A[0:32, 0:4], B[0:32, 0:2]) - T.writes(C[0:32, 0:4]) - tx = T.env_thread("threadIdx.x") - T.launch_thread(tx, 32) - T.evaluate( - T.ptx_mma( - "m16n8k8", - "row", - "col", - "fp16", - "fp16", - "fp32", - A.data, - A.elem_offset + tx * 4, - B.data, - B.elem_offset + tx * 2, - C.data, - C.elem_offset + tx * 4, - False, - dtype="float32", - ) - ) - - -tir.TensorIntrin.register("mma.ldmatrix_a", ldmatrix_a_desc, ldmatrix_a_impl) -tir.TensorIntrin.register("mma.ldmatrix_b", ldmatrix_b_desc, ldmatrix_b_impl) -tir.TensorIntrin.register("mma_sync", mma_sync_desc, mma_sync_impl) - - -def dense(n: int, m: int, k: int): - a = te.placeholder((n, k), name="A", dtype="float16") - b = te.placeholder((m, k), name="B", dtype="float16") - k = te.reduce_axis((0, k), name="k") - c = te.compute( - (n, m), - lambda i, j: te.sum( - tvm.tir.Cast("float32", a[i, k]) * tvm.tir.Cast("float32", b[j, k]), - axis=[k], - ), - name="C", - ) - return (a, b, c) - - -def test_integration_matmul(): - N = 4096 - M = 4096 - K = 4096 - - workload = te.create_prim_func(dense(n=N, m=M, k=K)) - - def schedule(sch: tir.Schedule): - block = sch.get_block("C") - i, j, k = sch.get_loops(block) - - i, i_tc = sch.split(i, factors=[None, 16]) - j, j_tc = sch.split(j, factors=[None, 8]) - k_outer, k_tc = sch.split(k, factors=[None, 8]) - - sch.reorder( - # fmt: off - i, j, k_outer, - # tensor core - i_tc, j_tc, k_tc - ) - - block_outer = sch.blockize(i_tc) - block, _ = block_outer, block - - sch.bind(sch.fuse(i, j), "blockIdx.x") - - def fetch_to_shared(block, idx, ndim): - block_read = sch.cache_read(block, idx, "shared") - sch.compute_at(block_read, k_outer) - warp_size = 32 - fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) - f_0, f_1 = sch.split(fused, factors=[None, warp_size]) - sch.bind(f_1, "threadIdx.x") - - fetch_to_shared(block, 0, 2) - fetch_to_shared(block, 1, 2) - - # fetch to A_warp 16 * 8 -> 32 * 4 - A_warp = sch.cache_read(block, 0, "warp") - - def lambda_a(i, j): - i_0 = i // 16 - j_0 = j // 8 - - i = i % 16 - j = j % 8 - return ( - i_0, - j_0, - (i % 8) * 4 + (j % 8) // 2, - 4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2, - ) - - sch.transform_layout(A_warp, 0, "write", index_map=lambda_a) - - sch.tensorize(sch.get_loops(A_warp)[2], "mma.ldmatrix_a") - - def lambda_b(i, j): - i_0 = i // 8 - j_0 = j // 8 - i = i % 8 - j = j % 8 - return i_0, j_0, i // 2 + j * 4, i % 2 - - B_warp = sch.cache_read(block, 1, "warp") - sch.transform_layout( - B_warp, - 0, - "write", - index_map=lambda_b, - ) - sch.tensorize(sch.get_loops(B_warp)[2], "mma.ldmatrix_b") - - # fetch to C_warp 16 * 8 -> 32 * 4 - C_warp = sch.cache_write(block, 0, "warp") - sch.reverse_compute_at(C_warp, sch.get_loops(block)[0]) - # need to do a reverse_compute_at to place it under blockidx.x - - sch.transform_layout( - C_warp, - 0, - "read", - index_map=lambda_a, - ) - - warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:] - f_0, f_1 = sch.split(warp_loop1, factors=[None, 8]) - f_2, f_3 = sch.split(warp_loop2, factors=[None, 2]) - sch.reorder(f_1, f_2, f_0, f_3) - fused_1 = sch.fuse(f_1, f_2) - fused_2 = sch.fuse(f_0, f_3) - sch.bind(fused_1, "threadIdx.x") - - - block_init_c = sch.decompose_reduction(block, sch.get_loops(block)[1]) - - block_init_c = sch.get_block("C_init") - init_loop1, init_loop2 = sch.get_loops(block_init_c)[-2:] - f_0, f_1 = sch.split(init_loop1, factors=[None, 8]) - f_2, f_3 = sch.split(init_loop2, factors=[None, 2]) - sch.reorder(f_1, f_2, f_0, f_3) - fused_1 = sch.fuse(f_1, f_2) - fused_2 = sch.fuse(f_0, f_3) - sch.bind(fused_1, "threadIdx.x") - - block = sch.get_block("C") - - i1, _, _ = sch.get_loops(block)[-3:] - - sch.tensorize(i1, "mma_sync") - - - sch = tir.Schedule(workload) - schedule(sch) - - print(sch.mod["main"].script()) - - target = "cuda" - f = tvm.build(sch.mod["main"], target=target, name="dense") - dev = tvm.device("cuda", 0) - a_np = np.random.uniform(size=(N, K)).astype("float16") - b_np = np.random.uniform(size=(M, K)).astype("float16") - c_np = np.dot(a_np.astype("float32"), b_np.transpose().astype("float32")) - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(b_np, dev) - c = tvm.nd.array(np.zeros((N, M), dtype="float32"), dev) - # sys.exit(0) - f = tvm.build(sch.mod["main"], target="cuda", name="dense") - f(a, b, c) - print(f.imported_modules[0].get_source()) - tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3) - - print("ok") - - evaluator = f.time_evaluator(f.entry_name, dev, number=100) - gflops = (N * M * K) * 2 / 1e9 - time_ms = evaluator(a, b, c).mean * 1e3 - print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3))) - - -if __name__ == "__main__": - test_integration_matmul()