Skip to content

Commit

Permalink
simplified fill
Browse files (browse the repository at this point in the history)
  • Loading branch information
masahi committed May 17, 2022
1 parent 7b13c73 commit 1adcb77
Showing 1 changed file with 3 additions and 14 deletions.
17 changes: 3 additions & 14 deletions tests/python/unittest/test_mma_16x8x16_4k_tune.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -130,10 +130,6 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
 for i, j, k in T.grid(16, 16, 16):
 with T.block("C"):
 i, j, k = T.axis.remap("SSR", [i, j, k])
-# thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j)
-# thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i, k)
-# thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(k, j)
-
 thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16)
 thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i % 16, k % 16)
 thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(k % 16, j % 16)
Expand Down Expand Up @@ -248,10 +244,9 @@ def mma_fill_desc(a: T.handle) -> None:
 with T.block("root"):
 T.reads()
 T.writes(C_warp[0:32, 0:8])
-for i0, i1 in T.grid(32, 8):
+for i0, i1 in T.grid(16, 16):
 with T.block("C_warp"):
-i_init = T.axis.spatial(16, i1 // 4 * 8 + i0 // 4)
-j_init = T.axis.spatial(16, (i0 % 4) * 4 + i1 % 4)
+i_init, j_init = T.axis.remap("SS", [i0, i1])
 thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i_init % 16, j_init % 16)
 T.reads()
 T.writes(C_warp[thread_id, local_id])
Expand Down Expand Up @@ -409,13 +404,7 @@ def index_map(i, j):
 sch.tensorize(mma_loop, "mma_sync")

 block_init_c = sch.get_block("C_init")
-init_loop1, init_loop2 = sch.get_loops(block_init_c)[-2:]
-f_0, f_1 = sch.split(init_loop1, factors=[None, 8])
-f_2, f_3 = sch.split(init_loop2, factors=[None, 4])
-sch.reorder(f_1, f_2, f_0, f_3)
-fused_1 = sch.fuse(f_1, f_2)
-fused_2 = sch.fuse(f_0, f_3)
-sch.tensorize(fused_1, "mma_fill")
+sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill")

 warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
 f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
Expand Down

0 comments on commit 1adcb77

Please sign in to comment.