From af0da1a2ee340a478cddc27d696ea661eac0aefd Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 19 May 2022 05:01:45 +0900 Subject: [PATCH] add requires_gpu decorator in tests, always test build on non-ampere --- ...est_tir_schedule_tensorize_ldmatrix_mma.py | 40 +++++++++---------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py index 2b888c5e86fd6..78c615ff06c34 100644 --- a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py +++ b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py @@ -75,6 +75,12 @@ def maybe_swap(i, j): return (a, b, c) +def is_ampere_or_newer(): + arch = tvm.contrib.nvcc.get_target_compute_version() + major, minor = tvm.contrib.nvcc.parse_compute_version(arch) + return major * 10 + minor >= 80 + + def run_test( k_inner, in_dtype, @@ -182,6 +188,10 @@ def tile_wmma_fragment(block_read, height, width): sch.tensorize(sch.get_loops(C_warp)[-2], mma_store_intrin) f = tvm.build(sch.mod["main"], target="cuda", name="dense") + + if not is_ampere_or_newer(): + return None + dev = tvm.device("cuda", 0) if in_dtype == "float16": @@ -221,16 +231,8 @@ def tile_wmma_fragment(block_read, height, width): return lambda: f.time_evaluator(f.entry_name, dev, number=500)(a, b, c) -def is_ampere_or_newer(): - arch = tvm.contrib.nvcc.get_target_compute_version() - major, minor = tvm.contrib.nvcc.parse_compute_version(arch) - return major * 10 + minor >= 80 - - +@tvm.testing.requires_cuda def test_f16f16f32_m16n16k16(): - if not is_ampere_or_newer(): - return - def index_map(i, j): return ( i // 16, @@ -261,7 +263,7 @@ def index_map(i, j): MMA_store_16x16_f32_global_INTRIN, ) - if measure_perf: + if measure_perf and timer: print("f16f16f32_m16n16k16: %f GFLOPS" % (gflops / (timer().mean))) timer = run_test( @@ -282,14 +284,12 @@ def index_map(i, j): MMA_store_16x16_f32_global_INTRIN, ) - if measure_perf: + if measure_perf and timer: print("f16f16f32_m16n16k16_trans: %f GFLOPS" % (gflops / (timer().mean))) +@tvm.testing.requires_cuda def test_f16f16f16_m16n16k16(): - if not is_ampere_or_newer(): - return - def index_map(i, j): return ( i // 16, @@ -320,7 +320,7 @@ def index_map(i, j): MMA_store_16x16_f16_global_INTRIN, ) - if measure_perf: + if measure_perf and timer: print("f16f16f16_m16n16k16: %f GFLOPS" % (gflops / (timer().mean))) timer = run_test( @@ -341,14 +341,12 @@ def index_map(i, j): MMA_store_16x16_f16_global_INTRIN, ) - if measure_perf: + if measure_perf and timer: print("f16f16f16_m16n16k16_trans: %f GFLOPS" % (gflops / (timer().mean))) +@tvm.testing.requires_cuda def test_i8i8i32_m16n16k32(): - if not is_ampere_or_newer(): - return - def index_map_A(i, j): return ( i // 16, @@ -393,7 +391,7 @@ def index_map_C(i, j): MMA_store_16x16_i32_global_INTRIN, ) - if measure_perf: + if measure_perf and timer: print("i8i8i32_m16n16k32: %f GOPS" % (gflops / (timer().mean))) timer = run_test( @@ -414,7 +412,7 @@ def index_map_C(i, j): MMA_store_16x16_i32_global_INTRIN, ) - if measure_perf: + if measure_perf and timer: print("i8i8i32_m16n16k32_trans: %f GOPS" % (gflops / (timer().mean)))