Skip to content

Commit

Permalink
RelayTextPrinter is now non-recursive. ExpandDataflow refactored (apa…
Browse files Browse the repository at this point in the history
…che#7817)

* RelayTextPrinter is now non-recursive. ExpandDataflow refactored

RelayTextPrinter is now non-recursive to allow printing larger
graphs. ExpandDataflow is generalised to have separate node expander.

Change-Id: Id5a3a470fbc8b90822502fbc8d24d534df1ea355

* requested changes

Change-Id: Iac69766428d5b9783279cb02a57064fd82842001

* unit test added

Change-Id: Id20ae72f9f5f8dd92d4d182360b28156c035e667
  • Loading branch information
d-smirnov authored and Trevor Morris committed May 6, 2021
1 parent 2ab729f commit 4f60f2e
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 53 deletions.
103 changes: 58 additions & 45 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@
#include <tvm/relay/function.h>
#include <tvm/relay/op.h>

#include <stack>
#include <deque>
#include <string>
#include <unordered_map>
#include <utility>

#include <vector>
namespace tvm {
namespace relay {

Expand Down Expand Up @@ -276,7 +276,9 @@ class MixedModeVisitor : public ::tvm::relay::ExprVisitor {
*/
class MixedModeMutator : public ::tvm::relay::ExprMutator {
public:
MixedModeMutator(bool pre = false) : pre_{pre} {};
Expr VisitExpr(const Expr& expr) final;

virtual Expr DispatchVisitExpr(const Expr& expr);
Expr VisitExpr_(const TupleNode* op) final { return Rewrite(op); };
Expr VisitExpr_(const CallNode* call_node) final { return Rewrite(call_node); };
Expand All @@ -294,6 +296,7 @@ class MixedModeMutator : public ::tvm::relay::ExprMutator {
virtual Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) { return post; }

protected:
bool pre_;
/*! \brief Implement Rewrite API by calling ExprMutator's VisitExpr_(op) to get a `post` node with
* changed inputs.
*/
Expand Down Expand Up @@ -410,72 +413,82 @@ Expr PostOrderRewrite(const Expr& expr, ExprRewriter* rewriter);
*/
void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);

/*!
* \brief A struct to keep info of traversed expr in ExpandDataflow function
*/
struct v_info {
explicit v_info(Expr node_) : node{node_} {}
v_info(Expr node_, bool children_expanded_)
: node{node_}, children_expanded{children_expanded_} {};
Expr node{};
bool children_expanded{false};
};

/*!
* \brief A function to iteratively traverse dataflow regions of a graph
*
* ExpandDataflow manually manages a stack and performs DFS to determine the processing
* order of nodes in an input graph.
*
* If it finds a dataflow node (Call, Tuple, TupleGetItem), it checks if the arguments to that node
* need to be processed via fcheck_visited. If so, the function pushes those arguments to the stack
* and continues iteratively to process the top of the stack. When it finds a node that doesn't
* match the dataflow types, or a node who's inputs have all been processed, it visits the current
* leaf via fvisit_leaf.
* By default fexpand_expr implemented in a way that if it finds a dataflow node (Call, Tuple,
* TupleGetItem), it checks if the arguments to that node need to be processed via fcheck_visited.
* If so, the function pushes those arguments to the stack and continues iteratively to process
* the top of the stack. When it finds a node that doesn't match the dataflow types, or a node who's
* inputs have all been processed, it visits the current leaf via fvisit_leaf.
*
* This function should be used internally to other classes to implement mixed-mode traversals. The
* expectation is that fvisit_leaf will perform recursive analysis within mixed-mode traversal if it
* hits a non-dataflow node.
*
* fcheck_visited and fvisit_leaf are templated to encourage compiler inlining.
* fcheck_visited, fvisit_leaf and fexpand_expr are templated to encourage reusing.
*/
template <typename FCheckVisited, typename FVisitLeaf>
void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) {
std::stack<std::pair<Expr, bool>> stack;
template <typename FCheckVisited, typename FVisitLeaf, typename FExpandExpr>
void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf,
FExpandExpr fexpand_expr) {
std::deque<v_info> stack;
auto fpush_to_stack = [&fcheck_visited, &stack](const Expr& expr) {
// The second state of the stack indicate whether the child has been
// expanded in the pre-order.
// NOTE: function will be inlined.
if (!fcheck_visited(expr)) {
stack.push({expr, false});
stack.emplace_front(v_info(expr));
}
};

fpush_to_stack(expr);
while (stack.size() > 0) {
auto node = stack.top().first;
if (fcheck_visited(node)) {
// if this node was visited through another path
// after being added to the stack ignore it.
stack.pop();
} else if (stack.top().second) {
// all the children have already been expanded.
// we can just run post order visit on it.
fvisit_leaf(node);
stack.pop();
} else if (const CallNode* op = node.as<CallNode>()) {
// mark expanded = true
stack.top().second = true;
// push the children to the stack in reverse order
// to match recursive processing order
v_info* front = &stack.front();
if (fcheck_visited(front->node)) {
stack.pop_front();
} else if (front->children_expanded) {
fvisit_leaf(front->node);
// TODO(d-smirnov): this is for compatibility with current implementation of MixedModeVisitor
stack.pop_front();
} else {
front->children_expanded = true;
for (auto e : fexpand_expr(front->node)) {
fpush_to_stack(e);
}
}
}
}

template <typename FCheckVisited, typename FVisitLeaf>
void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) {
auto fexpand_expr = [](const Expr& expr) {
std::vector<Expr> result;
if (const CallNode* op = expr.as<CallNode>()) {
for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
fpush_to_stack(*it);
result.push_back(*it);
}
fpush_to_stack(op->op);
} else if (const TupleNode* op = node.as<TupleNode>()) {
stack.top().second = true;
// push the children to the stack in reverse order
// to match recursive processing order
result.push_back(op->op);
} else if (const TupleNode* op = expr.as<TupleNode>()) {
for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
fpush_to_stack(*it);
result.push_back(*it);
}
} else if (const TupleGetItemNode* op = node.as<TupleGetItemNode>()) {
stack.top().second = true;
fpush_to_stack(op->tuple);
} else {
// No need to expand the children directly run visit.
fvisit_leaf(node);
stack.pop();
} else if (const TupleGetItemNode* op = expr.as<TupleGetItemNode>()) {
result.push_back(op->tuple);
}
}
return result;
};
ExpandDataflow(expr, fcheck_visited, fvisit_leaf, fexpand_expr);
}

