Skip to content

Commit

Permalink
[TensorIR][M2a] Decompose-Reduction (#9041)
Browse files Browse the repository at this point in the history
This PR is part of the TensorIR upstreaming effort (#7527),
which adds the `decompose-reduction` scheduling primitive.

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
  • Loading branch information
6 people committed Oct 2, 2021
1 parent feb4536 commit 6b3fe95
Show file tree
Hide file tree
Showing 13 changed files with 1,299 additions and 602 deletions.
16 changes: 16 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,22 @@ class ScheduleNode : public runtime::Object {
*/
virtual void ReverseComputeInline(const BlockRV& block) = 0;
/******** Schedule: Reduction ********/
/*!
* \brief Decompose a reduction block into two separate blocks.
* a) The init block, which is translated from the init statement of the reduction block;
* b) The update block, which is the original block without init statement.
*
* The init block is inserted right before the given loop.
*
* The schedule primitive requires:
* 1) The input block is a reduction block.
* 2) The input loop is the ancestor of the block.
* 3) The input loop is not lower than all the loops related to reduce block var.
* \param block_rv The reduction block to be decomposed
* \param loop_rv The loop above which the init block is inserted before.
* \return The init block
*/
virtual BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) = 0;
/*!
* \brief Factorize an associative reduction block by the specified loop.
* \details An associative reduction cannot be parallelized directly,
Expand Down
6 changes: 6 additions & 0 deletions include/tvm/tir/schedule/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,12 @@ class ScheduleStateNode : public Object {
/******** Property of blocks ********/
/*! \brief Returns the BlockInfo correpsonding to the block sref */
TVM_DLL BlockInfo GetBlockInfo(const StmtSRef& block_sref) const;
/*!
* \brief Recalculate the BlockInfo recursively under stmt.
* If stmt is a Block itself, we will not reset its affine binding flag unless it doesn't
* have block vars, since the affine flag depends on the outer scope of stmt.
*/
TVM_DLL void UpdateScopeBlockInfo(const Stmt& stmt);
/*!
* \brief Get the BlockScope correpsonding to the sref of scope root block
* \param scope_root The block sref to be retrieved
Expand Down
76 changes: 76 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1224,6 +1224,82 @@ def after_inline(a: T.handle, c: T.handle) -> None:

########## Schedule: Reduction ##########

def decompose_reduction(self, block: BlockRV, loop: LoopRV) -> BlockRV:
"""Decompose a reduction block into two separate blocks.
a) The init block, which is translated from the init statement of the reduction block;
b) The update block, which is the original block without init statement.
The init block is inserted right before the given loop.
The schedule primitive requires:
1) The input block is a reduction block.
2) The input loop is the ancestor of the block.
3) The input loop is not lower than all the loops related to reduce block var.
Parameters
----------
block : BlockRV
The reduction block to be decomposed
loop : LoopRV
The loop above which the init block is inserted before.
Returns
-------
init_block : BlockRV
The init block
Examples
--------
Before decompose-reduction, in TensorIR, the IR is:
.. code-block:: python
@tvm.script.tir
def before_decompose(a: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, [128, 128])
B = tir.match_buffer(b, [128, 128])
C = tir.match_buffer(c, [128, 128])
for i, j, k in tir.grid(128, 128, 128):
with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
with tir.init():
C[vi, vj] = 0.0
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
Create the schedule and do decompose-reduction with specified loop:
.. code-block:: python
sch = tir.Schedule(before_decompose)
C = sch.get_block("C")
i, j, k = sch.get_loops(C)
sch.decompose_reduction(C, i)
print(tvm.script.asscript(sch.mod["main"]))
After applying decompose-reduction, the IR becomes:
.. code-block:: python
@tvm.script.tir
def after_decompose(a: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, [128, 128])
B = tir.match_buffer(b, [128, 128])
C = tir.match_buffer(c, [128, 128])
for i in tir.serial(128):
for j in tir.serial(128):
with tir.block([128, 128]) as [vi, vj]:
C[vi, vj] = 0.0
for i, j, k in tir.grid(128, 128, 128):
with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
"""
return _ffi_api.ScheduleDecomposeReduction(self, block, loop) # type: ignore # pylint: disable=no-member

def rfactor(self, loop: LoopRV, factor_axis: int) -> LoopRV:
"""Factorize an associative reduction block by the specified loop.
Expand Down
9 changes: 9 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,15 @@ void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_inde

/******** Schedule: Reduction ********/

BlockRV ConcreteScheduleNode::DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) {
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
result = tir::DecomposeReduction(state_, this->GetSRef(block_rv), this->GetSRef(loop_rv));
TVM_TIR_SCHEDULE_END("decompose-reduction", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<BlockRV>(result);
}

BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) {
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
Expand Down
1 change: 1 addition & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class ConcreteScheduleNode : public ScheduleNode {
void ReverseComputeInline(const BlockRV& block) override;
/******** Schedule: Reduction ********/
BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override;
BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) override;
/******** Schedule: Block annotation ********/
void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
int offset) override;
Expand Down
17 changes: 17 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,23 @@ TVM_DLL void ComputeInline(ScheduleState self, const StmtSRef& block_sref);
*/
TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref);
/******** Schedule: Reduction ********/
/*!
* \brief Decompose a reduction block into two separate blocks.
* a) The init block, which is translated from the init statement of the reduction block;
* b) The update block, which is the original block without init statement.
*
* The init block is inserted right before the given loop.
*
* The schedule primitive requires:
* 1) The input block is a reduction block.
* 2) The input loop is the ancestor of the block.
* 3) The input loop is not lower than all the loops related to reduce block var.
* \param block_rv The reduction block to be decomposed
* \param loop_rv The loop above which the init block is inserted before.
* \return The init block
*/
TVM_DLL StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
const StmtSRef& loop_sref);
/*!
* \brief Factor a reduction block by the specified loop
* \details See python/tvm/tir/schedule/schedule.py
Expand Down
Loading

0 comments on commit 6b3fe95

Please sign in to comment.