Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RelayTextPrinter is now non-recursive. ExpandDataflow refactored #7817

Merged
merged 3 commits into from
Apr 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 58 additions & 45 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@
#include <tvm/relay/function.h>
#include <tvm/relay/op.h>

#include <stack>
#include <deque>
#include <string>
#include <unordered_map>
#include <utility>

#include <vector>
namespace tvm {
namespace relay {

Expand Down Expand Up @@ -276,7 +276,9 @@ class MixedModeVisitor : public ::tvm::relay::ExprVisitor {
*/
class MixedModeMutator : public ::tvm::relay::ExprMutator {
public:
MixedModeMutator(bool pre = false) : pre_{pre} {};
Expr VisitExpr(const Expr& expr) final;

virtual Expr DispatchVisitExpr(const Expr& expr);
Expr VisitExpr_(const TupleNode* op) final { return Rewrite(op); };
Expr VisitExpr_(const CallNode* call_node) final { return Rewrite(call_node); };
Expand All @@ -294,6 +296,7 @@ class MixedModeMutator : public ::tvm::relay::ExprMutator {
virtual Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) { return post; }

protected:
bool pre_;
/*! \brief Implement Rewrite API by calling ExprMutator's VisitExpr_(op) to get a `post` node with
* changed inputs.
*/
Expand Down Expand Up @@ -410,72 +413,82 @@ Expr PostOrderRewrite(const Expr& expr, ExprRewriter* rewriter);
*/
void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);

/*!
* \brief A struct to keep info of traversed expr in ExpandDataflow function
*/
struct v_info {
explicit v_info(Expr node_) : node{node_} {}
v_info(Expr node_, bool children_expanded_)
: node{node_}, children_expanded{children_expanded_} {};
Expr node{};
bool children_expanded{false};
};
d-smirnov marked this conversation as resolved.
Show resolved Hide resolved

/*!
* \brief A function to iteratively traverse dataflow regions of a graph
*
* ExpandDataflow manually manages a stack and performs DFS to determine the processing
* order of nodes in an input graph.
*
* If it finds a dataflow node (Call, Tuple, TupleGetItem), it checks if the arguments to that node
* need to be processed via fcheck_visited. If so, the function pushes those arguments to the stack
* and continues iteratively to process the top of the stack. When it finds a node that doesn't
* match the dataflow types, or a node who's inputs have all been processed, it visits the current
* leaf via fvisit_leaf.
* By default fexpand_expr implemented in a way that if it finds a dataflow node (Call, Tuple,
* TupleGetItem), it checks if the arguments to that node need to be processed via fcheck_visited.
* If so, the function pushes those arguments to the stack and continues iteratively to process
* the top of the stack. When it finds a node that doesn't match the dataflow types, or a node who's
* inputs have all been processed, it visits the current leaf via fvisit_leaf.
*
* This function should be used internally to other classes to implement mixed-mode traversals. The
* expectation is that fvisit_leaf will perform recursive analysis within mixed-mode traversal if it
* hits a non-dataflow node.
*
* fcheck_visited and fvisit_leaf are templated to encourage compiler inlining.
* fcheck_visited, fvisit_leaf and fexpand_expr are templated to encourage reusing.
*/
template <typename FCheckVisited, typename FVisitLeaf>
void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) {
std::stack<std::pair<Expr, bool>> stack;
template <typename FCheckVisited, typename FVisitLeaf, typename FExpandExpr>
void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf,
FExpandExpr fexpand_expr) {
std::deque<v_info> stack;
auto fpush_to_stack = [&fcheck_visited, &stack](const Expr& expr) {
// The second state of the stack indicate whether the child has been
// expanded in the pre-order.
// NOTE: function will be inlined.
if (!fcheck_visited(expr)) {
stack.push({expr, false});
stack.emplace_front(v_info(expr));
}
};

fpush_to_stack(expr);
while (stack.size() > 0) {
auto node = stack.top().first;
if (fcheck_visited(node)) {
// if this node was visited through another path
// after being added to the stack ignore it.
stack.pop();
} else if (stack.top().second) {
// all the children have already been expanded.
// we can just run post order visit on it.
fvisit_leaf(node);
stack.pop();
} else if (const CallNode* op = node.as<CallNode>()) {
// mark expanded = true
stack.top().second = true;
// push the children to the stack in reverse order
// to match recursive processing order
v_info* front = &stack.front();
if (fcheck_visited(front->node)) {
stack.pop_front();
} else if (front->children_expanded) {
fvisit_leaf(front->node);
// TODO(d-smirnov): this is for compatibility with current implementation of MixedModeVisitor
stack.pop_front();
} else {
front->children_expanded = true;
for (auto e : fexpand_expr(front->node)) {
fpush_to_stack(e);
}
}
}
}

template <typename FCheckVisited, typename FVisitLeaf>
void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) {
auto fexpand_expr = [](const Expr& expr) {
std::vector<Expr> result;
if (const CallNode* op = expr.as<CallNode>()) {
for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
fpush_to_stack(*it);
result.push_back(*it);
}
fpush_to_stack(op->op);
} else if (const TupleNode* op = node.as<TupleNode>()) {
stack.top().second = true;
// push the children to the stack in reverse order
// to match recursive processing order
result.push_back(op->op);
} else if (const TupleNode* op = expr.as<TupleNode>()) {
for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
fpush_to_stack(*it);
result.push_back(*it);
}
} else if (const TupleGetItemNode* op = node.as<TupleGetItemNode>()) {
stack.top().second = true;
fpush_to_stack(op->tuple);
} else {
// No need to expand the children directly run visit.
fvisit_leaf(node);
stack.pop();
} else if (const TupleGetItemNode* op = expr.as<TupleGetItemNode>()) {
result.push_back(op->tuple);
}
}
return result;
};
ExpandDataflow(expr, fcheck_visited, fvisit_leaf, fexpand_expr);
}