void ExpandANormalForm(const LetNode* op, std::function<void(const LetNode*)> pre_visit,
Expand Down
39 changes: 35 additions & 4 deletions src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,34 @@ bool RelayTextPrinter::AlwaysInline(const Expr& expr) {
expr.as<VarNode>() || expr.as<ConstructorNode>();
}

Doc RelayTextPrinter::VisitLeaf(const Expr& expr) {
if (!CheckVisited(expr)) {
Doc result = ExprFunctor<Doc(const Expr&)>::VisitExpr(expr);
// Add if not added after visiting
if (!CheckVisited(expr)) {
memo_[expr] = result;
} else {
result_memo_[expr] = result;
}
return result;
}
return memo_[expr];
}

bool RelayTextPrinter::CheckVisited(const Expr& expr) { return (memo_.count(expr)); }

Doc RelayTextPrinter::VisitExpr(const Expr& expr) {
auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); };
auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); };

if (fcheck_visited(expr)) {
return memo_[expr];
} else {
ExpandDataflow(expr, fcheck_visited, fvisit_leaf);
return memo_[expr];
}
}

//------------------------------------
// Overload of Expr printing functions
//------------------------------------
Expand All @@ -252,9 +280,6 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline, bo
inline_expr |= IsUnique(expr);
}

auto it = memo_.find(expr);
if (it != memo_.end()) return it->second;

Doc printed_expr;

