diff --git a/tests/python/unittest/test_mma_16x8x8.py b/tests/python/unittest/test_mma_16x8x8.py
new file mode 100644
index 000000000000..2a391536aedb
--- /dev/null
+++ b/tests/python/unittest/test_mma_16x8x8.py
@@ -0,0 +1,329 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import te, tir
+from tvm.script import tir as T
+import tvm.testing
+import numpy as np
+
+
+@T.prim_func
+def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
+    A_shared = T.match_buffer(
+        a, (16, 8), "float16", align=128, offset_factor=16, scope="shared"
+    )
+    A_warp = T.match_buffer(
+        c, (32, 4), "float16", align=128, offset_factor=16, scope="warp"
+    )
+
+    with T.block("root"):
+        T.reads(A_shared[0:16, 0:8])
+        T.writes(A_warp[0:32, 0:4])
+
+        for ax0, ax1 in T.grid(16, 8):
+            with T.block("A_shared_warp"):
+                v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(A_shared[v0, v1])
+                T.writes(
+                    A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2]
+                )
+                A_warp[
+                    v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2
+                ] = A_shared[v0, v1]
+
+
+@T.prim_func
+def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
+    s1 = T.var("int32")
+    s0 = T.var("int32")
+    A_shared = T.match_buffer(
+        a,
+        (16, 8),
+        "float16",
+        align=128,
+        offset_factor=16,
+        scope="shared",
+        strides=[s1, s0],
+    )
+    A_warp = T.match_buffer(
+        c, (32, 4), "float16", align=128, offset_factor=16, scope="warp"
+    )
+    with T.block("root"):
+        T.reads(A_shared[0:16, 0:8])
+        T.writes(A_warp[0:32, 0:4])
+        tx = T.env_thread("threadIdx.x")
+        T.launch_thread(tx, 32)
+
+        T.evaluate(
+            T.ptx_ldmatrix(
+                0,
+                2,
+                ".b16",
+                A_warp.data,
+                4 * tx,
+                A_shared.data,
+                8 * (tx % 16),
+                dtype="float16",
+            )
+        )
+
+
+@T.prim_func
+def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
+    B_shared = T.match_buffer(
+        a, (8, 8), "float16", align=128, offset_factor=16, scope="shared"
+    )
+    B_shared_warp = T.match_buffer(
+        c, (32, 2), "float16", align=128, offset_factor=16, scope="warp"
+    )
+
+    with T.block("root"):
+        T.reads(B_shared[0:8, 0:8])
+        T.writes(B_shared_warp[0:32, 0:2])
+
+        for ax0, ax1 in T.grid(8, 8):
+            with T.block("B_shared_warp"):
+                v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(B_shared[v0, v1])
+                T.writes(B_shared_warp[v1 * 4 + v0 // 2, v0 % 2])
+                B_shared_warp[v1 * 4 + v0 // 2, v0 % 2] = B_shared[v0, v1]
+
+
+@T.prim_func
+def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
+    s1 = T.var("int32")
+    s0 = T.var("int32")
+    B_shared = T.match_buffer(
+        a,
+        (8, 8),
+        "float16",
+        align=128,
+        offset_factor=16,
+        scope="shared",
+        strides=[s1, s0],
+    )
+    B_warp = T.match_buffer(
+        c, (32, 2), "float16", align=128, offset_factor=16, scope="warp"
+    )
+    with T.block("root"):
+        T.reads(B_shared[0:8, 0:8])
+        T.writes(B_warp[0:32, 0:2])
+        tx = T.env_thread("threadIdx.x")
+        T.launch_thread(tx, 32)
+
+        T.evaluate(
+            T.ptx_ldmatrix(
+                0,
+                1,
+                ".b16",
+                B_warp.data,
+                2 * tx,
+                B_shared.data,
+                8 * (tx % 8),
+                dtype="float16",
+            )
+        )
+
+
+@T.prim_func
+def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [32, 4], dtype="float16", scope="warp")
+    B = T.match_buffer(b, [32, 2], dtype="float16", scope="warp")
+    C = T.match_buffer(c, [32, 4], dtype="float32", scope="warp")
+    with T.block("root"):
+        T.reads(C[0:32, 0:4], A[0:32, 0:4], B[0:32, 0:2])
+        T.writes(C[0:32, 0:4])
+        for i0, i1, i2 in T.grid(16, 8, 8):
+            with T.block("C"):
+                i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+
+                T.reads(
+                    C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2],
+                    A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2],
+                    B[k * 4 + j // 2, j % 2],
+                )
+                T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
+                C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = C[
+                    i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2
+                ] + T.cast(
+                    A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2],
+                    "float32",
+                ) * T.cast(B[k * 4 + j // 2, j % 2], "float32")
+
+
+@T.prim_func
+def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (32, 4), "float16", align=128, offset_factor=1, scope="warp")
+    B = T.match_buffer(b, (32, 2), "float16", align=128, offset_factor=1, scope="warp")
+    C = T.match_buffer(c, (32, 4), "float32", align=128, offset_factor=1, scope="warp")
+
+    with T.block("root"):
+        T.reads(C[0:32, 0:4], A[0:32, 0:4], B[0:32, 0:2])
+        T.writes(C[0:32, 0:4])
+        tx = T.env_thread("threadIdx.x")
+        T.launch_thread(tx, 32)
+        T.evaluate(
+            T.ptx_mma(
+                "m16n8k8",
+                "row",
+                "col",
+                "fp16",
+                "fp16",
+                "fp32",
+                A.data,
+                A.elem_offset + tx * 4,
+                B.data,
+                B.elem_offset + tx * 2,
+                C.data,
+                C.elem_offset + tx * 4,
+                False,
+                dtype="float32",
+            )
+        )
+
+
+tir.TensorIntrin.register("mma.ldmatrix_a", ldmatrix_a_desc, ldmatrix_a_impl)
+tir.TensorIntrin.register("mma.ldmatrix_b", ldmatrix_b_desc, ldmatrix_b_impl)
+tir.TensorIntrin.register("mma_sync", mma_sync_desc, mma_sync_impl)
+
+
+def dense(n: int, m: int, k: int):
+    a = te.placeholder((n, k), name="A", dtype="float16")
+    b = te.placeholder((m, k), name="B", dtype="float16")
+    k = te.reduce_axis((0, k), name="k")
+    c = te.compute(
+        (n, m),
+        lambda i, j: te.sum(
+            tvm.tir.Cast("float32", a[i, k]) * tvm.tir.Cast("float32", b[j, k]),
+            axis=[k],
+        ),
+        name="C",
+    )
+    return (a, b, c)
+
+
+def test_integration_matmul():
+    N = 16
+    M = 8
+    K = 8
+
+    workload = te.create_prim_func(dense(n=N, m=M, k=K))
+
+    def schedule(sch: tir.Schedule):
+        block = sch.get_block("C")
+        i, j, k = sch.get_loops(block)
+
+        # Tile the row axis and bind the outer tile loop to blockIdx.x
+        i1, i2 = sch.split(i, factors=[None, 16])
+        sch.bind(i1, "blockIdx.x")
+
+        def fetch_to_shared(block, idx, ndim):
+            block_read = sch.cache_read(block, idx, "shared")
+            sch.compute_at(block_read, i1)
+            warp_size = 32
+            fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])
+            f_0, f_1 = sch.split(fused, factors=[None, warp_size])
+            sch.bind(f_1, "threadIdx.x")
+
+        fetch_to_shared(block, 0, 2)
+        fetch_to_shared(block, 1, 2)
+
+        # fetch to A_warp: 16 * 8 -> 32 * 4
+        A_warp = sch.cache_read(block, 0, "warp")
+        sch.transform_layout(
+            A_warp,
+            0,
+            "write",
+            index_map=lambda i, j: (
+                (i % 8) * 4 + (j % 8) // 2,
+                4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2,
+            ),
+        )
+
+        sch.tensorize(sch.get_loops(A_warp)[1], "mma.ldmatrix_a")
+
+        B_warp = sch.cache_read(block, 1, "warp")
+        sch.transform_layout(
+            B_warp,
+            0,
+            "write",
+            index_map=lambda i, j: (i // 2 + j * 4, i % 2),
+        )
+        sch.tensorize(sch.get_loops(B_warp)[1], "mma.ldmatrix_b")
+
+        # write back C: 16 * 8 -> 32 * 4; a reverse_compute_at is needed
+        # to place C_warp under the blockIdx.x binding
+        C_warp = sch.cache_write(block, 0, "warp")
+        sch.reverse_compute_at(C_warp, sch.get_loops(block)[0])
+        sch.transform_layout(
+            C_warp,
+            0,
+            "read",
+            index_map=lambda i, j: (
+                (i % 8) * 4 + (j % 8) // 2,
+                4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2,
+            ),
+        )
+        warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
+        f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
+        f_2, f_3 = sch.split(warp_loop2, factors=[None, 2])
+        sch.reorder(f_1, f_2, f_0, f_3)
+        fused_1 = sch.fuse(f_1, f_2)
+        fused_2 = sch.fuse(f_0, f_3)
+        sch.bind(fused_1, "threadIdx.x")
+
+        # Decompose the reduction to separate C_init from C_warp
+        loop = sch.get_loops(block)[1]
+        block_init_c = sch.decompose_reduction(block, loop)
+
+        # C_init: 16 * 8 -> 32 * 4
+        # the buffer access is already transformed by the previous step,
+        # so only split/reorder/fuse is needed here
+        C_init = block_init_c
+        init_loop1, init_loop2 = sch.get_loops(C_init)[-2:]
+        f_0, f_1 = sch.split(init_loop1, factors=[None, 8])
+        f_2, f_3 = sch.split(init_loop2, factors=[None, 2])
+        sch.reorder(f_1, f_2, f_0, f_3)
+        fused_1 = sch.fuse(f_1, f_2)
+        fused_2 = sch.fuse(f_0, f_3)
+        sch.bind(fused_1, "threadIdx.x")
+
+        # tensorize the MMA computation itself
+        i0, i1, i2, i3 = sch.get_loops(block)
+        sch.tensorize(i1, "mma_sync")
+
+    sch = tir.Schedule(workload)
+    schedule(sch)
+
+    print(sch.mod["main"].script())
+
+    target = "cuda"
+    f = tvm.build(sch.mod["main"], target=target, name="dense")
+    dev = tvm.device("cuda", 0)
+    a_np = np.random.uniform(size=(N, K)).astype("float16")
+    b_np = np.random.uniform(size=(M, K)).astype("float16")
+    c_np = np.dot(a_np.astype("float32"), b_np.transpose().astype("float32"))
+    a = tvm.nd.array(a_np, dev)
+    b = tvm.nd.array(b_np, dev)
+    c = tvm.nd.array(np.zeros((N, M), dtype="float32"), dev)
+    f(a, b, c)
+    print(f.imported_modules[0].get_source())
+    tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
+
+
+if __name__ == "__main__":
+    test_integration_matmul()
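
A quick way to sanity-check the (i, j) -> (thread, local_id) fragment index map that this test uses for the A and C tiles is to verify that it maps the 16x8 tile one-to-one onto the 32-thread x 4-register warp buffer, as the m16n8k8 fragment layout in the PTX documentation requires. A minimal standalone NumPy sketch of that check (illustrative, not part of the diff, and independent of TVM):

import numpy as np

# Count how often each (thread, register) slot is written when the 16x8 tile
# is mapped through the index map used for A_warp and C_warp above.
hits = np.zeros((32, 4), dtype=int)
for i in range(16):
    for j in range(8):
        thread = (i % 8) * 4 + (j % 8) // 2
        local_id = 4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2
        hits[thread, local_id] += 1
assert (hits == 1).all()  # every slot is hit exactly once, so the map is a bijection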