Skip to content

Commit

Permalink
[TVMSCRIPT] Fix printing of rank 0 buffer access (apache#8215)
Browse files Browse the repository at this point in the history
* [TVMSCRIPT] Fix printing of rank 0 buffer access

Also improve error messages and fix min/max/Select.

* fixes

* return fix

* remove print
  • Loading branch information
tkonolige authored and trevor-m committed Jun 17, 2021
1 parent b15d74c commit d50f5f4
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 40 deletions.
5 changes: 5 additions & 0 deletions python/tvm/script/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,11 @@ def opaque_axis(begin, end, span):
return get_axis(begin, end, "opaque", span)


@register
def Select(cond, if_body, else_body, span): # pylint: disable=invalid-name
return tvm.tir.Select(cond, if_body, else_body, span)


@register
class EvaluateIntrin(Intrin):
def __init__(self):
Expand Down
31 changes: 20 additions & 11 deletions python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,14 +749,17 @@ def f():
node.call.func_name.span,
)

if isinstance(func, Intrin) and func.stmt:
return call_with_error_reporting(
self.report_error,
node.call.func_name.span,
func.handle,
arg_list,
node.call.func_name.span,
)
if isinstance(func, Intrin):
if func.stmt:
return call_with_error_reporting(
self.report_error,
node.call.func_name.span,
func.handle,
arg_list,
node.call.func_name.span,
)
else:
self.report_error(f"This intrinsic cannot be used as a statement.", node.call.span)
elif isinstance(func, WithScopeHandler) and func.concise_scope and not func.def_symbol:
func.enter_scope(node, self.context, arg_list, node.call.func_name.span)
func.body = self.parse_body(node)
Expand All @@ -765,7 +768,11 @@ def f():
func.handle(node, self.context, arg_list, node.call.func_name.span)
return

self.report_error(f"Invalid Expr stmt {type(func).__name__}.", node.call.func_name.span)
self.report_error(
"Unexpected statement. Expected an assert, an intrinsic, a with statement, or a "
f"special statement, but got {type(func).__name__}.",
node.call.func_name.span,
)

def transform_Slice(self, node):
start = self.transform(node.start)
Expand All @@ -785,7 +792,9 @@ def transform_Subscript(self, node):

symbol = self.transform(node.params[0])
if symbol is None:
self.report_error(f"Variable {node.value.id} is not defined.", node.params[0].span)
self.report_error(
f"Variable {node.params[0].id.name} is not defined.", node.params[0].span
)

indexes = [self.transform(x) for x in node.params[1].values]
if isinstance(symbol, tvm.tir.expr.Var):
Expand Down Expand Up @@ -844,7 +853,7 @@ def transform_Attr(self, node):
self.report_error("Unsupported Attribute expression.", node.object.span)
if not hasattr(symbol, node.field.name):
self.report_error(
f"Type {type(symbol)} does not have a field called `{node.field}`.", node.span
f"Type {type(symbol)} does not have a field called `{node.field.name}`.", node.span
)
res = getattr(symbol, node.field.name)
return res
Expand Down
69 changes: 40 additions & 29 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc VisitExpr_(const IntImmNode* op) override;
Doc VisitExpr_(const FloatImmNode* op) override;
Doc VisitExpr_(const StringImmNode* op) override;
Doc VisitExpr_(const ProducerLoadNode* op) override;
Doc VisitExpr_(const BufferLoadNode* op) override;
Doc VisitExpr_(const LoadNode* op) override;
Doc VisitExpr_(const RampNode* op) override;
Expand Down Expand Up @@ -387,19 +388,19 @@ Doc TVMScriptPrinter::Print(const ObjectRef& node) {
} else if (node->IsInstance<MatchBufferRegionNode>()) {
return PrintMatchBufferRegion(node.as<MatchBufferRegionNode>());
} else {
meta_collector_.Collect(node);
return this->meta_.GetMetaNode(node);
LOG(FATAL) << "Do not know how to print " << node->GetTypeKey();
return Doc();
}
}

Doc TVMScriptPrinter::VisitExprDefault_(const Object* op) {
meta_collector_.Collect(GetRef<ObjectRef>(op));
return this->meta_.GetMetaNode(GetRef<ObjectRef>(op));
LOG(FATAL) << "Do not know how to print " << op->GetTypeKey();
return Doc();
}

Doc TVMScriptPrinter::VisitStmtDefault_(const Object* op) {
meta_collector_.Collect(GetRef<ObjectRef>(op));
return this->meta_.GetMetaNode(GetRef<ObjectRef>(op));
LOG(FATAL) << "Do not know how to print " << op->GetTypeKey();
return Doc();
}

