diff --git a/python/tvm/meta_schedule/default_config.py b/python/tvm/meta_schedule/default_config.py
index ac4028ec50f8..c701fd6568e0 100644
--- a/python/tvm/meta_schedule/default_config.py
+++ b/python/tvm/meta_schedule/default_config.py
@@ -174,10 +174,12 @@ def schedule_rules(  # pylint: disable=redefined-outer-name
         return sch_rules()
     if sch_rules is not None:
         raise TypeError(f"Expected `sch_rules` to be None or callable, but gets: {sch_rules}")
-    if target.kind.name in ["llvm", "hexagon"]:
+    if target.kind.name == "llvm":
         return _DefaultLLVM.schedule_rules()
     if target.kind.name in ["cuda", "rocm", "vulkan"]:
         return _DefaultCUDA.schedule_rules()
+    if target.kind.name == "hexagon":
+        return _DefaultHexagon.schedule_rules()
     raise ValueError(f"Unsupported target: {target}")


@@ -190,10 +192,12 @@ def postproc(  # pylint: disable=redefined-outer-name
         return postproc()
     if postproc is not None:
         raise TypeError(f"Expected `postproc` to be None or callable, but gets: {postproc}")
-    if target.kind.name in ["llvm", "hexagon"]:
+    if target.kind.name == "llvm":
         return _DefaultLLVM.postprocs()
     if target.kind.name in ["cuda", "rocm", "vulkan"]:
         return _DefaultCUDA.postprocs()
+    if target.kind.name == "hexagon":
+        return _DefaultHexagon.postprocs()
     raise ValueError(f"Unsupported target: {target}")


@@ -277,6 +281,55 @@ def mutator_probs() -> Dict[Mutator, float]:
         }


+class _DefaultHexagon:
+    """Default tuning configuration for Hexagon."""
+
+    @staticmethod
+    def schedule_rules() -> List[ScheduleRule]:
+        from tvm.meta_schedule import schedule_rule as M
+
+        return [
+            M.AutoInline(
+                into_producer=False,
+                into_consumer=True,
+                inline_const_tensor=True,
+                disallow_if_then_else=True,
+                require_injective=True,
+                require_ordered=True,
+                disallow_op=["tir.exp"],
+            ),
+            M.MultiLevelTilingWideVector(
+                structure="SRSRS",
+                vector_length_in_bits=1024,
+                max_innermost_factor=128,
+                reuse_read=None,
+                reuse_write=M.ReuseType(
+                    req="may",
+                    levels=[1, 2],
+                    scope="global",
+                ),
+            ),
+            M.ParallelizeVectorizeUnroll(
+                max_jobs_per_core=16,
+                max_vectorize_extent=128,
+                unroll_max_steps=[0, 16, 64, 512],
+                unroll_explicit=True,
+            ),
+        ]
+
+    @staticmethod
+    def postprocs() -> List[Postproc]:
+        from tvm.meta_schedule import postproc as M
+
+        return [
+            M.DisallowDynamicLoop(),
+            M.RewriteParallelVectorizeUnroll(),
+            M.RewriteReductionBlock(),
+            # TODO(masahi): Fix RewriteLayout for link-params=True case
+            # M.RewriteLayout(),
+        ]
+
+
 class _DefaultCUDA:
     """Default tuning configuration for CUDA."""

diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py
index b1cc0f67bd5f..96b554d4e659 100644
--- a/python/tvm/meta_schedule/tune.py
+++ b/python/tvm/meta_schedule/tune.py
@@ -554,6 +554,7 @@ def tune_relay(
     postprocs: Optional[FnPostproc] = None,
     mutator_probs: Optional[FnMutatorProb] = None,
     num_threads: Optional[int] = None,
+    executor=None,
 ) -> Union[Module, vm.Executable]:
     """Tune a Relay IRModule with a given target.

@@ -581,6 +582,9 @@ def tune_relay(
         The callbacks used during tuning.
     backend : str = "graph"
         The backend to use for relay compilation(graph / vm).
+    executor : relay.backend.Executor
+        The executor to be passed to relay.build(...). In particular, its link-params
+        attribute affects task extraction and workload database lookup.

     Returns
     -------
@@ -596,8 +600,23 @@ def tune_relay(
     target = default_config.target(target)
     # pylint: enable=protected-access,
     # parse the tuning contexts
+
+    if executor is None:
+        executor = relay.backend.Executor("graph")
+
+    if "link-params" in executor.attrs:
+        link_params = executor.attrs["link-params"]
+    else:
+        link_params = False
+
     with Profiler.timeit("TaskExtraction"):
-        extracted_tasks = extract_task_from_relay(mod, target, params)
+        pass_config = {
+            "relay.FuseOps.link_params": link_params,
+            "relay.backend.use_meta_schedule": True,
+            "relay.backend.tir_converter": "default",
+        }
+        extracted_tasks = extract_task_from_relay(mod, target, params, pass_config=pass_config)
+
     database = tune_extracted_tasks(
         extracted_tasks,
         config,
@@ -613,7 +632,7 @@ def tune_relay(
         mutator_probs=mutator_probs,
         num_threads=num_threads,
     )
-    relay_build = {"graph": relay.build, "vm": relay.vm.compile}[backend]
+
     with Profiler.timeit("PostTuningCompilation"):
         with target, autotvm_silencer(), database:
             with PassContext(
@@ -624,4 +643,8 @@ def tune_relay(
                     "relay.backend.tir_converter": "default",
                 },
             ):
-                return relay_build(mod, target=target, params=params)
+                if backend == "graph":
+                    return relay.build(mod, target=target, params=params, executor=executor)
+
+                # Executor is not supported by VM
+                return relay.vm.compile(mod, target=target, params=params)
diff --git a/tests/python/contrib/test_hexagon/test_meta_schedule.py b/tests/python/contrib/test_hexagon/test_meta_schedule.py
index 96d18c9b3076..74f3ab673ec8 100644
--- a/tests/python/contrib/test_hexagon/test_meta_schedule.py
+++ b/tests/python/contrib/test_hexagon/test_meta_schedule.py
@@ -21,15 +21,20 @@
 import tempfile

 import tvm.testing
-from tvm import te
+import tvm.topi.testing
+from tvm import te, relay
 from tvm import meta_schedule as ms
 from tvm.meta_schedule.arg_info import TensorInfo
 from tvm.meta_schedule.builder import BuilderInput
+from tvm.meta_schedule import postproc, schedule_rule
 from tvm.script import tir as T
 from tvm.tir import FloatImm
 from tvm.tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN
 from tvm.meta_schedule.runner import RunnerInput
 from tvm.contrib.hexagon.meta_schedule import get_hexagon_local_builder, get_hexagon_rpc_runner
+from tvm.relay.backend import Executor
+from tvm.topi.utils import get_const_tuple
+from tvm.meta_schedule.testing import te_workload

 MATMUL_N = 16
 MATMUL_M = 32
@@ -166,7 +171,6 @@ def verify_dense(sch, target, M, N, K, hexagon_session):
     print("%f ms, %f GOPS" % (time_ms, gflops / (time_ms / 1e3)))


-@pytest.mark.skip(reason="xgboost not installed on CI")
 @tvm.testing.requires_hexagon
 def test_vrmpy_dense(hexagon_launcher):
     if hexagon_launcher._serial_number == "simulator":
@@ -209,3 +213,207 @@ def schedule_dense_for_tune(sch):

     with hexagon_launcher.start_session() as session:
         verify_dense(sch, target, M, N, K, session)
+
+
+# This is an example of a schedule found by vrmpy auto tensorization.
+# It gets 440 GFLOPS on SD888.
+@tvm.script.ir_module
+class Module_vrmpy_auto_tensorize:
+    @T.prim_func
+    def main(
+        X: T.Buffer[(128, 768), "uint8"],
+        packedW: T.Buffer[(24, 192, 32, 4), "uint8"],
+        compute: T.Buffer[(128, 768), "int32"],
+    ) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        for i0_0_i1_0_0_fused in T.parallel(
+            512, annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}
+        ):
+            for i0_1_init, i1_0_1_init, i0_2_init, i1_0_2_init in T.grid(2, 3, 1, 1):
+                with T.block("compute_o_init"):
+                    i = T.axis.spatial(128, i0_0_i1_0_0_fused // 8 * 2 + i0_1_init + i0_2_init)
+                    j_o = T.axis.spatial(24, i1_0_2_init + i0_0_i1_0_0_fused % 8 * 3 + i1_0_1_init)
+                    T.reads()
+                    T.writes(compute[i, j_o * 32 : j_o * 32 + 32])
+                    for i1_1 in T.vectorized(32):
+                        with T.block("compute_init"):
+                            j_i_init = T.axis.spatial(32, i1_1)
+                            T.reads()
+                            T.writes(compute[i, j_o * 32 + j_i_init])
+                            compute[i, j_o * 32 + j_i_init] = 0
+            for i2_0_0, i0_1, i1_0_1, i2_0_1, i0_2, i1_0_2 in T.grid(32, 2, 3, 6, 1, 1):
+                with T.block("compute_o_update"):
+                    i = T.axis.spatial(128, i0_0_i1_0_0_fused // 8 * 2 + i0_1 + i0_2)
+                    j_o = T.axis.spatial(24, i1_0_2 + i0_0_i1_0_0_fused % 8 * 3 + i1_0_1)
+                    k_o = T.axis.reduce(192, i2_0_0 * 6 + i2_0_1)
+                    T.reads(
+                        compute[i, j_o * 32 : j_o * 32 + 32],
+                        X[i, k_o * 4 : k_o * 4 + 4],
+                        packedW[j_o, k_o, 0:32, 0:4],
+                    )
+                    T.writes(compute[i, j_o * 32 : j_o * 32 + 32])
+                    A = T.match_buffer(
+                        X[i, k_o * 4 : k_o * 4 + 4], [4], dtype="uint8", offset_factor=1
+                    )
+                    B = T.match_buffer(
+                        packedW[j_o, k_o, 0:32, 0:4], [32, 4], dtype="uint8", offset_factor=1
+                    )
+                    C = T.match_buffer(
+                        compute[i, j_o * 32 : j_o * 32 + 32], [32], dtype="int32", offset_factor=1
+                    )
+                    A_u8x4: T.uint8x4 = A[0:4]
+                    A_i32: T.int32 = T.reinterpret(A_u8x4, dtype="int32")
+                    B_i32x32: T.int32x32 = T.reinterpret(B[0, 0:128], dtype="int32x32")
+                    C[0:32] = T.call_llvm_pure_intrin(
+                        4390, T.uint32(3), C[0:32], B_i32x32, A_i32, dtype="int32x32"
+                    )
+
+
+@tvm.testing.requires_hexagon
+def test_vrmpy_dense_auto_tensorize(hexagon_launcher):
+    if hexagon_launcher._serial_number == "simulator":
+        pytest.skip(msg="Tuning on simulator not supported.")
+
+    target_hexagon = tvm.target.hexagon("v68")
+    target = tvm.target.Target(target_hexagon, host=target_hexagon)
+
+    M, N, K = 128, 768, 768
+    workload = te.create_prim_func(dense(M, N, K))
+
+    sch_rules = [
+        schedule_rule.MultiLevelTilingWithIntrin(
+            VRMPY_u8u8i32_INTRIN,
+            structure="SRSRS",
+            tile_binds=None,
+            max_innermost_factor=64,
+            vector_load_lens=None,
+            reuse_read=None,
+            reuse_write=schedule_rule.ReuseType(
+                req="may",
+                levels=[1, 2],
+                scope="global",
+            ),
+        ),
+        schedule_rule.ParallelizeVectorizeUnroll(
+            max_jobs_per_core=16,
+            max_vectorize_extent=128,
+            unroll_max_steps=[0, 16, 64, 512],
+            unroll_explicit=True,
+        ),
+    ]
+
+    postprocs = [
+        postproc.RewriteParallelVectorizeUnroll(),
+        postproc.RewriteReductionBlock(),
+        postproc.RewriteTensorize(vectorize_init_loop=True),
+    ]
+
+    if True:
+        with tempfile.TemporaryDirectory() as work_dir:
+            config = ms.TuneConfig(
+                strategy="replay_trace",
+                num_trials_per_iter=8,
+                max_trials_per_task=8,
+                max_trials_global=8,
+            )
+
+            sch = ms.tune_tir(
+                mod=workload,
+                target=target,
+                config=config,
+                work_dir=work_dir,
+                sch_rules=lambda: sch_rules,
+                postprocs=lambda: postprocs,
+                builder=get_hexagon_local_builder(),
+                runner=get_hexagon_rpc_runner(hexagon_launcher, number=10),
+            )
+    else:
+        sch = tvm.tir.Schedule(Module_vrmpy_auto_tensorize, debug_mask="all")
+
+    with hexagon_launcher.start_session() as session:
+        verify_dense(sch, target, M, N, K, session)
+
+
+@tvm.testing.requires_hexagon
+def test_conv2d_relay_auto_schedule(hexagon_launcher):
+    if hexagon_launcher._serial_number == "simulator":
+        pytest.skip(msg="Tuning on simulator not supported.")
+
+    target_hexagon = tvm.target.hexagon("v69")
+    target = tvm.target.Target(target_hexagon, host=target_hexagon)
+    I, O, H, W = 64, 64, 56, 56
+    kH = kW = 3
+
+    strides = (1, 1)
+    padding = (1, 1)
+
+    d_shape = (1, H, W, I)
+    w_shape = (kH, kW, I, O)
+    bias_shape = (1, 1, 1, w_shape[3])
+    out_channel = w_shape[3]
+
+    data = relay.var("data", shape=d_shape, dtype="float16")
+    weight = relay.var("weight", shape=w_shape, dtype="float16")
+    bias = relay.var("bias", shape=bias_shape, dtype="float16")
+    conv2d = relay.nn.conv2d(
+        data=data,
+        weight=weight,
+        kernel_size=(kH, kW),
+        channels=out_channel,
+        padding=padding,
+        strides=strides,
+        out_dtype="float16",
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+    )
+    mod = tvm.IRModule.from_expr(conv2d + bias)
+
+    data_np = np.random.randn(*d_shape).astype("float16")
+    weight_np = np.random.randn(*w_shape).astype("float16")
+    bias_np = np.random.randn(*bias_shape).astype("float16")
+    params = {"weight": weight_np, "bias": bias_np}
+
+    target_llvm = tvm.target.Target("llvm")
+
+    with tvm.transform.PassContext(
+        opt_level=3,
+    ):
+        lib_ref = relay.build(mod, target=target_llvm, params=params)
+
+    rt_mod_ref = tvm.contrib.graph_executor.GraphModule(lib_ref["default"](tvm.cpu(0)))
+
+    rt_mod_ref.set_input("data", data_np)
+
+    rt_mod_ref.run()
+
+    ref = rt_mod_ref.get_output(0).numpy()
+
+    config = ms.TuneConfig(
+        strategy="replay_trace",
+        num_trials_per_iter=8,
+        max_trials_per_task=8,
+        max_trials_global=8,
+    )
+
+    with tempfile.TemporaryDirectory() as work_dir:
+        executor = Executor("graph", {"link-params": True})
+        lib = ms.tune_relay(
+            mod=mod,
+            params=params,
+            target=target,
+            config=config,
+            work_dir=work_dir,
+            builder=get_hexagon_local_builder(),
+            runner=get_hexagon_rpc_runner(hexagon_launcher, number=20),
+            executor=executor,
+        )
+
+    with hexagon_launcher.start_session() as session:
+        rt_mod = session.get_executor_from_factory(lib)
+
+        rt_mod.set_input("data", data_np)
+
+        rt_mod.run()
+
+        out = rt_mod.get_output(0).numpy()
+        print(np.max(np.abs(ref - out)), np.mean(np.abs(ref - out)))
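
For reference, a minimal sketch of how the new executor argument to ms.tune_relay is meant to be used, mirroring test_conv2d_relay_auto_schedule above. The names mod, params, target, work_dir, hexagon_launcher, get_hexagon_local_builder, and get_hexagon_rpc_runner are assumed to be set up as in that test, and the TuneConfig values are illustrative only.

    from tvm import meta_schedule as ms
    from tvm.relay.backend import Executor

    # link-params=True asks relay.build to link the constant params into the module;
    # tune_relay reads this attribute so that task extraction and database lookup see
    # the same fused workloads that the final relay.build call will produce.
    executor = Executor("graph", {"link-params": True})

    lib = ms.tune_relay(
        mod=mod,        # Relay IRModule, e.g. the conv2d + bias module from the test above
        params=params,  # constant parameters, linked because link-params=True
        target=target,  # e.g. tvm.target.Target(tvm.target.hexagon("v69"), host=target_hexagon)
        config=ms.TuneConfig(
            strategy="replay_trace",
            num_trials_per_iter=8,
            max_trials_per_task=8,
            max_trials_global=8,
        ),
        work_dir=work_dir,
        builder=get_hexagon_local_builder(),
        runner=get_hexagon_rpc_runner(hexagon_launcher, number=20),
        executor=executor,  # new argument introduced by this patch
    )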