[TensorIR][M2a] Parallel, Vectorize, Bind & Unroll
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
5 people committed Aug 12, 2021
1 parent 66ac470 commit 6c86365
Showing 14 changed files with 1,155 additions and 22 deletions.
39 changes: 38 additions & 1 deletion include/tvm/tir/schedule/schedule.h
@@ -110,7 +110,7 @@ class ScheduleNode : public runtime::Object {
* guaranteeing that
* 1) SRef tree is completely reconstructed;
* 2) The IRModule being scheduled is not modified;
-  * 3) All the random variables are valid in the copy, pointing to the correpsonding sref
+  * 3) All the random variables are valid in the copy, pointing to the corresponding sref
* reconstructed
*/
virtual Schedule Copy() const = 0;
@@ -220,6 +220,43 @@ class ScheduleNode : public runtime::Object {
*/
virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) = 0;
/******** Schedule: Manipulate ForKind ********/
/*!
* \brief Parallelize the input loop. It requires:
* 1) The scope block that the loop is in should have stage-pipeline property
* 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine
* bindings
* 3) For each block under the loop, the loop can only be contained in data-parallel block iters'
* bindings
* \param loop_rv The loop to be parallelized
*/
virtual void Parallel(const LoopRV& loop_rv) = 0;
/*!
* \brief Vectorize the input loop. It requires:
* 1) The scope block that the loop is in should have stage-pipeline property
* 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine
* bindings
* 3) For each block under the loop, the loop can only be contained in data-parallel block iters'
* bindings
* \param loop_rv The loop to be vectorized
*/
virtual void Vectorize(const LoopRV& loop_rv) = 0;
/*!
* \brief Bind the input loop to the given thread axis. It requires:
* 1) The scope block that the loop is in should have stage-pipeline property
* 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine
* bindings
* 3) For each block under the loop, if the thread axis starts with "threadIdx", the loop can only
* be contained in data-parallel block iters' and reduction block iters' bindings. Otherwise the
* loop can only be contained in data-parallel block iters' bindings
* \param loop_rv The loop to be bound to the thread axis
* \param thread_axis The given thread axis
*/
virtual void Bind(const LoopRV& loop_rv, const String& thread_axis) = 0;
/*!
* \brief Unroll the input loop. It requires nothing
* \param loop_rv The loop to be unrolled
*/
virtual void Unroll(const LoopRV& loop_rv) = 0;
/******** Schedule: Insert cache stages ********/
/******** Schedule: Compute location ********/
/*!
241 changes: 234 additions & 7 deletions python/tvm/tir/schedule/schedule.py
@@ -20,7 +20,7 @@
from tvm._ffi import register_object as _register_object
from tvm.error import TVMError, register_error
from tvm.ir import IRModule, PrimExpr
- from tvm.runtime import Object
+ from tvm.runtime import Object, String
from tvm.tir import Block, For, IntImm, PrimFunc

from . import _ffi_api
@@ -170,7 +170,7 @@ def copy(self) -> "Schedule":
* guaranteeing that
* 1) SRef tree is completely reconstructed;
* 2) The IRModule being scheduled is untouched;
- * 3) All the random variables are valid in the copy, pointing to the correpsonding sref
+ * 3) All the random variables are valid in the copy, pointing to the corresponding sref
* reconstructed
Returns
@@ -226,7 +226,7 @@ def get(
Returns
-------
result : Optional[Union[int, Block, For]]
- The correpsonding result
+ The corresponding result
"""
if isinstance(rand_var_or_sref, StmtSRef):
return rand_var_or_sref.stmt
@@ -236,7 +236,7 @@
return result

def get_sref(self, rand_var_or_stmt: Union[BlockRV, LoopRV, Block, For]) -> Optional[StmtSRef]:
"""Returns the correpsonding sref to the given
"""Returns the corresponding sref to the given
1) LoopRV
2) BlockRV
3) Block
@@ -250,7 +250,7 @@ def get_sref(self, rand_var_or_stmt: Union[BlockRV, LoopRV, Block, For]) -> Opti
Returns
-------
result : Optional[StmtSRef]
- The correpsonding result
+ The corresponding result
"""
return _ffi_api.ScheduleGetSRef( # type: ignore # pylint: disable=no-member
self, rand_var_or_stmt
@@ -413,7 +413,7 @@ def before_split(a: ty.handle, b: ty.handle) -> None:
with tir.block([128, 128], "B") as [vi, vj]:
B[vi, vj] = A[vi, vj] * 2.0
- Create the schedule and do fuse:
+ Create the schedule and do split:
.. code-block:: python
@@ -444,6 +444,233 @@ def after_split(a: ty.handle, b: ty.handle) -> None:

########## Schedule: Manipulate ForKind ##########

def parallel(self, loop: LoopRV) -> None:
"""Parallelize the input loop. It requires:
1) The scope block that the loop is in should have stage-pipeline property
2) All the blocks under the loop are complete blocks or reduction blocks, and have affine
bindings
3) For each block under the loop, the loop can only be contained in data-parallel block
iters' bindings
Parameters
----------
loop : LoopRV
The loop to be parallelized
Examples
--------
Before parallel, in TensorIR, the IR is:
.. code-block:: python
@tvm.script.tir
def before_parallel(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
for i, j in tir.grid(128, 128):
with tir.block([128, 128], "B") as [vi, vj]:
tir.bind(vi, i)
tir.bind(vj, j)
B[vi, vj] = A[vi, vj] * 2.0
Create the schedule and do parallel:
.. code-block:: python
sch = tir.Schedule(before_parallel)
i, j = sch.get_loops(sch.get_block("B"))
sch.parallel(i)
After applying parallel, the IR becomes:
.. code-block:: python
@tvm.script.tir
def after_parallel(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
for i in tir.parallel(0, 128):
for j in tir.serial(0, 128):
with tir.block([128, 128], "B") as [vi, vj]:
tir.bind(vi, i)
tir.bind(vj, j)
B[vi, vj] = A[vi, vj] * 2.0
"""
_ffi_api.ScheduleParallel(self, loop) # type: ignore # pylint: disable=no-member
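
Requirement 3 also tells us when `parallel` must fail: a loop that appears in a
reduction block iter's binding cannot be parallelized. The sketch below is
illustrative only; `before_reduction`, the block name "C", and the iters vi/vk
are assumptions for the example, not part of this commit:

.. code-block:: python

    # Hypothetical reduction workload: C[vi] accumulates A[vi, vk] over vk.
    sch = tir.Schedule(before_reduction)
    i, k = sch.get_loops(sch.get_block("C"))
    sch.parallel(i)    # OK: i appears only in the data-parallel iter vi's binding
    # sch.parallel(k)  # would be rejected: k appears in reduction iter vk's binding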

def vectorize(self, loop: LoopRV) -> None:
"""Vectorize the input loop. It requires:
1) The scope block that the loop is in should have stage-pipeline property
2) All the blocks under the loop are complete blocks or reduction blocks, and have affine
bindings
3) For each block under the loop, the loop can only be contained in data-parallel block
iters' bindings
Parameters
----------
loop : LoopRV
The loop to be vectorized
Examples
--------
Before vectorize, in TensorIR, the IR is:
.. code-block:: python
@tvm.script.tir
def before_vectorize(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
for i, j in tir.grid(128, 128):
with tir.block([128, 128], "B") as [vi, vj]:
tir.bind(vi, i)
tir.bind(vj, j)
B[vi, vj] = A[vi, vj] * 2.0
Create the schedule and do vectorize:
.. code-block:: python
sch = tir.Schedule(before_vectorize)
i, j = sch.get_loops(sch.get_block("B"))
sch.vectorize(j)
After applying vectorize, the IR becomes:
.. code-block:: python
@tvm.script.tir
def after_vectorize(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
for i in tir.serial(0, 128):
for j in tir.vectorized(0, 128):
with tir.block([128, 128], "B") as [vi, vj]:
tir.bind(vi, i)
tir.bind(vj, j)
B[vi, vj] = A[vi, vj] * 2.0
"""
_ffi_api.ScheduleVectorize(self, loop) # type: ignore # pylint: disable=no-member

def bind(self, loop: LoopRV, thread_axis: str) -> None:
"""Bind the input loop to the given thread axis. It requires:
1) The scope block that the loop is in should have stage-pipeline property
2) All the blocks under the loop are complete blocks or reduction blocks, and have affine
bindings
3) For each block under the loop, if the thread axis starts with "threadIdx", the loop can
only be contained in data-parallel block iters' and reduction block iters' bindings. Otherwise
the loop can only be contained in data-parallel block iters' bindings
Parameters
----------
loop : LoopRV
The loop to be bound to the thread axis
thread_axis : str
The thread axis to be bound to the loop. Possible candidates:
- blockIdx.x/y/z
- threadIdx.x/y/z
- vthread
- vthread.x/y/z
Examples
--------
Before bind, in TensorIR, the IR is:
.. code-block:: python
@tvm.script.tir
def before_bind(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
for i, j in tir.grid(128, 128):
with tir.block([128, 128], "B") as [vi, vj]:
tir.bind(vi, i)
tir.bind(vj, j)
B[vi, vj] = A[vi, vj] * 2.0
Create the schedule and do bind:
.. code-block:: python
sch = tir.Schedule(before_bind)
i, j = sch.get_loops(sch.get_block("B"))
sch.bind(i, "blockIdx.x")
sch.bind(j, "threadIdx.x")
After applying bind, the IR becomes:
.. code-block:: python
@tvm.script.tir
def after_bind(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
for i in tir.thread_binding(0, 128, thread = "blockIdx.x"):
for j in tir.thread_binding(0, 128, thread = "threadIdx.x"):
with tir.block([128, 128], "B") as [vi, vj]:
tir.bind(vi, i)
tir.bind(vj, j)
B[vi, vj] = A[vi, vj] * 2.0
"""
_ffi_api.ScheduleBind(self, loop, String(thread_axis)) # type: ignore # pylint: disable=no-member
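
Per requirement 3, only a "threadIdx" axis may be bound to a loop that appears in a
reduction block iter's binding; this is how cross-thread reductions are expressed.
A minimal sketch under the same hypothetical `before_reduction` workload used in the
`parallel` note above:

.. code-block:: python

    sch = tir.Schedule(before_reduction)
    i, k = sch.get_loops(sch.get_block("C"))
    sch.bind(i, "blockIdx.x")   # data-parallel loop -> block index
    sch.bind(k, "threadIdx.x")  # reduction loop -> thread index ("threadIdx" only)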

def unroll(self, loop: LoopRV) -> None:
"""Unroll the input loop. It requires nothing
Parameters
----------
loop : LoopRV
The loop to be unrolled
Examples
--------
Before unroll, in TensorIR, the IR is:
.. code-block:: python
@tvm.script.tir
def before_unroll(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
for i, j in tir.grid(128, 128):
with tir.block([128, 128], "B") as [vi, vj]:
tir.bind(vi, i)
tir.bind(vj, j)
B[vi, vj] = A[vi, vj] * 2.0
Create the schedule and do unroll:
.. code-block:: python
sch = tir.Schedule(before_unroll)
i, j = sch.get_loops(sch.get_block("B"))
sch.unroll(i)
After applying unroll, the IR becomes:
.. code-block:: python
@tvm.script.tir
def after_unroll(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
for i in tir.unroll(0, 128):
for j in tir.serial(0, 128):
with tir.block([128, 128], "B") as [vi, vj]:
tir.bind(vi, i)
tir.bind(vj, j)
B[vi, vj] = A[vi, vj] * 2.0
"""
_ffi_api.ScheduleUnroll(self, loop) # type: ignore # pylint: disable=no-member
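
In practice, the four primitives above are composed with `split` to build a complete
CPU schedule. The following is a minimal, illustrative sketch; `before_schedule` and
the split factor 4 are assumptions chosen for the example, not part of this commit:

.. code-block:: python

    sch = tir.Schedule(before_schedule)       # any elementwise workload as above
    i, j = sch.get_loops(sch.get_block("B"))
    j0, j1 = sch.split(j, factors=[None, 4])  # split the inner loop by 4
    sch.parallel(i)      # run the outer loop across CPU threads
    sch.unroll(j0)       # fully unroll the middle loop
    sch.vectorize(j1)    # map the innermost loop to vector lanes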

########## Schedule: Insert cache stages ##########

########## Schedule: Compute location ##########
@@ -581,7 +808,7 @@ def rfactor(self, loop: LoopRV, factor_axis: int) -> LoopRV:
RFactor is a schedule primitive that implements the transformation described above:
Given a block that writes to buffer `B`, it factorizes a loop of extent `n`.
- For example, the pesudocode below accumulates `B[i] = sum(A[i, : , : ])`:
+ For example, the pseudocode below accumulates `B[i] = sum(A[i, : , : ])`:
.. code-block:: python
23 changes: 23 additions & 0 deletions src/tir/schedule/analysis.h
@@ -120,6 +120,20 @@ bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& scope_root_sref);

/*!
* \brief Check whether a subtree on SRef tree has compact data flow, and throw an exception if the
* subtree does not have compact data flow
* \details For a given StmtSRef, we say the subtree rooted at the StmtSRef has "compact data
* flow" property if:
* - the scope root of the input subtree root has stage-pipeline property, and
* - all its child blocks on SRef tree are complete blocks or reduction blocks.
* \param self The schedule state
* \param subtree_root_sref The root of the subtree to be checked in the SRef tree
* \throw ScheduleError If the subtree does not have compact data flow
* \sa IsCompleteBlock, IsReductionBlock
*/
void CheckSRefSubtreeCompactDataFlow(const ScheduleState& self, const StmtSRef& subtree_root_sref);

/******** Binding ********/
/*!
* \brief Verifies if the block binding in a specific BlockRealize is an affine binding.
@@ -132,6 +146,15 @@ void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
bool IsAffineBinding(const BlockRealize& realize, const Map<Var, Range>& loop_var_ranges,
arith::Analyzer* analyzer);

/*!
* \brief Check whether a block has an affine binding using the cached flag, and throw an exception
* if the block does not have an affine binding.
* \param self The schedule state
* \param block The block to be checked
* \throw ScheduleError If the input block does not have an affine binding
*/
void CheckAffineBinding(const ScheduleState& self, Block block);

/*!
* \brief Extracts the ranges of loop variables in a path of the sref tree
* \param low_inclusive The lowest node in the path