Skip to content

Commit

Permalink
[TensorIR][M2a] Parallel, Vectorize, Bind & Unroll (apache#8716)
Browse files Browse the repository at this point in the history
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
  • Loading branch information
6 people authored and ylc committed Jan 13, 2022
1 parent 2d3a4ed commit 1543108
Show file tree
Hide file tree
Showing 29 changed files with 1,600 additions and 40 deletions.
39 changes: 38 additions & 1 deletion include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class ScheduleNode : public runtime::Object {
* guaranteeing that
* 1) SRef tree is completely reconstructed;
* 2) The IRModule being scheduled is not modified;
* 3) All the random variables are valid in the copy, pointing to the correpsonding sref
* 3) All the random variables are valid in the copy, pointing to the corresponding sref
* reconstructed
*/
virtual Schedule Copy() const = 0;
Expand Down Expand Up @@ -220,6 +220,43 @@ class ScheduleNode : public runtime::Object {
*/
virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) = 0;
/******** Schedule: Manipulate ForKind ********/
/*!
* \brief Parallelize the input loop. It requires:
* 1) The scope block that the loop is in should have stage-pipeline property
* 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine
* bindings
* 3) For each block under the loop, the loop can only be contained in data-parallel block iters'
* bindings
* \param loop_rv The loop to be parallelized
*/
virtual void Parallel(const LoopRV& loop_rv) = 0;
/*!
* \brief Vectorize the input loop. It requires:
* 1) The scope block that the loop is in should have stage-pipeline property
* 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine
* bindings
* 3) For each block under the loop, the loop can only be contained in data-parallel block iters'
* bindings
* \param loop_rv The loop to be vectorized
*/
virtual void Vectorize(const LoopRV& loop_rv) = 0;
/*!
* \brief Bind the input loop to the given thread axis. It requires:
* 1) The scope block that the loop is in should have stage-pipeline property
* 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine
* bindings
 * 3) For each block under the loop, if the thread axis starts with "threadIdx", the loop can only
* be contained in data-parallel block iter and reduction block iters' bindings. Otherwise the
* loop can only be contained in data-parallel block iters' bindings
* \param loop_rv The loop to be bound to the thread axis
* \param thread_axis The thread axis to be bound to the loop
*/
virtual void Bind(const LoopRV& loop_rv, const String& thread_axis) = 0;
/*!
* \brief Unroll the input loop. It requires nothing
* \param loop_rv The loop to be unrolled
*/
virtual void Unroll(const LoopRV& loop_rv) = 0;
/******** Schedule: Insert cache stages ********/
/******** Schedule: Compute location ********/
/*!
Expand Down
12 changes: 12 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,18 @@ TVM_DLL Pass LowerMatchBuffer();
*/
TVM_DLL Pass FlattenBuffer();

/*!
* \brief Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and
* "vthread.x/y/z". Before the unification, two vars that are bound to a thread axis (e.g.,
* "threadIdx.x") use different IterVars and variables in their AttrStmts. After the
* unification, we use a consolidated IterVar and a variable for them.
* \return The pass.
* \note `vthread` is a legacy behavior that will be deprecated, though thread bindings of `vthread`
* are still also unified in this pass. Please use `vthread.x`, `vthread.y` and `vthread.z`
* instead.
*/
TVM_DLL Pass UnifyThreadBinding();