if (meta) {
Expand All @@ -277,13 +302,19 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline, bo
if (expr.as<VarNode>()) {
// This is our first time visiting the var and we hit the VarNode case
// in the visitor. Thus the variable is free.
doc_stack_.back() << "free_var " << printed_expr << ";" << Doc::NewLine();
if (var_memo_.insert(expr).second && result_memo_.count(expr)) {
doc_stack_.back() << "free_var " << result_memo_[expr] << ";" << Doc::NewLine();
}
// Memoization is done in AllocVar.
return memo_[expr];
} else if (inline_expr) {
memo_[expr] = printed_expr;
return printed_expr;
} else {
// Already exists. Reuse
if (!var_memo_.insert(expr).second) {
return memo_[expr];
}
Doc temp_var = AllocTemp();
memo_[expr] = temp_var;
doc_stack_.back() << temp_var << " = " << printed_expr << ";" << Doc::NewLine();
Expand Down
10 changes: 9 additions & 1 deletion src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "../ir/attr_functor.h"
Expand All @@ -60,6 +61,9 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
explicit RelayTextPrinter(bool show_meta_data, TextMetaDataContext* meta,
runtime::TypedPackedFunc<std::string(ObjectRef)> annotate)
: show_meta_data_(show_meta_data), annotate_(annotate), meta_(meta) {}
Doc VisitExpr(const Expr& expr) override;
virtual Doc VisitLeaf(const Expr& expr);
virtual bool CheckVisited(const Expr& expr);

/*!
* \brief Print additional info about expr in comment.
Expand Down Expand Up @@ -145,7 +149,7 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
Doc PrintType(const Type& type, bool meta);
Doc VisitTypeDefault_(const Object* node) final;
Doc VisitType_(const TypeVarNode* node) final;
Doc VisitType_(const GlobalTypeVarNode* node);
Doc VisitType_(const GlobalTypeVarNode* node) final;
Doc VisitType_(const TypeCallNode* node) final;
Doc PrintDType(DataType dtype);
Doc VisitType_(const TensorTypeNode* node) final;
Expand All @@ -170,6 +174,10 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
runtime::TypedPackedFunc<std::string(ObjectRef)> annotate_;
/*! \brief Stack of docs to implement scoped GNFing. */
std::vector<Doc> doc_stack_{};
/*! \brief Set for introduced vars */
std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> var_memo_;
/*! \brief Map for result and memo_ diffs for visited expression */
std::unordered_map<Expr, Doc, ObjectPtrHash, ObjectPtrEqual> result_memo_;
/*! \brief Map from Expr to Doc */
std::unordered_map<Expr, Doc, ObjectPtrHash, ObjectPtrEqual> memo_;
/*! \brief Map from Type to Doc */
Expand Down
6 changes: 3 additions & 3 deletions src/relay/analysis/dependency_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace tvm {
namespace relay {

// Creator of DependencyGraph
class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
class DependencyGraph::Creator : private MixedModeVisitor {
public:
explicit Creator(support::Arena* arena) : arena_(arena) {}

Expand Down Expand Up @@ -73,13 +73,13 @@ class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
return ret;
}

void VisitExpr(const Expr& e) final {
void VisitLeaf(const Expr& e) override {
if (visited_.count(e) == 0) {
if (graph_.expr_node.count(e) == 0) {
graph_.expr_node[e] = NewNode(false);
}
visited_.insert(e);
ExprFunctor<void(const Expr&)>::VisitExpr(e);
MixedModeVisitor::VisitLeaf(e);
graph_.post_dfs_order.push_back(graph_.expr_node[e]);
}
}
Expand Down
64 changes: 64 additions & 0 deletions tests/cpp/relay_text_printer_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <gtest/gtest.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/type_functor.h>
#include <tvm/node/functor.h>
#include <tvm/node/structural_equal.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/function.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/op_strategy.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/topi/broadcast.h>
#include <tvm/topi/generic/injective.h>

using namespace tvm;
using namespace tvm::relay;

TEST(Relay, LargeGraphPrint) {
auto foo = [] {
auto add_op = relay::Op::Get("add");
auto c_data = tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto c1 = relay::Constant(c_data);
Call y1 = relay::Call(add_op, {c1, c1});
for (int i = 0; i < 1e6; i++) {
y1 = relay::Call(add_op, {c1, y1});
}
relay::Function func = relay::Function({}, y1, relay::Type(), {});
std::string result = AsText(func);
ASSERT_GT(0, result.size());
};
ASSERT_EXIT((foo(), exit(0)), ::testing::ExitedWithCode(0), ".*");
}

int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}

0 comments on commit 4f60f2e

Please sign in to comment.