From 8dd2de58d1b1dca7ac35acbb947bd7cb0ffefb48 Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Mon, 6 Jun 2022 14:52:55 -0700
Subject: [PATCH] [MetaSchedule] Add MultiLevelTilingTensorCore rule for auto-tensorization on CUDA

Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
---
 include/tvm/meta_schedule/schedule_rule.h     |  26 ++
 include/tvm/tir/stmt.h                        |   4 +
 .../meta_schedule/schedule_rule/__init__.py   |   7 +-
 .../schedule_rule/multi_level_tiling.py       |  56 ++-
 .../meta_schedule/testing/schedule_rule.py    |  34 ++
 python/tvm/tir/tensor_intrin/cuda.py          |   8 +-
 .../postproc/rewrite_reduction_block.cc       |  15 +
 .../postproc/rewrite_tensorize.cc             |  34 +-
 .../schedule_rule/auto_inline.cc              |   3 +-
 .../schedule_rule/multi_level_tiling.cc       |  28 +-
 .../schedule_rule/multi_level_tiling.h        |  35 +-
 .../multi_level_tiling_tensor_core.cc         | 371 ++++++++++++++++
 src/tir/schedule/analysis/analysis.cc         |   3 +
 src/tir/schedule/transform.cc                 |   6 +-
 ...hedule_schedule_rule_multi_level_tiling.py | 420 +++++++++++++++++-
 15 files changed, 995 insertions(+), 55 deletions(-)
 create mode 100644 src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc

diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index 7e0e5bda57b6..2677864b4469 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -173,6 +173,32 @@ class ScheduleRule : public runtime::ObjectRef {
       Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
       Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);
+  /*!
+   * \brief Extension of MultiLevelTiling for auto-tensorizing with a single group of tensor core
+   * intrinsics
+   * \param intrin_group A group of tensor core intrinsics. The map should contain the keys "init",
+   * "load_a", "load_b", "compute" and "store", which represent the tensor intrins for
+   * initialization, loading operand A, loading operand B, tensor core computation and storing the
+   * result. The values of the map should be names of tensor intrinsics, which must be registered
+   * via TensorIntrin.register(...) beforehand.
+   * \param structure The tiling structure. Recommended:
+   * - 'SSRSRS' on CPU
+   * - 'SSSRRSRS' on GPU
+   * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
+   * - NullOpt on CPU
+   * - [blockIdx.y, blockIdx.x, threadIdx.y] on GPU
+   * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
+   * \param vector_load_lens The length of vector lane in vectorized cooperative fetching.
+   * NullOpt means disable vectorization
+   * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
+   * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
+   * \return The schedule rule created
+   */
+  TVM_DLL static ScheduleRule MultiLevelTilingTensorCore(
+      Map<String, String> intrin_group, String structure, Optional<Array<String>> tile_binds,
+      Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
+      Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);
+
   /*!
    * \brief Create a rule: add-rfactor to some blocks if needed
    * \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the

diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index ddc97549fc70..2060fb7920ed 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -1524,6 +1524,10 @@ constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize";
 
 /*! \brief Mark that a block is a preprocessor block for layout rewrite. */
 constexpr const char* meta_schedule_layout_rewrite_preproc = "meta_schedule.layout_rewrite_preproc";
+/*!
+ * \brief Mark that the init statement of a block should be further rewritten using tensorization.
+ */
+constexpr const char* meta_schedule_auto_tensorize_init = "meta_schedule.auto_tensorize_init";
 
 /*!
  * \brief Mark that a block is executed by a warp. This implies the extend of threadIdx.x is

diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py
index 18fc1de78c7b..dd0119b0a7f8 100644
--- a/python/tvm/meta_schedule/schedule_rule/__init__.py
+++ b/python/tvm/meta_schedule/schedule_rule/__init__.py
@@ -23,7 +23,12 @@
 from .auto_bind import AutoBind
 from .auto_inline import AutoInline
 from .cross_thread_reduction import CrossThreadReduction
-from .multi_level_tiling import MultiLevelTiling, MultiLevelTilingWithIntrin, ReuseType
+from .multi_level_tiling import (
+    MultiLevelTiling,
+    MultiLevelTilingWithIntrin,
+    ReuseType,
+    MultiLevelTilingTensorCore,
+)
 from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll
 from .random_compute_location import RandomComputeLocation
 from .schedule_rule import PyScheduleRule, ScheduleRule

diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
index 0bad6cbb4cd5..9e455f5af4ed 100644
--- a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
+++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Multi-level tiling with reuse."""
-from typing import Any, Dict, List, NamedTuple, Optional
+from typing import Any, Dict, List, Mapping, NamedTuple, Optional
 
 from tvm._ffi import register_object
 
@@ -131,3 +131,57 @@ def __init__(
             reuse_read.as_dict() if reuse_read is not None else None,
             reuse_write.as_dict() if reuse_write is not None else None,
         )
+
+
+@register_object("meta_schedule.MultiLevelTilingTensorCore")
+class MultiLevelTilingTensorCore(ScheduleRule):
+    """Extension of MultiLevelTiling for auto-tensorizing with a single group of tensor core
+    intrinsics.
+
+    Parameters
+    ----------
+    intrin_group : Mapping[str, str]
+        A group of tensor core intrinsics. The map should contain the keys "init", "load_a",
+        "load_b", "compute" and "store", which represent the tensor intrins for initialization,
+        loading operand A, loading operand B, tensor core computation and storing the result.
+        The values of the map should be names of tensor intrinsics, which must be registered via
+        TensorIntrin.register(...) beforehand.
+    structure : str
+        The tiling structure. Recommended:
+        - 'SSRSRS' on CPU
+        - 'SSSRRSRS' on GPU
+    tile_binds : Optional[List[str]]
+        For each level of tiles, which thread axis it is bound to. Recommended:
+        - None on CPU
+        - [blockIdx.y, blockIdx.x, threadIdx.y] on GPU
+    max_innermost_factor : Optional[int]
+        The maximum size of the innermost factor. None means no limit
+    vector_load_lens : Optional[List[int]]
+        The length of vector lane in vectorized cooperative fetching.
+        None means disable vectorization
+    reuse_read : Optional[ReuseType]
+        Data reuse configuration for reading. None means no reuse.
+    reuse_write : Optional[ReuseType]
+        Data reuse configuration for writing. None means no reuse.
+ """ + + def __init__( + self, + intrin_group: Mapping[str, str], + structure: str, + tile_binds: Optional[List[str]] = None, + max_innermost_factor: Optional[int] = None, + vector_load_lens: Optional[List[int]] = None, + reuse_read: Optional[ReuseType] = None, + reuse_write: Optional[ReuseType] = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleMultiLevelTilingTensorCore, # type: ignore # pylint: disable=no-member + intrin_group, + structure, + tile_binds, + max_innermost_factor, + vector_load_lens, + reuse_read.as_dict() if reuse_read is not None else None, + reuse_write.as_dict() if reuse_write is not None else None, + ) diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py index e159bfaaaa5a..f6ada833ecd4 100644 --- a/python/tvm/meta_schedule/testing/schedule_rule.py +++ b/python/tvm/meta_schedule/testing/schedule_rule.py @@ -26,6 +26,8 @@ ReuseType, ScheduleRule, ) +from tvm.meta_schedule.schedule_rule.multi_level_tiling import MultiLevelTilingTensorCore +from tvm.tir import tensor_intrin from tvm.target import Target @@ -110,6 +112,38 @@ def multi_level_tiling(target: Target) -> ScheduleRule: raise NotImplementedError(f"{target.kind.name} is not supported") +def multi_level_tiling_tensor_core(target: Target, scope="shared") -> ScheduleRule: + """Default schedule rules for with multi-level tiling reuse for tensor core""" + assert scope in ["shared", "global"] + if target.kind.name == "cuda": + return MultiLevelTilingTensorCore( + intrin_group={ + "init": tensor_intrin.WMMA_FILL_16x16x16_F32_INTRIN, + "load_a": tensor_intrin.WMMA_LOAD_16x16x16_F16_A_INTRIN, + "load_b": tensor_intrin.WMMA_LOAD_16x16x16_F16_B_INTRIN, + "compute": tensor_intrin.WMMA_SYNC_16x16x16_f16f16f32_INTRIN, + "store": tensor_intrin.WMMA_STORE_16x16x16_F32_SHARED_INTRIN + if scope == "shared" + else tensor_intrin.WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN, + }, + structure="SSSRRSRS", + tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"], + max_innermost_factor=4, # 64 // tensor intrin size + vector_load_lens=[1, 2, 3, 4], + reuse_read=ReuseType( + req="must", + levels=[4], + scope="shared", + ), + reuse_write=ReuseType( + req="must" if scope == "shared" else "no", + levels=[2], + scope="shared", + ), + ) + raise NotImplementedError(f"{target.kind.name} is not supported") + + def random_compute_location(target: Target) -> ScheduleRule: """Default schedule rules for with random-compute-location""" if target.kind.name == "llvm": diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py index 909b13e35c7c..fae877e2276b 100644 --- a/python/tvm/tir/tensor_intrin/cuda.py +++ b/python/tvm/tir/tensor_intrin/cuda.py @@ -769,15 +769,15 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: *get_wmma_load_intrin(16, 16, 16, "float16", "shared", True, False), ) -WMMA_LOAD_16x16x16_F16_A_INTRIN = "wmma_load_16x16x16_f16_a_trans" +WMMA_LOAD_16x16x16_F16_A_TRANS_INTRIN = "wmma_load_16x16x16_f16_a_trans" TensorIntrin.register( - WMMA_LOAD_16x16x16_F16_A_INTRIN, + WMMA_LOAD_16x16x16_F16_A_TRANS_INTRIN, *get_wmma_load_intrin(16, 16, 16, "float16", "shared", False, True), ) -WMMA_LOAD_16x16x16_F16_B_INTRIN = "wmma_load_16x16x16_f16_b_trans" +WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN = "wmma_load_16x16x16_f16_b_trans" TensorIntrin.register( - WMMA_LOAD_16x16x16_F16_B_INTRIN, + WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN, *get_wmma_load_intrin(16, 16, 16, "float16", "shared", True, True), ) diff --git 
a/src/meta_schedule/postproc/rewrite_reduction_block.cc b/src/meta_schedule/postproc/rewrite_reduction_block.cc
index cea1f5b93c9f..a31e6204f6d4 100644
--- a/src/meta_schedule/postproc/rewrite_reduction_block.cc
+++ b/src/meta_schedule/postproc/rewrite_reduction_block.cc
@@ -135,6 +135,21 @@ bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) {
     tir::BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name);
     Array<tir::LoopRV> loop_rvs = sch->GetLoops(block_rv);
     tir::BlockRV init_block_rv = sch->DecomposeReduction(block_rv, loop_rvs[decompose_point]);
+
+    // Rewrite auto tensorization related annotations
+    if (tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize).defined()) {
+      // Remove tensorization annotation as it shouldn't be propagated to the init block.
+      sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize);
+    }
+    if (Optional<String> tensorize_init =
+            tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize_init)) {
+      // The annotation of tensorization of the init statement should be moved to the init block
+      // after 'DecomposeReduction'.
+      sch->Annotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize,
+                    tensorize_init.value());
+      sch->Unannotate(block_rv, tir::attr::meta_schedule_auto_tensorize_init);
+      sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize_init);
+    }
     ++rewritten;
   }
   if (rewritten == 0) {

diff --git a/src/meta_schedule/postproc/rewrite_tensorize.cc b/src/meta_schedule/postproc/rewrite_tensorize.cc
index 3df907597296..e4ec27050556 100644
--- a/src/meta_schedule/postproc/rewrite_tensorize.cc
+++ b/src/meta_schedule/postproc/rewrite_tensorize.cc
@@ -35,26 +35,24 @@ void CollectTensorizationJobs(
   tir::PostOrderVisit(func->body, [=, &jobs](const ObjectRef& obj) {
     if (const auto* block = obj.as<tir::BlockNode>()) {
       tir::StmtSRef block_sref = sch->GetSRef(block);
+      std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint;
       if (Optional<String> intrin_name =
               tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize)) {
-        std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint;
-        if (block_name.find("init") == std::string::npos) {
-          jobs->emplace_back(block_name, func_name, [sch, intrin_name](tir::BlockRV block) {
-            try {
-              sch->Tensorize(block, intrin_name.value());
-            } catch (const std::exception& e) {
-              LOG(WARNING) << "Tensorize failed with error " << e.what();
-            }
-          });
-        } else if (vectorize_init_loop) {
-          jobs->emplace_back(block_name, func_name, [sch](tir::BlockRV block) {
-            Array<tir::BlockRV> child_blocks = sch->GetChildBlocks(block);
-            ICHECK(child_blocks.size() == 1);
-            Array<tir::LoopRV> init_loops = sch->GetLoops(child_blocks[0]);
-            ICHECK(init_loops.size() == 1);
-            sch->Vectorize(init_loops[0]);
-          });
-        }
+        jobs->emplace_back(block_name, func_name, [sch, intrin_name](tir::BlockRV block) {
+          try {
+            sch->Tensorize(block, intrin_name.value());
+          } catch (const std::exception& e) {
+            LOG(WARNING) << "Tensorize failed with error " << e.what();
+          }
+        });
+      } else if (block_name.find("init") != std::string::npos && vectorize_init_loop) {
+        jobs->emplace_back(block_name, func_name, [sch](tir::BlockRV block) {
+          Array<tir::BlockRV> child_blocks = sch->GetChildBlocks(block);
+          ICHECK(child_blocks.size() == 1);
+          Array<tir::LoopRV> init_loops = sch->GetLoops(child_blocks[0]);
+          ICHECK(init_loops.size() == 1);
+          sch->Vectorize(init_loops[0]);
+        });
       }
     }
   });

diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc
index 309f0a60aca0..df4d3ac85911
--- a/src/meta_schedule/schedule_rule/auto_inline.cc
+++
b/src/meta_schedule/schedule_rule/auto_inline.cc @@ -143,7 +143,8 @@ inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch, Array producer_srefs = GetProducers(state, block_sref); if (producer_srefs.size() == 1 && tir::IsCompleteBlock(sch->state(), producer_srefs[0], scope_block) && - CanReverseComputeInline(state, block_sref)) { + CanReverseComputeInline(state, block_sref) && + !GetAnn(producer_srefs[0], tir::attr::meta_schedule_auto_tensorize).defined()) { return InlineType::kInlineIntoProducer; } } diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index 2f2eb219e8c7..5f048dec007f 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -61,6 +61,8 @@ using tir::IterVarType; using tir::LoopRV; using tir::Schedule; +TVM_REGISTER_OBJECT_TYPE(StateNode); + State::State(tir::Schedule sch, tir::BlockRV block_rv, Array> tiles) { ObjectPtr node = make_object(); node->sch = std::move(sch); @@ -133,6 +135,7 @@ std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { new_state->sch->ReverseComputeAt(consumer_rvs[0], loop_rv, true); results.push_back(std::move(new_state)); } + state->write_reuse.emplace(0, consumer_rvs[0]); results.push_back(state); return results; } else { @@ -146,6 +149,7 @@ std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { BlockRV write_cache = state->sch->CacheWrite(/*block_rv=*/state->block_rv, /*read_buffer_index=*/0, /*storage_scope=*/config.scope); + state->write_reuse.emplace(0, write_cache); for (int level : levels) { State new_state = state->Copy(); const LoopRV& loop_rv = new_state->tiles[level - 1].back(); @@ -247,22 +251,26 @@ std::vector MultiLevelTilingNode::AddReadReuse(State state) const { Array buffer_loops = sch->GetLoops(cache_read_block); LoopRV fused = sch->Fuse(Array{buffer_loops.end() - buffer_ndim, // buffer_loops.end()}); - // Annotate cooperative fetching - if (!vector_load_lens.empty()) { - int n = vector_load_lens.size(); - double prob = 1.0 / n; - tir::ExprRV vector_load_len = - sch->SampleCategorical(support::AsArray(vector_load_lens), - Array(n, FloatImm(DataType::Float(64), prob))); - sch->Annotate(cache_read_block, tir::attr::meta_schedule_cooperative_fetch, - vector_load_len); - } + AnnotateCooperativeFetching(&sch, cache_read_block); + new_state->read_reuse.emplace(i, cache_read_block); } results.push_back(std::move(new_state)); } return results; } +void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch, + const tir::BlockRV& block) const { + if (!vector_load_lens.empty()) { + int n = vector_load_lens.size(); + double prob = 1.0 / n; + tir::ExprRV vector_load_len = + (*sch)->SampleCategorical(support::AsArray(vector_load_lens), + Array(n, FloatImm(DataType::Float(64), prob))); + (*sch)->Annotate(block, tir::attr::meta_schedule_cooperative_fetch, vector_load_len); + } +} + // Constructor ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional> tile_binds, diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h index 05179318d0b3..36c2efdafbef 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.h +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h @@ -22,6 +22,7 @@ #include #include +#include #include #include @@ -93,6 +94,10 @@ class StateNode : public Object { tir::BlockRV block_rv; /*! \brief The loop tiles */ Array> tiles; + /*! 
\brief The mapping from buffer index to read cache block. */ + std::unordered_map read_reuse; + /*! \brief The mapping from buffer index to write cache block. */ + std::unordered_map write_reuse; /*! * \brief Create a copy of the state. The underlying schedule is copied. Schedule rules that @@ -112,6 +117,31 @@ class State : public ObjectRef { TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(State, ObjectRef, StateNode); }; +class TensorCoreStateNode : public StateNode { + public: + /*! \brief The Tensor Core reindex block A for Tensor Core computation */ + tir::BlockRV tensor_core_reindex_A; + /*! \brief The Tensor Core reindex block B for Tensor Core computation */ + tir::BlockRV tensor_core_reindex_B; + /*! \brief The Tensor Core reindex store block for Tensor Core computation */ + tir::BlockRV tensor_core_reindex_store; + + State Copy() const final; + + static constexpr const char* _type_key = "meta_schedule.TensorCoreState"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorCoreStateNode, StateNode); +}; + +class TensorCoreState : public State { + public: + explicit TensorCoreState(tir::Schedule sch, tir::BlockRV block_rv, + Array> tiles = {}); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorCoreState, State, TensorCoreStateNode); +}; + +struct AutoTensorizationState : public State {}; + /*! * \brief Helper to apply a sub-rule to a list of auto scheduling states * \tparam FLambda The type of the sub-rule functor @@ -148,11 +178,14 @@ class MultiLevelTilingNode : public ScheduleRuleNode { void InitializeWithTuneContext(const TuneContext& context) final; // Entry of the mega rule; Inherited from ScheduleRuleNode - Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final; + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) override; protected: virtual std::vector ApplySubRules(std::vector states); + // Annotate a block to use cooperative fetching + void AnnotateCooperativeFetching(tir::Schedule* sch, const tir::BlockRV& block) const; + public: /*! * \brief The tiling structure. Recommended: diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc new file mode 100644 index 000000000000..7cf23fa8ad93 --- /dev/null +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -0,0 +1,371 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include + +#include +#include +#include + +#include "../utils.h" +#include "./multi_level_tiling.h" + +namespace tvm { +namespace meta_schedule { + +using tir::BlockRV; +using tir::LoopRV; +using tir::Schedule; + +struct TensorCoreIntrinGroup { + String init_intrin; + String load_a_intrin; + String load_b_intrin; + String compute_intrin; + String store_intrin; +}; + +TVM_REGISTER_OBJECT_TYPE(TensorCoreStateNode); + +TensorCoreState::TensorCoreState(Schedule sch, BlockRV block_rv, Array> tiles) { + ObjectPtr node = make_object(); + node->sch = std::move(sch); + node->block_rv = std::move(block_rv); + node->tiles = std::move(tiles); + data_ = std::move(node); +} + +State TensorCoreStateNode::Copy() const { + ObjectPtr node = make_object(*this); + node->sch = sch->Copy(); + return State(node); +} + +/*! + * \brief Extension of MultiLevelTiling for auto-tensorizing with a single intrinsic. + */ +class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode { + private: + // SubRule: Add tensorization-related transformations + inline std::vector TransformForTensorization(TensorCoreState state) const; + // Subrule: Add tensorized load + inline std::vector AddReadReuseTensorCore(TensorCoreState state) const; + // Subrule: Add tensorized store + inline std::vector AddWriteReuseTensorCore(TensorCoreState state) const; + + // Override ApplySubRules to apply tensorization-specific sub-rules + std::vector ApplySubRules(std::vector states) final; + + // Override Apply to apply tensorization-specific analysis before applying sub-rules + Array Apply(const Schedule& sch, const BlockRV& block_rv) final; + + /*! + * \brief Transform and tensorize with the given tensor intrin + * \param state The state of the meta schedule rule + * \param intrin_name The name of the tensor intrin + * \return The loop to be tensorized. NullOpt if the workload can't be tensorized. + */ + Optional TransformWithTensorIntrin(TensorCoreStateNode* state, + const String& intrin_name) const; + + using BufferTypeIndex = std::pair; + + /*! + * \brief Extract buffer index and its type from block reads/writes + * \param block_sref The sref to the block to extract + * \return The mapping from buffer to its type and and index + */ + std::unordered_map + ExtractBufferIndex(const tir::StmtSRef& block_sref) const; + + /*! + * \brief Tile, blockize and annotate for tensorization with the given intrin + * \param block_rv The block to be tensorized + * \param intrin_name The name of the tensor intrin + */ + void TileAndAnnotateTensorize(Schedule* sch, const BlockRV& block_rv, + const String& intrin_name) const; + + public: + /*! \brief The tensor core intrin group to apply */ + TensorCoreIntrinGroup intrin_group; + static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingTensorCore"; + TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingTensorCoreNode, MultiLevelTilingNode); + + private: + /*! 
+ * \brief The mapping info for auto tensorization + */ + tir::AutoTensorizeMappingInfo mapping_info_{nullptr}; +}; + +// Entry of the mega rule; Inherited from ScheduleRuleNode +Array MultiLevelTilingTensorCoreNode::Apply(const Schedule& sch, + const BlockRV& block_rv) { + if (!NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) { + return {sch}; + } + + Optional mapping_info = + tir::GetAutoTensorizeMappingInfo(sch->state(), sch->GetSRef(block_rv), + tir::TensorIntrin::Get(intrin_group.compute_intrin)->desc); + if (!mapping_info.defined()) { + return {sch}; + } + mapping_info_ = mapping_info.value(); + + // Create a copy of the schedule so that we can roll back transformations if tensorization + // fail. + Schedule original_sch = sch->Copy(); + sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure); + + Array results; + for (auto&& state : ApplySubRules({TensorCoreState(sch, block_rv)})) { + results.push_back(std::move(state->sch)); + } + if (results.empty()) { + return {original_sch}; + } + return results; +} + +std::vector MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector states) { + states = SubRule(std::move(states), [&](State state) { + return TransformForTensorization(Downcast(state)); + }); + states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); }); + states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); }); + states = SubRule(std::move(states), [&](State state) { + return AddWriteReuseTensorCore(Downcast(state)); + }); + states = SubRule(std::move(states), [&](State state) { return AddReadReuse(state); }); + states = SubRule(std::move(states), [&](State state) { + return AddReadReuseTensorCore(Downcast(state)); + }); + return states; +} + +void MultiLevelTilingTensorCoreNode::TileAndAnnotateTensorize(Schedule* sch, + const BlockRV& block_rv, + const String& intrin_name) const { + Optional loop = TileWithTensorIntrin(*sch, block_rv, intrin_name).value(); + ICHECK(loop.defined()); + BlockRV blockized_outer = (*sch)->Blockize(loop.value()); + (*sch)->Annotate(blockized_outer, tir::attr::meta_schedule_auto_tensorize, intrin_name); +} + +std::vector MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore( + TensorCoreState state) const { + // Add the cache write stage for Tensor Core + int level = r_indices_.front() - 1; + const LoopRV& loop = state->tiles[level].back(); + Schedule& sch = state->sch; + auto cache_write = sch->CacheWrite(state->block_rv, 0, "wmma.accumulator"); + sch->ReverseComputeAt(cache_write, loop, true); + + if (state->write_reuse.count(0)) { + AnnotateCooperativeFetching(&sch, state->write_reuse[0]); + } + sch->ReverseComputeInline(state->tensor_core_reindex_store); + TileAndAnnotateTensorize(&sch, cache_write, intrin_group.store_intrin); + return {state}; +} + +std::vector MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore( + TensorCoreState state) const { + const Array& r_tiles = state->tiles[r_indices_[1]]; + Schedule& sch = state->sch; + ICHECK(!r_tiles.empty()) << "ValueError: Cannot find the suitable reduction loop in the block"; + + auto f_tensorize_load = [&](int read_index, String scope, String intrin_name) { + auto cache_read = sch->CacheRead(state->block_rv, read_index, scope); + state->sch->ComputeAt(cache_read, r_tiles.back(), true); + TileAndAnnotateTensorize(&sch, cache_read, intrin_name); + }; + + f_tensorize_load(0, "wmma.matrix_a", intrin_group.load_a_intrin); + f_tensorize_load(1, "wmma.matrix_b", intrin_group.load_b_intrin); + 
sch->ComputeInline(state->tensor_core_reindex_A); + sch->ComputeInline(state->tensor_core_reindex_B); + + sch->StorageAlign(state->read_reuse.at(0), 0, -2, 32, 8); + sch->StorageAlign(state->read_reuse.at(1), 0, -2, 32, 8); + return {state}; +} + +std::unordered_map +MultiLevelTilingTensorCoreNode::ExtractBufferIndex(const tir::StmtSRef& block_sref) const { + const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + // Collect buffer info before + std::unordered_map buffer_index_info; + for (int i = 0; i < static_cast(block->reads.size()); ++i) { + buffer_index_info[block->reads[i]->buffer] = {tir::BufferIndexType::kRead, i}; + } + for (int i = 0; i < static_cast(block->writes.size()); ++i) { + buffer_index_info[block->writes[i]->buffer] = {tir::BufferIndexType::kWrite, i}; + } + return buffer_index_info; +} + +Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( + TensorCoreStateNode* state, const String& intrin_name) const { + BlockRV block_rv = state->block_rv; + tir::StmtSRef block_sref = state->sch->GetSRef(state->block_rv); + + std::unordered_map + buffer_index_info = ExtractBufferIndex(block_sref); + + // Add reindex stages + const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + if (block->reads.size() != 2 || block->writes.size() != 1) { + // only matmul-like computation is allowed + return NullOpt; + } + state->tensor_core_reindex_store = + state->sch->ReIndex(state->block_rv, 0, tir::BufferIndexType::kWrite); + state->tensor_core_reindex_A = + state->sch->ReIndex(state->block_rv, 0, tir::BufferIndexType::kRead); + state->tensor_core_reindex_B = + state->sch->ReIndex(state->block_rv, 1, tir::BufferIndexType::kRead); + + // Transform the layout of reindex buffers accordingly. + // The index map defines the mapping for the computation block. We need to extract the sub index + // map to transform the load and store block. + ICHECK_EQ(mapping_info_->mappings.size(), 1U); // assume only one mapping is present + const tir::IndexMap& index_map = mapping_info_->mappings[0]; + + // Find the correspondence between block iters and the iters in the index map. + std::unordered_map lhs_to_index_map_src; + std::unordered_map rhs_to_index_map_tgt; + std::unordered_set unmapped_index_map_src; + ICHECK_EQ(mapping_info_->lhs_iters.size(), index_map->initial_indices.size()); + for (int i = 0; i < static_cast(mapping_info_->lhs_iters.size()); ++i) { + lhs_to_index_map_src[mapping_info_->lhs_iters[i]->var] = index_map->initial_indices[i]; + } + // The number of result iters in the index map is equal or more than the number of rhs (the + // tensor intrin) iters. When there are extra iters, these iters represent unmapped iters from the + // lhs. They will be skipped during pattern matching for tensorization. + // An example of such case is batch matmul, the batch dimension is kept after layout + // transformations and it will be kept as a outer loop after tensorization. 
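+  // For instance, for a batched matmul mapped to a 16x16x16 wmma intrinsic, mapping_info_->lhs_iters
+  // is [b, i, j, k] while the intrinsic only has [i, j, k]; `offset` below is then 1 and the leading
+  // batch iter is collected into `unmapped_index_map_src`.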
+ int offset = static_cast(index_map->final_indices.size()) - + static_cast(mapping_info_->rhs_iters.size()); + ICHECK_GE(offset, 0); + for (int i = 0; i < offset; ++i) { + const tir::VarNode* var_ptr = index_map->final_indices[i].as(); + ICHECK(var_ptr != nullptr); + unmapped_index_map_src.insert(GetRef(var_ptr)); + } + for (int i = offset; i < static_cast(index_map->final_indices.size()); ++i) { + rhs_to_index_map_tgt[mapping_info_->rhs_iters[i - offset]->var] = index_map->final_indices[i]; + } + + auto f_get_sub_index_map = [&](const tir::Buffer& lhs_buffer, const tir::Region& lhs_region) { + std::vector sub_index_map_src; + std::vector sub_index_map_tgt; + const tir::Buffer& rhs_buffer = mapping_info_->lhs_buffer_map[lhs_buffer]; + for (const Range& range : lhs_region) { + ICHECK(tir::is_one(range->extent)); + const tir::VarNode* var_ptr = range->min.as(); + ICHECK(var_ptr != nullptr); + const tir::Var& lhs_representer = lhs_to_index_map_src[GetRef(var_ptr)]; + sub_index_map_src.push_back(lhs_representer); + if (unmapped_index_map_src.count(lhs_representer)) { + sub_index_map_tgt.push_back(lhs_representer); + } + } + for (size_t i = 0; i < mapping_info_->rhs_buffer_indices[rhs_buffer].size(); ++i) { + const tir::VarNode* var = mapping_info_->rhs_buffer_indices[rhs_buffer][i].as(); + ICHECK(var != nullptr); + sub_index_map_tgt.push_back(rhs_to_index_map_tgt[GetRef(var)]); + } + return tir::IndexMap(sub_index_map_src, sub_index_map_tgt); + }; + + for (const auto& it : buffer_index_info) { + const tir::Buffer& lhs_buffer = it.first; + const tir::BufferIndexType buffer_type = it.second.first; + int buffer_index = it.second.second; + // Refresh block pointer (block sref is not invalidated) + block = TVM_SREF_TO_BLOCK(block, block_sref); + const tir::BufferRegion& reindexed_buffer_region = buffer_type == tir::BufferIndexType::kRead + ? block->reads[buffer_index] + : block->writes[buffer_index]; + auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region); + state->sch->TransformLayout(state->block_rv, buffer_index, buffer_type, sub_index_map); + } + + // Transform the layout of current block and reindex blocks + state->sch->TransformBlockLayout(state->tensor_core_reindex_store, index_map); + state->sch->TransformBlockLayout(state->tensor_core_reindex_A, index_map); + state->sch->TransformBlockLayout(state->tensor_core_reindex_B, index_map); + state->sch->TransformBlockLayout(state->block_rv, index_map); + + return tir::TileWithTensorIntrin(state->sch, state->block_rv, intrin_group.compute_intrin); +} + +inline std::vector MultiLevelTilingTensorCoreNode::TransformForTensorization( + TensorCoreState state) const { + // Do reindex and layout transformations. + Optional transformed_loop_rv = + TransformWithTensorIntrin(state.operator->(), intrin_group.compute_intrin); + if (!transformed_loop_rv.defined()) { + // The workload can't be tensorized. + return {}; + } + + // Do blockize + state->block_rv = state->sch->Blockize(transformed_loop_rv.value()); + + // Add annotations for post processors. 
+ state->sch->Annotate(state->block_rv, tir::attr::meta_schedule_auto_tensorize, + intrin_group.compute_intrin); + state->sch->Annotate(state->block_rv, tir::attr::meta_schedule_auto_tensorize_init, + intrin_group.init_intrin); + state->sch->Annotate(state->block_rv, tir::attr::warp_execution, Bool(true)); + return {std::move(state)}; +} + +ScheduleRule ScheduleRule::MultiLevelTilingTensorCore( + Map intrin_group, String structure, Optional> tile_binds, + Optional max_innermost_factor, Optional> vector_load_lens, + Optional> reuse_read, Optional> reuse_write) { + auto node = MultiLevelTilingInitCommon( + structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write); + + auto f_initialize_intrin = [&intrin_group](String key_name, String* intrin_name) { + CHECK(intrin_group.count(key_name)) << "ValueError: " << key_name << " is not set."; + *intrin_name = intrin_group.at(key_name); + // Check the existence of the intrin + tir::TensorIntrin::Get(*intrin_name); + }; + f_initialize_intrin("init", &node->intrin_group.init_intrin); + f_initialize_intrin("load_a", &node->intrin_group.load_a_intrin); + f_initialize_intrin("load_b", &node->intrin_group.load_b_intrin); + f_initialize_intrin("compute", &node->intrin_group.compute_intrin); + f_initialize_intrin("store", &node->intrin_group.store_intrin); + + return ScheduleRule(node); +} + +TVM_REGISTER_NODE_TYPE(MultiLevelTilingTensorCoreNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingTensorCore") + .set_body_typed(ScheduleRule::MultiLevelTilingTensorCore); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index ac73ac3ce2c1..f8fed79ae834 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -1942,6 +1942,9 @@ bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) { } bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref) { + if (HasBeenMultiLevelTiled(block_sref)) { + return false; + } const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); if (block->writes.size() != 1 || block->reads.empty() || IsSpatial(block_sref) || !IsTrivialBinding(self, block_sref)) { diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index 436d529abdc5..a739373ab329 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -278,9 +278,9 @@ Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::Block int64_t total = int_block_extent->value; int64_t inner = int_desc_extent->value; ICHECK_EQ(total % inner, 0); - int64_t outer = int_block_extent->value / int_desc_extent->value; - // Do the split - Array split = sch->Split(loop2rv.at(block_loop_sref), {Integer(outer), Integer(inner)}); + // Do the split. Leave the outer extent as NullOpt (unspecified) so that the split factors + // can be used for different extents (needed during tuning). 
+ Array split = sch->Split(loop2rv.at(block_loop_sref), {NullOpt, Integer(inner)}); ICHECK_EQ(split.size(), 2); inner_loops.insert(sch->GetSRef(split[1]).operator->()); // The inner split will be reordered to the loop domain that is tensorized diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py index 30511d6690c7..a8c236bf5c31 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py @@ -16,11 +16,16 @@ # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring import tvm +import tvm.testing from tvm import te from tvm.meta_schedule import schedule_rule from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply from tvm.meta_schedule.testing import te_workload -from tvm.meta_schedule.testing.schedule_rule import multi_level_tiling +from tvm.meta_schedule.testing.schedule_rule import ( + auto_inline, + multi_level_tiling, + multi_level_tiling_tensor_core, +) from tvm.meta_schedule.testing.space_generation import check_trace from tvm.meta_schedule.tune_context import TuneContext from tvm.script import tir as T @@ -31,11 +36,13 @@ def _create_context(mod, target, rule) -> TuneContext: + if not isinstance(rule, (list, tuple)): + rule = [rule] ctx = TuneContext( mod=mod, target=target, space_generator=PostOrderApply(), - sch_rules=[rule], + sch_rules=rule, task_name="test", ) return ctx @@ -366,8 +373,8 @@ def test_multi_level_tiling_conv2d_nchwc_vnni(): """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main") sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0) -l11, l12 = sch.split(loop=l10, factors=[1, 4], preserve_unit_iters=True) -l13, l14 = sch.split(loop=l5, factors=[1, 16], preserve_unit_iters=True) +l11, l12 = sch.split(loop=l10, factors=[None, 4], preserve_unit_iters=True) +l13, l14 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True) l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0) sch.reorder(l21, l22, l23, l24, l25, l14, l12) b27 = sch.blockize(loop=l14) @@ -401,8 +408,8 @@ def test_multi_level_tiling_conv2d_nchwc_vnni(): """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main") sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0) -l11, l12 = sch.split(loop=l10, factors=[1, 4], preserve_unit_iters=True) -l13, l14 = sch.split(loop=l5, factors=[1, 16], preserve_unit_iters=True) +l11, l12 = sch.split(loop=l10, factors=[None, 4], preserve_unit_iters=True) +l13, l14 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True) l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0) sch.reorder(l21, l22, l23, l24, l25, l14, l12) b27 = sch.blockize(loop=l14) @@ -436,8 +443,8 @@ def test_multi_level_tiling_conv2d_nchwc_vnni(): """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main") sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0) -l11, l12 = sch.split(loop=l10, factors=[1, 4], preserve_unit_iters=True) -l13, l14 = sch.split(loop=l5, factors=[1, 16], 
preserve_unit_iters=True) +l11, l12 = sch.split(loop=l10, factors=[None, 4], preserve_unit_iters=True) +l13, l14 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True) l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0) sch.reorder(l21, l22, l23, l24, l25, l14, l12) b27 = sch.blockize(loop=l14) @@ -517,7 +524,7 @@ def test_multi_level_tiling_dense_dpa4(): """b0 = sch.get_block(name="compute", func_name="main") sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") l1, l2, l3 = sch.get_loops(block=b0) -l4, l5 = sch.split(loop=l3, factors=[32, 4], preserve_unit_iters=True) +l4, l5 = sch.split(loop=l3, factors=[None, 4], preserve_unit_iters=True) sch.reorder(l5) b6 = sch.blockize(loop=l5) sch.annotate(block_or_loop=b6, ann_key="meta_schedule.auto_tensorize", ann_val="dp4a") @@ -556,11 +563,392 @@ def test_multi_level_tiling_dense_dpa4(): check_trace(spaces, expected) +def test_cuda_tensor_core_conv2d(): + target = Target("cuda", host="llvm") + ctx = _create_context( + create_prim_func( + te_workload.conv2d_nhwc_f16( + N=1, H=16, W=16, CI=16, CO=16, kernel_size=3, stride=1, padding=1 + ) + ), + target, + multi_level_tiling_tensor_core(target=target, scope="shared"), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + + expected = [] + print("".join(spaces[0].trace.as_python())) + check_trace(spaces, expected) + + +def test_cuda_tensor_core_matmul_relu(): + m = n = k = 128 + target = Target("cuda", host="llvm") + ctx = _create_context( + create_prim_func( + te_workload.matmul_relu_fp16( + n=n, + m=m, + k=k, + ) + ), + target=target, + rule=[multi_level_tiling_tensor_core(target=target, scope="shared"), auto_inline(target)], + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + + expected = [ + """b0 = sch.get_block(name="C", func_name="main") +b1 = sch.get_block(name="compute", func_name="main") +sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") +b2 = sch.reindex(block=b0, buffer=("write", 0)) +b3 = sch.reindex(block=b0, buffer=("read", 0)) +b4 = sch.reindex(block=b0, buffer=("read", 1)) +sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, )) +sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (k, j, )) +sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, )) +sch.transform_block_layout(block=b2, index_map=lambda i, j, k: (i, j, k, )) +sch.transform_block_layout(block=b3, index_map=lambda i, j, k: (i, j, k, )) +sch.transform_block_layout(block=b4, index_map=lambda i, j, k: (i, j, k, )) +sch.transform_block_layout(block=b0, index_map=lambda i, j, k: (i, j, k, )) +l5, l6, l7 = sch.get_loops(block=b0) +l8, l9 = sch.split(loop=l7, factors=[None, 16], preserve_unit_iters=True) +l10, l11 = sch.split(loop=l6, factors=[None, 16], preserve_unit_iters=True) +l12, l13 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True) +l14, l15, l16, l17, l18, l19 = sch.get_loops(block=b0) +sch.reorder(l16, l18, l13, l11, l9) +b20 = sch.blockize(loop=l13) +sch.annotate(block_or_loop=b20, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync_16x16x16_f16f16f32") +sch.annotate(block_or_loop=b20, ann_key="meta_schedule.auto_tensorize_init", ann_val="wmma_fill_16x16x16_f32") +sch.annotate(block_or_loop=b20, ann_key="warp_execution", ann_val=1) +l21, l22, l23 = sch.get_loops(block=b20) +v24, v25, v26, v27, v28 = 
sch.sample_perfect_tile(loop=l21, n=5, max_innermost_factor=4) +l29, l30, l31, l32, l33 = sch.split(loop=l21, factors=[v24, v25, v26, v27, v28], preserve_unit_iters=True) +v34, v35, v36, v37, v38 = sch.sample_perfect_tile(loop=l22, n=5, max_innermost_factor=4) +l39, l40, l41, l42, l43 = sch.split(loop=l22, factors=[v34, v35, v36, v37, v38], preserve_unit_iters=True) +v44, v45, v46 = sch.sample_perfect_tile(loop=l23, n=3, max_innermost_factor=4) +l47, l48, l49 = sch.split(loop=l23, factors=[v44, v45, v46], preserve_unit_iters=True) +sch.reorder(l29, l39, l30, l40, l31, l41, l47, l48, l32, l42, l49, l33, l43) +l50 = sch.fuse(l29, l39, preserve_unit_iters=True) +sch.bind(loop=l50, thread_axis="blockIdx.y") +l51 = sch.fuse(l30, l40, preserve_unit_iters=True) +sch.bind(loop=l51, thread_axis="blockIdx.x") +l52 = sch.fuse(l31, l41, preserve_unit_iters=True) +sch.bind(loop=l52, thread_axis="threadIdx.y") +b53 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="shared") +sch.reverse_compute_at(block=b53, loop=l51, preserve_unit_loops=True) +b54 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="wmma.accumulator") +sch.reverse_compute_at(block=b54, loop=l52, preserve_unit_loops=True) +v55 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25]) +sch.annotate(block_or_loop=b53, ann_key="meta_schedule.cooperative_fetch", ann_val=v55) +sch.reverse_compute_inline(block=b2) +l56, l57, l58, l59, l60 = sch.get_loops(block=b54) +l61, l62 = sch.split(loop=l60, factors=[None, 16], preserve_unit_iters=True) +l63, l64 = sch.split(loop=l59, factors=[None, 16], preserve_unit_iters=True) +l65, l66, l67, l68, l69, l70, l71 = sch.get_loops(block=b54) +sch.reorder(l70, l64, l62) +b72 = sch.blockize(loop=l64) +sch.annotate(block_or_loop=b72, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_shared") +b73 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="shared") +sch.compute_at(block=b73, loop=l47, preserve_unit_loops=True) +l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b73) +l80 = sch.fuse(l78, l79, preserve_unit_iters=True) +v81 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25]) +sch.annotate(block_or_loop=b73, ann_key="meta_schedule.cooperative_fetch", ann_val=v81) +b82 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="shared") +sch.compute_at(block=b82, loop=l47, preserve_unit_loops=True) +l83, l84, l85, l86, l87, l88 = sch.get_loops(block=b82) +l89 = sch.fuse(l87, l88, preserve_unit_iters=True) +v90 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25]) +sch.annotate(block_or_loop=b82, ann_key="meta_schedule.cooperative_fetch", ann_val=v90) +b91 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="wmma.matrix_a") +sch.compute_at(block=b91, loop=l48, preserve_unit_loops=True) +l92, l93, l94, l95, l96, l97, l98 = sch.get_loops(block=b91) +l99, l100 = sch.split(loop=l98, factors=[None, 16], preserve_unit_iters=True) +l101, l102 = sch.split(loop=l97, factors=[None, 16], preserve_unit_iters=True) +l103, l104, l105, l106, l107, l108, l109, l110, l111 = sch.get_loops(block=b91) +sch.reorder(l110, l102, l100) +b112 = sch.blockize(loop=l102) +sch.annotate(block_or_loop=b112, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a") +b113 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="wmma.matrix_b") +sch.compute_at(block=b113, loop=l48, preserve_unit_loops=True) +l114, l115, l116, l117, l118, l119, l120 = 
sch.get_loops(block=b113) +l121, l122 = sch.split(loop=l120, factors=[None, 16], preserve_unit_iters=True) +l123, l124 = sch.split(loop=l119, factors=[None, 16], preserve_unit_iters=True) +l125, l126, l127, l128, l129, l130, l131, l132, l133 = sch.get_loops(block=b113) +sch.reorder(l132, l124, l122) +b134 = sch.blockize(loop=l124) +sch.annotate(block_or_loop=b134, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_b") +sch.compute_inline(block=b3) +sch.compute_inline(block=b4) +sch.storage_align(block=b73, buffer_index=0, axis=-2, factor=32, offset=8) +sch.storage_align(block=b82, buffer_index=0, axis=-2, factor=32, offset=8) +sch.reverse_compute_inline(block=b1)""".split( + "\n" + ) + ] + check_trace(spaces, expected) + + # test multi_level_tiling_tensor_core and multi_level_tiling can be used together in order + # to use multi_level_tiling as a fallback when the workload can't be tensorized + ctx = _create_context( + create_prim_func( + te_workload.matmul_relu_fp16( + n=n, + m=m, + k=k, + ) + ), + target=target, + rule=[ + multi_level_tiling_tensor_core(target=target, scope="shared"), + multi_level_tiling(target=target), + auto_inline(target), + ], + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +def test_cuda_tensor_core_matmul_relu_global(): + m = n = k = 128 + target = Target("cuda", host="llvm") + ctx = _create_context( + create_prim_func( + te_workload.matmul_relu_fp16( + n=n, + m=m, + k=k, + ), + ), + target=target, + rule=[multi_level_tiling_tensor_core(target=target, scope="global"), auto_inline(target)], + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + + expected = [ + """b0 = sch.get_block(name="C", func_name="main") +sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") +b1 = sch.reindex(block=b0, buffer=("write", 0)) +b2 = sch.reindex(block=b0, buffer=("read", 0)) +b3 = sch.reindex(block=b0, buffer=("read", 1)) +sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, )) +sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (k, j, )) +sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, )) +sch.transform_block_layout(block=b1, index_map=lambda i, j, k: (i, j, k, )) +sch.transform_block_layout(block=b2, index_map=lambda i, j, k: (i, j, k, )) +sch.transform_block_layout(block=b3, index_map=lambda i, j, k: (i, j, k, )) +sch.transform_block_layout(block=b0, index_map=lambda i, j, k: (i, j, k, )) +l4, l5, l6 = sch.get_loops(block=b0) +l7, l8 = sch.split(loop=l6, factors=[None, 16], preserve_unit_iters=True) +l9, l10 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True) +l11, l12 = sch.split(loop=l4, factors=[None, 16], preserve_unit_iters=True) +l13, l14, l15, l16, l17, l18 = sch.get_loops(block=b0) +sch.reorder(l15, l17, l12, l10, l8) +b19 = sch.blockize(loop=l12) +sch.annotate(block_or_loop=b19, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync_16x16x16_f16f16f32") +sch.annotate(block_or_loop=b19, ann_key="meta_schedule.auto_tensorize_init", ann_val="wmma_fill_16x16x16_f32") +sch.annotate(block_or_loop=b19, ann_key="warp_execution", ann_val=1) +l20, l21, l22 = sch.get_loops(block=b19) +v23, v24, v25, v26, v27 = sch.sample_perfect_tile(loop=l20, n=5, max_innermost_factor=4) +l28, l29, l30, l31, l32 = sch.split(loop=l20, factors=[v23, v24, v25, v26, v27], preserve_unit_iters=True) +v33, v34, v35, 
v36, v37 = sch.sample_perfect_tile(loop=l21, n=5, max_innermost_factor=4) +l38, l39, l40, l41, l42 = sch.split(loop=l21, factors=[v33, v34, v35, v36, v37], preserve_unit_iters=True) +v43, v44, v45 = sch.sample_perfect_tile(loop=l22, n=3, max_innermost_factor=4) +l46, l47, l48 = sch.split(loop=l22, factors=[v43, v44, v45], preserve_unit_iters=True) +sch.reorder(l28, l38, l29, l39, l30, l40, l46, l47, l31, l41, l48, l32, l42) +l49 = sch.fuse(l28, l38, preserve_unit_iters=True) +sch.bind(loop=l49, thread_axis="blockIdx.y") +l50 = sch.fuse(l29, l39, preserve_unit_iters=True) +sch.bind(loop=l50, thread_axis="blockIdx.x") +l51 = sch.fuse(l30, l40, preserve_unit_iters=True) +sch.bind(loop=l51, thread_axis="threadIdx.y") +b52 = sch.cache_write(block=b19, write_buffer_index=0, storage_scope="wmma.accumulator") +sch.reverse_compute_at(block=b52, loop=l51, preserve_unit_loops=True) +sch.reverse_compute_inline(block=b1) +l53, l54, l55, l56, l57 = sch.get_loops(block=b52) +l58, l59 = sch.split(loop=l57, factors=[None, 16], preserve_unit_iters=True) +l60, l61 = sch.split(loop=l56, factors=[None, 16], preserve_unit_iters=True) +l62, l63, l64, l65, l66, l67, l68 = sch.get_loops(block=b52) +sch.reorder(l67, l61, l59) +b69 = sch.blockize(loop=l61) +sch.annotate(block_or_loop=b69, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_global") +b70 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="shared") +sch.compute_at(block=b70, loop=l46, preserve_unit_loops=True) +l71, l72, l73, l74, l75, l76 = sch.get_loops(block=b70) +l77 = sch.fuse(l75, l76, preserve_unit_iters=True) +v78 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25]) +sch.annotate(block_or_loop=b70, ann_key="meta_schedule.cooperative_fetch", ann_val=v78) +b79 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="shared") +sch.compute_at(block=b79, loop=l46, preserve_unit_loops=True) +l80, l81, l82, l83, l84, l85 = sch.get_loops(block=b79) +l86 = sch.fuse(l84, l85, preserve_unit_iters=True) +v87 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25]) +sch.annotate(block_or_loop=b79, ann_key="meta_schedule.cooperative_fetch", ann_val=v87) +b88 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="wmma.matrix_a") +sch.compute_at(block=b88, loop=l47, preserve_unit_loops=True) +l89, l90, l91, l92, l93, l94, l95 = sch.get_loops(block=b88) +l96, l97 = sch.split(loop=l95, factors=[None, 16], preserve_unit_iters=True) +l98, l99 = sch.split(loop=l94, factors=[None, 16], preserve_unit_iters=True) +l100, l101, l102, l103, l104, l105, l106, l107, l108 = sch.get_loops(block=b88) +sch.reorder(l107, l99, l97) +b109 = sch.blockize(loop=l99) +sch.annotate(block_or_loop=b109, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a") +b110 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="wmma.matrix_b") +sch.compute_at(block=b110, loop=l47, preserve_unit_loops=True) +l111, l112, l113, l114, l115, l116, l117 = sch.get_loops(block=b110) +l118, l119 = sch.split(loop=l117, factors=[None, 16], preserve_unit_iters=True) +l120, l121 = sch.split(loop=l116, factors=[None, 16], preserve_unit_iters=True) +l122, l123, l124, l125, l126, l127, l128, l129, l130 = sch.get_loops(block=b110) +sch.reorder(l129, l121, l119) +b131 = sch.blockize(loop=l121) +sch.annotate(block_or_loop=b131, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_b") +sch.compute_inline(block=b2) +sch.compute_inline(block=b3) 
+sch.storage_align(block=b70, buffer_index=0, axis=-2, factor=32, offset=8) +sch.storage_align(block=b79, buffer_index=0, axis=-2, factor=32, offset=8)""".split( + "\n" + ) + ] + check_trace(spaces, expected) + + +def test_multi_level_tiling_non_tensorizable(): + # expected to do nothing on non-tensorizable workloads + m = n = k = 128 + target = Target("cuda", host="llvm") + ctx = _create_context( + create_prim_func( + # dtype doesn't match tensor intrin + te_workload.matmul_relu( + n=n, + m=m, + k=k, + ) + ), + target=target, + rule=multi_level_tiling_tensor_core(target=target, scope="global"), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + + expected = [ + "", # expected to do nothing when the workload can't be tensorized + ] + check_trace(spaces, expected) + + +def test_cuda_tensor_core_conv2d(): + target = Target("cuda", host="llvm") + ctx = _create_context( + create_prim_func( + # dtype doesn't match tensor intrin + te_workload.conv2d_nhwc_f16( + N=1, H=16, W=16, CI=32, CO=32, kernel_size=3, stride=1, padding=1 + ) + ), + target=target, + rule=multi_level_tiling_tensor_core(target=target, scope="shared"), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + + expected = [ + """b0 = sch.get_block(name="conv2d_nhwc", func_name="main") +sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") +b1 = sch.reindex(block=b0, buffer=("write", 0)) +b2 = sch.reindex(block=b0, buffer=("read", 0)) +b3 = sch.reindex(block=b0, buffer=("read", 1)) +sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda h, w, co: (((h*16) + w), co, )) +sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda co, rh, rw, rc: ((((rh*96) + (rw*32)) + rc), co, )) +sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda h, w, rh, rw, rc: (((h*16) + w), (((rh*96) + (rw*32)) + rc), )) +sch.transform_block_layout(block=b1, index_map=lambda n, h, w, co, rh, rw, rc: (n, ((h*16) + w), co, (((rh*96) + (rw*32)) + rc), )) +sch.transform_block_layout(block=b2, index_map=lambda n, h, w, co, rh, rw, rc: (n, ((h*16) + w), co, (((rh*96) + (rw*32)) + rc), )) +sch.transform_block_layout(block=b3, index_map=lambda n, h, w, co, rh, rw, rc: (n, ((h*16) + w), co, (((rh*96) + (rw*32)) + rc), )) +sch.transform_block_layout(block=b0, index_map=lambda n, h, w, co, rh, rw, rc: (n, ((h*16) + w), co, (((rh*96) + (rw*32)) + rc), )) +l4, l5, l6, l7 = sch.get_loops(block=b0) +l8, l9 = sch.split(loop=l7, factors=[None, 16], preserve_unit_iters=True) +l10, l11 = sch.split(loop=l6, factors=[None, 16], preserve_unit_iters=True) +l12, l13 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True) +l14, l15, l16, l17, l18, l19, l20 = sch.get_loops(block=b0) +sch.reorder(l17, l19, l13, l11, l9) +b21 = sch.blockize(loop=l13) +sch.annotate(block_or_loop=b21, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync_16x16x16_f16f16f32") +sch.annotate(block_or_loop=b21, ann_key="meta_schedule.auto_tensorize_init", ann_val="wmma_fill_16x16x16_f32") +sch.annotate(block_or_loop=b21, ann_key="warp_execution", ann_val=1) +l22, l23, l24, l25 = sch.get_loops(block=b21) +v26, v27, v28, v29, v30 = sch.sample_perfect_tile(loop=l22, n=5, max_innermost_factor=4) +l31, l32, l33, l34, l35 = sch.split(loop=l22, factors=[v26, v27, v28, v29, v30], preserve_unit_iters=True) +v36, v37, v38, v39, v40 = sch.sample_perfect_tile(loop=l23, n=5, max_innermost_factor=4) +l41, l42, l43, l44, l45 = 
sch.split(loop=l23, factors=[v36, v37, v38, v39, v40], preserve_unit_iters=True) +v46, v47, v48, v49, v50 = sch.sample_perfect_tile(loop=l24, n=5, max_innermost_factor=4) +l51, l52, l53, l54, l55 = sch.split(loop=l24, factors=[v46, v47, v48, v49, v50], preserve_unit_iters=True) +v56, v57, v58 = sch.sample_perfect_tile(loop=l25, n=3, max_innermost_factor=4) +l59, l60, l61 = sch.split(loop=l25, factors=[v56, v57, v58], preserve_unit_iters=True) +sch.reorder(l31, l41, l51, l32, l42, l52, l33, l43, l53, l59, l60, l34, l44, l54, l61, l35, l45, l55) +l62 = sch.fuse(l31, l41, l51, preserve_unit_iters=True) +sch.bind(loop=l62, thread_axis="blockIdx.y") +l63 = sch.fuse(l32, l42, l52, preserve_unit_iters=True) +sch.bind(loop=l63, thread_axis="blockIdx.x") +l64 = sch.fuse(l33, l43, l53, preserve_unit_iters=True) +sch.bind(loop=l64, thread_axis="threadIdx.y") +b65 = sch.cache_write(block=b21, write_buffer_index=0, storage_scope="shared") +sch.reverse_compute_at(block=b65, loop=l63, preserve_unit_loops=True) +b66 = sch.cache_write(block=b21, write_buffer_index=0, storage_scope="wmma.accumulator") +sch.reverse_compute_at(block=b66, loop=l64, preserve_unit_loops=True) +v67 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25]) +sch.annotate(block_or_loop=b65, ann_key="meta_schedule.cooperative_fetch", ann_val=v67) +sch.reverse_compute_inline(block=b1) +l68, l69, l70, l71, l72 = sch.get_loops(block=b66) +l73, l74 = sch.split(loop=l72, factors=[None, 16], preserve_unit_iters=True) +l75, l76 = sch.split(loop=l71, factors=[None, 16], preserve_unit_iters=True) +l77, l78, l79, l80, l81, l82, l83 = sch.get_loops(block=b66) +sch.reorder(l82, l76, l74) +b84 = sch.blockize(loop=l76) +sch.annotate(block_or_loop=b84, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_shared") +b85 = sch.cache_read(block=b21, read_buffer_index=0, storage_scope="shared") +sch.compute_at(block=b85, loop=l59, preserve_unit_loops=True) +l86, l87, l88, l89, l90, l91 = sch.get_loops(block=b85) +l92 = sch.fuse(l90, l91, preserve_unit_iters=True) +v93 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25]) +sch.annotate(block_or_loop=b85, ann_key="meta_schedule.cooperative_fetch", ann_val=v93) +b94 = sch.cache_read(block=b21, read_buffer_index=1, storage_scope="shared") +sch.compute_at(block=b94, loop=l59, preserve_unit_loops=True) +l95, l96, l97, l98, l99, l100 = sch.get_loops(block=b94) +l101 = sch.fuse(l99, l100, preserve_unit_iters=True) +v102 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25]) +sch.annotate(block_or_loop=b94, ann_key="meta_schedule.cooperative_fetch", ann_val=v102) +b103 = sch.cache_read(block=b21, read_buffer_index=0, storage_scope="wmma.matrix_a") +sch.compute_at(block=b103, loop=l60, preserve_unit_loops=True) +l104, l105, l106, l107, l108, l109, l110 = sch.get_loops(block=b103) +l111, l112 = sch.split(loop=l110, factors=[None, 16], preserve_unit_iters=True) +l113, l114 = sch.split(loop=l109, factors=[None, 16], preserve_unit_iters=True) +l115, l116, l117, l118, l119, l120, l121, l122, l123 = sch.get_loops(block=b103) +sch.reorder(l122, l114, l112) +b124 = sch.blockize(loop=l114) +sch.annotate(block_or_loop=b124, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a") +b125 = sch.cache_read(block=b21, read_buffer_index=1, storage_scope="wmma.matrix_b") +sch.compute_at(block=b125, loop=l60, preserve_unit_loops=True) +l126, l127, l128, l129, l130, l131, l132 = sch.get_loops(block=b125) +l133, l134 
= sch.split(loop=l132, factors=[None, 16], preserve_unit_iters=True) +l135, l136 = sch.split(loop=l131, factors=[None, 16], preserve_unit_iters=True) +l137, l138, l139, l140, l141, l142, l143, l144, l145 = sch.get_loops(block=b125) +sch.reorder(l144, l136, l134) +b146 = sch.blockize(loop=l136) +sch.annotate(block_or_loop=b146, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_b") +sch.compute_inline(block=b2) +sch.compute_inline(block=b3) +sch.storage_align(block=b85, buffer_index=0, axis=-2, factor=32, offset=8) +sch.storage_align(block=b94, buffer_index=0, axis=-2, factor=32, offset=8)""".split( + "\n" + ) + ] + check_trace(spaces, expected) + + if __name__ == "__main__": - test_cpu_matmul() - test_cpu_matmul_relu() - test_cuda_matmul() - test_cuda_matmul_relu() - test_cuda_sum_with_trivial_block_iter() - test_multi_level_tiling_conv2d_nchwc_vnni() - test_multi_level_tiling_dense_dpa4() + tvm.testing.main()
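
For reference, a minimal usage sketch (illustrative, not part of the patch) of the new schedule rule,
mirroring the helper added in python/tvm/meta_schedule/testing/schedule_rule.py. It assumes the WMMA
intrinsic names registered in python/tvm/tir/tensor_intrin/cuda.py:

    # Construct the rule with a WMMA f16f16f32 intrin group and a shared-memory epilogue.
    from tvm.meta_schedule.schedule_rule import MultiLevelTilingTensorCore, ReuseType
    from tvm.tir import tensor_intrin  # importing this registers the WMMA tensor intrinsics

    rule = MultiLevelTilingTensorCore(
        intrin_group={
            "init": tensor_intrin.WMMA_FILL_16x16x16_F32_INTRIN,
            "load_a": tensor_intrin.WMMA_LOAD_16x16x16_F16_A_INTRIN,
            "load_b": tensor_intrin.WMMA_LOAD_16x16x16_F16_B_INTRIN,
            "compute": tensor_intrin.WMMA_SYNC_16x16x16_f16f16f32_INTRIN,
            "store": tensor_intrin.WMMA_STORE_16x16x16_F32_SHARED_INTRIN,
        },
        structure="SSSRRSRS",
        tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
        max_innermost_factor=4,  # 64 // tensor intrin size
        vector_load_lens=[1, 2, 3, 4],
        reuse_read=ReuseType(req="must", levels=[4], scope="shared"),
        reuse_write=ReuseType(req="must", levels=[2], scope="shared"),
    )

The rule can then be passed via sch_rules to a TuneContext together with other rules (e.g. AutoInline),
as exercised in tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py above.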