From 68464841ea617f9c859c1f7c546b29b611428186 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Wed, 27 Apr 2022 02:17:44 +0900
Subject: [PATCH] [Metaschedule] Auto tensorization for CPU / GPU dot product
 (#11088)

* [Metaschedule] Auto-tensorization for CPU / GPU dot product

Co-authored-by: Siyuan Feng
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Ruihang Lai
Co-authored-by: Wuwei Lin

* doc update

* add vnni conv2d test

* add dp4a test

* adding tests for rewrite_tensorize

* add rewrite_tensorize test

* add missing pydoc

* black

* more doc

* adding auto tensorize integration test

* add dp4a test

* fix target name

* fix dtype in test

* skip bert test

* replace hard-coded llvm intrinsic id in test with look up

* remove unnecessary include, add doc for the rest of params

* update postproc.h

* update doc

* fix shape in te matmul workload

* fix newline in cppdoc

Co-authored-by: Siyuan Feng
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Ruihang Lai
Co-authored-by: Wuwei Lin
---
 include/tvm/meta_schedule/postproc.h          |   6 +-
 include/tvm/meta_schedule/schedule_rule.h     |  23 +
 include/tvm/tir/stmt.h                        |   5 +
 python/tvm/meta_schedule/postproc/__init__.py |   1 +
 .../postproc/rewrite_tensorize.py             |  38 ++
 .../meta_schedule/schedule_rule/__init__.py   |   2 +-
 .../schedule_rule/multi_level_tiling.py       |  49 ++
 .../tvm/meta_schedule/testing/te_workload.py  |   2 +-
 .../postproc/rewrite_tensorize.cc             | 105 ++++
 .../schedule_rule/multi_level_tiling.cc       |  25 +-
 .../schedule_rule/multi_level_tiling.h        |  30 ++
 .../multi_level_tiling_with_intrin.cc         |  79 +++
 .../test_meta_schedule_auto_tensorize.py      | 347 ++++++++++++
 ...eta_schedule_postproc_rewrite_tensorize.py | 509 ++++++++++++++++++
 ...hedule_schedule_rule_multi_level_tiling.py | 263 ++++++++-
 15 files changed, 1457 insertions(+), 27 deletions(-)
 create mode 100644 python/tvm/meta_schedule/postproc/rewrite_tensorize.py
 create mode 100644 src/meta_schedule/postproc/rewrite_tensorize.cc
 create mode 100644 src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
 create mode 100644 tests/python/integration/test_meta_schedule_auto_tensorize.py
 create mode 100644 tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py

diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h
index b35d725cfd40..8b32ce460933 100644
--- a/include/tvm/meta_schedule/postproc.h
+++ b/include/tvm/meta_schedule/postproc.h
@@ -149,10 +149,12 @@ class Postproc : public runtime::ObjectRef {
    */
   TVM_DLL static Postproc RewriteUnboundBlock(int max_threadblock);
   /*!
-   * \brief Create a postprocessor that tensorize Tensor Core related components
+   * \brief Create a postprocessor that applies tensorization to annotated blocks
+   * \param vectorize_init_loop Whether or not to vectorize the initialization loop produced by
+   * DecomposeReduction
    * \return The postprocessor created.
    */
-  TVM_DLL static Postproc RewriteTensorCore();
+  TVM_DLL static Postproc RewriteTensorize(bool vectorize_init_loop = false);
   /*!
   * \brief Creates a postprocessor that verifies if the GPU code is correct
diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index 1675bcce05ed..2b2eefeb7574 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -150,6 +150,29 @@ class ScheduleRule : public runtime::ObjectRef {
                                           Optional<Array<Integer>> vector_load_lens,    //
                                           Optional<Map<String, ObjectRef>> reuse_read,  //
                                           Optional<Map<String, ObjectRef>> reuse_write);
+
+  /*!
+   * \brief Extension of MultiLevelTiling for auto-tensorizing with a single intrinsic.
+   * \param intrin_name The name of a tensor intrinsic, must be registered via
+   * TensorIntrin.register(...) beforehand
+   * \param structure The tiling structure. Recommended:
+   * - 'SSRSRS' on CPU
+   * - 'SSSRRSRS' on GPU
+   * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
+   * - NullOpt on CPU
+   * - [blockIdx.x, vthread.x, threadIdx.x] on GPU
+   * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
+   * \param vector_load_lens The length of vector lane in vectorized cooperative fetching.
+   * NullOpt means disable vectorization
+   * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
+   * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
+   * \return The schedule rule created
+   */
+  TVM_DLL static ScheduleRule MultiLevelTilingWithIntrin(
+      String intrin_name, String structure, Optional<Array<String>> tile_binds,
+      Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
+      Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);
+
   /*!
    * \brief Create a rule: add-rfactor to some blocks if needed
    * \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 6cdd6499c821..48cac6d8d057 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -1509,6 +1509,11 @@ constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_expl
 /*! \brief Mark auto-unroll setting on the block. */
 constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit";
 
+/*!
+ * \brief Mark that a block should be further rewritten using tensorization.
+ */
+constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize";
+
 /*!
  * \brief Check if attr_key is a pragma key extension
  * \param attr_key The attr key to be compared
diff --git a/python/tvm/meta_schedule/postproc/__init__.py b/python/tvm/meta_schedule/postproc/__init__.py
index 96361e739186..39113bb90011 100644
--- a/python/tvm/meta_schedule/postproc/__init__.py
+++ b/python/tvm/meta_schedule/postproc/__init__.py
@@ -22,3 +22,4 @@
 from .rewrite_reduction_block import RewriteReductionBlock
 from .rewrite_unbound_block import RewriteUnboundBlock
 from .verify_gpu_code import VerifyGPUCode
+from .rewrite_tensorize import RewriteTensorize
diff --git a/python/tvm/meta_schedule/postproc/rewrite_tensorize.py b/python/tvm/meta_schedule/postproc/rewrite_tensorize.py
new file mode 100644
index 000000000000..85075c41b43c
--- /dev/null
+++ b/python/tvm/meta_schedule/postproc/rewrite_tensorize.py
@@ -0,0 +1,38 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""A postprocessor that applies tensorization to annotated blocks."""
+
+from tvm._ffi.registry import register_object
+from .. import _ffi_api
+from .postproc import Postproc
+
+
+@register_object("meta_schedule.RewriteTensorize")
+class RewriteTensorize(Postproc):
+    """A postprocessor that applies tensorization to annotated blocks.
+
+    Parameters
+    ----------
+    vectorize_init_loop : bool
+        Whether or not to vectorize the initialization loop produced by DecomposeReduction
+    """
+
+    def __init__(self, vectorize_init_loop=False) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.PostprocRewriteTensorize,  # type: ignore # pylint: disable=no-member
+            vectorize_init_loop,
+        )
diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py
index f03c6de3df4b..a958fdc39db1 100644
--- a/python/tvm/meta_schedule/schedule_rule/__init__.py
+++ b/python/tvm/meta_schedule/schedule_rule/__init__.py
@@ -22,7 +22,7 @@
 from .add_rfactor import AddRFactor
 from .auto_inline import AutoInline
 from .cross_thread_reduction import CrossThreadReduction
-from .multi_level_tiling import MultiLevelTiling, ReuseType
+from .multi_level_tiling import MultiLevelTiling, MultiLevelTilingWithIntrin, ReuseType
 from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll
 from .random_compute_location import RandomComputeLocation
 from .schedule_rule import PyScheduleRule, ScheduleRule
diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
index 2ff49168d0c6..0bad6cbb4cd5 100644
--- a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
+++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
@@ -82,3 +82,52 @@ def __init__(
             reuse_read.as_dict() if reuse_read is not None else None,
             reuse_write.as_dict() if reuse_write is not None else None,
         )
+
+
+@register_object("meta_schedule.MultiLevelTilingWithIntrin")
+class MultiLevelTilingWithIntrin(ScheduleRule):
+    """Extension of MultiLevelTiling for auto-tensorizing with a single intrinsic.
+
+    Parameters
+    ----------
+    intrin_name : str
+        The name of a tensor intrinsic, must be registered via TensorIntrin.register(...) beforehand
+    structure : str
+        The tiling structure. Recommended:
+        - 'SSRSRS' on CPU
+        - 'SSSRRSRS' on GPU
+    tile_binds : Optional[List[str]]
+        For each level of tiles, which thread axis it is bound to. Recommended:
+        - None on CPU
+        - [blockIdx.x, vthread.x, threadIdx.x] on GPU
+    max_innermost_factor : Optional[int]
+        The maximum size of the innermost factor. None means no limit
+    vector_load_lens : Optional[List[int]]
+        The length of vector lane in vectorized cooperative fetching.
+        None means disable vectorization
+    reuse_read : Optional[ReuseType]
+        Data reuse configuration for reading. None means no reuse.
+    reuse_write : Optional[ReuseType]
+        Data reuse configuration for writing. None means no reuse.
+ """ + + def __init__( + self, + intrin_name: str, + structure: str, + tile_binds: Optional[List[str]] = None, + max_innermost_factor: Optional[int] = None, + vector_load_lens: Optional[List[int]] = None, + reuse_read: Optional[ReuseType] = None, + reuse_write: Optional[ReuseType] = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleMultiLevelTilingWithIntrin, # type: ignore # pylint: disable=no-member + intrin_name, + structure, + tile_binds, + max_innermost_factor, + vector_load_lens, + reuse_read.as_dict() if reuse_read is not None else None, + reuse_write.as_dict() if reuse_write is not None else None, + ) diff --git a/python/tvm/meta_schedule/testing/te_workload.py b/python/tvm/meta_schedule/testing/te_workload.py index 49a60a27526a..52f5f49b0a12 100644 --- a/python/tvm/meta_schedule/testing/te_workload.py +++ b/python/tvm/meta_schedule/testing/te_workload.py @@ -607,7 +607,7 @@ def f_compute(i, j): def matmul_relu(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: a = te.placeholder((n, k), name="A") - b = te.placeholder((m, k), name="B") + b = te.placeholder((k, m), name="B") k = te.reduce_axis((0, k), name="k") c = te.compute( (n, m), diff --git a/src/meta_schedule/postproc/rewrite_tensorize.cc b/src/meta_schedule/postproc/rewrite_tensorize.cc new file mode 100644 index 000000000000..1ad394e49c59 --- /dev/null +++ b/src/meta_schedule/postproc/rewrite_tensorize.cc @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+#include <tvm/meta_schedule/postproc.h>
+
+#include
+
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+using tir::BlockRV;
+using tir::LoopRV;
+
+void ApplyTensorization(const tir::Schedule& sch, const String& func_name,
+                        const tir::PrimFuncNode* func, bool vectorize_init_loop) {
+  std::vector<std::pair<std::string, std::function<void(tir::BlockRV)>>> jobs;
+
+  tir::PostOrderVisit(func->body, [=, &jobs](const ObjectRef& obj) {
+    if (const auto* block = obj.as<tir::BlockNode>()) {
+      tir::StmtSRef block_sref = sch->GetSRef(block);
+      if (Optional<String> intrin_name =
+              tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize)) {
+        std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint;
+        if (block_name.find("init") == std::string::npos) {
+          jobs.emplace_back(block_name, [sch, intrin_name](tir::BlockRV block) {
+            try {
+              sch->Tensorize(block, intrin_name.value());
+            } catch (const std::exception& e) {
+              LOG(WARNING) << "Tensorize failed with error " << e.what();
+            }
+          });
+        } else if (vectorize_init_loop) {
+          jobs.emplace_back(block_name, [sch](tir::BlockRV block) {
+            Array<BlockRV> child_blocks = sch->GetChildBlocks(block);
+            ICHECK(child_blocks.size() == 1);
+            Array<LoopRV> init_loops = sch->GetLoops(child_blocks[0]);
+            ICHECK(init_loops.size() == 1);
+            sch->Vectorize(init_loops[0]);
+          });
+        }
+      }
+    }
+  });
+
+  for (auto kv : jobs) {
+    tir::BlockRV block = sch->GetBlock(kv.first, func_name);
+    sch->Unannotate(block, tir::attr::meta_schedule_auto_tensorize);
+    kv.second(block);
+  }
+}
+
+class RewriteTensorizeNode : public PostprocNode {
+ public:
+  void InitializeWithTuneContext(const TuneContext& context) final {}
+
+  bool Apply(const tir::Schedule& sch) final;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  bool vectorize_init_loop = false;
+
+  static constexpr const char* _type_key = "meta_schedule.RewriteTensorize";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RewriteTensorizeNode, PostprocNode);
+};
+
+bool RewriteTensorizeNode::Apply(const tir::Schedule& sch) {
+  for (const auto& kv : sch->mod()->functions) {
+    GlobalVar g_var = kv.first;
+    BaseFunc base_func = kv.second;
+    if (const tir::PrimFuncNode* prim_func = base_func.as<tir::PrimFuncNode>()) {
+      ApplyTensorization(sch, g_var->name_hint, prim_func, vectorize_init_loop);
+    }
+  }
+  return true;
+}
+
+Postproc Postproc::RewriteTensorize(bool vectorize_init_loop) {
+  ObjectPtr<RewriteTensorizeNode> n = make_object<RewriteTensorizeNode>();
+  n->vectorize_init_loop = vectorize_init_loop;
+  return Postproc(n);
+}
+
+TVM_REGISTER_NODE_TYPE(RewriteTensorizeNode);
+TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteTensorize")
+    .set_body_typed(Postproc::RewriteTensorize);
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
index 6b18b17867dc..0a3ea882b5eb 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -260,28 +260,9 @@ ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<String>> tile_binds,
                                             Optional<Array<Integer>> vector_load_lens,
                                             Optional<Map<String, ObjectRef>> reuse_read,
                                             Optional<Map<String, ObjectRef>> reuse_write) {
-  ObjectPtr<MultiLevelTilingNode> n = make_object<MultiLevelTilingNode>();
-  n->structure = structure;
-  n->tile_binds = tile_binds.value_or({});
-  n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value;
-  n->vector_load_lens = vector_load_lens.defined()
-                            ? support::AsVector<Integer, int>(vector_load_lens.value())
-                            : std::vector<int>();
-  n->reuse_read_ = reuse_read.defined() ? ReuseConfig(reuse_read.value()) : ReuseConfig();
-  n->reuse_write_ = reuse_write.defined() ? ReuseConfig(reuse_write.value()) : ReuseConfig();
-  for (int i = 0, len = structure.size(); i < len; ++i) {
-    char c = structure.data()[i];
-    if (c == 'S') {
-      n->s_indices_.push_back(i);
-    } else if (c == 'R') {
-      n->r_indices_.push_back(i);
-    } else {
-      LOG(FATAL) << "ValueError: Invalid tiling structure: " << structure;
-    }
-  }
-  n->thread_warp_size_ = -1;
-  n->max_threads_per_block_ = -1;
-  return ScheduleRule(n);
+  auto node = MultiLevelTilingInitCommon<MultiLevelTilingNode>(
+      structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);
+  return ScheduleRule(node);
 }
 
 TVM_REGISTER_NODE_TYPE(MultiLevelTilingNode);
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h
index b7712b5c1989..f260c4856e36 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.h
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h
@@ -181,6 +181,36 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
   TVM_DECLARE_BASE_OBJECT_INFO(MultiLevelTilingNode, ScheduleRuleNode);
 };
 
+template <typename NodeType>
+ObjectPtr<NodeType> MultiLevelTilingInitCommon(String structure, Optional<Array<String>> tile_binds,
+                                               Optional<Integer> max_innermost_factor,
+                                               Optional<Array<Integer>> vector_load_lens,
+                                               Optional<Map<String, ObjectRef>> reuse_read,
+                                               Optional<Map<String, ObjectRef>> reuse_write) {
+  ObjectPtr<NodeType> n = make_object<NodeType>();
+  n->structure = structure;
+  n->tile_binds = tile_binds.value_or({});
+  n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value;
+  n->vector_load_lens = vector_load_lens.defined()
+                            ? support::AsVector<Integer, int>(vector_load_lens.value())
+                            : std::vector<int>();
+  n->reuse_read_ = reuse_read.defined() ? ReuseConfig(reuse_read.value()) : ReuseConfig();
+  n->reuse_write_ = reuse_write.defined() ? ReuseConfig(reuse_write.value()) : ReuseConfig();
+  for (int i = 0, len = structure.size(); i < len; ++i) {
+    char c = structure.data()[i];
+    if (c == 'S') {
+      n->s_indices_.push_back(i);
+    } else if (c == 'R') {
+      n->r_indices_.push_back(i);
+    } else {
+      LOG(FATAL) << "ValueError: Invalid tiling structure: " << structure;
+    }
+  }
+  n->thread_warp_size_ = -1;
+  n->max_threads_per_block_ = -1;
+  return n;
+}
+
 }  // namespace meta_schedule
 }  // namespace tvm
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
new file mode 100644
index 000000000000..da3ea2484e6e
--- /dev/null
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "../../tir/schedule/transform.h"
+#include "../utils.h"
+#include "multi_level_tiling.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+/*!
+ * \brief Tile a subset of loops in the block according to the given tensor intrinsic, and annotate
+ * the tiled block for tensorization by the RewriteTensorize postprocessor.
+ */
+tir::BlockRV TileForIntrin(tir::Schedule sch, tir::BlockRV block, const std::string& intrin_name) {
+  Optional<tir::LoopRV> tiled_loop_rv = TileWithTensorIntrin(sch, block, intrin_name);
+  ICHECK(tiled_loop_rv.defined());
+  tir::BlockRV outer_block = sch->Blockize(tiled_loop_rv.value());
+  sch->Annotate(outer_block, tir::attr::meta_schedule_auto_tensorize, String(intrin_name));
+  return outer_block;
+}
+
+/*!
+ * \brief Extension of MultiLevelTiling for auto-tensorizing with a single intrinsic.
+ */
+class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode {
+ protected:
+  // Override ApplySubRules to tile the inner loops according to the given tensor intrinsic, then
+  // tile the outer loops.
+  virtual std::vector<State> ApplySubRules(std::vector<State> states) {
+    states = SubRule(std::move(states), [&](State state) {
+      state.block_rv = TileForIntrin(state.sch, state.block_rv, intrin_name);
+      return std::vector<State>(1, state);
+    });
+    return MultiLevelTilingNode::ApplySubRules(states);
+  }
+
+ public:
+  /*! \brief The name of a tensor intrinsic. */
+  String intrin_name;
+
+  static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingWithIntrin";
+  TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingWithIntrinNode, MultiLevelTilingNode);
+};
+
+ScheduleRule ScheduleRule::MultiLevelTilingWithIntrin(
+    String intrin_name, String structure, Optional<Array<String>> tile_binds,
+    Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
+    Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write) {
+  ICHECK(tir::TensorIntrin::Get(intrin_name).defined())
+      << "Provided tensor intrinsic " << intrin_name << " is not registered.";
+  auto node = MultiLevelTilingInitCommon<MultiLevelTilingWithIntrinNode>(
+      structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);
+  node->intrin_name = intrin_name;
+  return ScheduleRule(node);
+}
+
+TVM_REGISTER_NODE_TYPE(MultiLevelTilingWithIntrinNode);
+TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingWithIntrin")
+    .set_body_typed(ScheduleRule::MultiLevelTilingWithIntrin);
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/tests/python/integration/test_meta_schedule_auto_tensorize.py b/tests/python/integration/test_meta_schedule_auto_tensorize.py
new file mode 100644
index 000000000000..511e75723b03
--- /dev/null
+++ b/tests/python/integration/test_meta_schedule_auto_tensorize.py
@@ -0,0 +1,347 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest +import tvm +from tvm import relay +import tvm.testing +import numpy as np +from tvm.meta_schedule.tune import tune_extracted_tasks +from tvm.meta_schedule.relay_integration import extract_task_from_relay +from tvm.meta_schedule import ApplyHistoryBest +from tvm.meta_schedule import schedule_rule, postproc +from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base +from tvm import meta_schedule as ms +from tvm.tir.tensor_intrin import ( + VNNI_DOT_16x4_INTRIN as VNNI_INTRIN, + DP4A_INTRIN, + AMDGPU_SDOT4_INTRIN, +) +import tempfile +import tvm.topi.testing + + +config = ms.TuneConfig( + strategy="evolutionary", + num_trials_per_iter=32, + max_trials_per_task=32, + max_trials_global=20000, +) + +sch_rules_for_vnni = [ + schedule_rule.AutoInline( + into_producer=False, + into_consumer=True, + inline_const_tensor=True, + disallow_if_then_else=True, + require_injective=True, + require_ordered=True, + disallow_op=["tir.exp"], + ), + schedule_rule.AddRFactor(max_jobs_per_core=16, max_innermost_factor=64), + schedule_rule.MultiLevelTilingWithIntrin( + VNNI_INTRIN, + structure="SSRSRS", + tile_binds=None, + max_innermost_factor=64, + vector_load_lens=None, + reuse_read=None, + reuse_write=schedule_rule.ReuseType( + req="may", + levels=[1, 2], + scope="global", + ), + ), + schedule_rule.ParallelizeVectorizeUnroll( + max_jobs_per_core=16, + max_vectorize_extent=64, + unroll_max_steps=[0, 16, 64, 512], + unroll_explicit=True, + ), + schedule_rule.RandomComputeLocation(), +] + + +def get_sch_rules_for_dp4a(intrin): + return [ + schedule_rule.MultiLevelTilingWithIntrin( + intrin, + structure="SSSRRSRS", + tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"], + max_innermost_factor=64, + vector_load_lens=[1, 2, 3, 4], + reuse_read=schedule_rule.ReuseType( + req="must", + levels=[4], + scope="shared", + ), + reuse_write=schedule_rule.ReuseType( + req="must", + levels=[3], + scope="local", + ), + ), + schedule_rule.AutoInline( + into_producer=True, + into_consumer=True, + inline_const_tensor=True, + disallow_if_then_else=False, + require_injective=False, + require_ordered=False, + disallow_op=None, + ), + schedule_rule.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]), + schedule_rule.ParallelizeVectorizeUnroll( + max_jobs_per_core=-1, # disable parallelize + max_vectorize_extent=-1, # disable vectorize + unroll_max_steps=[0, 16, 64, 512, 1024], + unroll_explicit=True, + ), + ] + + +sch_rules_for_dp4a = get_sch_rules_for_dp4a(DP4A_INTRIN) +sch_rules_for_sdot4 = get_sch_rules_for_dp4a(AMDGPU_SDOT4_INTRIN) + +postprocs_for_vnni = [ + postproc.DisallowDynamicLoop(), + postproc.RewriteParallelVectorizeUnroll(), + postproc.RewriteReductionBlock(), + postproc.RewriteTensorize(vectorize_init_loop=True), +] + +postprocs_for_dp4a = [ + postproc.DisallowDynamicLoop(), + postproc.RewriteCooperativeFetch(), + postproc.RewriteUnboundBlock(), + postproc.RewriteParallelVectorizeUnroll(), + postproc.RewriteReductionBlock(), + postproc.RewriteTensorize(), + postproc.VerifyGPUCode(), +] + + +def tune_and_test(relay_mod, data_np, weight_np, op_name, target, sch_rules, postprocs): + tgt = "cuda" if "nvidia" in target else target + dev = tvm.device(tgt, 0) + + ref = ( + relay.create_executor("vm", mod=relay_mod, device=dev, target=tgt) + .evaluate()(*[data_np, weight_np]) + .numpy() + ) + + params = {"weight": weight_np} + + extracted_tasks = extract_task_from_relay(relay_mod, target, params) + + tune_tasks = list( + filter( + lambda task: op_name in task.task_name, + 
extracted_tasks, + ) + ) + + with tempfile.TemporaryDirectory() as work_dir: + database = tune_extracted_tasks( + tune_tasks, + config, + work_dir=work_dir, + sch_rules=lambda: sch_rules, + postprocs=lambda: postprocs, + ) + + with ApplyHistoryBest(database): + with tvm.transform.PassContext( + opt_level=3, + config={"relay.backend.use_meta_schedule": True}, + ): + lib = relay.build(relay_mod, target=target, params=params) + + if "cascadelake" in target: + asm = lib.lib.get_source("asm") + assert "vpdpbusd" in asm + + runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + + runtime.set_input("data", data_np) + runtime.run() + + out = runtime.get_output(0).numpy() + + np.testing.assert_equal(out, ref) + + +def _test_dense(data_dtype, sch_rules, postprocs, target): + M, N, K = 1024, 1024, 1024 + data_shape = (M, K) + weight_shape = (N, K) + + weight_dtype = "int8" + out_dtype = "int32" + + data = relay.var("data", shape=data_shape, dtype=data_dtype) + weight = relay.var("weight", shape=weight_shape, dtype=weight_dtype) + dense = relay.nn.dense(data, weight, out_dtype=out_dtype) + + relay_mod = tvm.IRModule.from_expr(dense) + + data_np = np.random.uniform(1, 10, size=data_shape).astype(data_dtype) + weight_np = np.random.uniform(1, 10, size=weight_shape).astype(weight_dtype) + + tune_and_test(relay_mod, data_np, weight_np, "dense", target, sch_rules, postprocs) + + +def _test_conv2d(data_dtype, sch_rules, postprocs, target): + d_shape = (1, 64, 56, 56) + w_shape = (64, 64, 3, 3) + + weight_dtype = "int8" + out_dtype = "int32" + + data = relay.var("data", shape=d_shape, dtype=data_dtype) + weight = relay.var("weight", shape=w_shape, dtype=weight_dtype) + out_channel = w_shape[0] + conv2d = relay.nn.conv2d( + data=data, + weight=weight, + kernel_size=w_shape[2:], + channels=out_channel, + padding=(1, 1), + strides=(1, 1), + out_dtype=out_dtype, + ) + + relay_mod = tvm.IRModule.from_expr(conv2d) + + data_np = np.random.uniform(1, 10, d_shape).astype(data_dtype) + weight_np = np.random.uniform(1, 10, size=w_shape).astype("int8") + + tune_and_test(relay_mod, data_np, weight_np, "conv2d", target, sch_rules, postprocs) + + +def _test_bert_int8(target, sch_rules, postprocs): + relay_mod, params, input_info = load_quantized_bert_base() + + relay_mod = relay.transform.FastMath()(relay_mod) + + extracted_tasks = extract_task_from_relay(relay_mod, target, params) + + tune_tasks = [] + + for task in filter( + lambda task: "dense" in task.task_name or "batch_matmul" in task.task_name, + extracted_tasks, + ): + relay_func = list(task.mod.functions.values())[0] + out_type = relay_func.body.checked_type + + if out_type.dtype != "float32": + tune_tasks.append(task) + + with tempfile.TemporaryDirectory() as work_dir: + database = tune_extracted_tasks( + tune_tasks, + config, + work_dir=work_dir, + sch_rules=lambda: sch_rules, + postprocs=lambda: postprocs, + ) + + with ApplyHistoryBest(database): + with tvm.transform.PassContext( + opt_level=3, + config={"relay.backend.use_meta_schedule": True}, + ): + lib = relay.build(relay_mod, target=target, params=params) + + dev = tvm.device("cuda" if "nvidia" in target else target, 0) + runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + + inputs = [] + + for name, shape in input_info: + arr = np.random.uniform(1, 10, size=shape).astype("int64") + runtime.set_input(name, arr) + inputs.append(arr) + + print(runtime.benchmark(dev, number=1, repeat=50).mean) + + +@pytest.mark.skip("Requires cascadelake") +def test_vnni_dense(): + 
_test_dense( + "uint8", sch_rules_for_vnni, postprocs_for_vnni, "llvm -mcpu=cascadelake -num-cores 4" + ) + + +@pytest.mark.skip("Only tested locally on sm_86 (for cuda) which is not supported by CI") +@tvm.testing.requires_gpu +def test_dp4a_dense(): + _test_dense("int8", sch_rules_for_dp4a, postprocs_for_dp4a, "nvidia/geforce-rtx-3070") + + # Uncomment to test on vulkan or rocm target + # _test_dense( + # "int8", sch_rules_for_dp4a, postprocs_for_dp4a, "vulkan -from_device=0" + # ) + # _test_dense( + # "int8", sch_rules_for_sdot4, postprocs_for_dp4a, "rocm" + # ) + + +@pytest.mark.skip("Requires cascadelake") +def test_vnni_conv2d(): + _test_conv2d( + "uint8", sch_rules_for_vnni, postprocs_for_vnni, "llvm -mcpu=cascadelake -num-cores 4" + ) + + +@pytest.mark.skip("Only tested locally on sm_86 (for cuda) which is not supported by CI") +@tvm.testing.requires_gpu +def test_dp4a_conv2d(): + _test_conv2d("int8", sch_rules_for_dp4a, postprocs_for_dp4a, "nvidia/geforce-rtx-3070") + + # Uncomment to test on vulkan or rocm target + # _test_conv2d( + # "int8", sch_rules_for_dp4a, postprocs_for_dp4a, "vulkan -from_device=0" + # ) + # _test_conv2d( + # "int8", sch_rules_for_sdot4, postprocs_for_dp4a, "rocm" + # ) + + +@pytest.mark.skip("Requires cascadelake") +def test_vnni_bert_int8(): + _test_bert_int8("llvm -mcpu=cascadelake -num-cores 4", sch_rules_for_vnni, postprocs_for_vnni) + + +@tvm.testing.requires_gpu +@pytest.mark.skip("Slow on CI") +def test_dp4a_bert_int8(): + _test_bert_int8("nvidia/geforce-rtx-3070", sch_rules_for_dp4a, postprocs_for_dp4a) + + # Uncomment to test on vulkan or rocm target + # _test_bert_int8("vulkan -from_device=0", sch_rules_for_dp4a, postprocs_for_dp4a) + # _test_bert_int8("rocm", sch_rules_for_sdot4, postprocs_for_dp4a) + + +if __name__ == "__main__": + test_vnni_dense() + test_vnni_conv2d() + test_vnni_bert_int8() + test_dp4a_dense() + test_dp4a_conv2d() + test_dp4a_bert_int8() diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py new file mode 100644 index 000000000000..bc84fb1ad0b2 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py @@ -0,0 +1,509 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import tvm +import tvm.tir.tensor_intrin +from tvm.script import tir as T +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule import postproc + + +@tvm.script.ir_module +class Conv2dNCHWcVNNIModuleTiled: + @T.prim_func + def main( + placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], + placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], + conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"], + ) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for ( + i0_0, + i1_0, + i2_0, + i3_0, + i4_0_0, + i0_1, + i1_1, + i2_1, + i3_1, + i4_0_1, + i5_0, + i6_0, + i7_0, + i8_0, + i9_0_0, + i0_2, + i1_2, + i2_2, + i3_2, + i4_0_2, + i5_1, + i6_1, + i7_1, + i8_1, + i9_0_1, + i0_3, + i1_3, + i2_3, + i3_3, + i4_0_3, + ) in T.grid( + 1, + 1, + 2, + 1, + 1, + 1, + 4, + 1, + 14, + 1, + 1, + 1, + 4, + 1, + 1, + 1, + 4, + 7, + 1, + 1, + 1, + 1, + 1, + 4, + 1, + 1, + 1, + 4, + 4, + 1, + ): + with T.block("conv2d_NCHWc_int8_o"): + n = T.axis.spatial(1, 0) + oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2) + oh = T.axis.spatial(56, i2_0 * 28 + i2_2 * 4 + i2_3) + ow = T.axis.spatial(56, i3_1 * 4 + i3_3) + oc_block_o = T.axis.spatial(1, 0) + kh = T.axis.reduce(1, 0) + kw = T.axis.reduce(1, 0) + ic_outer, ic_f_inner = T.axis.remap("RR", [i7_0, i8_1]) + ic_s_inner_o = T.axis.reduce(1, 0) + T.reads( + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], + placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:16, 0:4], + ) + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "dot_16x4_vnni"}) + with T.init(): + for i4_1 in T.serial(16): + with T.block("conv2d_NCHWc_int8_init"): + oc_block_init = T.axis.spatial(16, i4_1) + T.reads() + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_init]) + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_init] = 0 + for i4_1, i9_1 in T.grid(16, 4): + with T.block("conv2d_NCHWc_int8"): + oc_block, ic_s_inner = T.axis.remap("SR", [i4_1, i9_1]) + T.reads( + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block], + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], + placeholder_1[ + oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner + ], + ) + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[ + n, oc_chunk, oh, ow, oc_block + ] + T.cast( + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], + "int32", + ) * T.cast( + placeholder_1[ + oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner + ], + "int32", + ) + + +@tvm.script.ir_module +class Conv2dNCHWcVNNIModuleTensorized: + @T.prim_func + def main( + placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], + placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], + conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1, i5_0, i6_0 in T.grid( + 1, 1, 2, 1, 1, 1, 4, 1, 14, 1, 1, 1 + ): + for i1_2_init, i2_2_init, i2_3_init, i3_3_init in T.grid(4, 7, 4, 4): + with T.block("conv2d_NCHWc_int8_o_init"): + n = T.axis.spatial(1, 0) + oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2_init) + oh = T.axis.spatial(56, i2_0 * 28 + 
i2_2_init * 4 + i2_3_init) + ow = T.axis.spatial(56, i3_1 * 4 + i3_3_init) + oc_block_o = T.axis.spatial(1, 0) + T.reads() + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16]) + for i4_1 in T.vectorized(16): + with T.block("conv2d_NCHWc_int8_init"): + oc_block_init = T.axis.spatial(16, i4_1) + T.reads() + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_init]) + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_init] = 0 + for ( + i7_0, + i8_0, + i9_0_0, + i0_2, + i1_2, + i2_2, + i3_2, + i4_0_2, + i5_1, + i6_1, + i7_1, + i8_1, + i9_0_1, + i0_3, + i1_3, + i2_3, + i3_3, + i4_0_3, + ) in T.grid(4, 1, 1, 1, 4, 7, 1, 1, 1, 1, 1, 4, 1, 1, 1, 4, 4, 1): + with T.block("conv2d_NCHWc_int8_o_update"): + n = T.axis.spatial(1, 0) + oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2) + oh = T.axis.spatial(56, i2_0 * 28 + i2_2 * 4 + i2_3) + ow = T.axis.spatial(56, i3_1 * 4 + i3_3) + oc_block_o = T.axis.spatial(1, 0) + kh = T.axis.reduce(1, 0) + kw = T.axis.reduce(1, 0) + ic_outer, ic_f_inner = T.axis.remap("RR", [i7_0, i8_1]) + ic_s_inner_o = T.axis.reduce(1, 0) + T.reads( + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16], + placeholder[ + n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4 + ], + placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:16, 0:4], + ) + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16]) + A = T.match_buffer( + placeholder[ + n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4 + ], + [4], + dtype="uint8", + offset_factor=1, + ) + B = T.match_buffer( + placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:16, 0:4], + [16, 4], + dtype="int8", + offset_factor=1, + ) + C = T.match_buffer( + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16], + [16], + dtype="int32", + offset_factor=1, + ) + A_u8x4 = A.vload([0], "uint8x4") + A_i32 = T.reinterpret(A_u8x4, dtype="int32") + B_i8x64 = B.vload([0, 0], dtype="int8x64") + B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16") + C[T.ramp(0, 1, 16)] = C[T.ramp(0, 1, 16)] + T.call_llvm_pure_intrin( + T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"), + T.uint32(0), + T.broadcast(0, 16), + T.broadcast(A_i32, 16), + B_i32x16, + dtype="int32x16", + ) + + +@tvm.script.ir_module +class DenseDP4ATiled: + @T.prim_func + def main( + X: T.Buffer[(128, 128), "int8"], + W: T.Buffer[(128, 128), "int8"], + compute: T.Buffer[(128, 128), "int32"], + ) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + compute_local = T.alloc_buffer([128, 128], dtype="int32", scope="local") + X_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared") + W_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared") + for i0_0_i1_0_fused in T.thread_binding(16, thread="blockIdx.x"): + for i0_1_i1_1_fused in T.thread_binding(2, thread="vthread.x"): + for i0_2_i1_2_fused in T.thread_binding(2, thread="threadIdx.x"): + for i2_0_0 in T.serial(2): + for ax0_ax1_fused in T.serial(1024): + with T.block("X_shared"): + v0 = T.axis.spatial( + 128, i0_0_i1_0_fused // 2 * 16 + ax0_ax1_fused // 64 + ) + v1 = T.axis.spatial(128, i2_0_0 * 64 + ax0_ax1_fused % 64) + T.reads(X[v0, v1]) + T.writes(X_shared[v0, v1]) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + X_shared[v0, v1] = X[v0, v1] + for ax0_ax1_fused in T.serial(4096): + with T.block("W_shared"): + v0 = T.axis.spatial( + 128, i0_0_i1_0_fused % 2 * 64 + ax0_ax1_fused // 64 + ) + v1 = T.axis.spatial(128, i2_0_0 * 64 + ax0_ax1_fused % 64) + T.reads(W[v0, v1]) + T.writes(W_shared[v0, v1]) + T.block_attr({"meta_schedule.cooperative_fetch": 1}) + 
W_shared[v0, v1] = W[v0, v1] + for i2_0_1, i0_3, i1_3, i2_0_2, i0_4, i1_4 in T.grid(2, 4, 16, 8, 4, 1): + with T.block("compute_o"): + i = T.axis.spatial(128, i0_0_i1_0_fused // 2 * 16 + i0_3 * 4 + i0_4) + j = T.axis.spatial( + 128, + i0_0_i1_0_fused % 2 * 64 + + i0_1_i1_1_fused * 32 + + i0_2_i1_2_fused * 16 + + i1_3, + ) + k_o = T.axis.reduce(32, i2_0_0 * 16 + i2_0_1 * 8 + i2_0_2) + T.reads( + X_shared[i, k_o * 4 : k_o * 4 + 4], + W_shared[j, k_o * 4 : k_o * 4 + 4], + ) + T.writes(compute_local[i, j]) + T.block_attr({"meta_schedule.auto_tensorize": "dp4a"}) + with T.init(): + with T.block("compute_init"): + T.reads() + T.writes(compute_local[i, j]) + compute_local[i, j] = 0 + for i2_1 in T.serial(4): + with T.block("compute"): + k = T.axis.reduce(4, i2_1) + T.reads( + compute_local[i, j], + X_shared[i, k_o * 4 + k], + W_shared[j, k_o * 4 + k], + ) + T.writes(compute_local[i, j]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + compute_local[i, j] = compute_local[i, j] + T.cast( + X_shared[i, k_o * 4 + k], "int32" + ) * T.cast(W_shared[j, k_o * 4 + k], "int32") + for ax0, ax1 in T.grid(16, 16): + with T.block("compute_local"): + v0 = T.axis.spatial(128, i0_0_i1_0_fused // 2 * 16 + ax0) + v1 = T.axis.spatial( + 128, + i0_0_i1_0_fused % 2 * 64 + + i0_1_i1_1_fused * 32 + + i0_2_i1_2_fused * 16 + + ax1, + ) + T.reads(compute_local[v0, v1]) + T.writes(compute[v0, v1]) + compute[v0, v1] = compute_local[v0, v1] + + +@tvm.script.ir_module +class DenseDP4ATensorized: + @T.prim_func + def main( + X: T.Buffer[(128, 128), "int8"], + W: T.Buffer[(128, 128), "int8"], + compute: T.Buffer[(128, 128), "int32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + compute_local = T.alloc_buffer([128, 128], dtype="int32", scope="local") + X_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared") + W_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared") + for i0_0_i1_0_fused in T.thread_binding(16, thread="blockIdx.x"): + for i0_1_i1_1_fused in T.thread_binding(2, thread="vthread.x"): + for i0_2_i1_2_fused in T.thread_binding(2, thread="threadIdx.x"): + for i0_3_init, i1_3_init, i0_4_init in T.grid(4, 16, 4): + with T.block("compute_o_init"): + i = T.axis.spatial( + 128, i0_0_i1_0_fused // 2 * 16 + i0_3_init * 4 + i0_4_init + ) + j = T.axis.spatial( + 128, + i0_0_i1_0_fused % 2 * 64 + + i0_1_i1_1_fused * 32 + + i0_2_i1_2_fused * 16 + + i1_3_init, + ) + T.reads() + T.writes(compute_local[i, j]) + T.block_attr({"meta_schedule.auto_tensorize": "dp4a"}) + with T.block("compute_init"): + T.reads() + T.writes(compute_local[i, j]) + compute_local[i, j] = 0 + for i2_0_0 in T.serial(2): + for ax0_ax1_fused in T.serial(1024): + with T.block("X_shared"): + v0 = T.axis.spatial( + 128, i0_0_i1_0_fused // 2 * 16 + ax0_ax1_fused // 64 + ) + v1 = T.axis.spatial(128, i2_0_0 * 64 + ax0_ax1_fused % 64) + T.reads(X[v0, v1]) + T.writes(X_shared[v0, v1]) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + X_shared[v0, v1] = X[v0, v1] + for ax0_ax1_fused in T.serial(4096): + with T.block("W_shared"): + v0 = T.axis.spatial( + 128, i0_0_i1_0_fused % 2 * 64 + ax0_ax1_fused // 64 + ) + v1 = T.axis.spatial(128, i2_0_0 * 64 + ax0_ax1_fused % 64) + T.reads(W[v0, v1]) + T.writes(W_shared[v0, v1]) + T.block_attr({"meta_schedule.cooperative_fetch": 1}) + W_shared[v0, v1] = W[v0, v1] + for i2_0_1, i0_3, i1_3, i2_0_2, i0_4, i1_4 in T.grid(2, 4, 16, 8, 4, 1): + with T.block("compute_o_update"): + i = T.axis.spatial(128, 
i0_0_i1_0_fused // 2 * 16 + i0_3 * 4 + i0_4) + j = T.axis.spatial( + 128, + i0_0_i1_0_fused % 2 * 64 + + i0_1_i1_1_fused * 32 + + i0_2_i1_2_fused * 16 + + i1_3, + ) + k_o = T.axis.reduce(32, i2_0_0 * 16 + i2_0_1 * 8 + i2_0_2) + T.reads( + compute_local[i, j], + X_shared[i, k_o * 4 : k_o * 4 + 4], + W_shared[j, k_o * 4 : k_o * 4 + 4], + ) + T.writes(compute_local[i, j]) + A = T.match_buffer( + X_shared[i, k_o * 4 : k_o * 4 + 4], + [4], + dtype="int8", + scope="shared", + align=4, + offset_factor=1, + ) + B = T.match_buffer( + W_shared[j, k_o * 4 : k_o * 4 + 4], + [4], + dtype="int8", + scope="shared", + align=4, + offset_factor=1, + ) + C = T.match_buffer( + compute_local[i, j], + [1], + dtype="int32", + scope="local", + align=4, + offset_factor=1, + ) + C[0] = C[0] + T.call_pure_extern( + "__dp4a", + A[T.ramp(0, 1, 4)], + B[T.ramp(0, 1, 4)], + 0, + dtype="int32", + ) + for ax0, ax1 in T.grid(16, 16): + with T.block("compute_local"): + v0 = T.axis.spatial(128, i0_0_i1_0_fused // 2 * 16 + ax0) + v1 = T.axis.spatial( + 128, + i0_0_i1_0_fused % 2 * 64 + + i0_1_i1_1_fused * 32 + + i0_2_i1_2_fused * 16 + + ax1, + ) + T.reads(compute_local[v0, v1]) + T.writes(compute[v0, v1]) + compute[v0, v1] = compute_local[v0, v1] + + +def _create_context(mod, target, postprocs): + ctx = TuneContext( + mod=mod, + target=target, + postprocs=postprocs, + task_name="test", + ) + for rule in ctx.postprocs: + rule.initialize_with_tune_context(ctx) + return ctx + + +def test_rewrite_tensorize_conv2d_nchwc_vnni(): + mod = Conv2dNCHWcVNNIModuleTiled + target = tvm.target.Target("llvm -mcpu=cascadelake -num-cores 4") + ctx = _create_context( + mod, + target, + [ + postproc.RewriteReductionBlock(), + postproc.RewriteTensorize(True), + ], + ) + sch = tvm.tir.Schedule(mod, debug_mask="all") + sch.enter_postproc() + + for proc in ctx.postprocs: + proc.apply(sch) + + tvm.ir.assert_structural_equal(sch.mod, Conv2dNCHWcVNNIModuleTensorized) + + +def test_rewrite_tensorize_dense_dp4a(): + mod = DenseDP4ATiled + target = tvm.target.Target("nvidia/geforce-rtx-3070") + ctx = _create_context( + mod, + target, + [ + postproc.RewriteCooperativeFetch(), + postproc.RewriteReductionBlock(), + postproc.RewriteTensorize(), + ], + ) + sch = tvm.tir.Schedule(mod, debug_mask="all") + sch.enter_postproc() + + for proc in ctx.postprocs: + proc.apply(sch) + + tvm.ir.assert_structural_equal(sch.mod, DenseDP4ATensorized) + + +if __name__ == "__main__": + test_rewrite_tensorize_conv2d_nchwc_vnni() + test_rewrite_tensorize_dense_dp4a() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py index 555a1a8e1f15..43ce9969be84 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. 
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring - +import tvm +from tvm import te from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply from tvm.meta_schedule.testing import te_workload from tvm.meta_schedule.testing.schedule_rule import ( @@ -23,9 +24,11 @@ ) from tvm.meta_schedule.testing.space_generation import check_trace from tvm.meta_schedule.tune_context import TuneContext +from tvm.meta_schedule import schedule_rule from tvm.script import tir as T from tvm.te import create_prim_func from tvm.target import Target +from tvm.tir.tensor_intrin import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN, DP4A_INTRIN def _create_context(mod, target, rule) -> TuneContext: @@ -301,9 +304,267 @@ def sum_with_trivial_block_iter( check_trace(spaces, expected) +@tvm.script.ir_module +class Conv2dNCHWcVNNIModule: + @T.prim_func + def main( + placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], + placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], + conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"], + ) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4): + with T.block("conv2d_NCHWc_int8"): + ( + n, + oc_chunk, + oh, + ow, + oc_block, + kh, + kw, + ic_outer, + ic_f_inner, + ic_s_inner, + ) = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9]) + T.reads( + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], + placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], + ) + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block]) + with T.init(): + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0 + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[ + n, oc_chunk, oh, ow, oc_block + ] + T.cast( + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32" + ) * T.cast( + placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], + "int32", + ) + + +def test_multi_level_tiling_conv2d_nchwc_vnni(): + target = "llvm -mcpu=cascadelake -num-cores 4" + ctx = _create_context( + Conv2dNCHWcVNNIModule, + target=tvm.target.Target(target), + rule=schedule_rule.MultiLevelTilingWithIntrin( + VNNI_INTRIN, + structure="SSRSRS", + tile_binds=None, + max_innermost_factor=64, + vector_load_lens=None, + reuse_read=None, + reuse_write=schedule_rule.ReuseType( + req="may", + levels=[1, 2], + scope="global", + ), + ), + ) + + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + + expected = [ + """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main") +sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") +l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0) +l11, l12 = sch.split(loop=l10, factors=[1, 4]) +l13, l14 = sch.split(loop=l5, factors=[1, 16]) +l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0) +sch.reorder(l21, l22, l23, l24, l25, l14, l12) +b27 = sch.blockize(loop=l14) +sch.annotate(block_or_loop=b27, ann_key="meta_schedule.auto_tensorize", ann_val="dot_16x4_vnni") +l28, l29, l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b27) +v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l28, n=4, max_innermost_factor=64) +l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41]) +v46, v47, v48, v49 = sch.sample_perfect_tile(loop=l29, n=4, max_innermost_factor=64) +l50, l51, l52, l53 = 
sch.split(loop=l29, factors=[v46, v47, v48, v49]) +v54, v55, v56, v57 = sch.sample_perfect_tile(loop=l30, n=4, max_innermost_factor=64) +l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57]) +v62, v63, v64, v65 = sch.sample_perfect_tile(loop=l31, n=4, max_innermost_factor=64) +l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65]) +v70, v71, v72, v73 = sch.sample_perfect_tile(loop=l32, n=4, max_innermost_factor=64) +l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73]) +v78, v79 = sch.sample_perfect_tile(loop=l33, n=2, max_innermost_factor=64) +l80, l81 = sch.split(loop=l33, factors=[v78, v79]) +v82, v83 = sch.sample_perfect_tile(loop=l34, n=2, max_innermost_factor=64) +l84, l85 = sch.split(loop=l34, factors=[v82, v83]) +v86, v87 = sch.sample_perfect_tile(loop=l35, n=2, max_innermost_factor=64) +l88, l89 = sch.split(loop=l35, factors=[v86, v87]) +v90, v91 = sch.sample_perfect_tile(loop=l36, n=2, max_innermost_factor=64) +l92, l93 = sch.split(loop=l36, factors=[v90, v91]) +v94, v95 = sch.sample_perfect_tile(loop=l37, n=2, max_innermost_factor=64) +l96, l97 = sch.split(loop=l37, factors=[v94, v95]) +sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77) +b98 = sch.cache_write(block=b27, write_buffer_index=0, storage_scope="global") +sch.reverse_compute_at(block=b98, loop=l75, preserve_unit_loops=True)""".split( + "\n" + ), + """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main") +sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") +l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0) +l11, l12 = sch.split(loop=l10, factors=[1, 4]) +l13, l14 = sch.split(loop=l5, factors=[1, 16]) +l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0) +sch.reorder(l21, l22, l23, l24, l25, l14, l12) +b27 = sch.blockize(loop=l14) +sch.annotate(block_or_loop=b27, ann_key="meta_schedule.auto_tensorize", ann_val="dot_16x4_vnni") +l28, l29, l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b27) +v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l28, n=4, max_innermost_factor=64) +l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41]) +v46, v47, v48, v49 = sch.sample_perfect_tile(loop=l29, n=4, max_innermost_factor=64) +l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49]) +v54, v55, v56, v57 = sch.sample_perfect_tile(loop=l30, n=4, max_innermost_factor=64) +l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57]) +v62, v63, v64, v65 = sch.sample_perfect_tile(loop=l31, n=4, max_innermost_factor=64) +l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65]) +v70, v71, v72, v73 = sch.sample_perfect_tile(loop=l32, n=4, max_innermost_factor=64) +l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73]) +v78, v79 = sch.sample_perfect_tile(loop=l33, n=2, max_innermost_factor=64) +l80, l81 = sch.split(loop=l33, factors=[v78, v79]) +v82, v83 = sch.sample_perfect_tile(loop=l34, n=2, max_innermost_factor=64) +l84, l85 = sch.split(loop=l34, factors=[v82, v83]) +v86, v87 = sch.sample_perfect_tile(loop=l35, n=2, max_innermost_factor=64) +l88, l89 = sch.split(loop=l35, factors=[v86, v87]) +v90, v91 = sch.sample_perfect_tile(loop=l36, n=2, max_innermost_factor=64) +l92, l93 = sch.split(loop=l36, factors=[v90, v91]) +v94, v95 = sch.sample_perfect_tile(loop=l37, n=2, max_innermost_factor=64) +l96, l97 = 
sch.split(loop=l37, factors=[v94, v95]) +sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77) +b98 = sch.cache_write(block=b27, write_buffer_index=0, storage_scope="global") +sch.reverse_compute_at(block=b98, loop=l74, preserve_unit_loops=True)""".split( + "\n" + ), + """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main") +sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") +l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0) +l11, l12 = sch.split(loop=l10, factors=[1, 4]) +l13, l14 = sch.split(loop=l5, factors=[1, 16]) +l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0) +sch.reorder(l21, l22, l23, l24, l25, l14, l12) +b27 = sch.blockize(loop=l14) +sch.annotate(block_or_loop=b27, ann_key="meta_schedule.auto_tensorize", ann_val="dot_16x4_vnni") +l28, l29, l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b27) +v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l28, n=4, max_innermost_factor=64) +l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41]) +v46, v47, v48, v49 = sch.sample_perfect_tile(loop=l29, n=4, max_innermost_factor=64) +l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49]) +v54, v55, v56, v57 = sch.sample_perfect_tile(loop=l30, n=4, max_innermost_factor=64) +l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57]) +v62, v63, v64, v65 = sch.sample_perfect_tile(loop=l31, n=4, max_innermost_factor=64) +l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65]) +v70, v71, v72, v73 = sch.sample_perfect_tile(loop=l32, n=4, max_innermost_factor=64) +l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73]) +v78, v79 = sch.sample_perfect_tile(loop=l33, n=2, max_innermost_factor=64) +l80, l81 = sch.split(loop=l33, factors=[v78, v79]) +v82, v83 = sch.sample_perfect_tile(loop=l34, n=2, max_innermost_factor=64) +l84, l85 = sch.split(loop=l34, factors=[v82, v83]) +v86, v87 = sch.sample_perfect_tile(loop=l35, n=2, max_innermost_factor=64) +l88, l89 = sch.split(loop=l35, factors=[v86, v87]) +v90, v91 = sch.sample_perfect_tile(loop=l36, n=2, max_innermost_factor=64) +l92, l93 = sch.split(loop=l36, factors=[v90, v91]) +v94, v95 = sch.sample_perfect_tile(loop=l37, n=2, max_innermost_factor=64) +l96, l97 = sch.split(loop=l37, factors=[v94, v95]) +sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77)""".split( + "\n" + ), + ] + + check_trace(spaces, expected) + + +def test_multi_level_tiling_dense_dpa4(): + m, n, k = 128, 128, 128 + + X = te.placeholder((m, k), name="X", dtype="int8") + W = te.placeholder((n, k), name="W", dtype="int8") + ak = te.reduce_axis((0, k), name="k") + + matmul = te.compute( + (m, n), + lambda i, j: te.sum( + X[i, ak].astype("int32") * W[j, ak].astype("int32"), + axis=ak, + ), + name="compute", + ) + + func = te.create_prim_func([X, W, matmul]) + + ctx = _create_context( + func, + target=tvm.target.Target("cuda"), + rule=schedule_rule.MultiLevelTilingWithIntrin( + DP4A_INTRIN, + structure="SSSRRSRS", + tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"], + max_innermost_factor=64, + vector_load_lens=[1, 2, 3, 4], + reuse_read=schedule_rule.ReuseType( + req="must", + levels=[4], + scope="shared", + ), + reuse_write=schedule_rule.ReuseType( + req="must", + levels=[3], + 
scope="local", + ), + ), + ) + + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + + expected = [ + """b0 = sch.get_block(name="compute", func_name="main") +sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") +l1, l2, l3 = sch.get_loops(block=b0) +l4, l5 = sch.split(loop=l3, factors=[32, 4]) +sch.reorder(l5) +b6 = sch.blockize(loop=l5) +sch.annotate(block_or_loop=b6, ann_key="meta_schedule.auto_tensorize", ann_val="dp4a") +l7, l8, l9 = sch.get_loops(block=b6) +v10, v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64) +l15, l16, l17, l18, l19 = sch.split(loop=l7, factors=[v10, v11, v12, v13, v14]) +v20, v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64) +l25, l26, l27, l28, l29 = sch.split(loop=l8, factors=[v20, v21, v22, v23, v24]) +v30, v31, v32 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64) +l33, l34, l35 = sch.split(loop=l9, factors=[v30, v31, v32]) +sch.reorder(l15, l25, l16, l26, l17, l27, l33, l34, l18, l28, l35, l19, l29) +l36 = sch.fuse(l15, l25) +sch.bind(loop=l36, thread_axis="blockIdx.x") +l37 = sch.fuse(l16, l26) +sch.bind(loop=l37, thread_axis="vthread.x") +l38 = sch.fuse(l17, l27) +sch.bind(loop=l38, thread_axis="threadIdx.x") +b39 = sch.cache_write(block=b6, write_buffer_index=0, storage_scope="local") +sch.reverse_compute_at(block=b39, loop=l38, preserve_unit_loops=True) +b40 = sch.cache_read(block=b6, read_buffer_index=0, storage_scope="shared") +sch.compute_at(block=b40, loop=l33, preserve_unit_loops=True) +l41, l42, l43, l44, l45, l46 = sch.get_loops(block=b40) +l47 = sch.fuse(l45, l46) +v48 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25]) +sch.annotate(block_or_loop=b40, ann_key="meta_schedule.cooperative_fetch", ann_val=v48) +b49 = sch.cache_read(block=b6, read_buffer_index=1, storage_scope="shared") +sch.compute_at(block=b49, loop=l33, preserve_unit_loops=True) +l50, l51, l52, l53, l54, l55 = sch.get_loops(block=b49) +l56 = sch.fuse(l54, l55) +v57 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25]) +sch.annotate(block_or_loop=b49, ann_key="meta_schedule.cooperative_fetch", ann_val=v57)""".split( + "\n" + ) + ] + + check_trace(spaces, expected) + + if __name__ == "__main__": test_cpu_matmul() test_cpu_matmul_relu() test_cuda_matmul() test_cuda_matmul_relu() test_cuda_sum_with_trivial_block_iter() + test_multi_level_tiling_conv2d_nchwc_vnni() + test_multi_level_tiling_dense_dpa4()