/*!
* A pass to merge multiple TIR-level dynamic shared memory allocations into one
*/
Expand Down
240 changes: 234 additions & 6 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def copy(self) -> "Schedule":
* guaranteeing that
* 1) SRef tree is completely reconstructed;
* 2) The IRModule being scheduled is untouched;
* 3) All the random variables are valid in the copy, pointing to the correpsonding sref
* 3) All the random variables are valid in the copy, pointing to the corresponding sref
* reconstructed
Returns
Expand Down Expand Up @@ -226,7 +226,7 @@ def get(
Returns
-------
result : Optional[Union[int, Block, For]]
The correpsonding result
The corresponding result
"""
if isinstance(rand_var_or_sref, StmtSRef):
return rand_var_or_sref.stmt
Expand All @@ -236,7 +236,7 @@ def get(
return result

def get_sref(self, rand_var_or_stmt: Union[BlockRV, LoopRV, Block, For]) -> Optional[StmtSRef]:
"""Returns the correpsonding sref to the given
"""Returns the corresponding sref to the given
1) LoopRV
2) BlockRV
3) Block
Expand All @@ -250,7 +250,7 @@ def get_sref(self, rand_var_or_stmt: Union[BlockRV, LoopRV, Block, For]) -> Opti
Returns
-------
result : Optional[StmtSRef]
The correpsonding result
The corresponding result
"""
return _ffi_api.ScheduleGetSRef( # type: ignore # pylint: disable=no-member
self, rand_var_or_stmt
Expand Down Expand Up @@ -413,7 +413,7 @@ def before_split(a: ty.handle, b: ty.handle) -> None:
with tir.block([128, 128], "B") as [vi, vj]:
B[vi, vj] = A[vi, vj] * 2.0
Create the schedule and do fuse:
Create the schedule and do split:
.. code-block:: python
Expand Down Expand Up @@ -444,6 +444,234 @@ def after_split(a: ty.handle, b: ty.handle) -> None:

########## Schedule: Manipulate ForKind ##########

def parallel(self, loop: LoopRV) -> None:
    """Parallelize the input loop. It requires:
    1) The scope block that the loop is in should have stage-pipeline property
    2) All the blocks under the loop are complete blocks or reduction blocks, and have affine
    bindings
    3) For each block under the loop, the loop can only be contained in data-parallel block
    iters' bindings

    Parameters
    ----------
    loop : LoopRV
        The loop to be parallelized

    Examples
    --------
    Before parallel, in TensorIR, the IR is:

    .. code-block:: python

        @tvm.script.tir
        def before_parallel(a: ty.handle, b: ty.handle) -> None:
            A = tir.match_buffer(a, (128, 128))
            B = tir.match_buffer(b, (128, 128))
            for i, j in tir.grid(128, 128):
                with tir.block([128, 128], "B") as [vi, vj]:
                    tir.bind(vi, i)
                    tir.bind(vj, j)
                    B[vi, vj] = A[vi, vj] * 2.0

    Create the schedule and do parallel:

    .. code-block:: python

        sch = tir.Schedule(before_parallel)
        i, j = sch.get_loops(sch.get_block("B"))
        sch.parallel(i)

    After applying parallel, the IR becomes:

    .. code-block:: python

        @tvm.script.tir
        def after_parallel(a: ty.handle, b: ty.handle) -> None:
            A = tir.match_buffer(a, (128, 128))
            B = tir.match_buffer(b, (128, 128))
            for i in tir.parallel(0, 128):
                for j in tir.serial(0, 128):
                    with tir.block([128, 128], "B") as [vi, vj]:
                        tir.bind(vi, i)
                        tir.bind(vj, j)
                        B[vi, vj] = A[vi, vj] * 2.0
    """
    # Delegate to the C++ implementation through the FFI boundary.
    _ffi_api.ScheduleParallel(self, loop)  # type: ignore # pylint: disable=no-member

def vectorize(self, loop: LoopRV) -> None:
    """Vectorize the input loop. It requires:
    1) The scope block that the loop is in should have stage-pipeline property
    2) All the blocks under the loop are complete blocks or reduction blocks, and have affine
    bindings
    3) For each block under the loop, the loop can only be contained in data-parallel block
    iters' bindings

    Parameters
    ----------
    loop : LoopRV
        The loop to be vectorized

    Examples
    --------
    Before vectorize, in TensorIR, the IR is:

    .. code-block:: python

        @tvm.script.tir
        def before_vectorize(a: ty.handle, b: ty.handle) -> None:
            A = tir.match_buffer(a, (128, 128))
            B = tir.match_buffer(b, (128, 128))
            for i, j in tir.grid(128, 128):
                with tir.block([128, 128], "B") as [vi, vj]:
                    tir.bind(vi, i)
                    tir.bind(vj, j)
                    B[vi, vj] = A[vi, vj] * 2.0

    Create the schedule and do vectorize:

    .. code-block:: python

        sch = tir.Schedule(before_vectorize)
        i, j = sch.get_loops(sch.get_block("B"))
        sch.vectorize(j)

    After applying vectorize, the IR becomes:

    .. code-block:: python

        @tvm.script.tir
        def after_vectorize(a: ty.handle, b: ty.handle) -> None:
            A = tir.match_buffer(a, (128, 128))
            B = tir.match_buffer(b, (128, 128))
            for i in tir.serial(0, 128):
                for j in tir.vectorized(0, 128):
                    with tir.block([128, 128], "B") as [vi, vj]:
                        tir.bind(vi, i)
                        tir.bind(vj, j)
                        B[vi, vj] = A[vi, vj] * 2.0
    """
    # Delegate to the C++ implementation through the FFI boundary.
    _ffi_api.ScheduleVectorize(self, loop)  # type: ignore # pylint: disable=no-member

def bind(self, loop: LoopRV, thread_axis: str) -> None:
    """Bind the input loop to the given thread axis. It requires:
    1) The scope block that the loop is in should have stage-pipeline property
    2) All the blocks under the loop are complete blocks or reduction blocks, and have affine
    bindings
    3) For each block under the loop, if the thread axis starts with "threadIdx", the loop can
    only be contained in data-parallel block iter and reduction block iters' bindings. Otherwise
    the loop can only be contained in data-parallel block iters' bindings

    Parameters
    ----------
    loop : LoopRV
        The loop to be bound to the thread axis
    thread_axis : str
        The thread axis to be bound to the loop. Possible candidates:
        - blockIdx.x/y/z
        - threadIdx.x/y/z
        - vthread.x/y/z
        - vthread (It is a legacy behavior that will be deprecated. Please use `vthread.x/y/z`
        instead.)

    Examples
    --------
    Before bind, in TensorIR, the IR is:

    .. code-block:: python

        @tvm.script.tir
        def before_bind(a: ty.handle, b: ty.handle) -> None:
            A = tir.match_buffer(a, (128, 128))
            B = tir.match_buffer(b, (128, 128))
            for i, j in tir.grid(128, 128):
                with tir.block([128, 128], "B") as [vi, vj]:
                    tir.bind(vi, i)
                    tir.bind(vj, j)
                    B[vi, vj] = A[vi, vj] * 2.0

    Create the schedule and do bind:

    .. code-block:: python

        sch = tir.Schedule(before_bind)
        i, j = sch.get_loops(sch.get_block("B"))
        sch.bind(i, "blockIdx.x")
        sch.bind(j, "threadIdx.x")

    After applying bind, the IR becomes:

    .. code-block:: python

        @tvm.script.tir
        def after_bind(a: ty.handle, b: ty.handle) -> None:
            A = tir.match_buffer(a, (128, 128))
            B = tir.match_buffer(b, (128, 128))
            for i in tir.thread_binding(0, 128, thread = "blockIdx.x"):
                for j in tir.thread_binding(0, 128, thread = "threadIdx.x"):
                    with tir.block([128, 128], "B") as [vi, vj]:
                        tir.bind(vi, i)
                        tir.bind(vj, j)
                        B[vi, vj] = A[vi, vj] * 2.0
    """
    # Delegate to the C++ implementation through the FFI boundary.
    _ffi_api.ScheduleBind(self, loop, thread_axis)  # type: ignore # pylint: disable=no-member

def unroll(self, loop: LoopRV) -> None:
    """Unroll the input loop. It requires nothing

    Parameters
    ----------
    loop : LoopRV
        The loop to be unrolled

    Examples
    --------
    Before unroll, in TensorIR, the IR is:

    .. code-block:: python

        @tvm.script.tir
        def before_unroll(a: ty.handle, b: ty.handle) -> None:
            A = tir.match_buffer(a, (128, 128))
            B = tir.match_buffer(b, (128, 128))
            for i, j in tir.grid(128, 128):
                with tir.block([128, 128], "B") as [vi, vj]:
                    tir.bind(vi, i)
                    tir.bind(vj, j)
                    B[vi, vj] = A[vi, vj] * 2.0

    Create the schedule and do unroll:

    .. code-block:: python

        sch = tir.Schedule(before_unroll)
        i, j = sch.get_loops(sch.get_block("B"))
        sch.unroll(i)

    After applying unroll, the IR becomes:

    .. code-block:: python

        @tvm.script.tir
        def after_unroll(a: ty.handle, b: ty.handle) -> None:
            A = tir.match_buffer(a, (128, 128))
            B = tir.match_buffer(b, (128, 128))
            for i in tir.unroll(0, 128):
                for j in tir.serial(0, 128):
                    with tir.block([128, 128], "B") as [vi, vj]:
                        tir.bind(vi, i)
                        tir.bind(vj, j)
                        B[vi, vj] = A[vi, vj] * 2.0
    """
    # Delegate to the C++ implementation through the FFI boundary.
    _ffi_api.ScheduleUnroll(self, loop)  # type: ignore # pylint: disable=no-member

########## Schedule: Insert cache stages ##########

########## Schedule: Compute location ##########
Expand Down Expand Up @@ -581,7 +809,7 @@ def rfactor(self, loop: LoopRV, factor_axis: int) -> LoopRV:
RFactor is a schedule primitive that implements the transformation described above:
Given a block that writes to buffer `B`, it factorizes a loop of extent `n`.
For example, the pesudocode below accumulates `B[i] = sum(A[i, : , : ])`:
For example, the pseudocode below accumulates `B[i] = sum(A[i, : , : ])`:
.. code-block:: python
Expand Down
22 changes: 22 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,28 @@ def FlattenBuffer():
return _ffi_api.FlattenBuffer() # type: ignore


def UnifyThreadBinding():
    """Unify all the thread bindings for "blockIdx.x/y/z",
    "threadIdx.x/y/z", and "vthread.x/y/z". Before the unification,
    two vars that are bound to a thread axis (e.g., "threadIdx.x")
    use different IterVars and variables in their AttrStmts. After
    the unification, we use a consolidated IterVar and a variable
    for them.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass

    Note
    ----
    `vthread` is a legacy behavior that will be deprecated, though
    thread bindings of `vthread` are still also unified in this
    pass. Please use `vthread.x`, `vthread.y` and `vthread.z` instead.
    """
    # Delegate to the C++ pass implementation through the FFI boundary.
    return _ffi_api.UnifyThreadBinding()  # type: ignore


def MergeDynamicSharedMemoryAllocations():
"""This pass merges multiple TIR-level dynamic shared memory allocations
into one allocation.
Expand Down
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::UnifyThreadBinding());
pass_list.push_back(tir::transform::BF16Legalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());
Expand Down
Loading

0 comments on commit 1543108

Please sign in to comment.