From 6c86365187af2de92150f1046b05a08c94f495d6 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 12 Aug 2021 17:47:21 +0000 Subject: [PATCH] [TensorIR][M2a] Parallel, Vectorize, Bind & Unroll Co-authored-by: Siyuan Feng Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Junru Shao Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin --- include/tvm/tir/schedule/schedule.h | 39 +- python/tvm/tir/schedule/schedule.py | 241 +++++++++++- src/tir/schedule/analysis.h | 23 ++ src/tir/schedule/analysis/analysis.cc | 59 +++ src/tir/schedule/concrete_schedule.cc | 48 ++- src/tir/schedule/concrete_schedule.h | 6 +- src/tir/schedule/primitive.h | 43 +- src/tir/schedule/primitive/for_kind.cc | 289 ++++++++++++++ src/tir/schedule/schedule.cc | 6 + src/tir/schedule/state.cc | 8 +- src/tir/schedule/traced_schedule.cc | 40 ++ src/tir/schedule/traced_schedule.h | 4 + src/tir/transforms/flatten_buffer.cc | 5 +- .../unittest/test_tir_schedule_for_kind.py | 366 ++++++++++++++++++ 14 files changed, 1155 insertions(+), 22 deletions(-) create mode 100644 src/tir/schedule/primitive/for_kind.cc create mode 100644 tests/python/unittest/test_tir_schedule_for_kind.py diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index e2083778431e3..019e440dac348 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -110,7 +110,7 @@ class ScheduleNode : public runtime::Object { * guaranteeing that * 1) SRef tree is completely reconstructed; * 2) The IRModule being scheduled is not modified; - * 3) All the random variables are valid in the copy, pointing to the correpsonding sref + * 3) All the random variables are valid in the copy, pointing to the corresponding sref * reconstructed */ virtual Schedule Copy() const = 0; @@ -220,6 +220,43 @@ class ScheduleNode : public runtime::Object { */ virtual Array Split(const LoopRV& loop_rv, const Array>& factors) = 0; /******** Schedule: Manipulate ForKind ********/ + /*! + * \brief Parallelize the input loop. It requires: + * 1) The scope block that the loop is in should have stage-pipeline property + * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + * bindings + * 3) For each block under the loop, the loop can only be contained in data-parallel block iters' + * bindings + * \param loop_rv The loop to be parallelized + */ + virtual void Parallel(const LoopRV& loop_rv) = 0; + /*! + * \brief Vectorize the input loop. It requires: + * 1) The scope block that the loop is in should have stage-pipeline property + * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + * bindings + * 3) For each block under the loop, the loop can only be contained in data-parallel block iters' + * bindings + * \param loop_rv The loop to be vectorized + */ + virtual void Vectorize(const LoopRV& loop_rv) = 0; + /*! + * \brief Bind the input loop to the given thread axis. It requires: + * 1) The scope block that the loop is in should have stage-pipeline property + * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + * bindings + * 3) For each block under the loop, if the thread axis starts with "threadIdx`, the loop can only + * be contained in data-parallel block iter and reduction block iters' bindings. 
Otherwise the + * loop can only be contained in data-parallel block iters' bindings + * \param loop_rv The loop to be bound to the thread axis + * \param thread_axis The given thread axis + */ + virtual void Bind(const LoopRV& loop_rv, const String& thread_axis) = 0; + /*! + * \brief Unroll the input loop. It requires nothing + * \param loop_rv The loop to be unrolled + */ + virtual void Unroll(const LoopRV& loop_rv) = 0; /******** Schedule: Insert cache stages ********/ /******** Schedule: Compute location ********/ /*! diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 4bbb5b9b1582a..096ec7dabec38 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -20,7 +20,7 @@ from tvm._ffi import register_object as _register_object from tvm.error import TVMError, register_error from tvm.ir import IRModule, PrimExpr -from tvm.runtime import Object +from tvm.runtime import Object, String from tvm.tir import Block, For, IntImm, PrimFunc from . import _ffi_api @@ -170,7 +170,7 @@ def copy(self) -> "Schedule": * guaranteeing that * 1) SRef tree is completely reconstructed; * 2) The IRModule being scheduled is untouched; - * 3) All the random variables are valid in the copy, pointing to the correpsonding sref + * 3) All the random variables are valid in the copy, pointing to the corresponding sref * reconstructed Returns @@ -226,7 +226,7 @@ def get( Returns ------- result : Optional[Union[int, Block, For]] - The correpsonding result + The corresponding result """ if isinstance(rand_var_or_sref, StmtSRef): return rand_var_or_sref.stmt @@ -236,7 +236,7 @@ def get( return result def get_sref(self, rand_var_or_stmt: Union[BlockRV, LoopRV, Block, For]) -> Optional[StmtSRef]: - """Returns the correpsonding sref to the given + """Returns the corresponding sref to the given 1) LoopRV 2) BlockRV 3) Block @@ -250,7 +250,7 @@ def get_sref(self, rand_var_or_stmt: Union[BlockRV, LoopRV, Block, For]) -> Opti Returns ------- result : Optional[StmtSRef] - The correpsonding result + The corresponding result """ return _ffi_api.ScheduleGetSRef( # type: ignore # pylint: disable=no-member self, rand_var_or_stmt @@ -413,7 +413,7 @@ def before_split(a: ty.handle, b: ty.handle) -> None: with tir.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - Create the schedule and do fuse: + Create the schedule and do split: .. code-block:: python @@ -444,6 +444,233 @@ def after_split(a: ty.handle, b: ty.handle) -> None: ########## Schedule: Manipulate ForKind ########## + def parallel(self, loop: LoopRV) -> None: + """Parallelize the input loop. It requires: + 1) The scope block that the loop is in should have stage-pipeline property + 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + bindings + 3) For each block under the loop, the loop can only be contained in data-parallel block + iters' bindings + + Parameters + ---------- + loop : LoopRV + The loop to be parallelized + + Examples + -------- + + Before parallel, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_parallel(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do parallel: + + .. 
code-block:: python + + sch = tir.Schedule(before_parallel) + i, j = sch.get_loops(sch.get_block("B")) + sch.parallel(i) + + After applying parallel, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_parallel(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i in tir.parallel(0, 128): + for j in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + """ + _ffi_api.ScheduleParallel(self, loop) # type: ignore # pylint: disable=no-member + + def vectorize(self, loop: LoopRV) -> None: + """Vectorize the input loop. It requires: + 1) The scope block that the loop is in should have stage-pipeline property + 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + bindings + 3) For each block under the loop, the loop can only be contained in data-parallel block + iters' bindings + + Parameters + ---------- + loop : LoopRV + The loop to be vectorized + + Examples + -------- + + Before vectorize, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_vectorize(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do vectorize: + + .. code-block:: python + + sch = tir.Schedule(before_vectorize) + i, j = sch.get_loops(sch.get_block("B")) + sch.vectorize(j) + + After applying vectorize, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_vectorize(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i in tir.serial(0, 128): + for j in tir.vectorized(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + """ + _ffi_api.ScheduleVectorize(self, loop) # type: ignore # pylint: disable=no-member + + def bind(self, loop: LoopRV, thread_axis: str) -> None: + """Bind the input loop to the given thread axis. It requires: + 1) The scope block that the loop is in should have stage-pipeline property + 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + bindings + 3) For each block under the loop, if the thread axis starts with "threadIdx`, the loop can + only be contained in data-parallel block iter and reduction block iters' bindings. Otherwise + the loop can only be contained in data-parallel block iters' bindings + + Parameters + ---------- + loop : LoopRV + The loop to be bound to the thread axis + thread_axis : str + The thread axis to be bound to the loop. Possible candidates: + - blockIdx.x/y/z + - threadIdx.x/y/z + - vthread + - vthread.x/y/z + + Examples + -------- + + Before bind, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_bind(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do bind: + + .. 
code-block:: python + + sch = tir.Schedule(before_bind) + i, j = sch.get_loops(sch.get_block("B")) + sch.bind(i, "blockIdx.x") + sch.bind(j, "threadIdx.x") + + After applying bind, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_bind(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i in tir.thread_binding(0, 128, thread = "blockIdx.x"): + for j in tir.thread_binding(0, 128, thread = "threadIdx.x"): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + """ + _ffi_api.ScheduleBind(self, loop, String(thread_axis)) # type: ignore # pylint: disable=no-member + + def unroll(self, loop: LoopRV) -> None: + """Unroll the input loop. It requires nothing + + Parameters + ---------- + loop : LoopRV + The loop to be unrolled + + Examples + -------- + + Before unroll, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_unroll(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do unroll: + + .. code-block:: python + + sch = tir.Schedule(before_unroll) + i, j = sch.get_loops(sch.get_block("B")) + sch.unroll(i) + + After applying unroll, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_unroll(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i in tir.unroll(0, 128): + for j in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + """ + _ffi_api.ScheduleUnroll(self, loop) # type: ignore # pylint: disable=no-member + ########## Schedule: Insert cache stages ########## ########## Schedule: Compute location ########## @@ -581,7 +808,7 @@ def rfactor(self, loop: LoopRV, factor_axis: int) -> LoopRV: RFactor is a schedule primitive that implements the transformation described above: Given a block that writes to buffer `B`, it factorizes a loop of extent `n`. - For example, the pesudocode below accumulates `B[i] = sum(A[i, : , : ])`: + For example, the pseudocode below accumulates `B[i] = sum(A[i, : , : ])`: .. code-block:: python diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 9baf4b5245ead..ae5f0ccbb7368 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -120,6 +120,20 @@ bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref); +/*! + * \brief Check whether a subtree on SRef tree has compact data flow, and throw an exception if the + * subtree does not have compact data flow + * \details For a given StmtSRef, We say the subtree rooted from the StmtSRef has "compact data + * flow" property if: + * - the scope root of the input subtree root has stage-pipeline property, and + * - all its child blocks on SRef tree are complete blocks or reduction blocks. 
+ * \param self The schedule state + * \param subtree_root_sref The root of the subtree to be checked in the SRef tree + * \throw ScheduleError If the subtree does not have compact data flow + * \sa IsCompleteBlock, IsReductionBlock + */ +void CheckSRefSubtreeCompactDataFlow(const ScheduleState& self, const StmtSRef& subtree_root_sref); + /******** Binding ********/ /*! * \brief Verifies if the block binding in a specific BlockRealize is an affine binding. @@ -132,6 +146,15 @@ void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, bool IsAffineBinding(const BlockRealize& realize, const Map& loop_var_ranges, arith::Analyzer* analyzer); +/*! + * \brief Check whether a block has an affine binding using the cached flag, and throw an exception + * if the block does not have an affine binding. + * \param self The schedule state + * \param block The block to be checked + * \throw ScheduleError If the input block does not have an affine binding + */ +void CheckAffineBinding(const ScheduleState& self, Block block); + /*! * \brief Extracts the ranges of loop variables in a path of the sref tree * \param low_inclusive The lowest node in the path diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 3ee98ec5b7d2e..82f0c2a73d690 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -315,6 +315,43 @@ void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, } } +void CheckSRefSubtreeCompactDataFlow(const ScheduleState& self, const StmtSRef& subtree_root_sref) { + class NotCompactDataFlowError : public ScheduleError { + public: + explicit NotCompactDataFlowError(IRModule mod, Stmt subtree_root, Block violate_block) + : mod_(std::move(mod)), + subtree_root_(std::move(subtree_root)), + violate_block_(std::move(violate_block)) { + ICHECK(subtree_root_->IsInstance() || subtree_root_->IsInstance()); + } + String FastErrorString() const final { + return "ScheduleError: The queried subtree root in SRef tree does not have compact data " + "flow, because some of its child block on SRef tree is neither a complete block nor a " + "reduction block"; + } + String DetailRenderTemplate() const final { + return "The queried subtree root {0} in SRef tree does not have compact data flow, because " + "its child block {1} on SRef tree is neither a complete block nor a reduction block"; + } + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {subtree_root_, violate_block_}; } + + IRModule mod_; + Stmt subtree_root_; + Block violate_block_; + }; + + StmtSRef scope_root = GetScopeRoot(self, subtree_root_sref, /*require_stage_pipeline=*/true); + Array child_blocks = GetChildBlockSRefOnSRefTree(self, scope_root); + for (const StmtSRef& block : child_blocks) { + if (!IsCompleteBlock(self, block, scope_root) && !IsReductionBlock(self, block, scope_root)) { + const BlockNode* violate_block = TVM_SREF_TO_BLOCK(violate_block, block); + throw NotCompactDataFlowError(self->mod, GetRef(subtree_root_sref->stmt), + GetRef(violate_block)); + } + } +} + /******** Binding ********/ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_var_ranges, @@ -340,6 +377,28 @@ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_va return true; } +void CheckAffineBinding(const ScheduleState& self, Block block) { + class NotAffineBindingError : public ScheduleError { + public: + explicit NotAffineBindingError(IRModule mod, Block block) + : 
mod_(std::move(mod)), block_(std::move(block)) {} + String FastErrorString() const final { + return "ScheduleError: The block is required to have an affine binding"; + } + String DetailRenderTemplate() const final { + return "The block {0} is required to have an affine binding"; + } + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + IRModule mod_; + Block block_; + }; + + if (!self->IsAffineBlockBinding(self->stmt2ref.at(block.get()))) { + throw NotAffineBindingError(self->mod, std::move(block)); + } +} + Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, const Optional& high_exclusive, const runtime::StorageScope& extra_relax_scope) { diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 610628c6d88ae..589c750abceb4 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -207,7 +207,8 @@ Schedule ConcreteScheduleNode::Copy() const { } \ } -/******** Block/Loop relation ********/ +/******** Schedule: Schedule: Sampling ********/ +/******** Schedule: Get blocks & loops ********/ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_name) { class NotSingleResult : public ScheduleError { @@ -257,7 +258,7 @@ Array ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) { return CreateRV(tir::GetLoops(this->GetSRef(block_rv))); } -/******** Schedule: loops manipulation ********/ +/******** Schedule: Transform loops ********/ LoopRV ConcreteScheduleNode::Fuse(const Array& loop_rvs) { CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)"; @@ -345,7 +346,40 @@ Array ConcreteScheduleNode::Split(const LoopRV& loop_rv, return CreateRV(results); } -/******** Schedule: compute location ********/ +/******** Schedule: Manipulate ForKind ********/ + +void ConcreteScheduleNode::Parallel(const LoopRV& loop_rv) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::Parallel(state_, this->GetSRef(loop_rv)); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("parallel", this->error_render_level_); +} + +void ConcreteScheduleNode::Vectorize(const LoopRV& loop_rv) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::Vectorize(state_, this->GetSRef(loop_rv)); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("vectorize", this->error_render_level_); +} + +void ConcreteScheduleNode::Bind(const LoopRV& loop_rv, const String& thread_axis) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::Bind(state_, this->GetSRef(loop_rv), + IterVar(/*dom=*/Range(nullptr), /*var=*/Var(thread_axis), /*iter_type=*/kThreadIndex, + /*thread_tag=*/thread_axis)); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("bind", this->error_render_level_); +} + +void ConcreteScheduleNode::Unroll(const LoopRV& loop_rv) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::Unroll(state_, this->GetSRef(loop_rv)); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("unroll", this->error_render_level_); +} + +/******** Schedule: Insert cache stages ********/ +/******** Schedule: Compute location ********/ void ConcreteScheduleNode::ComputeInline(const BlockRV& block_rv) { TVM_TIR_SCHEDULE_BEGIN(); @@ -361,9 +395,7 @@ void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) { this->state_->DebugVerify(); } -/******** Schedule: loop binding/annotation ********/ -/******** Schedule: cache read/write ********/ -/******** Schedule: reduction ********/ +/******** Schedule: Reduction ********/ BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { StmtSRef 
result{nullptr}; @@ -374,7 +406,9 @@ BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { return CreateRV(result); } -/******** Schedule: blockize & tensorize ********/ +/******** Schedule: Blockize & Tensorize ********/ +/******** Schedule: Annotation ********/ +/******** Schedule: Misc ********/ } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index ec0dd079243b7..c9f9ead2ff753 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -49,7 +49,7 @@ class ConcreteScheduleNode : public ScheduleNode { // `state_` is not visited // `error_render_level_` is not visited // `symbol_table_` is not visited - // `analyzer_` is not visitied + // `analyzer_` is not visited } virtual ~ConcreteScheduleNode() = default; @@ -82,6 +82,10 @@ class ConcreteScheduleNode : public ScheduleNode { LoopRV Fuse(const Array& loop_rvs) override; Array Split(const LoopRV& loop_rv, const Array>& factors) override; /******** Schedule: Manipulate ForKind ********/ + void Parallel(const LoopRV& loop_rv) override; + void Vectorize(const LoopRV& loop_rv) override; + void Bind(const LoopRV& loop_rv, const String& thread_axis) override; + void Unroll(const LoopRV& loop_rv) override; /******** Schedule: Insert cache stages ********/ /******** Schedule: Compute location ********/ void ComputeInline(const BlockRV& block) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 22e25f1c54a77..c3eaa6a15176e 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -42,7 +42,6 @@ Array GetBlocks(const ScheduleState& self, const String& name, const S */ Array GetLoops(const StmtSRef& block_sref); /******** Schedule: Transform loops ********/ - /*! * Split a loop into a list of consecutive loops. It requires: * 1) The loop can't have annotation or thread binding. @@ -65,6 +64,47 @@ TVM_DLL Array Split(ScheduleState self, const StmtSRef& loop_sref, */ TVM_DLL StmtSRef Fuse(ScheduleState self, const Array& loop_srefs); /******** Schedule: Manipulate ForKind ********/ +/*! + * \brief Parallelize the input loop. It requires: + * 1) The scope block that the loop is in should have stage-pipeline property + * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + * bindings + * 3) For each block under the loop, the loop can only be contained in data-parallel block iters' + * bindings + * \param self The state of the schedule + * \param loop_sref The sref of the loop to be parallelized + */ +TVM_DLL void Parallel(ScheduleState self, const StmtSRef& loop_sref); +/*! + * \brief Vectorize the input loop. It requires: + * 1) The scope block that the loop is in should have stage-pipeline property + * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + * bindings + * 3) For each block under the loop, the loop can only be contained in data-parallel block iters' + * bindings + * \param self The state of the schedule + * \param loop_sref The sref of the loop to be vectorized + */ +TVM_DLL void Vectorize(ScheduleState self, const StmtSRef& loop_sref); +/*! + * \brief Bind the input loop to the given thread axis. 
It requires: + * 1) The scope block that the loop is in should have stage-pipeline property + * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + * bindings + * 3) For each block under the loop, if the thread axis starts with "threadIdx`, the loop can only + * be contained in data-parallel block iter and reduction block iters' bindings. Otherwise the + * loop can only be contained in data-parallel block iters' bindings + * \param self The state of the schedule + * \param loop_sref The sref of the loop to be bound to the thread axis + * \param thread_axis The given thread axis + */ +TVM_DLL void Bind(ScheduleState self, const StmtSRef& loop_sref, const IterVar& thread_axis); +/*! + * \brief Unroll the input loop. It requires nothing + * \param self The state of the schedule + * \param loop_sref The loop to be unrolled + */ +TVM_DLL void Unroll(ScheduleState self, const StmtSRef& loop_sref); /******** Schedule: Insert cache stages ********/ /******** Schedule: Compute location ********/ /*! @@ -96,6 +136,7 @@ TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref /*! * \brief Factor a reduction block by the specified loop * \details See python/tvm/tir/schedule/schedule.py + * \param self The state of the schedule * \param loop_sref The loop outside block for which we want to do rfactor * \param factor_axis The position where the new dimension is placed in the new introduced rfactor * buffer. Suppose the original reduction block writes to buffer `B` with diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/tir/schedule/primitive/for_kind.cc new file mode 100644 index 0000000000000..a6056d6070424 --- /dev/null +++ b/src/tir/schedule/primitive/for_kind.cc @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +class WrongBlockIterTypeError : public ScheduleError { + public: + explicit WrongBlockIterTypeError(IRModule mod, ForKind for_kind, Var loop_var, Block block) + : mod_(std::move(mod)), loop_var_(std::move(loop_var)), block_(std::move(block)) { + op_str_ = for_kind == ForKind::kParallel + ? "parallel" + : for_kind == ForKind::kVectorized ? 
"vectorize" : "bind"; + } + String FastErrorString() const final { + std::ostringstream os; + os << "ScheduleError: The \"" << op_str_ + << "\" cannot be fulfilled with regard to some of its underlying block"; + return os.str(); + } + String DetailRenderTemplate() const final { + std::ostringstream os; + if (op_str_ != "bind") { + os << "The \"" << op_str_ + << "\" cannot be fulfilled with regard to block {0} because some block iter whose block " + "binding contains the loop var is not a data parallel block iter"; + } else { + os << "The \"bind\" cannot be fulfilled with regard to block {0}. This is because some of its" + " block iter whose block binding contains " + << loop_var_ + << " does not meet any of the conditions:\n1) the block iter is data parallel;\n2) the " + "block iter is a reduction block iter, and the thread axis to be bound is " + "\"threadIdx.x/y/z\""; + } + return os.str(); + } + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + IRModule mod_; + std::string op_str_; + Var loop_var_; + Block block_; +}; + +/*! + * \brief Check if a loop can be parallelized/vectorized/bound with regard to a specific block + * \details There are two conditions: + * 1) The block is required to have affine bindings, and + * 2) For each block iter whose binding contains the input loop variable, either + * - the block iter is data parallel, or + * - the block iter is a reduction block iter, and the input `thread_tag` starts with "threadIdx" + * in case of cross-thread reduction. + * \param self The schedule state + * \param for_kind The desired ForKind (only `kParallel`, `kVectorized` and `kThreadBinding` are + * allowed) + * \param loop_var The loop variable of the loop to be checked + * \param block_realize The block-realize of the block to be checked + * \param thread_scope The thread scope of the thread axis to be bound, which is an invalid value if + * the operation is not "bind" + * \throws ScheduleError If the input loop cannot be parallelized/vectorized/bound with regard to + * the input block + */ +void CheckLoopParallelizableInBlock(const ScheduleState& self, ForKind for_kind, + const Var& loop_var, const BlockRealize& block_realize, + runtime::ThreadScope thread_scope) { + const Block& block = block_realize->block; + + // Cond 1. The block is required to have affine bindings. + CheckAffineBinding(self, block); + + // Cond 2. For each block iter whose binding contains `loop_var`, only two cases are allowed. + ICHECK_EQ(block->iter_vars.size(), block_realize->iter_values.size()); + int n_iters = static_cast(block->iter_vars.size()); + for (int i = 0; i < n_iters; ++i) { + const IterVar& iter_var = block->iter_vars[i]; + const PrimExpr& binding = block_realize->iter_values[i]; + + if (!UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) { + continue; + } + // Only two cases are allowed: + // - The block iter is data parallel, or + // - The block iter is a reduction block iter, and the `thread_scope` is "threadIdx.x/y/z" + // in case of cross-thread reduction. + IterVarType iter_type = iter_var->iter_type; + if (!(iter_type == kDataPar || + (iter_type == kCommReduce && thread_scope.rank == 1 && thread_scope.dim_index != -1))) { + throw WrongBlockIterTypeError(self->mod, for_kind, loop_var, block); + } + } +} + +/*! 
+ * \brief For each block (recursive) under the given loop, check whether the input loop can be + * parallelized/vectorized/bound with regard to the block + * \param self The schedule state + * \param loop The loop to be parallelized/vectorized/bound + * \param for_kind The desired ForKind (only `kParallel`, `kVectorized` and `kThreadBinding` are + * allowed) + * \param thread_scope The thread scope of the thread axis to be bound, which is an invalid value if + * the operation is not "bind" + */ +void CheckParallelizability(const ScheduleState& self, const For& loop, ForKind for_kind, + runtime::ThreadScope thread_scope) { + PreOrderVisit(loop, [&](const ObjectRef& node) { + if (const auto* realize = node.as()) { + CheckLoopParallelizableInBlock(self, for_kind, loop->loop_var, GetRef(realize), + thread_scope); + } + return true; + }); +} + +/*! + * \brief The implementation of parallelizing/vectorizing/binding a given loop + * \param self The schedule state + * \param loop_sref The sref of the loop to be parallelized/vectorized/bound + * \param for_kind The type of the operation (only `kParallel`, `kVectorized` and `kThreadBinding` + * are allowed) + * \param thread_axis The thread axis that the input loop is bound to, which is defined only when + * `for_kind` is `kThreadBinding` + */ +void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref, ForKind for_kind, + Optional thread_axis) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + + /* + * Check: + * - 1. the subtree rooted from the input loop in sref tree has compact data flow + * - 2. all the blocks under the given loop have affine block bindings + * - 3. the input loop can be only bound to data parallel block iters, or the loop can be bound to + * reduction block iter if `thread` is `threadIdx.x/y/z` in case of cross-thread reduction + * When the above conditions are all satisfied, this input loop can be + * parallelized/vectorized/bound. + */ + // Step 1. Check whether the subtree rooted from the `loop` in sref tree has compact data flow. + CheckSRefSubtreeCompactDataFlow(self, loop_sref); + + // Step 2. Check whether the loop can be parallelized/vectorized/bound with regard to each + // underlying block. + CheckParallelizability(self, GetRef(loop), for_kind, + thread_axis.defined() + ? runtime::ThreadScope::Create(thread_axis.value()->thread_tag) + : runtime::ThreadScope{-1, -1}); + + // Step 3. 
Loop update and IR replacement + ObjectPtr new_loop = make_object(*loop); + new_loop->kind = for_kind; + new_loop->thread_binding = std::move(thread_axis); + self->Replace(loop_sref, For(new_loop), {}); +} + +void Parallel(ScheduleState self, const StmtSRef& loop_sref) { + ParallelizeComputation(self, loop_sref, ForKind::kParallel, NullOpt); +} + +void Vectorize(ScheduleState self, const StmtSRef& loop_sref) { + ParallelizeComputation(self, loop_sref, ForKind::kVectorized, NullOpt); +} + +void Bind(ScheduleState self, const StmtSRef& loop_sref, const IterVar& thread_axis) { + ParallelizeComputation(self, loop_sref, ForKind::kThreadBinding, thread_axis); +} + +void Unroll(ScheduleState self, const StmtSRef& loop_sref) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + ObjectPtr new_loop = make_object(*loop); + new_loop->kind = ForKind::kUnrolled; + new_loop->thread_binding = NullOpt; + self->Replace(loop_sref, For(new_loop), {}); +} + +/******** Instruction Registration ********/ + +struct ParallelTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Parallel"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) { + return sch->Parallel(loop_rv); + } + + static String UnpackedAsPython(Array outputs, String loop_rv) { + PythonAPICall py("parallel"); + py.Input("loop", loop_rv); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct VectorizeTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Vectorize"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) { + return sch->Vectorize(loop_rv); + } + + static String UnpackedAsPython(Array outputs, String loop_rv) { + PythonAPICall py("vectorize"); + py.Input("loop", loop_rv); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct BindTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Bind"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 1; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, String thread) { + return sch->Bind(loop_rv, thread); + } + + static String UnpackedAsPython(Array outputs, String loop_rv, String thread) { + PythonAPICall py("bind"); + py.Input("loop", loop_rv); + py.Input("thread", thread); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct UnrollTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Unroll"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) { return sch->Unroll(loop_rv); } + + static String UnpackedAsPython(Array outputs, String loop_rv) { + PythonAPICall py("unroll"); + py.Input("loop", loop_rv); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(ParallelTraits); 
+TVM_REGISTER_INST_KIND_TRAITS(VectorizeTraits); +TVM_REGISTER_INST_KIND_TRAITS(BindTraits); +TVM_REGISTER_INST_KIND_TRAITS(UnrollTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 3232a3344ee7d..0d53b1d2ac1ab 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -126,6 +126,12 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops") TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method(&ScheduleNode::Fuse); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method(&ScheduleNode::Split); /******** (FFI) Manipulate ForKind ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleParallel") + .set_body_method(&ScheduleNode::Parallel); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleVectorize") + .set_body_method(&ScheduleNode::Vectorize); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBind").set_body_method(&ScheduleNode::Bind); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnroll").set_body_method(&ScheduleNode::Unroll); /******** (FFI) Insert cache stages ********/ /******** (FFI) Compute location ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline") diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index 6dd09680e987b..9a9b97497e042 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -112,7 +112,7 @@ bool ProducerCoversConsumer(const Array& buffer_shape, * \param self The schedule class * \param stmt The statement, or the realize node of the statement whose sref to be set * \param seq_index The seq_index to be set - * \note The method is NOP for statements that are not scheduleable, i.e. not For or Block + * \note The method is NOP for statements that are not schedulable, i.e. not For or Block */ void SetSeqIndex(ScheduleStateNode* self, const Stmt& stmt, int seq_index) { if (const auto* realize = stmt.as()) { @@ -405,7 +405,7 @@ class StateCreator : private StmtVisitor { std::unordered_map block2realize_; /*! \brief The stack frames of blocks in the DFS visit. */ std::vector> block_frames_; - /*! \brief The auxilary analyzer */ + /*! 
\brief The auxiliary analyzer */ arith::Analyzer analyzer_; }; @@ -565,7 +565,7 @@ class SRefTreePruner : public StmtVisitor { } auto it = self_->stmt2ref.find(op); ICHECK(it != self_->stmt2ref.end()) - << "IndexError: Cannot find correpsonding StmtSRef for the loop:\n" + << "IndexError: Cannot find corresponding StmtSRef for the loop:\n" << GetRef(op); StmtSRef& sref = it->second; // Detect reuse @@ -588,7 +588,7 @@ class SRefTreePruner : public StmtVisitor { } auto it = self_->stmt2ref.find(op); ICHECK(it != self_->stmt2ref.end()) - << "IndexError: Cannot find correpsonding StmtSRef for the block:\n" + << "IndexError: Cannot find corresponding StmtSRef for the block:\n" << GetRef(op); StmtSRef& sref = it->second; // Detect reuse diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index d664d7f6ce98d..2cdfe5d21b921 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -101,6 +101,46 @@ Array TracedScheduleNode::Split(const LoopRV& loop_rv, /******** Schedule: Manipulate ForKind ********/ +void TracedScheduleNode::Parallel(const LoopRV& loop_rv) { + ConcreteScheduleNode::Parallel(loop_rv); + + static const InstructionKind& kind = InstructionKind::Get("Parallel"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{}, + /*outputs=*/{})); +} + +void TracedScheduleNode::Vectorize(const LoopRV& loop_rv) { + ConcreteScheduleNode::Vectorize(loop_rv); + + static const InstructionKind& kind = InstructionKind::Get("Vectorize"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{}, + /*outputs=*/{})); +} + +void TracedScheduleNode::Bind(const LoopRV& loop_rv, const String& thread_axis) { + ConcreteScheduleNode::Bind(loop_rv, thread_axis); + + static const InstructionKind& kind = InstructionKind::Get("Bind"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{thread_axis}, + /*outputs=*/{})); +} + +void TracedScheduleNode::Unroll(const LoopRV& loop_rv) { + ConcreteScheduleNode::Unroll(loop_rv); + + static const InstructionKind& kind = InstructionKind::Get("Unroll"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{}, + /*outputs=*/{})); +} + /******** Schedule: Insert cache stages ********/ /******** Schedule: Compute location ********/ diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index b4518cbba8b57..6d6e4db1225e3 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -55,6 +55,10 @@ class TracedScheduleNode : public ConcreteScheduleNode { LoopRV Fuse(const Array& loop_rvs) final; Array Split(const LoopRV& loop_rv, const Array>& factor_rvs) final; /******** Schedule: Manipulate ForKind ********/ + void Parallel(const LoopRV& loop_rv) final; + void Vectorize(const LoopRV& loop_rv) final; + void Bind(const LoopRV& loop_rv, const String& thread_axis) final; + void Unroll(const LoopRV& loop_rv) final; /******** Schedule: Insert cache stages ********/ /******** Schedule: Compute location ********/ void ComputeInline(const BlockRV& block_rv) final; diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index 85c4123460569..5eb6d5b039211 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -140,7 +140,10 @@ class BufferFlattener : public StmtExprMutator { /*var=*/std::move(var), /*iter_type=*/IterVarType::kThreadIndex, 
/*thread_tag=*/thread_tag); - String attr_key = thread_tag == "vthread" ? attr::virtual_thread : attr::thread_extent; + String attr_key = (thread_tag == "vthread" || thread_tag == "vthread.x" || + thread_tag == "vthread.y" || thread_tag == "vthread.z") + ? attr::virtual_thread + : attr::thread_extent; return AttrStmt(/*node=*/std::move(iter_var), /*attr_key=*/std::move(attr_key), /*value=*/std::move(extent), diff --git a/tests/python/unittest/test_tir_schedule_for_kind.py b/tests/python/unittest/test_tir_schedule_for_kind.py new file mode 100644 index 0000000000000..2a481c329242e --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_for_kind.py @@ -0,0 +1,366 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import sys + +import numpy as np +import pytest +import tvm +import tvm.testing +from tvm import tir +from tvm.script import ty +from tvm.tir.schedule.testing import verify_trace_roundtrip + +# pylint: disable=no-member,invalid-name,unused-variable + + +@tvm.script.tir +def element_wise(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + +@tvm.script.tir +def element_wise_parallelized(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i0 in tir.parallel(0, 128): + for i1 in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i0) + tir.bind(vj, i1) + B[vi, vj] = A[vi, vj] * 2.0 + + +@tvm.script.tir +def element_wise_i_bound(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i0 in tir.thread_binding(0, 128, thread="threadIdx.x"): + for i1 in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i0) + tir.bind(vj, i1) + B[vi, vj] = A[vi, vj] * 2.0 + + +@tvm.script.tir +def element_wise_compute_at_split(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + B = tir.alloc_buffer((128, 128)) + for i in tir.serial(0, 128): + for j0 in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j0) + B[vi, vj] = A[vi, vj] * 2.0 + for j1o, j1i in tir.grid(32, 4): + with tir.block([128, 128], "C") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j1o * 4 + j1i) + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def element_wise_compute_at_split_vectorized(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + B = tir.alloc_buffer((128, 128)) + for i in tir.serial(0, 128): + for j0 in tir.serial(0, 
128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j0) + B[vi, vj] = A[vi, vj] * 2.0 + for j1o in tir.serial(0, 32): + for j1i in tir.vectorized(0, 4): + with tir.block([128, 128], "C") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j1o * 4 + j1i) + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def element_wise_split_predicate(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + for i, j_0, j_1 in tir.grid(128, 13, 10): + with tir.block([128, 128], "B") as [vi, vj]: + tir.where(j_0 * 10 + j_1 < 128) + tir.bind(vi, i) + tir.bind(vj, j_0 * 10 + j_1) + B[vi, vj] = A[vi, vj] * 2.0 + + +@tvm.script.tir +def element_wise_split_predicate_parallelized(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + for i in tir.serial(0, 128): + for j_0 in tir.parallel(0, 13): + for j_1 in tir.serial(0, 10): + with tir.block([128, 128], "B") as [vi, vj]: + tir.where(j_0 * 10 + j_1 < 128) + tir.bind(vi, i) + tir.bind(vj, j_0 * 10 + j_1) + B[vi, vj] = A[vi, vj] * 2.0 + + +@tvm.script.tir +def element_wise_split_predicate_vectorized(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + for i in tir.vectorized(0, 128): + for j_0, j_1 in tir.grid(13, 10): + with tir.block([128, 128], "B") as [vi, vj]: + tir.where(j_0 * 10 + j_1 < 128) + tir.bind(vi, i) + tir.bind(vj, j_0 * 10 + j_1) + B[vi, vj] = A[vi, vj] * 2.0 + + +@tvm.script.tir +def element_wise_compute_at_split_j0_j1o_bound(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + B = tir.alloc_buffer((128, 128)) + for i in tir.serial(0, 128): + for j0 in tir.thread_binding(0, 128, thread="threadIdx.x"): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j0) + B[vi, vj] = A[vi, vj] * 2.0 + for j1o in tir.thread_binding(0, 32, thread="threadIdx.x"): + for j1i in tir.serial(0, 4): + with tir.block([128, 128], "C") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j1o * 4 + j1i) + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + + with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@tvm.script.tir +def rowsum(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@tvm.script.tir +def rowsum_unrolled(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + for i0 in tir.unroll(0, 128): + for i1 in tir.serial(0, 128): + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + tir.bind(vi, i0) + tir.bind(vk, i1) + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@tvm.script.tir +def rowsum_not_quasi_affine(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + for i, k in tir.grid(128, 16): + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + tir.bind(vi, i) + tir.bind(vk, tir.floordiv(k * k, 2)) + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] 
+ A[vi, vk] + + +@tvm.script.tir +def rowsum_not_compact_data_flow(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + with tir.init(): + B[vk] = 0.0 + B[vk] = B[vk] + A[vi, vk] + + +@tvm.script.tir +def rowsum_cross_thread_reduction(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + for i0 in tir.serial(0, 128): + for i1 in tir.thread_binding(0, 128, thread="threadIdx.x"): + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + tir.bind(vi, i0) + tir.bind(vk, i1) + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@tvm.script.tir +def opaque_block(a: ty.handle) -> None: + A = tir.match_buffer(a, (16,)) + for i in tir.serial(0, 15): + with tir.block([], "opaque"): + A[i + 1] = A[i + 1] + A[i] + + +# pylint: enable=no-member,invalid-name,unused-variable + + +def test_parallel(): + s = tir.Schedule(element_wise, debug_mask="all") + i, _ = s.get_loops(s.get_block("B")) + s.parallel(i) + tvm.ir.assert_structural_equal(s.mod["main"], element_wise_parallelized) + verify_trace_roundtrip(s, mod=element_wise) + + +def test_parallel_predicate(): + s = tir.Schedule(element_wise_split_predicate, debug_mask="all") + _, j, _ = s.get_loops(s.get_block("B")) + s.parallel(j) + tvm.ir.assert_structural_equal(s.mod["main"], element_wise_split_predicate_parallelized) + verify_trace_roundtrip(s, mod=element_wise_split_predicate) + + +def test_parallel_reduction_block_iter(): + s = tir.Schedule(matmul, debug_mask="all") + _, _, k = s.get_loops(s.get_block("C")) + with pytest.raises(tvm.tir.ScheduleError): + s.parallel(k) + + +def test_parallel_not_quasi_affine(): + s = tir.Schedule(rowsum_not_quasi_affine, debug_mask="all") + i, _ = s.get_loops(s.get_block("B")) + with pytest.raises(tvm.tir.ScheduleError): + s.parallel(i) + + +def test_parallel_not_compact_data_flow(): + s = tir.Schedule(rowsum_not_compact_data_flow, debug_mask="all") + i, _ = s.get_loops(s.get_block("B")) + with pytest.raises(tvm.tir.ScheduleError): + s.parallel(i) + + +def test_vectorize(): + s = tir.Schedule(element_wise_compute_at_split, debug_mask="all") + _, _, j1i = s.get_loops(s.get_block("C")) + s.vectorize(j1i) + tvm.ir.assert_structural_equal(s.mod["main"], element_wise_compute_at_split_vectorized) + verify_trace_roundtrip(s, mod=element_wise_compute_at_split) + + +def test_vectorize_predicate(): + s = tir.Schedule(element_wise_split_predicate, debug_mask="all") + i, _, _ = s.get_loops(s.get_block("B")) + s.vectorize(i) + tvm.ir.assert_structural_equal(s.mod["main"], element_wise_split_predicate_vectorized) + verify_trace_roundtrip(s, mod=element_wise_split_predicate) + + +def test_vectorize_opaque_block(): + s = tir.Schedule(opaque_block, debug_mask="all") + (i,) = s.get_loops(s.get_block("opaque")) + with pytest.raises(tvm.tir.ScheduleError): + s.vectorize(i) + + +def test_unroll(): + s = tir.Schedule(rowsum, debug_mask="all") + i, _ = s.get_loops(s.get_block("B")) + s.unroll(i) + tvm.ir.assert_structural_equal(s.mod["main"], rowsum_unrolled) + verify_trace_roundtrip(s, mod=rowsum) + + +def test_unroll_after_bind(): + s = tir.Schedule(rowsum, debug_mask="all") + i, _ = s.get_loops(s.get_block("B")) + s.bind(i, "blockIdx.x") + s.unroll(i) + tvm.ir.assert_structural_equal(s.mod["main"], rowsum_unrolled) + verify_trace_roundtrip(s, mod=rowsum) + + +def test_bind1(): + s = tir.Schedule(element_wise, debug_mask="all") + i, _ 
= s.get_loops(s.get_block("B")) + s.bind(i, "threadIdx.x") + tvm.ir.assert_structural_equal(s.mod["main"], element_wise_i_bound) + verify_trace_roundtrip(s, mod=element_wise) + + +def test_bind2(): + s = tir.Schedule(element_wise_compute_at_split, debug_mask="all") + _, j0 = s.get_loops(s.get_block("B")) + _, j1o, _ = s.get_loops(s.get_block("C")) + s.bind(j0, "threadIdx.x") + s.bind(j1o, "threadIdx.x") + tvm.ir.assert_structural_equal(s.mod["main"], element_wise_compute_at_split_j0_j1o_bound) + verify_trace_roundtrip(s, mod=element_wise_compute_at_split) + + +def test_bind_cross_thread_reduction(): + s = tir.Schedule(rowsum, debug_mask="all") + _, k = s.get_loops(s.get_block("B")) + s.bind(k, "threadIdx.x") + tvm.ir.assert_structural_equal(s.mod["main"], rowsum_cross_thread_reduction) + verify_trace_roundtrip(s, mod=rowsum) + + +def test_bind_not_cross_thread_reduction(): + s = tir.Schedule(rowsum, debug_mask="all") + _, k = s.get_loops(s.get_block("B")) + with pytest.raises(tvm.tir.ScheduleError): + s.bind(k, "blockIdx.x") + + +def test_bind_after_bind(): + s = tir.Schedule(element_wise, debug_mask="all") + i, _ = s.get_loops(s.get_block("B")) + s.bind(i, "blockIdx.x") + s.bind(i, "threadIdx.x") + tvm.ir.assert_structural_equal(s.mod["main"], element_wise_i_bound) + verify_trace_roundtrip(s, mod=element_wise) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:]))
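
For reference only (this is not part of the patch): a minimal usage sketch of how the four new primitives compose. It assumes the `element_wise` and `matmul` TVMScript functions defined in the new test file above, and it uses only schedule APIs added or exercised in this PR (`parallel`, `vectorize`, `bind`, `unroll`, plus the `verify_trace_roundtrip` helper the tests rely on).

    # Hypothetical usage sketch; `element_wise` and `matmul` are assumed to be the
    # TVMScript functions from tests/python/unittest/test_tir_schedule_for_kind.py.
    from tvm import tir
    from tvm.tir.schedule.testing import verify_trace_roundtrip

    # CPU-style schedule: parallelize the outer loop and vectorize the inner one.
    sch = tir.Schedule(element_wise, debug_mask="all")
    i, j = sch.get_loops(sch.get_block("B"))
    sch.parallel(i)    # loop `i` becomes ForKind::kParallel
    sch.vectorize(j)   # loop `j` becomes ForKind::kVectorized
    verify_trace_roundtrip(sch, mod=element_wise)

    # GPU-style schedule: bind the spatial loops to thread axes, unroll the reduction loop.
    sch = tir.Schedule(matmul, debug_mask="all")
    i, j, k = sch.get_loops(sch.get_block("C"))
    sch.bind(i, "blockIdx.x")   # data-parallel block iter, so blockIdx.x is allowed
    sch.bind(j, "threadIdx.x")  # data-parallel block iter, so threadIdx.x is allowed
    sch.unroll(k)               # unroll has no prerequisites
    verify_trace_roundtrip(sch, mod=matmul)

The checks implemented in for_kind.cc (compact data flow of the subtree, affine block bindings, and the per-block-iter type constraints) are what would reject an invalid request here, e.g. binding the reduction loop `k` to "blockIdx.x", as covered by test_bind_not_cross_thread_reduction above.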