Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR] IRPrinter->NodePrinter, move to node/printer.h #4622

Merged
merged 1 commit into from
Jan 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 0 additions & 31 deletions include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -470,37 +470,6 @@ inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) {
}
return ret;
}

// Printer infra.
/*! \brief A Pretty printer class to print the IR. */
class IRPrinter {
public:
/*! \brief The output stream */
std::ostream& stream;
/*! \brief The indentation level. */
int indent{0};
explicit IRPrinter(std::ostream& stream) // NOLINT(*)
: stream(stream) {}

/*! \brief The node to be printed. */
TVM_DLL void Print(const ObjectRef& node);
/*! \brief Print indent to the stream */
TVM_DLL void PrintIndent();
// Allow registration to be printer.
using FType = NodeFunctor<void(const ObjectRef&, IRPrinter *)>;
TVM_DLL static FType& vtable();
};
} // namespace tvm

namespace tvm {
namespace runtime {
// default print function for all objects
// provide in the runtime namespace as this is where objectref originally comes from.
inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*)
IRPrinter(os).Print(n);
return os;
}
} // namespace runtime
} // namespace tvm

namespace std {
Expand Down
20 changes: 11 additions & 9 deletions include/tvm/node/functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,16 @@
#define TVM_NODE_FUNCTOR_H_

#include <dmlc/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/node/node.h>
#include <tvm/runtime/object.h>

#include <vector>
#include <type_traits>
#include <utility>

namespace tvm {

using runtime::ObjectRef;

/*!
* \brief A dynamically dispatched functor on the type of the first argument.
*
Expand Down Expand Up @@ -137,11 +139,11 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
* \brief Useful macro to set NodeFunctor dispatch in a global static field.
*
* \code
* // Use NodeFunctor to implement IRPrinter similar to Visitor Pattern.
* // Use NodeFunctor to implement NodePrinter similar to Visitor Pattern.
* // vtable allows easy patch of new Node types, without changing
* // interface of IRPrinter.
* // interface of NodePrinter.
*
* class IRPrinter {
* class NodePrinter {
* public:
* std::ostream& stream;
* // the dispatch function.
Expand All @@ -150,18 +152,18 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
* f(e, this);
* }
*
* using FType = NodeFunctor<void (const ObjectRef&, IRPrinter *)>;
* using FType = NodeFunctor<void (const ObjectRef&, NodePrinter* )>;
* // function to return global function table
* static FType& vtable();
* };
*
* // in cpp/cc file
* IRPrinter::FType& IRPrinter::vtable() { // NOLINT(*)
* NodePrinter::FType& NodePrinter::vtable() { // NOLINT(*)
* static FType inst; return inst;
* }
*
* TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
* .set_dispatch<Add>([](const ObjectRef& ref, IRPrinter* p) {
* TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
* .set_dispatch<Add>([](const ObjectRef& ref, NodePrinter* p) {
* auto* n = static_cast<const Add*>(ref.get());
* p->print(n->a);
* p->stream << '+'
Expand Down
1 change: 1 addition & 0 deletions include/tvm/node/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <tvm/node/reflection.h>
#include <tvm/node/printer.h>

#include <string>
#include <vector>
Expand Down
61 changes: 61 additions & 0 deletions include/tvm/node/printer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/node/printer.h
* \brief Printer class to print repr string of each AST/IR nodes.
*/
#ifndef TVM_NODE_PRINTER_H_
#define TVM_NODE_PRINTER_H_

#include <tvm/node/functor.h>
#include <iostream>

namespace tvm {
/*! \brief A printer class to print the AST/IR nodes. */
class NodePrinter {
public:
/*! \brief The output stream */
std::ostream& stream;
/*! \brief The indentation level. */
int indent{0};

explicit NodePrinter(std::ostream& stream) // NOLINT(*)
: stream(stream) {}

/*! \brief The node to be printed. */
TVM_DLL void Print(const ObjectRef& node);
/*! \brief Print indent to the stream */
TVM_DLL void PrintIndent();
// Allow registration to be printer.
using FType = NodeFunctor<void(const ObjectRef&, NodePrinter*)>;
TVM_DLL static FType& vtable();
};
} // namespace tvm

namespace tvm {
namespace runtime {
// default print function for all objects
// provide in the runtime namespace as this is where objectref originally comes from.
inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*)
NodePrinter(os).Print(n);
return os;
}
} // namespace runtime
} // namespace tvm
#endif // TVM_NODE_PRINTER_H_
4 changes: 2 additions & 2 deletions src/arithmetic/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ inline void PrintBoundValue(std::ostream& os, int64_t val) {
}
}

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ConstIntBoundNode>([](const ObjectRef& node, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ConstIntBoundNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const ConstIntBoundNode*>(node.get());
p->stream << "ConstIntBound[";
PrintBoundValue(p->stream, op->min_value);
Expand Down
4 changes: 2 additions & 2 deletions src/arithmetic/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -811,8 +811,8 @@ IntSet EvalSet(Range r,

TVM_REGISTER_NODE_TYPE(IntervalSetNode);

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IntervalSetNode>([](const ObjectRef& node, IRPrinter *p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IntervalSetNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const IntervalSetNode*>(node.get());
p->stream << "IntervalSet"
<< "[" << op->min_value << ", "
Expand Down
4 changes: 2 additions & 2 deletions src/arithmetic/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ ModularSet::ModularSet(int64_t coeff, int64_t base) {
data_ = std::move(node);
}

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ModularSetNode>([](const ObjectRef& node, IRPrinter *p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ModularSetNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const ModularSetNode*>(node.get());
p->stream << "ModularSet("
<< "coeff=" << op->coeff << ", base="
Expand Down
8 changes: 4 additions & 4 deletions src/codegen/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ using runtime::PackedFunc;
TVM_REGISTER_NODE_TYPE(TargetNode);
TVM_REGISTER_NODE_TYPE(GenericFuncNode);

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TargetNode>([](const ObjectRef& node, IRPrinter *p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TargetNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const TargetNode*>(node.get());
p->stream << op->str();
});
Expand Down Expand Up @@ -665,8 +665,8 @@ tvm::BuildConfig BuildConfig::Current() {

TVM_REGISTER_NODE_TYPE(BuildConfigNode);

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<BuildConfigNode>([](const ObjectRef& node, IRPrinter *p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<BuildConfigNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const BuildConfigNode*>(node.get());
p->stream << "build_config(";
p->stream << "data_alignment=" << op->data_alignment << ", ";
Expand Down
3 changes: 3 additions & 0 deletions src/codegen/llvm/codegen_arm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
* \brief ARM specific code generator
*/
#ifdef TVM_LLVM_VERSION

#include <tvm/runtime/registry.h>

#include "codegen_cpu.h"

namespace tvm {
Expand Down
2 changes: 2 additions & 0 deletions src/codegen/llvm/codegen_x86_64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
* \brief X86-64 specific code generator
*/
#ifdef TVM_LLVM_VERSION

#include <tvm/runtime/registry.h>
#include "codegen_cpu.h"

#include "llvm/MC/MCSubtargetInfo.h"
Expand Down
2 changes: 2 additions & 0 deletions src/codegen/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
* \brief LLVM runtime module for TVM
*/
#ifdef TVM_LLVM_VERSION

#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/codegen.h>
#include <mutex>
#include "llvm_common.h"
Expand Down
1 change: 1 addition & 0 deletions src/codegen/source_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* \brief Source code module, only for viewing
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include "codegen_source_base.h"
#include "../runtime/file_util.h"
#include "../runtime/meta_data.h"
Expand Down
1 change: 1 addition & 0 deletions src/contrib/hybrid/codegen_hybrid.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
/*!
* \file codegen_hybrid.cc
*/
#include <tvm/runtime/registry.h>
#include <iomanip>
#include <cctype>
#include "codegen_hybrid.h"
Expand Down
9 changes: 5 additions & 4 deletions src/ir/span.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
* \brief The span data structure.
*/
#include <tvm/ir/span.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>

namespace tvm {
Expand Down Expand Up @@ -48,8 +49,8 @@ SourceName SourceName::Get(const std::string& name) {
TVM_REGISTER_GLOBAL("relay._make.SourceName")
.set_body_typed(SourceName::Get);

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<SourceNameNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<SourceNameNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const SourceNameNode*>(ref.get());
p->stream << "SourceName(" << node->name << ", " << node << ")";
});
Expand All @@ -73,8 +74,8 @@ TVM_REGISTER_NODE_TYPE(SpanNode);
TVM_REGISTER_GLOBAL("relay._make.Span")
.set_body_typed(SpanNode::make);

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<SpanNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<SpanNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const SpanNode*>(ref.get());
p->stream << "Span(" << node->source << ", " << node->lineno << ", "
<< node->col_offset << ")";
Expand Down
13 changes: 7 additions & 6 deletions src/ir/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* \brief Common type system AST nodes throughout the IR.
*/
#include <tvm/ir/type.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>

namespace tvm {
Expand All @@ -40,8 +41,8 @@ TVM_REGISTER_GLOBAL("relay._make.TypeVar")
return TypeVarNode::make(name, static_cast<TypeKind>(kind));
});

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TypeVarNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeVarNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TypeVarNode*>(ref.get());
p->stream << "TypeVar(" << node->name_hint << ", "
<< node->kind << ")";
Expand All @@ -61,8 +62,8 @@ TVM_REGISTER_GLOBAL("relay._make.GlobalTypeVar")
return GlobalTypeVarNode::make(name, static_cast<TypeKind>(kind));
});

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<GlobalTypeVarNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<GlobalTypeVarNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const GlobalTypeVarNode*>(ref.get());
p->stream << "GlobalTypeVar(" << node->name_hint << ", "
<< node->kind << ")";
Expand All @@ -85,8 +86,8 @@ TVM_REGISTER_NODE_TYPE(FuncTypeNode);
TVM_REGISTER_GLOBAL("relay._make.FuncType")
.set_body_typed(FuncTypeNode::make);

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<FuncTypeNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<FuncTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const FuncTypeNode*>(ref.get());
p->stream << "FuncType(" << node->type_params << ", "
<< node->arg_types << ", " << node->ret_type << ", "
Expand Down
4 changes: 2 additions & 2 deletions src/lang/attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ Attrs DictAttrsNode::make(Map<std::string, ObjectRef> dict) {
return Attrs(n);
}

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<DictAttrsNode>([](const ObjectRef& node, IRPrinter *p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<DictAttrsNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const DictAttrsNode*>(node.get());
p->stream << op->dict;
});
Expand Down
4 changes: 2 additions & 2 deletions src/lang/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -450,8 +450,8 @@ Buffer BufferNode::make(Var data,
return Buffer(n);
}

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<BufferNode>([](const ObjectRef& node, IRPrinter *p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<BufferNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const BufferNode*>(node.get());
p->stream << "buffer(" << op->name << ", " << op << ")";
});
Expand Down
8 changes: 4 additions & 4 deletions src/lang/data_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ int32_t Layout::FactorOf(const LayoutAxis& axis) const {
return -1;
}

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<LayoutNode>([](const ObjectRef& node, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<LayoutNode>([](const ObjectRef& node, NodePrinter* p) {
auto* l = static_cast<const LayoutNode*>(node.get());
p->stream << "Layout(" << l->name << ")";
});
Expand Down Expand Up @@ -361,8 +361,8 @@ BijectiveLayout BijectiveLayoutNode::make(const Layout& src_layout,
return BijectiveLayout(n);
}

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<BijectiveLayoutNode>([](const ObjectRef& node, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<BijectiveLayoutNode>([](const ObjectRef& node, NodePrinter* p) {
auto* b = static_cast<const BijectiveLayoutNode*>(node.get());
p->stream << "BijectiveLayout(" << b->src_layout.name()
<< "->" << b->dst_layout.name() << ")";
Expand Down
Loading