void ExpandANormalForm(const LetNode* op, std::function<void(const LetNode*)> pre_visit,
Expand Down
39 changes: 35 additions & 4 deletions src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,34 @@ bool RelayTextPrinter::AlwaysInline(const Expr& expr) {
expr.as<VarNode>() || expr.as<ConstructorNode>();
}

Doc RelayTextPrinter::VisitLeaf(const Expr& expr) {
if (!CheckVisited(expr)) {
Doc result = ExprFunctor<Doc(const Expr&)>::VisitExpr(expr);
// Add if not added after visiting
if (!CheckVisited(expr)) {
memo_[expr] = result;
} else {
result_memo_[expr] = result;
}
return result;
}
return memo_[expr];
}

bool RelayTextPrinter::CheckVisited(const Expr& expr) { return (memo_.count(expr)); }

Doc RelayTextPrinter::VisitExpr(const Expr& expr) {
auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); };
auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); };

if (fcheck_visited(expr)) {
return memo_[expr];
} else {
ExpandDataflow(expr, fcheck_visited, fvisit_leaf);
return memo_[expr];
}
}

//------------------------------------
// Overload of Expr printing functions
//------------------------------------
Expand All @@ -252,9 +280,6 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline, bo
inline_expr |= IsUnique(expr);
}

auto it = memo_.find(expr);
if (it != memo_.end()) return it->second;

Doc printed_expr;

