From 8800b7ae0647c50c500dbcfb1b6dc7c5013fa20f Mon Sep 17 00:00:00 2001 From: Dmitriy Smirnov Date: Fri, 23 Apr 2021 20:15:42 +0100 Subject: [PATCH] RelayTextPrinter is now non-recursive. ExpandDataflow refactored (#7817) * RelayTextPrinter is now non-recursive. ExpandDataflow refactored RelayTextPrinter is now non-recursive to allow printing larger graphs. ExpandDataflow is generalised to have separate node expander. Change-Id: Id5a3a470fbc8b90822502fbc8d24d534df1ea355 * requested changes Change-Id: Iac69766428d5b9783279cb02a57064fd82842001 * unit test added Change-Id: Id20ae72f9f5f8dd92d4d182360b28156c035e667 --- include/tvm/relay/expr_functor.h | 103 ++++++++++++++----------- src/printer/relay_text_printer.cc | 39 +++++++++- src/printer/text_printer.h | 10 ++- src/relay/analysis/dependency_graph.cc | 6 +- tests/cpp/relay_text_printer_test.cc | 64 +++++++++++++++ 5 files changed, 169 insertions(+), 53 deletions(-) create mode 100644 tests/cpp/relay_text_printer_test.cc diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index e6eec61a7e9d0..688ad8254fa85 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -32,11 +32,11 @@ #include #include -#include +#include #include #include #include - +#include namespace tvm { namespace relay { @@ -276,7 +276,9 @@ class MixedModeVisitor : public ::tvm::relay::ExprVisitor { */ class MixedModeMutator : public ::tvm::relay::ExprMutator { public: + MixedModeMutator(bool pre = false) : pre_{pre} {}; Expr VisitExpr(const Expr& expr) final; + virtual Expr DispatchVisitExpr(const Expr& expr); Expr VisitExpr_(const TupleNode* op) final { return Rewrite(op); }; Expr VisitExpr_(const CallNode* call_node) final { return Rewrite(call_node); }; @@ -294,6 +296,7 @@ class MixedModeMutator : public ::tvm::relay::ExprMutator { virtual Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) { return post; } protected: + bool pre_; /*! \brief Implement Rewrite API by calling ExprMutator's VisitExpr_(op) to get a `post` node with * changed inputs. */ @@ -410,72 +413,82 @@ Expr PostOrderRewrite(const Expr& expr, ExprRewriter* rewriter); */ void PostOrderVisit(const Expr& node, std::function fvisit); +/*! + * \brief A struct to keep info of traversed expr in ExpandDataflow function + */ +struct v_info { + explicit v_info(Expr node_) : node{node_} {} + v_info(Expr node_, bool children_expanded_) + : node{node_}, children_expanded{children_expanded_} {}; + Expr node{}; + bool children_expanded{false}; +}; + /*! * \brief A function to iteratively traverse dataflow regions of a graph * * ExpandDataflow manually manages a stack and performs DFS to determine the processing * order of nodes in an input graph. * - * If it finds a dataflow node (Call, Tuple, TupleGetItem), it checks if the arguments to that node - * need to be processed via fcheck_visited. If so, the function pushes those arguments to the stack - * and continues iteratively to process the top of the stack. When it finds a node that doesn't - * match the dataflow types, or a node who's inputs have all been processed, it visits the current - * leaf via fvisit_leaf. + * By default fexpand_expr implemented in a way that if it finds a dataflow node (Call, Tuple, + * TupleGetItem), it checks if the arguments to that node need to be processed via fcheck_visited. + * If so, the function pushes those arguments to the stack and continues iteratively to process + * the top of the stack. When it finds a node that doesn't match the dataflow types, or a node who's + * inputs have all been processed, it visits the current leaf via fvisit_leaf. * * This function should be used internally to other classes to implement mixed-mode traversals. The * expectation is that fvisit_leaf will perform recursive analysis within mixed-mode traversal if it * hits a non-dataflow node. * - * fcheck_visited and fvisit_leaf are templated to encourage compiler inlining. + * fcheck_visited, fvisit_leaf and fexpand_expr are templated to encourage reusing. */ -template -void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) { - std::stack> stack; +template +void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf, + FExpandExpr fexpand_expr) { + std::deque stack; auto fpush_to_stack = [&fcheck_visited, &stack](const Expr& expr) { - // The second state of the stack indicate whether the child has been - // expanded in the pre-order. - // NOTE: function will be inlined. if (!fcheck_visited(expr)) { - stack.push({expr, false}); + stack.emplace_front(v_info(expr)); } }; + fpush_to_stack(expr); while (stack.size() > 0) { - auto node = stack.top().first; - if (fcheck_visited(node)) { - // if this node was visited through another path - // after being added to the stack ignore it. - stack.pop(); - } else if (stack.top().second) { - // all the children have already been expanded. - // we can just run post order visit on it. - fvisit_leaf(node); - stack.pop(); - } else if (const CallNode* op = node.as()) { - // mark expanded = true - stack.top().second = true; - // push the children to the stack in reverse order - // to match recursive processing order + v_info* front = &stack.front(); + if (fcheck_visited(front->node)) { + stack.pop_front(); + } else if (front->children_expanded) { + fvisit_leaf(front->node); + // TODO(d-smirnov): this is for compatibility with current implementation of MixedModeVisitor + stack.pop_front(); + } else { + front->children_expanded = true; + for (auto e : fexpand_expr(front->node)) { + fpush_to_stack(e); + } + } + } +} + +template +void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) { + auto fexpand_expr = [](const Expr& expr) { + std::vector result; + if (const CallNode* op = expr.as()) { for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) { - fpush_to_stack(*it); + result.push_back(*it); } - fpush_to_stack(op->op); - } else if (const TupleNode* op = node.as()) { - stack.top().second = true; - // push the children to the stack in reverse order - // to match recursive processing order + result.push_back(op->op); + } else if (const TupleNode* op = expr.as()) { for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) { - fpush_to_stack(*it); + result.push_back(*it); } - } else if (const TupleGetItemNode* op = node.as()) { - stack.top().second = true; - fpush_to_stack(op->tuple); - } else { - // No need to expand the children directly run visit. - fvisit_leaf(node); - stack.pop(); + } else if (const TupleGetItemNode* op = expr.as()) { + result.push_back(op->tuple); } - } + return result; + }; + ExpandDataflow(expr, fcheck_visited, fvisit_leaf, fexpand_expr); } void ExpandANormalForm(const LetNode* op, std::function pre_visit, diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 31f98ce4d2705..4fc03039466c7 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -236,6 +236,34 @@ bool RelayTextPrinter::AlwaysInline(const Expr& expr) { expr.as() || expr.as(); } +Doc RelayTextPrinter::VisitLeaf(const Expr& expr) { + if (!CheckVisited(expr)) { + Doc result = ExprFunctor::VisitExpr(expr); + // Add if not added after visiting + if (!CheckVisited(expr)) { + memo_[expr] = result; + } else { + result_memo_[expr] = result; + } + return result; + } + return memo_[expr]; +} + +bool RelayTextPrinter::CheckVisited(const Expr& expr) { return (memo_.count(expr)); } + +Doc RelayTextPrinter::VisitExpr(const Expr& expr) { + auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); }; + auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); }; + + if (fcheck_visited(expr)) { + return memo_[expr]; + } else { + ExpandDataflow(expr, fcheck_visited, fvisit_leaf); + return memo_[expr]; + } +} + //------------------------------------ // Overload of Expr printing functions //------------------------------------ @@ -252,9 +280,6 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline, bo inline_expr |= IsUnique(expr); } - auto it = memo_.find(expr); - if (it != memo_.end()) return it->second; - Doc printed_expr; if (meta) { @@ -277,13 +302,19 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline, bo if (expr.as()) { // This is our first time visiting the var and we hit the VarNode case // in the visitor. Thus the variable is free. - doc_stack_.back() << "free_var " << printed_expr << ";" << Doc::NewLine(); + if (var_memo_.insert(expr).second && result_memo_.count(expr)) { + doc_stack_.back() << "free_var " << result_memo_[expr] << ";" << Doc::NewLine(); + } // Memoization is done in AllocVar. return memo_[expr]; } else if (inline_expr) { memo_[expr] = printed_expr; return printed_expr; } else { + // Already exists. Reuse + if (!var_memo_.insert(expr).second) { + return memo_[expr]; + } Doc temp_var = AllocTemp(); memo_[expr] = temp_var; doc_stack_.back() << temp_var << " = " << printed_expr << ";" << Doc::NewLine(); diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 90e46c5624fad..e2a61e1d940ff 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -37,6 +37,7 @@ #include #include +#include #include #include "../ir/attr_functor.h" @@ -60,6 +61,9 @@ class RelayTextPrinter : public ExprFunctor, explicit RelayTextPrinter(bool show_meta_data, TextMetaDataContext* meta, runtime::TypedPackedFunc annotate) : show_meta_data_(show_meta_data), annotate_(annotate), meta_(meta) {} + Doc VisitExpr(const Expr& expr) override; + virtual Doc VisitLeaf(const Expr& expr); + virtual bool CheckVisited(const Expr& expr); /*! * \brief Print additional info about expr in comment. @@ -145,7 +149,7 @@ class RelayTextPrinter : public ExprFunctor, Doc PrintType(const Type& type, bool meta); Doc VisitTypeDefault_(const Object* node) final; Doc VisitType_(const TypeVarNode* node) final; - Doc VisitType_(const GlobalTypeVarNode* node); + Doc VisitType_(const GlobalTypeVarNode* node) final; Doc VisitType_(const TypeCallNode* node) final; Doc PrintDType(DataType dtype); Doc VisitType_(const TensorTypeNode* node) final; @@ -170,6 +174,10 @@ class RelayTextPrinter : public ExprFunctor, runtime::TypedPackedFunc annotate_; /*! \brief Stack of docs to implement scoped GNFing. */ std::vector doc_stack_{}; + /*! \brief Set for introduced vars */ + std::unordered_set var_memo_; + /*! \brief Map for result and memo_ diffs for visited expression */ + std::unordered_map result_memo_; /*! \brief Map from Expr to Doc */ std::unordered_map memo_; /*! \brief Map from Type to Doc */ diff --git a/src/relay/analysis/dependency_graph.cc b/src/relay/analysis/dependency_graph.cc index 3a4fb59475a4c..66ff8e684115c 100644 --- a/src/relay/analysis/dependency_graph.cc +++ b/src/relay/analysis/dependency_graph.cc @@ -32,7 +32,7 @@ namespace tvm { namespace relay { // Creator of DependencyGraph -class DependencyGraph::Creator : private ExprFunctor { +class DependencyGraph::Creator : private MixedModeVisitor { public: explicit Creator(support::Arena* arena) : arena_(arena) {} @@ -73,13 +73,13 @@ class DependencyGraph::Creator : private ExprFunctor { return ret; } - void VisitExpr(const Expr& e) final { + void VisitLeaf(const Expr& e) override { if (visited_.count(e) == 0) { if (graph_.expr_node.count(e) == 0) { graph_.expr_node[e] = NewNode(false); } visited_.insert(e); - ExprFunctor::VisitExpr(e); + MixedModeVisitor::VisitLeaf(e); graph_.post_dfs_order.push_back(graph_.expr_node[e]); } } diff --git a/tests/cpp/relay_text_printer_test.cc b/tests/cpp/relay_text_printer_test.cc new file mode 100644 index 0000000000000..ed80290647205 --- /dev/null +++ b/tests/cpp/relay_text_printer_test.cc @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace tvm; +using namespace tvm::relay; + +TEST(Relay, LargeGraphPrint) { + auto foo = [] { + auto add_op = relay::Op::Get("add"); + auto c_data = tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto c1 = relay::Constant(c_data); + Call y1 = relay::Call(add_op, {c1, c1}); + for (int i = 0; i < 1e6; i++) { + y1 = relay::Call(add_op, {c1, y1}); + } + relay::Function func = relay::Function({}, y1, relay::Type(), {}); + std::string result = AsText(func); + ASSERT_GT(0, result.size()); + }; + ASSERT_EXIT((foo(), exit(0)), ::testing::ExitedWithCode(0), ".*"); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +}