diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index 8313da067f09..6ee394791991 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -115,7 +115,6 @@ class ScheduleRule : public runtime::ObjectRef {
    * \brief Create an auto-inline rule that inlines spatial blocks if it satisfies some conditions
    * \param into_producer If allows to inline a block into its producer
    * \param into_consumer If allows to inline a block into its consumer
-   * \param into_cache_only If it only allows to inline into a block generated by cache_read/write
    * \param inline_const_tensor Always inline constant tensors
    * \param disallow_if_then_else Always disallow if-then-else-like constructs
    * \param require_ordered Always require the read-to-write mapping to be ordered
@@ -125,7 +124,6 @@ class ScheduleRule : public runtime::ObjectRef {
    */
   TVM_DLL static ScheduleRule AutoInline(bool into_producer,          //
                                          bool into_consumer,          //
-                                         bool into_cache_only,        //
                                          bool inline_const_tensor,    //
                                          bool disallow_if_then_else,  //
                                          bool require_injective,      //
diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py
index b90780d5bfdb..be5c0e0b620b 100644
--- a/python/tvm/meta_schedule/schedule_rule/__init__.py
+++ b/python/tvm/meta_schedule/schedule_rule/__init__.py
@@ -16,4 +16,5 @@
 Meta Schedule schedule rules are used for modification of blocks in a schedule.
 See also PostOrderApply.
 """
+from .auto_inline import AutoInline
 from .schedule_rule import PyScheduleRule, ScheduleRule
diff --git a/python/tvm/meta_schedule/schedule_rule/auto_inline.py b/python/tvm/meta_schedule/schedule_rule/auto_inline.py
new file mode 100644
index 000000000000..22206f3fcc24
--- /dev/null
+++ b/python/tvm/meta_schedule/schedule_rule/auto_inline.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Auto-Inline. Rule that inlines spatial blocks if it satisfies some conditions"""
+from typing import List, Optional
+
+from tvm._ffi import register_object
+
+from .. import _ffi_api
+from .schedule_rule import ScheduleRule
+
+
+@register_object("meta_schedule.AutoInline")
+class AutoInline(ScheduleRule):
+    """Rule that inlines spatial blocks if it satisfies some conditions
+
+    Parameters
+    ----------
+    into_producer : bool
+        If allows to inline a block into its producer
+    into_consumer : bool
+        If allows to inline a block into its consumer
+    inline_const_tensor : bool
+        Always inline constant tensors
+    disallow_if_then_else : bool
+        Always disallow if-then-else-like constructs
+    require_injective : bool
+        Always require the read-to-write mapping to be injective
+    require_ordered : bool
+        Always require the read-to-write mapping to be ordered
+    disallow_op : Optional[List[str]]
+        The operators that are disallowed in auto inline
+    """
+
+    def __init__(
+        self,
+        into_producer: bool,
+        into_consumer: bool,
+        inline_const_tensor: bool,
+        disallow_if_then_else: bool,
+        require_injective: bool,
+        require_ordered: bool,
+        disallow_op: Optional[List[str]] = None,
+    ) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.ScheduleRuleAutoInline,  # type: ignore # pylint: disable=no-member
+            into_producer,
+            into_consumer,
+            inline_const_tensor,
+            disallow_if_then_else,
+            require_injective,
+            require_ordered,
+            disallow_op,
+        )
diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py
new file mode 100644
index 000000000000..e69be1333092
--- /dev/null
+++ b/python/tvm/meta_schedule/testing/schedule_rule.py
@@ -0,0 +1,47 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Default schedule rules"""
+from tvm.meta_schedule.schedule_rule import (
+    AutoInline,
+    ScheduleRule,
+)
+from tvm.target import Target
+
+
+def auto_inline(target: Target) -> ScheduleRule:
+    """Default schedule rules for auto inline"""
+    if target.kind.name == "llvm":
+        return AutoInline(
+            into_producer=False,
+            into_consumer=True,
+            inline_const_tensor=True,
+            disallow_if_then_else=True,
+            require_injective=True,
+            require_ordered=True,
+            disallow_op=["tir.exp"],
+        )
+    if target.kind.name == "cuda":
+        return AutoInline(
+            into_producer=True,
+            into_consumer=True,
+            inline_const_tensor=True,
+            disallow_if_then_else=False,
+            require_injective=False,
+            require_ordered=False,
+            disallow_op=None,
+        )
+    raise NotImplementedError(f"{target.kind.name} is not supported")
diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc
new file mode 100644
index 000000000000..38156f86e6cb
--- /dev/null
+++ b/src/meta_schedule/schedule_rule/auto_inline.cc
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+/*! \brief The type of inline to be performed on a specific block */
+enum class InlineType : int32_t {
+  /*! \brief No inline opportunity */
+  kNoInline = 0,
+  /*! \brief Inline the block into its consumer */
+  kInlineIntoConsumer = 1,
+  /*! \brief Inline the block into its producer */
+  kInlineIntoProducer = 2,
+};
+
+/*! \brief The rule that inlines spatial blocks if it satisfies some conditions. */
+class AutoInlineNode : public ScheduleRuleNode {
+ public:
+  /*! \brief Checks if the specific block should be inlined */
+  inline InlineType CheckInline(const tir::Schedule& sch, const tir::BlockRV& block_rv);
+
+  // Inherited from ScheduleRuleNode
+  void InitializeWithTuneContext(const TuneContext& context) final {}
+
+  // Inherited from ScheduleRuleNode
+  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
+    InlineType inline_type = CheckInline(sch, block_rv);
+    if (inline_type == InlineType::kInlineIntoConsumer) {
+      sch->ComputeInline(block_rv);
+    } else if (inline_type == InlineType::kInlineIntoProducer) {
+      sch->ReverseComputeInline(block_rv);
+    }
+    return {sch};
+  }
+
+ public:
+  /*! \brief If allows to inline a block into its producer */
+  bool into_producer;
+  /*! \brief If allows to inline a block into its consumer */
+  bool into_consumer;
+  /*! \brief Always inline constant tensors */
+  bool inline_const_tensor;
+  /*! \brief Always disallow if-then-else-like constructs */
+  bool disallow_if_then_else;
+  /*! \brief Always require the read-to-write mapping to be injective to do auto inline */
+  bool require_injective;
+  /*! \brief Always require the read-to-write mapping to be ordered to do auto inline */
+  bool require_ordered;
+  /*! \brief The operators that are disallowed in auto inline */
+  Array<Op> disallow_op;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("into_producer", &into_producer);
+    v->Visit("into_consumer", &into_consumer);
+    v->Visit("inline_const_tensor", &inline_const_tensor);
+    v->Visit("disallow_if_then_else", &disallow_if_then_else);
+    v->Visit("require_injective", &require_injective);
+    v->Visit("require_ordered", &require_ordered);
+    v->Visit("disallow_op", &disallow_op);
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.AutoInline";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AutoInlineNode, ScheduleRuleNode);
+};
+
+inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch,
+                                              const tir::BlockRV& block_rv) {
+  using namespace tvm::tir;
+  StmtSRef block_sref = sch->GetSRef(block_rv);
+  ScheduleState state = sch->state();
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  BlockRealize realize = GetBlockRealize(state, block_sref);
+  // Cond 1. The block has only one write buffer
+  if (block->writes.size() != 1) {
+    return InlineType::kNoInline;
+  }
+  // Cond 2. For a block that generates a constant tensor, ignore all other conditions
+  if (inline_const_tensor && block->reads.empty()) {
+    return InlineType::kInlineIntoConsumer;
+  }
+  // Cond 3. The block doesn't contain any disallowed operators
+  if (!disallow_op.empty() && HasOp(realize, disallow_op)) {
+    return InlineType::kNoInline;
+  }
+  // Cond 4. The block doesn't have any if-then-else-like constructs
+  if (disallow_if_then_else && HasIfThenElse(realize)) {
+    return InlineType::kNoInline;
+  }
+  // Cond 5. The mapping from read indices to write indices are injective and ordered
+  if (require_injective || require_ordered) {
+    const BufferRegion& write_region = block->writes[0];
+    for (const BufferRegion& read_region : block->reads) {
+      bool injective, ordered;
+      auto _ = std::ignore;
+      std::tie(/*exists=*/_, /*surjective=*/_, injective, ordered, /*no_const_read=*/_,
+               /*no_shift_read=*/_) = AnalyzeReadWritePattern(read_region, write_region);
+      if (require_injective && injective == false) {
+        return InlineType::kNoInline;
+      }
+      if (require_ordered && ordered == false) {
+        return InlineType::kNoInline;
+      }
+    }
+  }
+  // Last cond: Check inline into the consumers or the spatial producer
+  tir::StmtSRef scope_block = tir::GetScopeRoot(sch->state(), block_sref,  //
+                                                /*require_stage_pipeline=*/false,  //
+                                                /*require_subtree_compact_dataflow=*/false);
+  if (into_consumer) {
+    Array<tir::StmtSRef> consumer_srefs = GetConsumers(state, block_sref);
+    if (!consumer_srefs.empty() && CanComputeInline(state, block_sref)) {
+      return InlineType::kInlineIntoConsumer;
+    }
+  }
+  if (into_producer) {
+    Array<tir::StmtSRef> producer_srefs = GetProducers(state, block_sref);
+    if (producer_srefs.size() == 1 &&
+        tir::IsCompleteBlock(sch->state(), producer_srefs[0], scope_block) &&
+        CanReverseComputeInline(state, block_sref)) {
+      return InlineType::kInlineIntoProducer;
+    }
+  }
+  return InlineType::kNoInline;
+}
+
+ScheduleRule ScheduleRule::AutoInline(bool into_producer,          //
+                                      bool into_consumer,          //
+                                      bool inline_const_tensor,    //
+                                      bool disallow_if_then_else,  //
+                                      bool require_injective,      //
+                                      bool require_ordered,        //
+                                      Optional<Array<String>> disallow_op) {
+  ObjectPtr<AutoInlineNode> n = make_object<AutoInlineNode>();
+  n->into_producer = into_producer;
+  n->into_consumer = into_consumer;
+  n->inline_const_tensor = inline_const_tensor;
+  n->disallow_if_then_else = disallow_if_then_else;
+  n->require_injective = require_injective;
+  n->require_ordered = require_ordered;
+  n->disallow_op.clear();
+  if (disallow_op.defined()) {
+    Array<String> op_names = disallow_op.value();
+    n->disallow_op.reserve(op_names.size());
+    for (const String& op_name : op_names) {
+      n->disallow_op.push_back(Op::Get(op_name));
+    }
+  }
+  return ScheduleRule(n);
+}
+
+TVM_REGISTER_NODE_TYPE(AutoInlineNode);
+TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoInline")
+    .set_body_typed(ScheduleRule::AutoInline);
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index ae72d592339f..1070833be19d 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -20,6 +20,7 @@
 #define TVM_TIR_SCHEDULE_ANALYSIS_H_
 
 #include
+#include
 #include
 #include
@@ -442,6 +443,50 @@ bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const S
 bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
                          const StmtSRef& loop_sref, bool preserve_unit_loops);
 
+/*!
+ * \brief Checks if the given AST contains the specific operators
+ * \param stmt The AST statement to be checked
+ * \param ops The list of operators to be checked
+ * \return A boolean indicating whether the AST contains the specific operators
+ */
+bool HasOp(const Stmt& stmt, const Array<Op>& ops);
+
+/*!
+ * \brief Checks if the given AST statement contains if-then-else, including
+ * 1) IfThenElse statement
+ * 2) Select expression
+ * 3) The operator `tir.if_then_else`
+ * 4) non-constant-true Block predicates
+ * \param stmt The AST statement to be checked
+ * \return A boolean indicating whether the statement contains the if-then-else pattern
+ */
+bool HasIfThenElse(const Stmt& stmt);
+
+/*!
+ * \brief Given the read/write region, extract the pattern of their index correspondence
+ * namely, the mapping from read index to the write index.
+ * \param read_region The read region
+ * \param write_region The write region
+ * \return A tuple of booleans, the extracted pattern
+ * 0) exists: if the pattern is found
+ * 1) surjective: if the pattern is surjective, i.e. each write index is mapped at least once
+ *    e.g. A[i, j] = B[i, i, j]
+ * 2) injective: if the pattern is injective, i.e. each write index is mapped at most once.
+ *    e.g. A[i, j] = B[i]
+ * 3) ordered: if the mapping is ordered
+ * 4) no_const_read: if there is no constant indexing in the read indices,
+ *    e.g. A[i, j] = B[0, i, j]
+ * 5) no_shift_read: if there is no constant shift in the read indices,
+ *    e.g. A[i, j] = B[i + 1, j]
+ */
+std::tuple<bool, bool, bool, bool, bool, bool>
+AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& write_region);
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 0a7d57effd0d..36a1d05f4cf2 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -1345,6 +1345,139 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref) {
   return GetRef<StmtSRef>(p);
 }
 
+/******** Misc ********/
+
+bool HasOp(const Stmt& stmt, const Array<Op>& ops) {
+  std::unordered_set<const Object*> op_set;
+  op_set.reserve(ops.size());
+  for (const Op& op : ops) {
+    op_set.insert(op.operator->());
+  }
+  bool found = false;
+  PreOrderVisit(stmt, [&found, &op_set](const ObjectRef& obj) -> bool {
+    if (found) {
+      return false;
+    }
+    if (const auto* call = obj.as<CallNode>()) {
+      if (op_set.count(call->op.operator->())) {
+        found = true;
+      }
+    }
+    return !found;
+  });
+  return found;
+}
+
+bool HasIfThenElse(const Stmt& stmt) {
+  bool has_branch = false;
+  auto f_visit = [&has_branch](const ObjectRef& obj) -> bool {
+    if (has_branch) {
+      // stop visiting
+      return false;
+    }
+    if (const auto* realize = obj.as<BlockRealizeNode>()) {
+      // Case 1: BlockRealize
+      if (!is_one(realize->predicate)) {
+        has_branch = true;
+      }
+    } else if (obj->IsInstance<IfThenElseNode>() || obj->IsInstance<SelectNode>()) {
+      // Case 2: IfThenElse / Select
+      has_branch = true;
+    } else if (const auto* call = obj.as<CallNode>()) {
+      // Case 3: Call the `if_then_else` operator
+      static const Op& op_if_then_else = Op::Get("tir.if_then_else");
+      if (call->op.same_as(op_if_then_else)) {
+        has_branch = true;
+      }
+    }
+    return !has_branch;
+  };
+  PreOrderVisit(stmt, f_visit);
+  return has_branch;
+}
+
+std::tuple<bool, bool, bool, bool, bool, bool>
+AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& write_region) {
+  static constexpr const std::tuple<bool, bool, bool, bool, bool, bool> kNotExist =
+      std::make_tuple(false, false, false, false, false, false);
+  // Step 1. Extract the write indices
+  int w_dim = write_region->buffer->shape.size();
+  std::unordered_map<const VarNode*, int> var2idx;
+  var2idx.reserve(w_dim);
+  for (int i = 0; i < w_dim; ++i) {
+    const Range& dom = write_region->region[i];
+    if (as_const_int(dom->extent) == nullptr) {
+      return kNotExist;
+    }
+    if (const auto* v = dom->min.as<VarNode>()) {
+      var2idx.emplace(v, i);
+    } else {
+      return kNotExist;
+    }
+  }
+  // Step 2. Map each read index to a write index
+  bool no_const_read = true;
+  bool no_shift_read = true;
+  int r_dim = read_region->buffer->shape.size();
+  std::vector<int> mapped(r_dim, -1);
+  for (int i = 0; i < r_dim; ++i) {
+    const Range& dom = read_region->region[i];
+    if (as_const_int(dom->extent) == nullptr) {
+      return kNotExist;
+    }
+    // Case 1. Read index is a constant
+    if (as_const_int(dom->min) != nullptr) {
+      no_const_read = false;
+      continue;
+    }
+    // Case 2. Read index cannot be recognized as `var +/- const`
+    // where `var` is a write index and `const` is an optional constant shift
+    Optional<IntImm> opt_const = NullOpt;
+    const VarNode* var =
+        static_cast<const VarNode*>(AnalyzeVarWithShift(dom->min, &opt_const).get());
+    if (var == nullptr || !var2idx.count(var)) {
+      return kNotExist;
+    }
+    // Case 3. Read index is `var +/- const`
+    mapped[i] = var2idx.at(var);
+    if (opt_const.defined()) {
+      no_shift_read = false;
+    }
+  }
+  // Step 3. Check if the mapping is ordered, and count how many times each var is mapped
+  std::vector<int> mapped_counter(w_dim, 0);
+  bool ordered = true;
+  int last_mapped = -1;
+  for (int i : mapped) {
+    if (i != -1) {
+      ++mapped_counter[i];
+      if (last_mapped != -1 && last_mapped > i) {
+        ordered = false;
+      }
+      last_mapped = i;
+    }
+  }
+  // Step 4. Check if the mapping is surjective or injective
+  // Surjective: each write index is mapped at least once
+  // Injective: each write index is mapped at most once
+  bool surjective = true;
+  bool injective = true;
+  for (int cnt : mapped_counter) {
+    if (cnt == 0) {
+      surjective = false;
+    } else if (cnt >= 2) {
+      injective = false;
+    }
+  }
+  return std::make_tuple(/*exist=*/true, surjective, injective, ordered, no_const_read,
+                         no_shift_read);
+}
+
 /******** Storage Scope ********/
 
 void CheckStorageScope(const ScheduleState& self, String storage_scope) {
diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h
index 860b3f64b5dc..4df335079f93 100644
--- a/src/tir/schedule/utils.h
+++ b/src/tir/schedule/utils.h
@@ -229,6 +229,34 @@ inline const int64_t* GetLoopIntExtent(const StmtSRef& loop_sref) {
   return as_const_int(loop->extent);
 }
 
+/*!
+ * \brief Check if an expression consists of a single variable,
+ * or a variable plus/minus a constant integer shift
+ * \param expr The expression to be checked
+ * \return The single variable in the expression, or NullOpt if the expression is neither a variable
+ * nor a constant shift from a variable
+ */
+inline Optional<Var> AnalyzeVarWithShift(const PrimExpr& expr, Optional<IntImm>* constant) {
+  if (const auto* var = expr.as<VarNode>()) {
+    *constant = NullOpt;
+    return GetRef<Var>(var);
+  }
+  arith::PVar<Var> var;
+  arith::PVar<IntImm> shift;
+  // match: "var + shift"
+  if ((var + shift).Match(expr) || (shift + var).Match(expr)) {
+    *constant = shift.Eval();
+    return var.Eval();
+  }
+  // match: "var - shift"
+  if ((var - shift).Match(expr)) {
+    IntImm result = shift.Eval();
+    *constant = IntImm(result->dtype, -result->value);
+    return var.Eval();
+  }
+  return NullOpt;
+}
+
 /******** Annotation ********/
 
 /*!
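Aside for reviewers (not part of the patch): the bookkeeping in AnalyzeReadWritePattern above is easier to check against small cases in plain Python. The sketch below mirrors Steps 2-4 under the assumption that each read index has already been parsed into either the marker "const" or a (write_axis, shift) pair; the helper name classify_pattern and that input format are invented for illustration only.

# Pure-Python rendering of the surjective/injective/ordered classification.
def classify_pattern(read_indices, num_write_axes):
    """read_indices: one entry per read dim, either "const" or (write_axis, shift)."""
    mapped = []
    no_const_read = True
    no_shift_read = True
    for idx in read_indices:
        if idx == "const":
            no_const_read = False  # constant indexing in a read, e.g. B[0, i, j]
            continue
        axis, shift = idx
        mapped.append(axis)
        if shift != 0:
            no_shift_read = False  # constant shift in a read, e.g. B[i + 1, j]
    counts = [0] * num_write_axes
    for axis in mapped:
        counts[axis] += 1
    ordered = all(a <= b for a, b in zip(mapped, mapped[1:]))  # non-decreasing write axes
    surjective = all(c >= 1 for c in counts)  # every write axis mapped at least once
    injective = all(c <= 1 for c in counts)  # every write axis mapped at most once
    return surjective, injective, ordered, no_const_read, no_shift_read

# e.g. A[i, j] = B[i, 0, j + 1]:
# classify_pattern([(0, 0), "const", (1, 1)], 2) == (True, True, True, False, False)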
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py new file mode 100644 index 000000000000..e206fcc4502c --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py @@ -0,0 +1,300 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import tvm +from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply +from tvm.meta_schedule.testing.schedule_rule import auto_inline +from tvm.meta_schedule.tune_context import TuneContext +from tvm.script import tir as T +from tvm.target import Target + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks + +@tvm.script.ir_module +class Conv2DBiasBnReLU: + @T.prim_func + def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.handle, var_bn_offset: T.handle, var_compute: T.handle) -> None: + X = T.match_buffer(var_X, [1, 512, 56, 56], dtype="float32") + W = T.match_buffer(var_W, [512, 512, 3, 3], dtype="float32") + B = T.match_buffer(var_B, [512, 1, 1], dtype="float32") + bn_scale = T.match_buffer(var_bn_scale, [512, 1, 1], dtype="float32") + bn_offset = T.match_buffer(var_bn_offset, [512, 1, 1], dtype="float32") + compute = T.match_buffer(var_compute, [1, 512, 56, 56], dtype="float32") + pad_temp = T.alloc_buffer([1, 512, 58, 58], dtype="float32") + compute_1 = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + bias_add = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + bn_mul = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + bn_add = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 512, 58, 58): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(i2_1 >= 1 and i2_1 < 57 and i3_1 >= 1 and i3_1 < 57, X[i0_1, i1_1, i2_1 - 1, i3_1 - 1], T.float32(0), dtype="float32") + for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 512, 56, 56, 512, 3, 3): + with T.block("compute"): + nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + with T.init(): + compute_1[nn, ff, yy, xx] = T.float32(0) + compute_1[nn, ff, yy, xx] = compute_1[nn, ff, yy, xx] + pad_temp[nn, rc, yy + ry, xx + rx] * W[ff, rc, ry, rx] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("bias_add"): + i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3]) + bias_add[i, j, k, l] = compute_1[i, j, k, l] + B[j, 0, 0] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("bn_mul"): + i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3]) + 
bn_mul[i, j, k, l] = bias_add[i, j, k, l] * bn_scale[j, 0, 0] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("bn_add"): + i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3]) + bn_add[i, j, k, l] = bn_mul[i, j, k, l] + bn_offset[j, 0, 0] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("compute_1"): + i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + compute[i0_2, i1_2, i2_2, i3_2] = T.max(bn_add[i0_2, i1_2, i2_2, i3_2], T.float32(0)) + + +@tvm.script.ir_module +class Conv2DBiasBnReLUInlined: + @T.prim_func + def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.handle, var_bn_offset: T.handle, var_compute: T.handle) -> None: + X = T.match_buffer(var_X, [1, 512, 56, 56], dtype="float32") + W = T.match_buffer(var_W, [512, 512, 3, 3], dtype="float32") + B = T.match_buffer(var_B, [512, 1, 1], dtype="float32") + bn_scale = T.match_buffer(var_bn_scale, [512, 1, 1], dtype="float32") + bn_offset = T.match_buffer(var_bn_offset, [512, 1, 1], dtype="float32") + compute = T.match_buffer(var_compute, [1, 512, 56, 56], dtype="float32") + pad_temp = T.alloc_buffer([1, 512, 58, 58], dtype="float32") + compute_1 = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 512, 58, 58): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(i2_1 >= 1 and i2_1 < 57 and i3_1 >= 1 and i3_1 < 57, X[i0_1, i1_1, i2_1 - 1, i3_1 - 1], T.float32(0), dtype="float32") + for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 512, 56, 56, 512, 3, 3): + with T.block("compute"): + nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + with T.init(): + compute_1[nn, ff, yy, xx] = T.float32(0) + compute_1[nn, ff, yy, xx] = compute_1[nn, ff, yy, xx] + pad_temp[nn, rc, yy + ry, xx + rx] * W[ff, rc, ry, rx] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("compute_1"): + i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + compute[i0_2, i1_2, i2_2, i3_2] = T.max((compute_1[i0_2, i1_2, i2_2, i3_2] + B[i1_2, 0, 0]) * bn_scale[i1_2, 0, 0] + bn_offset[i1_2, 0, 0], T.float32(0)) + + +@tvm.script.ir_module +class MultiLevelTiledConv2D: + @T.prim_func + def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.handle, var_bn_offset: T.handle, var_compute: T.handle) -> None: + X = T.match_buffer(var_X, [1, 512, 56, 56], dtype="float32") + W = T.match_buffer(var_W, [512, 512, 3, 3], dtype="float32") + B = T.match_buffer(var_B, [512, 1, 1], dtype="float32") + bn_scale = T.match_buffer(var_bn_scale, [512, 1, 1], dtype="float32") + bn_offset = T.match_buffer(var_bn_offset, [512, 1, 1], dtype="float32") + compute = T.match_buffer(var_compute, [1, 512, 56, 56], dtype="float32") + pad_temp = T.alloc_buffer([1, 512, 58, 58], dtype="float32") + compute_1 = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + compute_local = T.alloc_buffer([1, 512, 56, 56], dtype="float32", scope="local") + pad_temp_shared = T.alloc_buffer([1, 512, 58, 58], dtype="float32", scope="shared") + W_shared = T.alloc_buffer([512, 512, 3, 3], dtype="float32", scope="shared") + for i0, i1, i2, i3 in T.grid(1, 512, 58, 58): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(i2_1 >= 1 and i2_1 < 57 and i3_1 >= 1 and i3_1 < 57, X[i0_1, i1_1, i2_1 - 1, i3_1 - 1], T.float32(0), dtype="float32") + for i0_0_i1_0_i2_0_i3_0_fused in 
T.thread_binding(0, 224, thread="blockIdx.x"): + for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(0, 2, thread="vthread.x"): + for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"): + for i4_0, i5_0, i6_0 in T.grid(1, 3, 1): + for ax0_ax1_ax2_ax3_fused_0 in T.serial(0, 40960, annotations={"meta_schedule.cooperative_fetch":1}): + for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(0, 3): + with T.block("pad_temp_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(512, (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) // 30 // 8 % 512) + v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 + i5_0 + (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) // 30 % 8) + v3 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) % 30) + pad_temp_shared[v0, v1, v2, v3] = pad_temp[v0, v1, v2, v3] + for ax0_ax1_ax2_ax3_fused_0 in T.serial(0, 12288, annotations={"meta_schedule.cooperative_fetch":1}): + for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(0, 4): + with T.block("W_shared"): + v0 = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) // 1536) + v1 = T.axis.spatial(512, (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) // 3 % 512) + v2 = T.axis.spatial(3, i5_0) + v3 = T.axis.spatial(3, (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) % 3) + W_shared[v0, v1, v2, v3] = W[v0, v1, v2, v3] + for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(32, 1, 1, 1, 1, 1, 1, 16, 1, 3, 1, 8, 2, 28): + with T.block("compute"): + nn = T.axis.spatial(1, 0) + ff = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + i1_4) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 2 % 7 * 8 + i0_1_i1_1_i2_1_i3_1_fused * 4 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + i2_4) + xx = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + i3_4) + rc = T.axis.reduce(512, i4_1 * 16 + i4_2) + ry, rx = T.axis.remap("RR", [i5_0, i6_2]) + with T.init(): + compute_local[nn, ff, yy, xx] = T.float32(0) + compute_local[nn, ff, yy, xx] = compute_local[nn, ff, yy, xx] + pad_temp_shared[nn, rc, yy + ry, xx + rx] * W_shared[ff, rc, ry, rx] + for ax0, ax1, ax2, ax3 in T.grid(1, 8, 2, 28): + with T.block("compute_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + ax1) + v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 + i0_1_i1_1_i2_1_i3_1_fused * 4 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + ax2) + v3 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + ax3) + compute_1[v0, v1, v2, v3] = compute_local[v0, v1, v2, v3] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("compute_1"): + i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + compute[i0_2, i1_2, i2_2, i3_2] = T.max((compute_1[i0_2, i1_2, i2_2, i3_2] + B[i1_2, 0, 0]) * bn_scale[i1_2, 0, 0] + bn_offset[i1_2, 0, 0], T.float32(0)) + + +@tvm.script.ir_module +class MultiLevelTiledConv2DAfterInline: + @T.prim_func + def main(X: T.Buffer[(1, 512, 56, 56), "float32"], W: T.Buffer[(512, 512, 3, 3), "float32"], B: T.Buffer[(512, 1, 1), "float32"], bn_scale: T.Buffer[(512, 1, 1), "float32"], bn_offset: T.Buffer[(512, 1, 1), "float32"], compute: T.Buffer[(1, 512, 56, 56), "float32"]) -> None: + compute_local = T.alloc_buffer([1, 512, 56, 56], dtype="float32", scope="local") + for i0_0_i1_0_i2_0_i3_0_fused in 
T.thread_binding(224, thread="blockIdx.x"): + for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(2, thread="vthread.x"): + for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(8, thread="threadIdx.x"): + for i4_0, i5_0, i6_0, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(1, 3, 1, 32, 1, 1, 1, 1, 1, 1, 16, 1, 3, 1, 8, 2, 28): + with T.block("compute"): + nn = T.axis.spatial(1, 0) + ff = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + i1_4) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 2 % 7 * 8 + i0_1_i1_1_i2_1_i3_1_fused * 4 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + i2_4) + xx = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + i3_4) + rc = T.axis.reduce(512, i4_1 * 16 + i4_2) + ry, rx = T.axis.remap("RR", [i5_0, i6_2]) + with T.init(): + compute_local[nn, ff, yy, xx] = T.float32(0) + compute_local[nn, ff, yy, xx] = compute_local[nn, ff, yy, xx] + T.if_then_else(yy + ry >= 1 and yy + ry < 57 and xx + rx >= 1 and xx + rx < 57, X[nn, rc, yy + ry - 1, xx + rx - 1], T.float32(0), dtype="float32") * W[ff, rc, ry, rx] + for ax0, ax1, ax2, ax3 in T.grid(1, 8, 2, 28): + with T.block("compute_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + ax1) + v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 + i0_1_i1_1_i2_1_i3_1_fused * 4 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + ax2) + v3 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + ax3) + compute[v0, v1, v2, v3] = T.max((compute_local[v0, v1, v2, v3] + B[v1, 0, 0]) * bn_scale[v1, 0, 0] + bn_offset[v1, 0, 0], T.float32(0)) + + +@tvm.script.ir_module +class SoftmaxBeforeInline: + @T.prim_func + def main(A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None: + T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") + T_softmax_exp = T.alloc_buffer([256, 256], dtype="float32") + T_softmax_expsum = T.alloc_buffer([256], dtype="float32") + for i0, i1 in T.grid(256, 256): + with T.block("T_softmax_maxelem"): + i0_1, k = T.axis.remap("SR", [i0, i1]) + with T.init(): + T_softmax_maxelem[i0_1] = T.min_value("float32") + T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k]) + for i0, i1 in T.grid(256, 256): + with T.block("T_softmax_exp"): + i0_2, i1_1 = T.axis.remap("SS", [i0, i1]) + T_softmax_exp[i0_2, i1_1] = T.exp(A[i0_2, i1_1] - T_softmax_maxelem[i0_2], dtype="float32") + for i0_3, i1 in T.grid(256, 256): + with T.block("T_softmax_expsum"): + i0_4, k = T.axis.remap("SR", [i0_3, i1]) + with T.init(): + T_softmax_expsum[i0_4] = T.float32(0) + T_softmax_expsum[i0_4] = T_softmax_expsum[i0_4] + T_softmax_exp[i0_4, k] + for i0_5, i1 in T.grid(256, 256): + with T.block("T_softmax_norm"): + i0_6, i1_2 = T.axis.remap("SS", [i0_5, i1]) + T_softmax_norm[i0_6, i1_2] = T_softmax_exp[i0_6, i1_2] / T_softmax_expsum[i0_6] + + +@tvm.script.ir_module +class SoftmaxAfterInline: + @T.prim_func + def main(A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None: + T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") + T_softmax_expsum = T.alloc_buffer([256], dtype="float32") + for i0, i1 in T.grid(256, 256): + with T.block("T_softmax_maxelem"): + i0_1, k = T.axis.remap("SR", [i0, i1]) + with T.init(): + T_softmax_maxelem[i0_1] = T.min_value("float32") + T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k]) + for i0, i1 in T.grid(256, 256): + with 
T.block("T_softmax_expsum"): + i0_2, k = T.axis.remap("SR", [i0, i1]) + with T.init(): + T_softmax_expsum[i0_2] = T.float32(0) + T_softmax_expsum[i0_2] = T_softmax_expsum[i0_2] + T.exp(A[i0_2, k] - T_softmax_maxelem[i0_2], dtype="float32") + for i0_3, i1 in T.grid(256, 256): + with T.block("T_softmax_norm"): + i0_4, i1_1 = T.axis.remap("SS", [i0_3, i1]) + T_softmax_norm[i0_4, i1_1] = T.exp(A[i0_4, i1_1] - T_softmax_maxelem[i0_4], dtype="float32") / T_softmax_expsum[i0_4] + + +# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks +# fmt: on + + +def _create_context(mod, target, rule): + ctx = TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + sch_rules=[rule], + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for sch_rule in ctx.sch_rules: + sch_rule.initialize_with_tune_context(ctx) + return ctx + + +def test_inline_consumer_chain(): + mod = Conv2DBiasBnReLU + target = Target("llvm") + ctx = _create_context( + mod=mod, + target=target, + rule=auto_inline(target=target), + ) + (space,) = ctx.space_generator.generate_design_space(mod=mod) + tvm.ir.assert_structural_equal(lhs=space.mod, rhs=Conv2DBiasBnReLUInlined) + + +def test_inline_into_cache(): + mod = MultiLevelTiledConv2D + target = Target("cuda", host="llvm") + ctx = _create_context( + mod=mod, + target=target, + rule=auto_inline(target=target), + ) + (space,) = ctx.space_generator.generate_design_space(mod=mod) + tvm.ir.assert_structural_equal(lhs=space.mod, rhs=MultiLevelTiledConv2DAfterInline) + + +def test_inline_into_multiple_consumers(): + mod = SoftmaxBeforeInline + target = Target("cuda", host="llvm") + ctx = _create_context( + mod=mod, + target=target, + rule=auto_inline(target=target), + ) + (space,) = ctx.space_generator.generate_design_space(mod=mod) + tvm.ir.assert_structural_equal(lhs=space.mod, rhs=SoftmaxAfterInline) + + +if __name__ == "__main__": + test_inline_consumer_chain() + test_inline_into_cache() + test_inline_into_multiple_consumers()