Skip to content

Commit

Permalink
[TIR][REFACTIR] RewriteForTensorCore -> te/schedule
Browse files Browse the repository at this point in the history
RewriteForTensor depends on the schedule information, which makes it differ
from a typical pass(which should get all the information from the input TIR).

As a result, we refactor it as a SchedulePostProc step for now.
We should revisit it later as we introduce more support for tensor core patterns in the TIR.
  • Loading branch information
tqchen committed Apr 19, 2020
1 parent c3511c5 commit 016d03a
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 72 deletions.
49 changes: 32 additions & 17 deletions include/tvm/te/schedule_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,23 @@
namespace tvm {
namespace te {

/*!
* \brief To automatically inline the element-wise operations.
*
* \param sch The schedule to be inlined.
*/
void AutoInlineElemWise(Schedule sch);

/*!
* \brief To automatically inline operations with injective writes
* (i.e. writes without reduction or sequential loops). Note
* that in this case, guarantees about contiguity, transpose, stride,
* alignemnt and memory footprint in general do not hold.
*
* \param sch The schedule to be inlined.
*/
TVM_DLL void AutoInlineInjective(Schedule sch);

/*!
* \brief Infer the bound of all iteration variables relates to the schedule.
*
Expand All @@ -55,6 +72,21 @@ Map<IterVar, Range> InferBound(const Schedule& sch);
*/
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool debug_keep_trivial_loop);


/*!
* \brief Try to modify the AST generated by ScheduleOps to support TensorCore.
*
* \param stmt The stmt to be trasnformed.
* \param schedule The original schedule.
* \param extern_buffer Map specifies external
* buffer assignment of input and outputs.
* \return Transformed stmt.
*/
Stmt SchedulePostProcRewriteForTensorCore(
Stmt stmt,
Schedule schedule,
Map<Tensor, Buffer> extern_buffer);

/*!
* \brief Postprocessing the Stmt generated by ScheduleOps to create
* a PrimFunc that can then be used for further TIR optimizations.
Expand All @@ -75,23 +107,6 @@ PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list,
Stmt body,
Optional<Map<Tensor, Buffer>> bindings);

/*!
* \brief To automatically inline the element-wise operations.
*
* \param sch The schedule to be inlined.
*/
void AutoInlineElemWise(Schedule sch);

/*!
* \brief To automatically inline operations with injective writes
* (i.e. writes without reduction or sequential loops). Note
* that in this case, guarantees about contiguity, transpose, stride,
* alignemnt and memory footprint in general do not hold.
*
* \param sch The schedule to be inlined.
*/
TVM_DLL void AutoInlineInjective(Schedule sch);

} // namespace te
} // namespace tvm
#endif // TVM_TE_SCHEDULE_PASS_H_
13 changes: 0 additions & 13 deletions include/tvm/tir/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,19 +164,6 @@ Stmt Inline(Stmt stmt,
Array<Var> args,
PrimExpr body);

/*!
* \brief Try to modify the AST to support TensorCore
*
* \param stmt The stmt to be trasnformed.
* \param schedule The original schedule.
* \param extern_buffer Map specifies external
* buffer assignment of input and outputs.
* \return Transformed stmt.
*/
Stmt RewriteForTensorCore(Stmt stmt,
te::Schedule schedule,
Map<te::Tensor, Buffer> extern_buffer);

/*!
* \brief Verify if there is any argument bound to compact buffer.
*
Expand Down
45 changes: 27 additions & 18 deletions python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,23 +84,43 @@ def get_binds(args, compact=False, binds=None):
return binds, arg_list


def form_body(sch):
def form_irmodule(sch, args, name, binds):
"""According to the given schedule, form a function.
Parameters
----------
sch : tvm.te.schedule.Schedule
The given scheduler to form the raw body
The given scheduler to form the raw body
args : list of Buffer or Tensor or Var
The argument lists to the function.
name : str
The name of result function.
binds : dict of :any:`Tensor` to :any:`Buffer`, optional
The binds information
Returns
-------
The body formed according to the given schedule
"""
# normalize schedule first
cfg = BuildConfig.current()
sch = sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
return stmt

compact = ir_pass.VerifyCompactBuffer(stmt)
binds, arg_list = get_binds(args, compact, binds)

stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds)
func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)

func = func.with_attr("global_symbol", name)
if cfg.restricted_func:
func = func.with_attr("tir.noalias", True)
return tvm.IRModule({name: func})


def _wrap_as_prim_func_pass(flist, name):
Expand Down Expand Up @@ -166,24 +186,13 @@ def lower(sch,

# Phase 0
if isinstance(sch, schedule.Schedule):
stmt = form_body(sch)

for f in lower_phase0:
stmt = f(stmt)

compact = ir_pass.VerifyCompactBuffer(stmt)
binds, arg_list = get_binds(args, compact, binds)
stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)

# Start the new style pass manager.
func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
func = func.with_attr("global_symbol", name)
if cfg.restricted_func:
func = func.with_attr("tir.noalias", True)
mod = tvm.IRModule({name: func})
mod = form_irmodule(sch, args, name, binds)
else:
mod = sch

# Phase 1
pass_list = [
_wrap_as_prim_func_pass(lower_phase0, "Custom-Phase0"),
tvm.tir.transform.InjectPrefetch(),
tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers),
tvm.tir.transform.NarrowDataType(32),
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/te/hybrid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
# 2. Support multi-level HalideIR
import inspect
import tvm._ffi
from tvm.driver.build_module import form_body
import tvm.te.schedule
from tvm._ffi.base import decorate

from .module import HybridModule
Expand Down Expand Up @@ -87,8 +87,10 @@ def build(sch, inputs, outputs, name="hybrid_func"):
The built results is wrapped in a HybridModule.
The usage of HybridModule is roughly the same as normal TVM-built modules.
"""
sch = sch.normalize()
bounds = tvm.te.schedule.InferBound(sch)
stmt = tvm.te.schedule.ScheduleOps(sch, bounds)

stmt = form_body(sch)
src = _Dump(stmt, inputs, outputs, name)

return HybridModule(src, name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
*/

