Skip to content

Commit

Permalink
add requires_gpu decorator in tests, always test build on non-ampere
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 18, 2022
1 parent bbd9eaa commit 90e01fd
Showing 1 changed file with 19 additions and 21 deletions.
40 changes: 19 additions & 21 deletions tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ def maybe_swap(i, j):
return (a, b, c)


def is_ampere_or_newer():
    """Return True when the target GPU's compute capability is sm_80+ (Ampere or later)."""
    version_str = tvm.contrib.nvcc.get_target_compute_version()
    major, minor = tvm.contrib.nvcc.parse_compute_version(version_str)
    # Fold (major, minor) into a single sm level, e.g. (8, 0) -> 80.
    sm_level = major * 10 + minor
    return sm_level >= 80


def run_test(
k_inner,
in_dtype,
Expand Down Expand Up @@ -182,6 +188,10 @@ def tile_wmma_fragment(block_read, height, width):
sch.tensorize(sch.get_loops(C_warp)[-2], mma_store_intrin)

f = tvm.build(sch.mod["main"], target="cuda", name="dense")

if not is_ampere_or_newer():
return None

dev = tvm.device("cuda", 0)

if in_dtype == "float16":
Expand Down Expand Up @@ -221,16 +231,8 @@ def tile_wmma_fragment(block_read, height, width):
return lambda: f.time_evaluator(f.entry_name, dev, number=500)(a, b, c)


def is_ampere_or_newer():
arch = tvm.contrib.nvcc.get_target_compute_version()
major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
return major * 10 + minor >= 80


@tvm.testing.requires_cuda
def test_f16f16f32_m16n16k16():
if not is_ampere_or_newer():
return

def index_map(i, j):
return (
i // 16,
Expand Down Expand Up @@ -261,7 +263,7 @@ def index_map(i, j):
MMA_store_16x16_f32_global_INTRIN,
)

if measure_perf:
if measure_perf and timer:
print("f16f16f32_m16n16k16: %f GFLOPS" % (gflops / (timer().mean)))

timer = run_test(
Expand All @@ -282,14 +284,12 @@ def index_map(i, j):
MMA_store_16x16_f32_global_INTRIN,
)

if measure_perf:
if measure_perf and timer:
print("f16f16f32_m16n16k16_trans: %f GFLOPS" % (gflops / (timer().mean)))


@tvm.testing.requires_cuda
def test_f16f16f16_m16n16k16():
if not is_ampere_or_newer():
return

def index_map(i, j):
return (
i // 16,
Expand Down Expand Up @@ -320,7 +320,7 @@ def index_map(i, j):
MMA_store_16x16_f16_global_INTRIN,
)

if measure_perf:
if measure_perf and timer:
print("f16f16f16_m16n16k16: %f GFLOPS" % (gflops / (timer().mean)))

timer = run_test(
Expand All @@ -341,14 +341,12 @@ def index_map(i, j):
MMA_store_16x16_f16_global_INTRIN,
)

if measure_perf:
if measure_perf and timer:
print("f16f16f16_m16n16k16_trans: %f GFLOPS" % (gflops / (timer().mean)))


@tvm.testing.requires_cuda
def test_i8i8i32_m16n16k32():
if not is_ampere_or_newer():
return

def index_map_A(i, j):
return (
i // 16,
Expand Down Expand Up @@ -393,7 +391,7 @@ def index_map_C(i, j):
MMA_store_16x16_i32_global_INTRIN,
)

if measure_perf:
if measure_perf and timer:
print("i8i8i32_m16n16k32: %f GOPS" % (gflops / (timer().mean)))

timer = run_test(
Expand All @@ -414,7 +412,7 @@ def index_map_C(i, j):
MMA_store_16x16_i32_global_INTRIN,
)

if measure_perf:
if measure_perf and timer:
print("i8i8i32_m16n16k32_trans: %f GOPS" % (gflops / (timer().mean)))


Expand Down

0 comments on commit 90e01fd

Please sign in to comment.