Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 17, 2022
1 parent a9640f4 commit e80a1f1
Showing 1 changed file with 6 additions and 11 deletions.
17 changes: 6 additions & 11 deletions tests/python/unittest/test_mma_16x8x8_4k_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,14 +183,14 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
# Register the TIR tensor intrinsics (description/implementation pairs defined
# above) under string names so schedule() can tensorize loops by name
# (e.g. sch.tensorize(loop_b, "mma.ldmatrix_b")).
tir.TensorIntrin.register("mma.ldmatrix_b", ldmatrix_b_desc, ldmatrix_b_impl)
tir.TensorIntrin.register("mma_sync", mma_sync_desc, mma_sync_impl)


# Matmul problem size: a 4096 x 4096 x 4096 fp16 GEMM.
N = 4096
M = 4096
K = 4096

# The fp16 matmul workload lowered to a TIR PrimFunc for scheduling/tuning.
workload = te.create_prim_func(te_workload.matmul_fp16(n=N, m=M, k=K))

# tune: presumably selects MetaSchedule tuning (sampled tile factors) versus
# the hand-written fixed factors -- the guarded `if tune:` branch is only
# partially visible here; confirm against the full file.
tune = False
# When True, schedule() tensorizes the A/B warp-level copies with the
# "mma.ldmatrix_a"/"mma.ldmatrix_b" intrinsics registered above.
use_ldmatrix = True


def schedule(sch: tir.Schedule):
Expand All @@ -199,6 +199,7 @@ def schedule(sch: tir.Schedule):
i, i_tc = sch.split(i, factors=[None, 16])
j, j_tc = sch.split(j, factors=[None, 8])
k, k_tc = sch.split(k, factors=[None, 8])

sch.reorder(
i, j, k,
i_tc, j_tc, k_tc,
Expand All @@ -211,10 +212,12 @@ def schedule(sch: tir.Schedule):
i_factors = sch.sample_perfect_tile(i, n=5)
j_factors = sch.sample_perfect_tile(j, n=5)
k_factors = sch.sample_perfect_tile(k, n=3)
num_ty = sch.get(i_factors[2]) * sch.get(j_factors[2])
else:
i_factors = [1, 16, 4, 2, 2]
j_factors = [1, 64, 1, 8, 1]
k_factors = [128, 4, 1]
num_ty = i_factors[2] * j_factors[2]

i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors)
j0, j1, j2, j3, j4 = sch.split(j, factors=j_factors)
Expand All @@ -241,11 +244,6 @@ def schedule(sch: tir.Schedule):
sch.bind(block_idy, "blockIdx.y")
sch.bind(thread_idy, "threadIdx.y")

if isinstance(i_factors[2], int):
num_ty = i_factors[2] * j_factors[2]
else:
num_ty = sch.get(i_factors[2]) * sch.get(j_factors[2])

def fetch_to_shared(block, idx, ndim):
block_read = sch.cache_read(block, idx, "shared")
sch.compute_at(block_read, k0)
Expand Down Expand Up @@ -327,8 +325,6 @@ def lambda_b(i, j):
index_map=lambda_a,
)

use_ldmatrix = True

if use_ldmatrix:
sch.tensorize(loop_a, "mma.ldmatrix_a")
sch.tensorize(loop_b, "mma.ldmatrix_b")
Expand All @@ -347,8 +343,8 @@ def lambda_b(i, j):
fused_1 = sch.fuse(warp_loop2, f_0)
sch.bind(fused_1, "threadIdx.x")

loop = sch.get_loops(block_inner)[-3]
sch.tensorize(loop, "mma_sync")
mma_loop = sch.get_loops(block_inner)[-3]
sch.tensorize(mma_loop, "mma_sync")

block_init_c = sch.get_block("C_init")
init_loop1, init_loop2 = sch.get_loops(block_init_c)[-2:]
Expand Down Expand Up @@ -378,7 +374,6 @@ def lambda_b(i, j):
sch = ms.tune_tir(
mod=workload,
target=tvm.target.Target("nvidia/geforce-rtx-3070"),
# use replay or evolutionary search
config=ms.TuneConfig(
strategy="evolutionary",
num_trials_per_iter=32,
Expand Down

0 comments on commit e80a1f1

Please sign in to comment.