Doc TVMScriptPrinter::VisitExpr_(const IntImmNode* op) {
Expand All @@ -414,11 +415,7 @@ Doc TVMScriptPrinter::VisitExpr_(const StringImmNode* op) { return Doc::StrLiter

Doc TVMScriptPrinter::VisitExpr_(const CastNode* op) {
Doc doc;
if (cast(op->dtype, op->value)->IsInstance<CastNode>()) {
doc << Print(op->value) << ".astype(" << PrintDType(op->dtype) << ")";
} else {
doc << "tir.cast(" << Print(op->value) << ", " << PrintDType(op->dtype) << ")";
}
doc << "tir.cast(" << Print(op->value) << ", " << PrintDType(op->dtype) << ")";
return doc;
}

Expand Down Expand Up @@ -480,14 +477,24 @@ Doc TVMScriptPrinter::VisitExpr_(const NotNode* op) {

Doc TVMScriptPrinter::VisitExpr_(const SelectNode* op) {
Doc doc;
doc << "tir.select(" << Print(op->condition) << ", " << Print(op->true_value) << ", "
doc << "tir.Select(" << Print(op->condition) << ", " << Print(op->true_value) << ", "
<< Print(op->false_value) << ")";
return doc;
}

Doc TVMScriptPrinter::VisitExpr_(const ProducerLoadNode* op) {
LOG(FATAL) << "Cannot print a tir.ProducerLoad as it is not valid in TIR Primfuncs. You need to "
"lower this function first.";
return Doc();
}

Doc TVMScriptPrinter::VisitExpr_(const BufferLoadNode* op) {
Doc doc;
doc << Print(op->buffer) << Print(op->indices);
if (op->indices.size() == 0) {
doc << Print(op->buffer) << "[()]";
} else {
doc << Print(op->buffer) << Print(op->indices);
}
return doc;
}

Expand Down Expand Up @@ -661,12 +668,8 @@ Doc TVMScriptPrinter::VisitStmt_(const AssertStmtNode* op) {

Doc TVMScriptPrinter::VisitStmt_(const StoreNode* op) {
Doc doc;
if (!is_one(op->predicate) || op->value.dtype().lanes() != 1) {
doc << "tir.store(" << Print(op->buffer_var) << ", " << Print(op->index) << ", "
<< Print(op->value) << ", " << Print(op->predicate) << ")";
} else {
doc << Print(op->buffer_var) << "[" << Print(op->index) << "] = " << Print(op->value);
}
doc << "tir.store(" << Print(op->buffer_var) << ", " << Print(op->index) << ", "
<< Print(op->value) << ", " << Print(op->predicate) << ")";
return doc;
}

Expand Down Expand Up @@ -786,7 +789,11 @@ Doc TVMScriptPrinter::VisitType_(const TupleTypeNode* node) {

Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode* op) {
Doc doc;
doc << Print(op->buffer) << Print(op->indices) << " = " << Print(op->value);
if (op->indices.size() == 0) {
doc << Print(op->buffer) << "[()] = " << Print(op->value);
} else {
doc << Print(op->buffer) << Print(op->indices) << " = " << Print(op->value);
}
return doc;
}

Expand Down Expand Up @@ -1051,17 +1058,21 @@ Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) {

Doc TVMScriptPrinter::PrintBufferRegion(const BufferRegionNode* op) {
Doc doc;
doc << Print(op->buffer) << "[";
for (size_t i = 0; i < op->region.size(); ++i) {
if (i != 0) doc << ", ";
const auto& range = op->region[i];
if (!is_one(range->extent)) {
doc << Print(range->min) << ":" << Print(range->min + range->extent);
} else {
doc << Print(range->min);
if (op->region.size() == 0) {
doc << Print(op->buffer) << "[()]";
} else {
doc << Print(op->buffer) << "[";
for (size_t i = 0; i < op->region.size(); ++i) {
if (i != 0) doc << ", ";
const auto& range = op->region[i];
if (!is_one(range->extent)) {
doc << Print(range->min) << ":" << Print(range->min + range->extent);
} else {
doc << Print(range->min);
}
}
doc << "]";
}
doc << "]";
return doc;
}

Expand Down
58 changes: 58 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -2888,6 +2888,64 @@ def test_opaque_block():
assert len(root_block.body.body[1].block.iter_vars) == 0


@tvm.script.tir
def rank0(a: ty.handle) -> None:
A = tir.match_buffer(a, (), "float32")
B = tir.alloc_buffer((), "float32")
A[()] = 2
B[()] = A[()]


def test_rank0_buffers():
func = rank0
rt_func = tvm.script.from_source(tvm.script.asscript(func, True))
tvm.ir.assert_structural_equal(func, rt_func)


@tvm.script.tir
def rank0_block(a: ty.handle) -> None:
A = tir.match_buffer(a, (), "float32")
B = tir.alloc_buffer((), "float32")
tir.store(B.data, 0, tir.load("float32", A.data, 0))

with tir.block([], "update") as []:
tir.reads([A[()]])
tir.writes([B[()]])
for i in range(0, 1):
B[()] = A[()]


def test_rank0_blocks():
func = rank0_block
rt_func = tvm.script.from_source(tvm.script.asscript(func, True))
tvm.ir.assert_structural_equal(func, rt_func)


@tvm.script.tir
def select(a: ty.handle) -> None:
A = tir.match_buffer(a, (), "float32")
A[()] = tir.Select(True, 1, 2)


def test_select():
func = select
rt_func = tvm.script.from_source(tvm.script.asscript(func, True))
tvm.ir.assert_structural_equal(func, rt_func)


@tvm.script.tir
def minmax(a: ty.handle) -> None:
A = tir.match_buffer(a, (), "float32")
A[()] = tir.min(1, 2)
A[()] = tir.max(1, 2)


def test_minmax():
func = minmax
rt_func = tvm.script.from_source(tvm.script.asscript(func, True))
tvm.ir.assert_structural_equal(func, rt_func)


if __name__ == "__main__":
test_opt_gemm_normalize()
test_opt_gemm_mod_host()
Expand Down

0 comments on commit d50f5f4

Please sign in to comment.