
[TensorIR][M2a] Reduction Factoring (RFactor) #8544

Merged (9 commits) on Jul 31, 2021
22 changes: 10 additions & 12 deletions include/tvm/tir/analysis.h
@@ -96,22 +96,20 @@ TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr);
TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);

/*!
* \brief Whether e expression used any var in variable set.
* \param expr The expression to be checked.
* \param vset_contains The check function to see if var is in the vset.
* \return Whether e uses vset.
* \brief Whether the given Stmt uses any var in the given variable set.
* \param stmt The Stmt to be checked.
* \param vset_contains The check function to see if a var is in the variable set.
* \return Whether `stmt` uses any var in the given variable set.
*/
TVM_DLL bool ExprUseVar(const PrimExpr& expr, std::function<bool(const VarNode*)> vset_contains);
TVM_DLL bool UsesVar(const Stmt& stmt, std::function<bool(const VarNode*)> vset_contains);

/*!
* \brief Whether e expression used var.
* \param expr The expression to be checked.
* \param var The variable.
* \return Whether e uses v.
* \brief Whether the given PrimExpr uses any var in the given variable set.
* \param expr The PrimExpr to be checked.
* \param vset_contains The check function to see if var is in the variable set.
* \return Whether `expr` uses any var in the given variable set.
*/
inline bool ExprUseVar(const PrimExpr& expr, const Var& var) {
return ExprUseVar(expr, [&](const VarNode* node) { return var.get() == node; });
}
TVM_DLL bool UsesVar(const PrimExpr& expr, std::function<bool(const VarNode*)> vset_contains);

/*!
* \brief Verifies whether the IR stmt or Expr is in SSA form.
2 changes: 1 addition & 1 deletion include/tvm/tir/schedule/block_scope.h
@@ -262,7 +262,7 @@ class BlockScope : public ObjectRef {
* \param child_block_srefs The srefs to the leaf blocks
* \note We assume the leaf blocks are given in pre-DFS order
*/
TVM_DLL BlockScope(const Array<StmtSRef>& child_block_srefs);
TVM_DLL explicit BlockScope(const Array<StmtSRef>& child_block_srefs);

TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockScope, ObjectRef, BlockScopeNode);
};
20 changes: 20 additions & 0 deletions include/tvm/tir/schedule/schedule.h
@@ -242,6 +242,26 @@ class ScheduleNode : public runtime::Object {
/******** Schedule: loop binding/annotation ********/
/******** Schedule: cache read/write ********/
/******** Schedule: reduction ********/
/*!
* \brief Factorize an associative reduction block by the specified loop.
* \details An associative reduction cannot be parallelized directly,
* because it leads to potential race conditions during accumulation.
* Alternatively, the reduction could be factorized on a loop with the following steps:
* - Step 1: evenly slice the reduction into `n` separate chunks, where `n` is the loop extent;
* - Step 2: compute the chunks separately and write the results into `n` intermediate buffers;
* - Step 3: accumulate the `n` intermediate buffers into the result buffer.
* Note that Step 2 introduces opportunities for parallelization.
* RFactor is a schedule primitive that implements the transformation described above.


* \param loop_rv The loop outside the block for which we want to do rfactor
* \param factor_axis The position where the new dimension is placed in the newly introduced rfactor
* buffer. Suppose the original reduction block writes to buffer `B` with
* ndim(B) dimensions, then `factor_axis` should be in range `[-ndim(B) - 1,
* ndim(B)]`, and the negative index will be normalized to a non-negative one
* \return The rfactor block
*/
virtual BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) = 0;
/******** Schedule: blockize & tensorize ********/
};

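To make the three-step transformation described above concrete, here is a minimal NumPy sketch (an illustration, not code from this diff; the names `A`, `B`, `B_rf` and the small extents are made up) that factors the reduction `B[i] = sum(A[i, :, :])` over the last axis and checks the result against direct accumulation:

```python
import numpy as np

n = 16                                      # extent of the factorized loop k
A = np.random.rand(8, 8, n).astype("float32")

# Reference: direct accumulation over both reduction axes.
B_direct = A.sum(axis=(1, 2))

# Steps 1-2: one partial sum per k-slice, written into the rf-buffer.
# Each iteration is independent, which is what makes parallelization legal.
B_rf = np.zeros((A.shape[0], n), dtype="float32")   # factor_axis = 1
for k in range(n):
    B_rf[:, k] = A[:, :, k].sum(axis=1)

# Step 3: the write-back block accumulates the rf-buffer into the result.
B = np.zeros(A.shape[0], dtype="float32")
for k in range(n):
    B += B_rf[:, k]

np.testing.assert_allclose(B, B_direct, rtol=1e-4)
```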
19 changes: 19 additions & 0 deletions include/tvm/tir/stmt.h
@@ -865,6 +865,7 @@ class For : public Stmt {
Map<String, ObjectRef> annotations = Map<String, ObjectRef>(), Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode);
};

/*!
@@ -1361,6 +1362,24 @@ TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span());
// overload printing of for type.
TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind);

// inline implementations
inline const char* ForKind2String(ForKind t) {
switch (t) {
case ForKind::kSerial:
return "serial";
case ForKind::kParallel:
return "parallel";
case ForKind::kVectorized:
return "vectorized";
case ForKind::kUnrolled:
return "unroll";
case ForKind::kThreadBinding:
return "thread_binding";
}
LOG(FATAL) << "Unknown ForKind";
return "Unknown";
}

} // namespace tir
} // namespace tvm
#endif // TVM_TIR_STMT_H_
2 changes: 1 addition & 1 deletion python/tvm/script/special_stmt.py
@@ -225,7 +225,7 @@ def alloc_buffer(
data=None,
strides=None,
elem_offset=None,
scope="",
scope="global",
align=-1,
offset_factor=0,
buffer_type="default",
148 changes: 146 additions & 2 deletions python/tvm/tir/schedule/schedule.py
@@ -431,7 +431,7 @@ def before_inline(a: ty.handle, c: ty.handle) -> None:

.. code-block:: python

sch = tir.Schedule(before_inline, debug_mode=True)
sch = tir.Schedule(before_inline)
sch.compute_inline(sch.get_block("B"))
print(tvm.script.asscript(sch.mod["main"]))

@@ -491,7 +491,7 @@ def before_inline(a: ty.handle, c: ty.handle) -> None:

.. code-block:: python

sch = tir.Schedule(before_inline, debug_mode=True)
sch = tir.Schedule(before_inline)
sch.reverse_compute_inline(sch.get_block("C"))
print(tvm.script.asscript(sch.mod["main"]))

@@ -512,6 +512,150 @@ def after_inline(a: ty.handle, c: ty.handle) -> None:
########## Schedule: loop binding/annotation ##########
########## Schedule: cache read/write ##########
########## Schedule: reduction ##########
def rfactor(self, loop: LoopRV, factor_axis: int) -> BlockRV:
"""Factorize an associative reduction block by the specified loop.

An associative reduction cannot be parallelized directly,
because it leads to potential race conditions during accumulation.
Alternatively, the reduction could be factorized on a loop with the following steps:
- Step 1: evenly slice the reduction into `n` separate chunks, where `n` is the loop extent;
- Step 2: compute the chunks separately and write the results into `n` intermediate buffers;
- Step 3: accumulate the `n` intermediate buffers into the result buffer.
Note that Step 2 introduces opportunities for parallelization.

RFactor is a schedule primitive that implements the transformation described above:
Given a block that writes to buffer `B`, it factorizes a loop of extent `n`.

For example, the pseudocode below accumulates `B[i] = sum(A[i, :, :])`:


.. code-block:: python

for i in range(128): # loop i is a data parallel loop
for j in range(128): # loop j is a reduction loop
for k in range(128): # loop k is a reduction loop
B[i] = B[i] + A[i, j, k]


Suppose RFactor is applied on the innermost loop `k` and `factor_axis = 1`.
RFactor then creates an intermediate buffer and two blocks.

- The intermediate buffer, or "rf-buffer", is a buffer of rank `ndim(B) + 1` and
size `size(B) * n`, whose shape expands from `shape(B)` by adding an axis of `n`
at the position specified by `factor_axis`. For example,

* shape(B) = [1, 2, 3], factor_axis = 0 => shape(B_rf) = [n, 1, 2, 3]
* shape(B) = [1, 2, 3], factor_axis = 1 => shape(B_rf) = [1, n, 2, 3]
* shape(B) = [1, 2, 3], factor_axis = 2 => shape(B_rf) = [1, 2, n, 3]
* shape(B) = [1, 2, 3], factor_axis = 3 => shape(B_rf) = [1, 2, 3, n]

- The rfactor block, or "rf-block", is a block that writes to the `rf-buffer` without
accumulating over the loop `k`, i.e. the loop `k` is converted from a reduction loop
to a data parallel loop. In our example, the rf-block is:


.. code-block:: python

B_rf = np.zeros((128, 128)) # the rf-buffer
for k in range(128): # loop k is converted to a data parallel loop
for i in range(128): # loop i is a data parallel loop (unchanged)
for j in range(128): # loop j is a reduction loop (unchanged)
B_rf[i, k] = B_rf[i, k] + A[i, j, k]


- The write-back block, or `wb-block`, is a block that accumulates the rf-buffer into
the result buffer. All the reduction loops are removed except the loop `k` for accumulation.
In our example, the wb-block is:

.. code-block:: python

for i in range(128): # loop i is a data parallel loop (unchanged)
# loop j is removed because it is a reduction loop
for k in range(128): # loop k is a reduction loop (unchanged)
B[i] = B[i] + B_rf[i, k]

Parameters
----------
loop : LoopRV
The loop outside the block for which we want to do rfactor
factor_axis : int
The position where the new dimension is placed in the newly introduced rfactor buffer

Returns
-------
rf_block : BlockRV
The block which computes partial results over each slice (i.e., the first block
as described in the above illustration)

Examples
--------

Before rfactor, in TensorIR, the IR is:

.. code-block:: python

@tvm.script.tir
def before_rfactor(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128, 128), "float32")
B = tir.match_buffer(b, (128,), "float32")
with tir.block([128, tir.reduce_axis(0, 128),
tir.reduce_axis(0, 128)], "B") as [vii, vi, vj]:
with tir.init():
B[vii] = 0.0
B[vii] = B[vii] + A[vii, vi, vj]

Create the schedule and do rfactor:

.. code-block:: python

sch = tir.Schedule(before_rfactor)
_, _, k = sch.get_loops(sch.get_block("B"))
sch.rfactor(k, 0)
print(tvm.script.asscript(sch.mod["main"]))

After applying rfactor, the IR becomes:

.. code-block:: python

@tvm.script.tir
def after_rfactor(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, [128, 128, 128])
B = tir.match_buffer(b, [128])
B_rf = tir.alloc_buffer([128, 128])
with tir.block([128, 128, tir.reduce_axis(0, 128)], "B_rf") as [vi2, vii, vi]:
with tir.init():
B_rf[vi2, vii] = 0.0
B_rf[vi2, vii] = (B_rf[vi2, vii] + A[vii, vi, vi2])
with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vii_1, vi2_1]:
with tir.init():
B[vii_1] = 0.0
B[vii_1] = (B[vii_1] + B_rf[vi2_1, vii_1])


Note
----

Rfactor requires:
1) `loop` has only one child block, and it is a reduction block;
2) `loop` is a reduction loop, i.e. the loop variable is bound to only reduction variables
in the block binding;
3) `loop` is not parallelized, vectorized, unrolled or bound to any thread axis;
4) The block scope that `loop` is in is a staged-pipeline;
5) The outermost loop outside the reduction block should have the reduction block as its first child block;
6) The outermost reduction loop should have only one child block;
7) A unary-extent loop that is not bound to any reduction or data parallel variables in the block binding
should not appear under any reduction loop;
8) The reduction block should write to only one buffer, and its init and body should each be
a single, simple `BufferStore`, whose update pattern is registered as an associative reducer.
The pre-defined patterns include: plus, multiplication, min and max;
9) None of the loops above the block may be bound to both a data parallel and a reduction
block binding at the same time;
10) `factor_axis` should be in range `[-ndim(B) - 1, ndim(B)]`,
where `B` is the buffer that the reduction block writes to.
Negative indexing is normalized according to numpy convention.
"""
return _ffi_api_schedule.ScheduleRFactor(self, loop, factor_axis) # type: ignore # pylint: disable=no-member

########## Schedule: blockize & tensorize ##########


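As a usage sketch (adapted from the docstring example above, not itself part of the diff; `before_rfactor` refers to the TVMScript function defined in that example), one way to drive the new primitive and inspect its result:

```python
import tvm
from tvm import tir

sch = tir.Schedule(before_rfactor)

# Factorize the innermost reduction loop `k`, placing the new axis at
# position 0 of the rf-buffer, so B_rf has shape [128, 128].
_, _, k = sch.get_loops(sch.get_block("B"))
rf_block = sch.rfactor(k, factor_axis=0)

# The returned rf-block computes the partial sums; its loops can be
# inspected here (and parallelized once loop-binding primitives land).
rf_loops = sch.get_loops(rf_block)
print(tvm.script.asscript(sch.mod["main"]))
```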
6 changes: 4 additions & 2 deletions src/arith/canonical_simplify.cc
@@ -1137,8 +1137,10 @@ PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op)
// and recursively mark the corresponding components
for (size_t i = 0; i < simplified_result.size(); ++i)
if (!used[i]) {
if (ExprUseVar(simplified_result[idx], op->combiner->lhs[i]) ||
ExprUseVar(simplified_result[idx], op->combiner->rhs[i]))
if (UsesVar(simplified_result[idx],
[v = op->combiner->lhs[i].get()](const VarNode* var) { return var == v; }) ||
UsesVar(simplified_result[idx],
[v = op->combiner->rhs[i].get()](const VarNode* var) { return var == v; }))
mark_used(i);
}
};
4 changes: 2 additions & 2 deletions src/arith/detect_linear_equation.cc
@@ -108,7 +108,7 @@ class LinearEqDetector : public ExprFunctor<LinearEqEntry(const PrimExpr&, const
}
LinearEqEntry VisitExprDefault_(const Object* op, const PrimExpr& e) final {
if (fail_) return LinearEqEntry();
if (ExprUseVar(e, var_)) {
if (UsesVar(e, [this](const VarNode* var) { return var == var_.get(); })) {
fail_ = true;
return LinearEqEntry();
} else {
@@ -159,7 +159,7 @@ Array<PrimExpr> DetectLinearEquation(const PrimExpr& e, const Array<Var>& vars)
for (size_t i = vars.size(); i > 1; --i) {
vset.insert(vars[i - 1].get());
// The previous coeff contains the variable
if (ExprUseVar(coeff[i - 2], vset_contains)) {
if (UsesVar(coeff[i - 2], vset_contains)) {
return Array<PrimExpr>();
}
}
17 changes: 0 additions & 17 deletions src/printer/tir_text_printer.cc
@@ -485,23 +485,6 @@ Doc TIRTextPrinter::VisitStmt_(const EvaluateNode* op) {
return doc;
}

inline const char* ForKind2String(ForKind t) {
switch (t) {
case ForKind::kSerial:
return "serial";
case ForKind::kParallel:
return "parallel";
case ForKind::kVectorized:
return "vectorized";
case ForKind::kUnrolled:
return "unroll";
case ForKind::kThreadBinding:
return "thread_binding";
}
LOG(FATAL) << "Unknown ForKind";
return "Unknown";
}

Doc TIRTextPrinter::VisitStmt_(const ForNode* op) {
Doc doc;
doc << "for (" << Print(op->loop_var) << ", " << Print(op->min) << ", "
17 changes: 0 additions & 17 deletions src/printer/tvmscript_printer.cc
@@ -710,23 +710,6 @@ Doc TVMScriptPrinter::VisitStmt_(const EvaluateNode* op) {
return doc;
}

inline const char* ForKind2String(ForKind t) {
switch (t) {
case ForKind::kSerial:
return "serial";
case ForKind::kParallel:
return "parallel";
case ForKind::kVectorized:
return "vectorized";
case ForKind::kUnrolled:
return "unroll";
case ForKind::kThreadBinding:
return "thread_binding";
}
LOG(FATAL) << "Unknown ForKind";
return "Unknown";
}

Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) {
Doc doc;
var_not_in_headers.insert(op->loop_var.get());
4 changes: 2 additions & 2 deletions src/te/autodiff/ad_simplify.cc
@@ -834,7 +834,7 @@ std::pair<PrimExpr, PrimExpr> ImplicationNotContainingVars(
return {pair_a.first || pair_b.first, (pair_a.first || pair_b.second) &&
(pair_b.first || pair_a.second) &&
(pair_a.second || pair_b.second)};
} else if (!tir::ExprUseVar(cond, [&vars](const VarNode* var) { return vars.count(var); })) {
} else if (!tir::UsesVar(cond, [&vars](const VarNode* var) { return vars.count(var); })) {
return {cond, const_true()};
} else {
return {const_true(), cond};
@@ -1014,7 +1014,7 @@ PrimExpr TrySimplifyCompute(const PrimExpr& expr, const PrimExpr& cond,
// Keep only those variables of the new vars which are used in the new_expr
Array<Var> used_res_variables;
for (const Var& var : res->dst->variables) {
if (ExprUseVar(new_expr, var)) {
if (tir::UsesVar(new_expr, [&var](const VarNode* var_) { return var_ == var.get(); })) {
ICHECK(res->dst->ranges.count(var)) << "Range of " << var << " cannot be inferred.";
used_res_variables.push_back(var);
}
2 changes: 1 addition & 1 deletion src/te/operation/compute_op.cc
@@ -591,7 +591,7 @@ Stmt TransformUpdate(const Stage& stage, const std::unordered_map<IterVar, Range
auto fbanned = [&](const VarNode* node) { return banned.count(node); };

for (const PrimExpr& pred : n.main_predicates) {
if (tir::ExprUseVar(pred, fbanned)) {
if (tir::UsesVar(pred, fbanned)) {
LOG(FATAL) << "Tensorize update transform failed, the condition " << pred
<< " has a conflict with the reset condition";
}