From 1f0f8f19b98093332ba3253d96551b1e499bf984 Mon Sep 17 00:00:00 2001 From: "Huang, Guangtai" Date: Sat, 19 Jun 2021 05:59:28 +0800 Subject: [PATCH] [IRPrinter] Prevent multiple printing of optional info (#8279) * fix * test --- src/printer/relay_text_printer.cc | 4 +++- src/printer/text_printer.h | 2 ++ tests/python/relay/test_ir_text_printer.py | 9 +++++++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 2de331be9581..aad42fc9b0ea 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -54,6 +54,9 @@ namespace relay { */ Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) { Doc doc; + if (!opt_info_memo_.insert(expr).second) { + return doc; + } // default annotations if (annotate_ == nullptr) { if ((expr.as() || expr.as()) && expr->checked_type_.defined()) { @@ -65,7 +68,6 @@ Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) { doc << annotated_expr; } } - return doc; } diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 52ab701008c7..7a529cc0b914 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -176,6 +176,8 @@ class RelayTextPrinter : public ExprFunctor, std::vector doc_stack_{}; /*! \brief Set for introduced vars */ std::unordered_set var_memo_; + /*! \brief Set for exprs have been printed optional information */ + std::unordered_set opt_info_memo_; /*! \brief Map for result and memo_ diffs for visited expression */ std::unordered_map result_memo_; /*! \brief Map from Expr to Doc */ diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 4968660b95c8..b4d02e4815fb 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -276,5 +276,14 @@ def test_span(): assert "Add1" in txt +def test_optional_info(): + c = relay.const(1) + call = relay.add(c, c) + m = tvm.IRModule.from_expr(call) + m = relay.transform.InferType()(m) + txt = astext(m) + assert txt.count("/* ty=int32 */") == 3 + + if __name__ == "__main__": pytest.main([__file__])