
Commit

simplify iterator in layout transform
masahi committed May 17, 2022
1 parent 9362803 commit 76c1bcf
Showing 2 changed files with 24 additions and 28 deletions.
3 changes: 1 addition & 2 deletions src/tir/schedule/ir_comparator.cc
@@ -293,8 +293,7 @@ bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const Buf
   for (int i = 0; i < offset; i++) {
     // High-dim region must be element-wise
     if (!is_one(lhs->region[i]->extent)) return false;
-    // TODO(masahi): Simplify in layout transform?
-    // if (!analyzer_.CanProveEqual(indices_base[i], lhs->region[i]->min)) return false;
+    if (!analyzer_.CanProveEqual(indices_base[i], lhs->region[i]->min)) return false;
   }
   for (size_t i = 0; i < rhs->region.size(); i++) {
     // check extent match
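
The restored check asks the arith analyzer to prove that each high-dimensional index produced by the layout transform equals the min of the corresponding buffer region. As a rough Python-side illustration of the kind of fact that becomes provable once the iterators are simplified (a hypothetical sketch, not code from this commit):

import tvm
from tvm import tir

# Hypothetical example: with i known to lie in [0, 16), the analyzer can
# reduce floormod(i, 16) back to i, which is what makes index equalities
# like the one checked above provable.
analyzer = tvm.arith.Analyzer()
i = tir.Var("i", "int32")
analyzer.bind(i, tvm.ir.Range.from_min_extent(0, 16))
print(analyzer.simplify(tir.floormod(i, 16)))  # expected output: i
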
49 changes: 23 additions & 26 deletions tests/python/unittest/test_mma_16x8x16_4k_tune.py
@@ -136,9 +136,9 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
         for i, j, k in T.grid(16, 16, 16):
             with T.block("C"):
                 i, j, k = T.axis.remap("SSR", [i, j, k])
-                thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16)
-                thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i % 16, k % 16)
-                thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(k % 16, j % 16)
+                thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j)
+                thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i, k)
+                thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(k, j)

                 T.reads(
                     C[thread_id_C, local_id_C],
@@ -252,7 +252,7 @@ def mma_fill_desc(a: T.handle) -> None:
         for i0, i1 in T.grid(16, 16):
             with T.block("C_warp"):
                 i_init, j_init = T.axis.remap("SS", [i0, i1])
-                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i_init % 16, j_init % 16)
+                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i_init, j_init)
                 T.reads()
                 T.writes(C_warp[thread_id, local_id])
                 C_warp[thread_id, local_id] = T.float32(0)
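
Both hunks drop the % 16 from the arguments to shared_16x16_to_ldmatrix_32x8_layout: the block iterators come from T.grid(16, 16, 16) and T.grid(16, 16), so they already range over [0, 16) and the modulo is a no-op. A self-contained sketch, assuming the 16x16-to-32x8 ldmatrix index mapping commonly used for these MMA intrinsics (the helper's real definition lives elsewhere in the test file and is not part of this diff):

# Assumed mapping from a 16x16 tile coordinate (i, j) to (thread_id, local_id)
# in the 32x8 per-thread register layout used by ldmatrix.
def shared_16x16_to_ldmatrix_32x8_layout(i, j):
    thread_id = 4 * (i % 8) + (j % 8) // 2
    local_id = 4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2
    return thread_id, local_id

# Since i and j already range over [0, 16), i % 16 == i and j % 16 == j,
# so the simplified call sites compute exactly the same mapping.
for i in range(16):
    for j in range(16):
        assert shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16) == shared_16x16_to_ldmatrix_32x8_layout(i, j)
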
Expand Down Expand Up @@ -410,9 +410,6 @@ def index_map(i, j):
     sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill")
     sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store")

-    # print(sch.mod.script())
-    # return
-

 ir_module = tvm.IRModule({"main": workload})
 sch = tvm.tir.Schedule(ir_module)
@@ -439,22 +436,22 @@ def index_map(i, j):
 # print(sch.mod.script())
 # print(sch.trace)

-f = tvm.build(sch.mod["main"], target="cuda", name="dense")
-dev = tvm.device("cuda", 0)
-a_np = np.random.uniform(size=(N, K)).astype("float16")
-b_np = np.random.uniform(size=(K, M)).astype("float16")
-c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
-a = tvm.nd.array(a_np, dev)
-b = tvm.nd.array(b_np, dev)
-c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
-
-
-print(f.imported_modules[0].get_source())
-f(a, b, c)
-tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
-print("ok")
-
-evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
-gflops = (N * M * K) * 2 / 1e9
-time_ms = evaluator(a, b, c).mean * 1e3
-print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
+# f = tvm.build(sch.mod["main"], target="cuda", name="dense")
+# dev = tvm.device("cuda", 0)
+# a_np = np.random.uniform(size=(N, K)).astype("float16")
+# b_np = np.random.uniform(size=(K, M)).astype("float16")
+# c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
+# a = tvm.nd.array(a_np, dev)
+# b = tvm.nd.array(b_np, dev)
+# c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
+
+
+# print(f.imported_modules[0].get_source())
+# f(a, b, c)
+# tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
+# print("ok")
+
+# evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
+# gflops = (N * M * K) * 2 / 1e9
+# time_ms = evaluator(a, b, c).mean * 1e3
+# print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))

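For reference, the commented-out benchmark reports throughput as 2 * N * M * K floating-point operations over the measured time. A small worked example of that arithmetic, assuming the "4k" in the file name stands for N = M = K = 4096 (the actual sizes are defined outside the lines shown in this diff):

# Assumption: N = M = K = 4096; the test sets these elsewhere.
N = M = K = 4096
gflop = (N * M * K) * 2 / 1e9      # one multiply and one add per term, about 137.44 GFLOP
time_ms = 1.0                      # stand-in for evaluator(a, b, c).mean * 1e3
gflops = gflop / (time_ms / 1e3)   # same formula as in the benchmark above
print("%f ms, %f GFLOPS" % (time_ms, gflops))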