Skip to content

Commit

Permalink
[IRPrinter] Prevent multiple printing of optional info (#8279)
Browse files Browse the repository at this point in the history
* fix

* test
  • Loading branch information
hgt312 committed Jun 18, 2021
1 parent edb7e77 commit 1f0f8f1
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ namespace relay {
*/
Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) {
Doc doc;
if (!opt_info_memo_.insert(expr).second) {
return doc;
}
// default annotations
if (annotate_ == nullptr) {
if ((expr.as<ConstantNode>() || expr.as<CallNode>()) && expr->checked_type_.defined()) {
Expand All @@ -65,7 +68,6 @@ Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) {
doc << annotated_expr;
}
}

return doc;
}

Expand Down
2 changes: 2 additions & 0 deletions src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
std::vector<Doc> doc_stack_{};
/*! \brief Set for introduced vars */
std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> var_memo_;
/*! \brief Set for exprs have been printed optional information */
std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> opt_info_memo_;
/*! \brief Map for result and memo_ diffs for visited expression */
std::unordered_map<Expr, Doc, ObjectPtrHash, ObjectPtrEqual> result_memo_;
/*! \brief Map from Expr to Doc */
Expand Down
9 changes: 9 additions & 0 deletions tests/python/relay/test_ir_text_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,5 +276,14 @@ def test_span():
assert "Add1" in txt


def test_optional_info():
c = relay.const(1)
call = relay.add(c, c)
m = tvm.IRModule.from_expr(call)
m = relay.transform.InferType()(m)
txt = astext(m)
assert txt.count("/* ty=int32 */") == 3


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 1f0f8f1

Please sign in to comment.