/*!
* \file tensor_core.cc
* \file schedule_postproc_rewrite_for_tensor_core.cc
*
* \brief Rewrite the Stmt generated by ScheduleOps
* to accomondate tensorcore.
*/
// IR Passes for TensorCore CodeGen
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/te/operation.h>
Expand All @@ -32,12 +35,11 @@
#include <tvm/target/target.h>
#include <tvm/runtime/device_api.h>
#include <unordered_map>
#include "ir_util.h"
#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"

namespace tvm {
namespace tir {
namespace te {

using namespace te;
using runtime::StorageRank;
Expand Down Expand Up @@ -86,10 +88,10 @@ class MMAMatcher: public StmtVisitor {
}

void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::pragma_tensor_core) {
if (op->attr_key == tir::attr::pragma_tensor_core) {
tensor_core_on_ = true;
StmtVisitor::VisitStmt_(op);
} else if (op->attr_key == attr::realize_scope) {
} else if (op->attr_key == tir::attr::realize_scope) {
storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
this->VisitStmt(op->body);
} else {
Expand Down Expand Up @@ -414,18 +416,18 @@ class BufferAnalyser : public StmtExprVisitor {
}

void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
if (op->attr_key == tir::attr::thread_extent) {
if (const IntImmNode* value = op->value.as<IntImmNode>()) {
thread_extent_.insert(
std::make_pair(
op->node.as<IterVarNode>()->var->name_hint,
value->value));
}
StmtExprVisitor::VisitStmt_(op);
} else if (op->attr_key == attr::realize_scope) {
} else if (op->attr_key == tir::attr::realize_scope) {
storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
this->VisitStmt(op->body);
} else if (op->attr_key == attr::buffer_dim_align) {
} else if (op->attr_key == tir::attr::buffer_dim_align) {
te::Tensor tensor = Downcast<te::Tensor>(op->node);
const CallNode* tuple = op->value.as<CallNode>();
CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
Expand Down Expand Up @@ -850,7 +852,7 @@ class TensorCoreIRMutator : public StmtExprMutator {

Stmt VisitStmt_(const AttrStmtNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
if (op->attr_key == attr::realize_scope) {
if (op->attr_key == tir::attr::realize_scope) {
auto node = op->node.as<te::OperationNode>();
if (node != nullptr) {
if (!frag_reg_.count(node->name)) {
Expand Down Expand Up @@ -1186,9 +1188,10 @@ class TensorCoreIRMutator : public StmtExprMutator {
int warp_threads_y_{-1};
};

Stmt RewriteForTensorCore(Stmt stmt,
Schedule schedule,
Map<Tensor, Buffer> extern_buffer) {
Stmt SchedulePostProcRewriteForTensorCore(
Stmt stmt,
Schedule schedule,
Map<Tensor, Buffer> extern_buffer) {
// Check if current lower target is CUDA
auto target = tvm::Target::Current(true);
if (target.defined() && target->target_name != "cuda") {
Expand Down Expand Up @@ -1223,5 +1226,13 @@ Stmt RewriteForTensorCore(Stmt stmt,
return TensorCoreIRMutator(schedule_analyser, buffer_analyser)(std::move(stmt));
}

} // namespace tir
TVM_REGISTER_GLOBAL("schedule.SchedulePostProcRewriteForTensorCore")
.set_body_typed([](Stmt stmt,
Schedule schedule,
Map<te::Tensor, Buffer> extern_buffer) {
return SchedulePostProcRewriteForTensorCore(
stmt, schedule, extern_buffer);
});

} // namespace te
} // namespace tvm
8 changes: 0 additions & 8 deletions src/tir/pass/ffi_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,6 @@ TVM_REGISTER_GLOBAL("ir_pass.Substitute")
}
});

TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore")
.set_body_typed
([](const Stmt& stmt,
const te::Schedule& schedule,
const Map<te::Tensor, Buffer>& extern_buffer) {
return RewriteForTensorCore(stmt, schedule, extern_buffer);
});

TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = ExprUseVar(args[0].operator PrimExpr(), args[1].operator Var());
Expand Down

0 comments on commit 016d03a

Please sign in to comment.