if (meta) {
Expand All @@ -277,13 +302,19 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline, bo
if (expr.as<VarNode>()) {
// This is our first time visiting the var and we hit the VarNode case
// in the visitor. Thus the variable is free.
doc_stack_.back() << "free_var " << printed_expr << ";" << Doc::NewLine();
if (var_memo_.insert(expr).second && result_memo_.count(expr)) {
doc_stack_.back() << "free_var " << result_memo_[expr] << ";" << Doc::NewLine();
}
// Memoization is done in AllocVar.
return memo_[expr];
} else if (inline_expr) {
memo_[expr] = printed_expr;
return printed_expr;
} else {
// Already exists. Reuse
if (!var_memo_.insert(expr).second) {
return memo_[expr];
}
Doc temp_var = AllocTemp();
memo_[expr] = temp_var;
doc_stack_.back() << temp_var << " = " << printed_expr << ";" << Doc::NewLine();
Expand Down
10 changes: 9 additions & 1 deletion src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "../ir/attr_functor.h"
Expand All @@ -60,6 +61,9 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
explicit RelayTextPrinter(bool show_meta_data, TextMetaDataContext* meta,
runtime::TypedPackedFunc<std::string(ObjectRef)> annotate)
: show_meta_data_(show_meta_data), annotate_(annotate), meta_(meta) {}
Doc VisitExpr(const Expr& expr) override;
virtual Doc VisitLeaf(const Expr& expr);
virtual bool CheckVisited(const Expr& expr);

/*!
* \brief Print additional info about expr in comment.
Expand Down Expand Up @@ -145,7 +149,7 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
Doc PrintType(const Type& type, bool meta);
Doc VisitTypeDefault_(const Object* node) final;
Doc VisitType_(const TypeVarNode* node) final;
Doc VisitType_(const GlobalTypeVarNode* node);
Doc VisitType_(const GlobalTypeVarNode* node) final;
Doc VisitType_(const TypeCallNode* node) final;
Doc PrintDType(DataType dtype);
Doc VisitType_(const TensorTypeNode* node) final;
Expand All @@ -170,6 +174,10 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
runtime::TypedPackedFunc<std::string(ObjectRef)> annotate_;
/*! \brief Stack of docs to implement scoped GNFing. */
std::vector<Doc> doc_stack_{};
/*! \brief Set for introduced vars */
std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> var_memo_;
/*! \brief Map for result and memo_ diffs for visited expression */
std::unordered_map<Expr, Doc, ObjectPtrHash, ObjectPtrEqual> result_memo_;
/*! \brief Map from Expr to Doc */
std::unordered_map<Expr, Doc, ObjectPtrHash, ObjectPtrEqual> memo_;
/*! \brief Map from Type to Doc */
Expand Down
6 changes: 3 additions & 3 deletions src/relay/analysis/dependency_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace tvm {
namespace relay {

// Creator of DependencyGraph
class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
class DependencyGraph::Creator : private MixedModeVisitor {
public:
explicit Creator(support::Arena* arena) : arena_(arena) {}

Expand Down Expand Up @@ -73,13 +73,13 @@ class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
return ret;
}

void VisitExpr(const Expr& e) final {
void VisitLeaf(const Expr& e) override {
if (visited_.count(e) == 0) {
if (graph_.expr_node.count(e) == 0) {
graph_.expr_node[e] = NewNode(false);
}
visited_.insert(e);
ExprFunctor<void(const Expr&)>::VisitExpr(e);
MixedModeVisitor::VisitLeaf(e);
graph_.post_dfs_order.push_back(graph_.expr_node[e]);
}
}
Expand Down
64 changes: 64 additions & 0 deletions tests/cpp/relay_text_printer_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <gtest/gtest.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/type_functor.h>
#include <tvm/node/functor.h>
#include <tvm/node/structural_equal.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/function.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/op_strategy.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/topi/broadcast.h>
#include <tvm/topi/generic/injective.h>

using namespace tvm;
using namespace tvm::relay;

TEST(Relay, LargeGraphPrint) {
auto foo = [] {
auto add_op = relay::Op::Get("add");
auto c_data = tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto c1 = relay::Constant(c_data);
Call y1 = relay::Call(add_op, {c1, c1});
for (int i = 0; i < 1e6; i++) {
y1 = relay::Call(add_op, {c1, y1});
}
relay::Function func = relay::Function({}, y1, relay::Type(), {});
std::string result = AsText(func);
ASSERT_GT(0, result.size());
};
ASSERT_EXIT((foo(), exit(0)), ::testing::ExitedWithCode(0), ".*");
}

int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}