16x8x16 4k tune working
masahi committed May 17, 2022
1 parent c3cb170 commit 5b2d486
Showing 2 changed files with 19 additions and 19 deletions.
src/tir/transforms/lower_warp_memory.cc (3 changes: 1 addition & 2 deletions)
@@ -118,7 +118,6 @@ class WarpStoreCoeffFinder : private StmtExprVisitor {
int num_matrix = op->args[1].as<IntImmNode>()->value;
warp_coeff_ = num_matrix * 2;
} else if (op->op.same_as(builtin::mma_fill()) && op->args[1].as<VarNode>() == buffer_) {
- LOG(INFO) << op->args[0];
auto* ptr = op->args[0].as<IntImmNode>();
CHECK(ptr);
warp_coeff_ = ptr->value;
@@ -500,7 +499,7 @@ Pass LowerWarpMemory() {
WarpMemoryRewriter warp_memory_rewriter(warp_size);
auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body));
n->body = UpdatePointerStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt);
- LOG(INFO) << f;
+ // LOG(INFO) << f;
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {});
tests/python/unittest/test_mma_16x8x16_4k_tune.py (35 changes: 18 additions & 17 deletions)
@@ -53,9 +53,9 @@ def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
4,
".b16",
A_warp.data,
- 8 * tx,
- A_shared.data,
- 16 * (tx % 16) + 8 * (tx // 16),
+ A_warp.elem_offset + 8 * tx,
+ A_shared.access_ptr("r"),
+ s1 * (tx % 16) + 8 * (tx // 16),
dtype="float16",
)
)
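The rewritten source index above scales the row term by the dynamic stride s1 instead of a hard-coded 16, and goes through A_shared.access_ptr("r") rather than the raw A_shared.data handle. A minimal standalone check of the per-lane offset expression (ldmatrix_src_offset is a hypothetical helper; s1 is assumed to be the row stride of the 16x16 fp16 tile and tx the lane id 0..31):

    # Hypothetical helper mirroring the source index s1 * (tx % 16) + 8 * (tx // 16).
    def ldmatrix_src_offset(tx, s1):
        return s1 * (tx % 16) + 8 * (tx // 16)

    # With a contiguous 16x16 tile (s1 == 16), the 32 lanes address the start of
    # every 8-element half-row exactly once.
    offsets = sorted(ldmatrix_src_offset(tx, 16) for tx in range(32))
    assert offsets == sorted(16 * r + 8 * h for r in range(16) for h in range(2))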
@@ -106,9 +106,9 @@ def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
4,
".b16",
B_warp.data,
- 8 * tx,
- B_shared.data,
- 16 * (tx % 16) + 8 * (tx // 16),
+ B_warp.elem_offset + 8 * tx,
+ B_shared.access_ptr("r"),
+ s1 * (tx % 16) + 8 * (tx // 16),
dtype="float16",
)
)
@@ -313,9 +313,10 @@ def schedule(sch: tir.Schedule):
k_factors = sch.sample_perfect_tile(k, n=3)
num_ty = sch.get(i_factors[2]) * sch.get(j_factors[2])
else:
- i_factors = [1, 16, 4, 2, 2]
- j_factors = [1, 32, 1, 8, 1]
- k_factors = [64, 4, 1]
+ i_factors = [4, 8, 2, 4, 1]
+ j_factors = [1, 64, 2, 1, 2]
+ k_factors = [128, 2, 1]
+
num_ty = i_factors[2] * j_factors[2]

i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors)
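A quick sanity check of the hand-picked factors above (a standalone sketch, assuming the "4k" in the test name means M = N = K = 4096, i.e. 256 warp-level tiles of 16 along each axis):

    from math import prod

    # Hand-picked tile factors from the schedule above.
    i_factors = [4, 8, 2, 4, 1]
    j_factors = [1, 64, 2, 1, 2]
    k_factors = [128, 2, 1]

    # Each axis of 256 outer tiles (4096 / 16) must be covered exactly.
    assert prod(i_factors) == prod(j_factors) == prod(k_factors) == 256
    # num_ty (the threadIdx.y extent) is the product of the warp-level factors.
    assert i_factors[2] * j_factors[2] == 4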
@@ -487,12 +488,12 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
f = tvm.build(sch.mod["main"], target="cuda", name="dense")

- # print(f.imported_modules[0].get_source())
- # f(a, b, c)
- # tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
- # print("ok")
+ print(f.imported_modules[0].get_source())
+ f(a, b, c)
+ tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
+ print("ok")

- # evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
- # gflops = (N * M * K) * 2 / 1e9
- # time_ms = evaluator(a, b, c).mean * 1e3
- # print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
+ evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
+ gflops = (N * M * K) * 2 / 1e9
+ time_ms = evaluator(a, b, c).mean * 1e3
+ print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
