diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc
index 31af140167e1..572bc5c9a131 100644
--- a/src/tir/transforms/lower_warp_memory.cc
+++ b/src/tir/transforms/lower_warp_memory.cc
@@ -118,7 +118,6 @@ class WarpStoreCoeffFinder : private StmtExprVisitor {
       int num_matrix = op->args[1].as<IntImmNode>()->value;
       warp_coeff_ = num_matrix * 2;
     } else if (op->op.same_as(builtin::mma_fill()) && op->args[1].as<VarNode>() == buffer_) {
-      LOG(INFO) << op->args[0];
       auto* ptr = op->args[0].as<IntImmNode>();
       CHECK(ptr);
       warp_coeff_ = ptr->value;;
@@ -500,7 +499,7 @@ Pass LowerWarpMemory() {
     WarpMemoryRewriter warp_memory_rewriter(warp_size);
     auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body));
     n->body = UpdatePointerStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt);
-    LOG(INFO) << f;
+    // LOG(INFO) << f;
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {});
diff --git a/tests/python/unittest/test_mma_16x8x16_4k_tune.py b/tests/python/unittest/test_mma_16x8x16_4k_tune.py
index e7f5454e9e59..53eaedb7bbde 100644
--- a/tests/python/unittest/test_mma_16x8x16_4k_tune.py
+++ b/tests/python/unittest/test_mma_16x8x16_4k_tune.py
@@ -53,9 +53,9 @@ def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
                 4,
                 ".b16",
                 A_warp.data,
-                8 * tx,
-                A_shared.data,
-                16 * (tx % 16) + 8 * (tx // 16),
+                A_warp.elem_offset + 8 * tx,
+                A_shared.access_ptr("r"),
+                s1 * (tx % 16) + 8 * (tx // 16),
                 dtype="float16",
             )
         )
@@ -106,9 +106,9 @@ def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
                 4,
                 ".b16",
                 B_warp.data,
-                8 * tx,
-                B_shared.data,
-                16 * (tx % 16) + 8 * (tx // 16),
+                B_warp.elem_offset + 8 * tx,
+                B_shared.access_ptr("r"),
+                s1 * (tx % 16) + 8 * (tx // 16),
                 dtype="float16",
             )
         )
@@ -313,9 +313,10 @@ def schedule(sch: tir.Schedule):
         k_factors = sch.sample_perfect_tile(k, n=3)
         num_ty = sch.get(i_factors[2]) * sch.get(j_factors[2])
     else:
-        i_factors = [1, 16, 4, 2, 2]
-        j_factors = [1, 32, 1, 8, 1]
-        k_factors = [64, 4, 1]
+        i_factors = [4, 8, 2, 4, 1]
+        j_factors = [1, 64, 2, 1, 2]
+        k_factors = [128, 2, 1]
+        num_ty = i_factors[2] * j_factors[2]
 
     i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors)
@@ -487,12 +488,12 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
 c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
 
 f = tvm.build(sch.mod["main"], target="cuda", name="dense")
-# print(f.imported_modules[0].get_source())
-# f(a, b, c)
-# tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
-# print("ok")
+print(f.imported_modules[0].get_source())
+f(a, b, c)
+tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
+print("ok")
 
-# evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
-# gflops = (N * M * K) * 2 / 1e9
-# time_ms = evaluator(a, b, c).mean * 1e3
-# print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
+evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
+gflops = (N * M * K) * 2 / 1e9
+time_ms = evaluator(a, b, c).mean * 1e3
+print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
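
Sanity check for the new ldmatrix source offsets above (an illustrative sketch, not part of the patch): assuming s1 == 16, i.e. a contiguous 16x16 fp16 tile in shared memory, the per-thread offset s1 * (tx % 16) + 8 * (tx // 16) hands each of the 32 threads the start of a distinct 8-halfword row fragment, and together those fragments cover the tile exactly once.

import numpy as np

s1 = 16  # assumed row stride of a contiguous 16x16 shared-memory tile
covered = np.zeros(16 * 16, dtype=int)
for tx in range(32):
    start = s1 * (tx % 16) + 8 * (tx // 16)  # offset passed to ptx_ldmatrix
    covered[start : start + 8] += 1          # each address feeds 8 contiguous halfwords

assert (covered == 1).all()  # every element of the 16x16 tile is addressed exactly once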