diff --git a/include/tvm/api_registry.h b/include/tvm/api_registry.h index c41c3087f4ac0..292e4948d2116 100644 --- a/include/tvm/api_registry.h +++ b/include/tvm/api_registry.h @@ -49,7 +49,7 @@ namespace tvm { * \brief Node container of EnvFunc * \sa EnvFunc */ -class EnvFuncNode : public Node { +class EnvFuncNode : public Object { public: /*! \brief Unique name of the global function */ std::string name; @@ -63,7 +63,7 @@ class EnvFuncNode : public Node { } static constexpr const char* _type_key = "EnvFunc"; - TVM_DECLARE_NODE_TYPE_INFO(EnvFuncNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(EnvFuncNode, Object); }; /*! @@ -73,10 +73,10 @@ class EnvFuncNode : public Node { * An EnvFunc is saved by its name in the global registry * under the assumption that the same function is registered during load. */ -class EnvFunc : public NodeRef { +class EnvFunc : public ObjectRef { public: EnvFunc() {} - explicit EnvFunc(NodePtr n) : NodeRef(n) {} + explicit EnvFunc(ObjectPtr n) : ObjectRef(n) {} /*! \return The internal global function pointer */ const EnvFuncNode* operator->() const { return static_cast(get()); @@ -119,12 +119,12 @@ class TypedEnvFunc; * \sa EnvFunc */ template -class TypedEnvFunc : public NodeRef { +class TypedEnvFunc : public ObjectRef { public: /*! \brief short hand for this function type */ using TSelf = TypedEnvFunc; TypedEnvFunc() {} - explicit TypedEnvFunc(ObjectPtr n) : NodeRef(n) {} + explicit TypedEnvFunc(ObjectPtr n) : ObjectRef(n) {} /*! * \brief Assign global function to a TypedEnvFunc * \param other Another global function. diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index bda6ac647f557..e5f75673a9cbd 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -55,7 +55,7 @@ class Analyzer; * * set = [min_value, max_value] */ -class ConstIntBoundNode : public Node { +class ConstIntBoundNode : public Object { public: int64_t min_value; int64_t max_value; @@ -74,14 +74,14 @@ class ConstIntBoundNode : public Node { static const constexpr int64_t kNegInf = -kPosInf; static constexpr const char* _type_key = "arith.ConstIntBound"; - TVM_DECLARE_NODE_TYPE_INFO(ConstIntBoundNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(ConstIntBoundNode, Object); }; /*! * \brief reference class to ConstIntBoundNode * \sa ConstIntBoundNode */ -class ConstIntBound : public NodeRef { +class ConstIntBound : public ObjectRef { public: /*! * \brief constructor by fields. @@ -92,7 +92,7 @@ class ConstIntBound : public NodeRef { static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf; static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf; - TVM_DEFINE_NODE_REF_METHODS(ConstIntBound, NodeRef, ConstIntBoundNode); + TVM_DEFINE_OBJECT_REF_METHODS(ConstIntBound, ObjectRef, ConstIntBoundNode); }; /*! @@ -155,7 +155,7 @@ class ConstIntBoundAnalyzer { * This is useful to decide if the index is dividable by certain value. * For example, if index = 0 + 4 x, then we know it can be divided by 4. */ -class ModularSetNode : public Node { +class ModularSetNode : public Object { public: /*! \brief linear co-efficient */ int64_t coeff; @@ -168,18 +168,18 @@ class ModularSetNode : public Node { } static constexpr const char* _type_key = "arith.ModularSet"; - TVM_DECLARE_NODE_TYPE_INFO(ModularSetNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(ModularSetNode, Object); }; /*! * \brief reference of ModularSetNode * \sa ModularSetNode */ -class ModularSet : public NodeRef { +class ModularSet : public ObjectRef { public: TVM_DLL ModularSet(int64_t coeff, int64_t base); - TVM_DEFINE_NODE_REF_METHODS(ModularSet, NodeRef, ModularSetNode); + TVM_DEFINE_OBJECT_REF_METHODS(ModularSet, ObjectRef, ModularSetNode); }; /*! @@ -349,20 +349,20 @@ enum SignType { /*! * \brief Base class of all IntSet containers. */ -struct IntSetNode : public Node { +struct IntSetNode : public Object { static constexpr const char* _type_key = "IntSet"; - TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Object); + TVM_DECLARE_BASE_OBJECT_INFO(IntSetNode, Object); }; /*! * \brief Integer set class, represent a set of integers in one dimension. */ -class IntSet : public NodeRef { +class IntSet : public ObjectRef { public: /*! \brief constructor */ IntSet() {} // constructor from not container. - explicit IntSet(ObjectPtr n) : NodeRef(n) {} + explicit IntSet(ObjectPtr n) : ObjectRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -598,7 +598,7 @@ IntSet EvalSet(Range r, const std::unordered_map& dom_map); /*! \brief Map from Expr to IntSet */ -using ExprIntSetMap = std::unordered_map; +using ExprIntSetMap = std::unordered_map; /*! * \brief Find the integer set of every sub-expression, given the * domain of each iteration variables. diff --git a/include/tvm/attrs.h b/include/tvm/attrs.h index 8810c4e4a0df1..0178eabe02ebf 100644 --- a/include/tvm/attrs.h +++ b/include/tvm/attrs.h @@ -65,7 +65,7 @@ namespace tvm { */ #define TVM_DECLARE_ATTRS(ClassName, TypeKey) \ static constexpr const char* _type_key = TypeKey; \ - TVM_DECLARE_NODE_TYPE_INFO(ClassName, ::tvm::BaseAttrsNode) \ + TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \ template \ void __VisitAttrs__(FVisit& __fvisit__) // NOLINT(*) @@ -83,9 +83,9 @@ namespace tvm { * \tparam TNodeRef the type to be created. * \return A instance that will represent None. */ -template -inline TNodeRef NullValue() { - return TNodeRef(NodePtr(nullptr)); +template +inline TObjectRef NullValue() { + return TObjectRef(ObjectPtr(nullptr)); } template<> @@ -106,7 +106,7 @@ struct AttrError : public dmlc::Error { /*! * \brief Information about attribute fields in string representations. */ -class AttrFieldInfoNode : public Node { +class AttrFieldInfoNode : public Object { public: /*! \brief name of the field */ std::string name; @@ -121,11 +121,14 @@ class AttrFieldInfoNode : public Node { v->Visit("description", &description); } static constexpr const char* _type_key = "AttrFieldInfo"; - TVM_DECLARE_NODE_TYPE_INFO(AttrFieldInfoNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object); }; /*! \brief AttrFieldInfo */ -TVM_DEFINE_NODE_REF(AttrFieldInfo, AttrFieldInfoNode); +class AttrFieldInfo : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(AttrFieldInfo, ObjectRef, AttrFieldInfoNode); +}; class AttrsHashHandler; class AttrsEqualHandler; @@ -217,7 +220,7 @@ class AttrsHash { * subclass AttrsNode instead. * \sa AttrsNode */ -class BaseAttrsNode : public Node { +class BaseAttrsNode : public Object { public: using TVMArgs = runtime::TVMArgs; using TVMRetValue = runtime::TVMRetValue; @@ -271,16 +274,16 @@ class BaseAttrsNode : public Node { TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0; static constexpr const char* _type_key = "Attrs"; - TVM_DECLARE_BASE_NODE_INFO(BaseAttrsNode, Node); + TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object); }; /*! \brief Base attribute container for all attributes */ -class Attrs : public NodeRef { +class Attrs : public ObjectRef { public: // normal constructor Attrs() {} // construct from shared ptr. - explicit Attrs(NodePtr n) : NodeRef(n) {} + explicit Attrs(ObjectPtr n) : ObjectRef(n) {} /*! \return The attribute node */ const BaseAttrsNode* operator->() const { @@ -305,13 +308,13 @@ class Attrs : public NodeRef { class DictAttrsNode : public BaseAttrsNode { public: /*! \brief internal attrs map */ - Map dict; + Map dict; /*! * \brief Consruct a Attrs backed by DictAttrsNode. * \param dict The attributes. * \return The dict attributes. */ - TVM_DLL static Attrs make(Map dict); + TVM_DLL static Attrs make(Map dict); // implementations void VisitAttrs(AttrVisitor* v) final; void VisitNonDefaultAttrs(AttrVisitor* v) final; @@ -321,7 +324,7 @@ class DictAttrsNode : public BaseAttrsNode { size_t ContentHash(AttrsHash hasher) const final; // type info static constexpr const char* _type_key = "DictAttrs"; - TVM_DECLARE_NODE_TYPE_INFO(DictAttrsNode, BaseAttrsNode); + TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode); }; @@ -639,7 +642,7 @@ class AttrDocEntry { public: using TSelf = AttrDocEntry; - explicit AttrDocEntry(NodePtr info) + explicit AttrDocEntry(ObjectPtr info) : info_(info) { } TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { @@ -663,15 +666,15 @@ class AttrDocEntry { } private: - NodePtr info_; + ObjectPtr info_; }; class AttrDocVisitor { public: template AttrDocEntry operator()(const char* key, T* v) { - NodePtr info - = make_node(); + ObjectPtr info + = make_object(); info->name = key; info->type_info = TypeName::value; fields_.push_back(AttrFieldInfo(info)); diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h index fac18a9b17533..44c7918631532 100644 --- a/include/tvm/buffer.h +++ b/include/tvm/buffer.h @@ -48,10 +48,10 @@ enum BufferType : int { * It is a composition of primitive symbolic types, * used to specify the memory layout of the Tensor used in program input. */ -class Buffer : public NodeRef { +class Buffer : public ObjectRef { public: Buffer() {} - explicit Buffer(ObjectPtr n) : NodeRef(n) {} + explicit Buffer(ObjectPtr n) : ObjectRef(n) {} /*! * \brief Return a new buffer that is equivalent with current one * but always add stride field. @@ -101,7 +101,7 @@ class Buffer : public NodeRef { }; /*! \brief Node to represent a buffer */ -class BufferNode : public Node { +class BufferNode : public Object { public: // Data fields. /*! @@ -169,7 +169,7 @@ class BufferNode : public Node { BufferType buffer_type); static constexpr const char* _type_key = "Buffer"; - TVM_DECLARE_NODE_TYPE_INFO(BufferNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object); }; inline const BufferNode* Buffer::operator->() const { diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h index fba929cda1be7..5078621e4bdad 100644 --- a/include/tvm/build_module.h +++ b/include/tvm/build_module.h @@ -39,7 +39,7 @@ namespace tvm { * \brief Container for target device information. * Use target::llvm, target::cuda etc functions instead of constructing directly. */ -class TargetNode : public Node { +class TargetNode : public Object { public: /*! \brief The name of the target device */ std::string target_name; @@ -82,7 +82,7 @@ class TargetNode : public Node { TVM_DLL std::unordered_set libs() const; static constexpr const char* _type_key = "Target"; - TVM_DECLARE_NODE_TYPE_INFO(TargetNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(TargetNode, Object); private: /*! \brief Internal string repr. */ @@ -90,10 +90,10 @@ class TargetNode : public Node { }; /*! \brief reference cpass to the target. */ -class Target : public NodeRef { +class Target : public ObjectRef { public: Target() {} - explicit Target(ObjectPtr n) : NodeRef(n) {} + explicit Target(ObjectPtr n) : ObjectRef(n) {} /*! * \brief Create a Target given a string * \param target_str the string to parse @@ -178,7 +178,7 @@ TVM_DLL Target ext_dev(const std::vector& options = /*! * \brief Container for build configuration options */ -class BuildConfigNode : public Node { +class BuildConfigNode : public Object { public: /*! * \brief The data alignment to use when constructing buffers. If this is set to @@ -254,16 +254,16 @@ class BuildConfigNode : public Node { } static constexpr const char* _type_key = "BuildConfig"; - TVM_DECLARE_NODE_TYPE_INFO(BuildConfigNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(BuildConfigNode, Object); }; /*! * \brief Build configuration for compilations. */ -class BuildConfig : public ::tvm::NodeRef { +class BuildConfig : public ::tvm::ObjectRef { public: BuildConfig() {} - explicit BuildConfig(ObjectPtr n) : NodeRef(n) {} + explicit BuildConfig(ObjectPtr n) : ObjectRef(n) {} const BuildConfigNode* operator->() const { return static_cast(get()); } @@ -375,10 +375,10 @@ class GenericFuncNode; /*! * \brief Generic function that can be specialized on a per-target basis. */ -class GenericFunc : public NodeRef { +class GenericFunc : public ObjectRef { public: GenericFunc() {} - explicit GenericFunc(ObjectPtr n) : NodeRef(n) {} + explicit GenericFunc(ObjectPtr n) : ObjectRef(n) {} /*! * \brief Set the default function implementaiton. @@ -471,7 +471,7 @@ inline runtime::TVMRetValue GenericFunc::operator()(Args&& ...args) const { /*! * \brief Represents a generic function that can be specialized on a per-target basis. */ -class GenericFuncNode : public Node { +class GenericFuncNode : public Object { public: /*! \brief name of the function */ std::string name_; @@ -483,7 +483,7 @@ class GenericFuncNode : public Node { void VisitAttrs(AttrVisitor* v) {} static constexpr const char* _type_key = "GenericFunc"; - TVM_DECLARE_NODE_TYPE_INFO(GenericFuncNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(GenericFuncNode, Object); }; inline GenericFuncNode* GenericFunc::operator->() { diff --git a/include/tvm/data_layout.h b/include/tvm/data_layout.h index 5e2cc08660db1..8c7247ff860b1 100644 --- a/include/tvm/data_layout.h +++ b/include/tvm/data_layout.h @@ -92,7 +92,7 @@ class LayoutAxis { class Layout; // Internal node container Buffer -class LayoutNode : public Node { +class LayoutNode : public Object { public: /*! \brief string representation of layout, "" for scalar. */ std::string name; @@ -112,7 +112,7 @@ class LayoutNode : public Node { TVM_DLL static Layout make(const std::string& layout); static constexpr const char* _type_key = "Layout"; - TVM_DECLARE_NODE_TYPE_INFO(LayoutNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(LayoutNode, Object); }; /*! @@ -125,9 +125,9 @@ class LayoutNode : public Node { * Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel). * Layout for scalar is defined, while both its name and axes have size 0. */ -class Layout : public NodeRef { +class Layout : public ObjectRef { public: - explicit Layout(ObjectPtr n) : NodeRef(n) {} + explicit Layout(ObjectPtr n) : ObjectRef(n) {} /*! \brief default constructor */ Layout() = default; @@ -311,7 +311,7 @@ class Layout : public NodeRef { class BijectiveLayout; // Internal node container BijectiveLayout -class BijectiveLayoutNode : public Node { +class BijectiveLayoutNode : public Object { public: /*! \brief Describes how source axes can be mapped to the destination axes, * e.g., [i0 / 16, i1, i0 % 16] can describe NC -> NC16n @@ -333,7 +333,7 @@ class BijectiveLayoutNode : public Node { } static constexpr const char* _type_key = "BijectiveLayout"; - TVM_DECLARE_NODE_TYPE_INFO(BijectiveLayoutNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(BijectiveLayoutNode, Object); TVM_DLL static BijectiveLayout make(const Layout& src_layout, const Layout& dst_layout); @@ -344,10 +344,10 @@ class BijectiveLayoutNode : public Node { * provides API to transform N-dimention tensor from the source indices (i0, i1, …, im) * to the destination indices (j0, j1, … jm). */ -class BijectiveLayout : public NodeRef { +class BijectiveLayout : public ObjectRef { public: BijectiveLayout() = default; - explicit BijectiveLayout(NodePtr n) : NodeRef(n) {} + explicit BijectiveLayout(ObjectPtr n) : ObjectRef(n) {} // Given the source shape, infer the destination shape. TVM_DLL Array ForwardShape(const Array& shape) const; diff --git a/include/tvm/expr.h b/include/tvm/expr.h index f27cb9879fb76..0605cc512690f 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -38,20 +38,20 @@ namespace tvm { /*! \brief Base node of all expressions. */ -class ExprNode : public Node { +class ExprNode : public Object { public: /*! \brief The data type of the expression. */ DataType dtype; static constexpr const char* _type_key = "Expr"; - TVM_DECLARE_BASE_NODE_INFO(ExprNode, Node); + TVM_DECLARE_BASE_OBJECT_INFO(ExprNode, Object); }; /*! \brief Container of all expressions. */ -class Expr : public NodeRef { +class Expr : public ObjectRef { public: Expr() {} - explicit Expr(ObjectPtr ptr) : NodeRef(ptr) {} + explicit Expr(ObjectPtr ptr) : ObjectRef(ptr) {} /*! * \brief construct from integer. * \param value The value to be constructed. @@ -78,16 +78,16 @@ class Expr : public NodeRef { }; /*! \brief Base node of all statements. */ -class StmtNode : public Node { +class StmtNode : public Object { public: static constexpr const char* _type_key = "Stmt"; - TVM_DECLARE_BASE_NODE_INFO(StmtNode, Node); + TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object); }; /*! \brief Container of all statements */ -class Stmt : public NodeRef { +class Stmt : public ObjectRef { public: - TVM_DEFINE_NODE_REF_METHODS(Stmt, NodeRef, StmtNode); + TVM_DEFINE_OBJECT_REF_METHODS(Stmt, ObjectRef, StmtNode); }; class Var; @@ -118,7 +118,7 @@ class Variable : public ExprNode { } static constexpr const char* _type_key = "Variable"; - TVM_DECLARE_NODE_TYPE_INFO(Variable, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Variable, ExprNode); }; /*! \brief a named variable in TVM */ @@ -156,8 +156,8 @@ class Var : public Expr { // Backward compatibility, will be removed later. using VarExpr = Var; using BaseExprNode = ExprNode; -using ExprHash = NodeHash; -using ExprEqual = NodeEqual; +using ExprHash = ObjectHash; +using ExprEqual = ObjectEqual; class Integer; /*! \brief ExprNode: constant integer. */ @@ -174,7 +174,7 @@ class IntImm : public ExprNode { TVM_DLL static Integer make(DataType t, int64_t value); static constexpr const char* _type_key = "IntImm"; - TVM_DECLARE_NODE_TYPE_INFO(IntImm, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(IntImm, ExprNode); }; /*! @@ -222,7 +222,7 @@ class Integer : public Expr { }; /*! \brief range over one dimension */ -class RangeNode : public Node { +class RangeNode : public Object { public: /*! \brief beginning of the node */ Expr min; @@ -238,11 +238,11 @@ class RangeNode : public Node { } static constexpr const char* _type_key = "Range"; - TVM_DECLARE_NODE_TYPE_INFO(RangeNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object); }; /*! \brief Range constainer */ -class Range : public NodeRef { +class Range : public ObjectRef { public: /*! * \brief constructor by begin and end @@ -261,7 +261,7 @@ class Range : public NodeRef { */ static Range make_by_min_extent(Expr min, Expr extent); // declare range. - TVM_DEFINE_NODE_REF_METHODS(Range, NodeRef, RangeNode); + TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode); }; /*! \brief container class of iteration variable. */ @@ -343,12 +343,12 @@ enum IterVarType : int { * \brief Iteration Variable, * represents an iteration over an integer interval. */ -class IterVar : public NodeRef { +class IterVar : public ObjectRef { public: // construct a new iter var without a domain IterVar() {} // construct from shared ptr. - explicit IterVar(ObjectPtr n) : NodeRef(n) {} + explicit IterVar(ObjectPtr n) : ObjectRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -384,14 +384,14 @@ using Domain = Array; * \brief Dump the node to stderr, used for debug purposes. * \param node The input node */ -TVM_DLL void Dump(const NodeRef& node); +TVM_DLL void Dump(const ObjectRef& node); // definition of Node. /*! * \brief An iteration variable representing an iteration * over a one dimensional interval. */ -class IterVarNode : public Node { +class IterVarNode : public Object { public: /*! * \brief the domain of iteration, if known, can be None @@ -420,7 +420,7 @@ class IterVarNode : public Node { std::string thread_tag = ""); static constexpr const char* _type_key = "IterVar"; - TVM_DECLARE_NODE_TYPE_INFO(IterVarNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object); }; // inline implementations @@ -490,17 +490,22 @@ class IRPrinter { using FType = NodeFunctor; TVM_DLL static FType& vtable(); }; +} // namespace tvm -// default print function for all nodes +namespace tvm { +namespace runtime { +// default print function for all objects +// provide in the runtime namespace as this is where objectref originally comes from. inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*) IRPrinter(os).Print(n); return os; } +} // namespace runtime } // namespace tvm namespace std { template <> -struct hash<::tvm::IterVar> : public ::tvm::NodeHash { +struct hash<::tvm::IterVar> : public ::tvm::ObjectHash { }; } #endif // TVM_EXPR_H_ diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 33aa72b50805a..c55a4695de4d3 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -53,7 +53,7 @@ class UIntImm : public ExprNode { TVM_DLL static Expr make(DataType t, uint64_t value); static constexpr const char* _type_key = "UIntImm"; - TVM_DECLARE_NODE_TYPE_INFO(UIntImm, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(UIntImm, ExprNode); }; /*! \brief Floating point constants. */ @@ -70,7 +70,7 @@ class FloatImm : public ExprNode { TVM_DLL static Expr make(DataType t, double value); static constexpr const char* _type_key = "FloatImm"; - TVM_DECLARE_NODE_TYPE_INFO(FloatImm, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(FloatImm, ExprNode); }; /*! \brief String constants, only used in asserts. */ @@ -87,7 +87,7 @@ class StringImm : public ExprNode { TVM_DLL Expr static make(std::string value); static constexpr const char* _type_key = "StringImm"; - TVM_DECLARE_NODE_TYPE_INFO(StringImm, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(StringImm, ExprNode); }; /*! @@ -107,7 +107,7 @@ class Cast : public ExprNode { TVM_DLL static Expr make(DataType t, Expr v); static constexpr const char* _type_key = "Cast"; - TVM_DECLARE_NODE_TYPE_INFO(Cast, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Cast, ExprNode); }; /*! @@ -132,14 +132,14 @@ class BinaryOpNode : public ExprNode { CHECK(a.defined()) << "ValueError: a is undefined\n"; CHECK(b.defined()) << "ValueError: b is undefined\n"; CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n"; - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = a.dtype(); node->a = std::move(a); node->b = std::move(b); return Expr(node); } - TVM_DECLARE_NODE_TYPE_INFO(T, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(T, ExprNode); }; /*! \brief a + b */ @@ -224,14 +224,14 @@ class CmpOpNode : public ExprNode { CHECK(a.defined()) << "ValueError: a is undefined\n"; CHECK(b.defined()) << "ValueError: b is undefined\n"; CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n"; - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); node->b = std::move(b); return Expr(node); } - TVM_DECLARE_NODE_TYPE_INFO(T, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(T, ExprNode); }; /*! \brief a == b */ @@ -287,7 +287,7 @@ class And : public ExprNode { TVM_DLL static Expr make(Expr a, Expr b); static constexpr const char* _type_key = "And"; - TVM_DECLARE_NODE_TYPE_INFO(And, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(And, ExprNode); }; /*! \brief a || b */ @@ -307,7 +307,7 @@ class Or : public ExprNode { TVM_DLL static Expr make(Expr a, Expr b); static constexpr const char* _type_key = "Or"; - TVM_DECLARE_NODE_TYPE_INFO(Or, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Or, ExprNode); }; /*! \brief !a */ @@ -324,7 +324,7 @@ class Not : public ExprNode { TVM_DLL static Expr make(Expr a); static constexpr const char* _type_key = "Not"; - TVM_DECLARE_NODE_TYPE_INFO(Not, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Not, ExprNode); }; /*! @@ -353,7 +353,7 @@ class Select : public ExprNode { TVM_DLL static Expr make(Expr condition, Expr true_value, Expr false_value); static constexpr const char* _type_key = "Select"; - TVM_DECLARE_NODE_TYPE_INFO(Select, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Select, ExprNode); }; /*! @@ -390,7 +390,7 @@ class Load : public ExprNode { TVM_DLL static Expr make(DataType dtype, Var buffer_var, Expr index, Expr predicate); static constexpr const char* _type_key = "Load"; - TVM_DECLARE_NODE_TYPE_INFO(Load, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Load, ExprNode); }; /*! @@ -421,7 +421,7 @@ class Ramp : public ExprNode { TVM_DLL static Expr make(Expr base, Expr stride, int lanes); static constexpr const char* _type_key = "Ramp"; - TVM_DECLARE_NODE_TYPE_INFO(Ramp, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Ramp, ExprNode); }; /*! \brief Create a vector where all the elements are value. */ @@ -441,7 +441,7 @@ class Broadcast : public ExprNode { TVM_DLL static Expr make(Expr value, int lanes); static constexpr const char* _type_key = "Broadcast"; - TVM_DECLARE_NODE_TYPE_INFO(Broadcast, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Broadcast, ExprNode); }; /*! @@ -466,7 +466,7 @@ class Let : public ExprNode { TVM_DLL static Expr make(Var var, Expr value, Expr body); static constexpr const char* _type_key = "Let"; - TVM_DECLARE_NODE_TYPE_INFO(Let, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Let, ExprNode); }; // Call node, represent a function call or a multi-dimensional array load. @@ -477,7 +477,7 @@ class Let : public ExprNode { // We should move most information into function itself and remove name. /*! \brief Base node of internal functions. */ -class FunctionBaseNode : public Node { +class FunctionBaseNode : public Object { public: /*! \return the name of the function */ virtual const std::string& func_name() const = 0; @@ -486,9 +486,9 @@ class FunctionBaseNode : public Node { }; /*! \brief reference to a function */ -class FunctionRef : public NodeRef { +class FunctionRef : public ObjectRef { public: - TVM_DEFINE_NODE_REF_METHODS(FunctionRef, NodeRef, FunctionBaseNode); + TVM_DEFINE_OBJECT_REF_METHODS(FunctionRef, ObjectRef, FunctionBaseNode); }; /*! @@ -560,7 +560,7 @@ class Call : public ExprNode { bool is_vectorizable() const; static constexpr const char* _type_key = "Call"; - TVM_DECLARE_NODE_TYPE_INFO(Call, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Call, ExprNode); // Build-in intrinsics static constexpr const char* reinterpret = "reinterpret"; @@ -602,16 +602,16 @@ class Shuffle : public ExprNode { TVM_DLL static Expr make_extract_element(Expr vector, int index); static constexpr const char* _type_key = "Shuffle"; - TVM_DECLARE_NODE_TYPE_INFO(Shuffle, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Shuffle, ExprNode); }; // Reduce operator class CommReducerNode; -class CommReducer : public NodeRef { +class CommReducer : public ObjectRef { public: CommReducer() {} - explicit CommReducer(NodePtr n) : NodeRef(n) {} + explicit CommReducer(ObjectPtr n) : ObjectRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -630,7 +630,7 @@ class CommReducer : public NodeRef { * \brief A commutative reducer node to represent a commutative * binary operator with identity element */ -class CommReducerNode : public Node { +class CommReducerNode : public Object { public: /*! \brief The left argument of reducer */ Array lhs; @@ -660,7 +660,7 @@ class CommReducerNode : public Node { } static constexpr const char* _type_key = "CommReducer"; - TVM_DECLARE_NODE_TYPE_INFO(CommReducerNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(CommReducerNode, Object); }; inline const CommReducerNode* CommReducer::get() const { @@ -704,7 +704,7 @@ class Reduce : public ExprNode { } static constexpr const char* _type_key = "Reduce"; - TVM_DECLARE_NODE_TYPE_INFO(Reduce, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Reduce, ExprNode); }; /*! \brief Any shape. */ @@ -719,7 +719,7 @@ class Any : public ExprNode { TVM_DLL static Expr make(); static constexpr const char* _type_key = "Any"; - TVM_DECLARE_NODE_TYPE_INFO(Any, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Any, ExprNode); }; // Statements @@ -744,7 +744,7 @@ class LetStmt : public StmtNode { TVM_DLL static Stmt make(Var var, Expr value, Stmt body); static constexpr const char* _type_key = "LetStmt"; - TVM_DECLARE_NODE_TYPE_INFO(LetStmt, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(LetStmt, StmtNode); }; /*! @@ -760,7 +760,7 @@ class LetStmt : public StmtNode { class AttrStmt : public StmtNode { public: /*! \brief this is attribute about certain node */ - NodeRef node; + ObjectRef node; /*! \brief the type key of the attribute */ std::string attr_key; /*! \brief The attribute value, value is well defined at current scope. */ @@ -775,13 +775,13 @@ class AttrStmt : public StmtNode { v->Visit("body", &body); } - TVM_DLL static Stmt make(NodeRef node, + TVM_DLL static Stmt make(ObjectRef node, std::string type_key, Expr value, Stmt body); static constexpr const char* _type_key = "AttrStmt"; - TVM_DECLARE_NODE_TYPE_INFO(AttrStmt, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmt, StmtNode); }; /*! @@ -808,7 +808,7 @@ class AssertStmt : public StmtNode { TVM_DLL static Stmt make(Expr condition, Expr message, Stmt body); static constexpr const char* _type_key = "AssertStmt"; - TVM_DECLARE_NODE_TYPE_INFO(AssertStmt, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmt, StmtNode); }; // TODO(tvm-team): consider consolidate with AttrStmt. @@ -831,7 +831,7 @@ class ProducerConsumer : public StmtNode { TVM_DLL static Stmt make(FunctionRef func, bool is_producer, Stmt body); static constexpr const char* _type_key = "ProducerConsumer"; - TVM_DECLARE_NODE_TYPE_INFO(ProducerConsumer, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ProducerConsumer, StmtNode); }; /*! @@ -876,7 +876,7 @@ class Store : public StmtNode { Expr predicate); static constexpr const char* _type_key = "Store"; - TVM_DECLARE_NODE_TYPE_INFO(Store, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Store, StmtNode); }; /*! @@ -906,7 +906,7 @@ class Provide : public StmtNode { Array args); static constexpr const char* _type_key = "Provide"; - TVM_DECLARE_NODE_TYPE_INFO(Provide, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Provide, StmtNode); }; /*! @@ -963,7 +963,7 @@ class Allocate : public StmtNode { const Array& extents); static constexpr const char* _type_key = "Allocate"; - TVM_DECLARE_NODE_TYPE_INFO(Allocate, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Allocate, StmtNode); }; /*! \brief Free the resources in the buffer before the scope ends. */ @@ -979,7 +979,7 @@ class Free : public StmtNode { TVM_DLL static Stmt make(Var buffer_var); static constexpr const char* _type_key = "Free"; - TVM_DECLARE_NODE_TYPE_INFO(Free, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Free, StmtNode); }; /*! @@ -1018,7 +1018,7 @@ class Realize : public StmtNode { Stmt body); static constexpr const char* _type_key = "Realize"; - TVM_DECLARE_NODE_TYPE_INFO(Realize, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Realize, StmtNode); }; /*! @@ -1040,7 +1040,7 @@ class Block : public StmtNode { TVM_DLL static Stmt make(const std::vector &stmts); static constexpr const char* _type_key = "Block"; - TVM_DECLARE_NODE_TYPE_INFO(Block, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Block, StmtNode); }; /*! @@ -1064,7 +1064,7 @@ class IfThenElse : public StmtNode { TVM_DLL static Stmt make(Expr condition, Stmt then_case, Stmt else_case = Stmt()); static constexpr const char* _type_key = "IfThenElse"; - TVM_DECLARE_NODE_TYPE_INFO(IfThenElse, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElse, StmtNode); }; /*! @@ -1085,7 +1085,7 @@ class Evaluate : public StmtNode { TVM_DLL static Stmt make(Expr v); static constexpr const char* _type_key = "Evaluate"; - TVM_DECLARE_NODE_TYPE_INFO(Evaluate, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Evaluate, StmtNode); }; /*! \brief Additional annotation of for loop. */ @@ -1152,7 +1152,7 @@ class For : public StmtNode { } static constexpr const char* _type_key = "For"; - TVM_DECLARE_NODE_TYPE_INFO(For, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(For, StmtNode); }; /*! @@ -1182,7 +1182,7 @@ class Prefetch : public StmtNode { Region bounds); static constexpr const char* _type_key = "Prefetch"; - TVM_DECLARE_NODE_TYPE_INFO(Prefetch, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Prefetch, StmtNode); }; /*! @@ -1636,7 +1636,7 @@ namespace std { template <> struct hash<::tvm::ir::TensorKey> { std::size_t operator()(const ::tvm::ir::TensorKey& k) const { - size_t lhs = ::tvm::NodeHash()(k.f); + size_t lhs = ::tvm::ObjectHash()(k.f); size_t rhs = static_cast(k.value_index); lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); return lhs; diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h index 04ce7934ff2f4..9b2632f87b3cf 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/ir_functor_ext.h @@ -164,7 +164,7 @@ class ExprFunctor { virtual R VisitExpr_(const UIntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const FloatImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const StringImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExprDefault_(const Node* op, Args ...) { + virtual R VisitExprDefault_(const Object* op, Args ...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); return R(); } @@ -255,7 +255,7 @@ class StmtFunctor { virtual R VisitStmt_(const Prefetch* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmtDefault_(const Node* op, Args ...) { + virtual R VisitStmtDefault_(const Object* op, Args ...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); return R(); } diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index b0b13df729cc8..6e1fed5a85427 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -418,7 +418,7 @@ Stmt HoistIfThenElse(Stmt stmt); */ LoweredFunc MakeAPI(Stmt body, std::string name, - Array api_args, + Array api_args, int num_unpacked_args, bool is_restricted); diff --git a/include/tvm/ir_visitor.h b/include/tvm/ir_visitor.h index b85cf233a42fb..cffcdcbdf5b83 100644 --- a/include/tvm/ir_visitor.h +++ b/include/tvm/ir_visitor.h @@ -87,7 +87,7 @@ class TVM_DLL IRVisitor { /*! * \brief recursively visit an IR node */ - virtual void Visit(const NodeRef& node) { + virtual void Visit(const ObjectRef& node) { static const FVisit& f = vtable(); if (node.defined()) f(node, this); } @@ -152,7 +152,7 @@ class TVM_DLL IRVisitor { * \param node The ir to be visited. * \param fvisit The visitor function to be applied. */ -TVM_DLL void PostOrderVisit(const NodeRef& node, std::function fvisit); +TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function fvisit); } // namespace ir } // namespace tvm diff --git a/include/tvm/lowered_func.h b/include/tvm/lowered_func.h index 6709f545cb399..3de6bfdbb0873 100644 --- a/include/tvm/lowered_func.h +++ b/include/tvm/lowered_func.h @@ -131,7 +131,7 @@ class LoweredFuncNode : public ir::FunctionBaseNode { } static constexpr const char* _type_key = "LoweredFunc"; - TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(LoweredFuncNode, Object); }; // Implementations of inline functions @@ -143,7 +143,7 @@ inline const LoweredFuncNode* LoweredFunc::operator->() const { namespace std { template <> -struct hash<::tvm::LoweredFunc> : public tvm::NodeHash { +struct hash<::tvm::LoweredFunc> : public tvm::ObjectHash { }; } diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index 1a276ae695fc2..d20fb288039cb 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -35,7 +35,7 @@ namespace tvm { /*! \brief array node content in array */ -class ArrayNode : public Node { +class ArrayNode : public Object { public: /*! \brief the data content */ std::vector data; @@ -44,11 +44,11 @@ class ArrayNode : public Node { } static constexpr const char* _type_key = "Array"; - TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object); }; /*! \brief map node content */ -class MapNode : public Node { +class MapNode : public Object { public: void VisitAttrs(AttrVisitor* visitor) { } @@ -63,12 +63,12 @@ class MapNode : public Node { ContainerType data; static constexpr const char* _type_key = "Map"; - TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object); }; /*! \brief specialized map node with string as key */ -class StrMapNode : public Node { +class StrMapNode : public Object { public: /*! \brief The corresponding conatiner type */ using ContainerType = std::unordered_map; @@ -80,7 +80,7 @@ class StrMapNode : public Node { ContainerType data; static constexpr const char* _type_key = "StrMap"; - TVM_DECLARE_FINAL_OBJECT_INFO(StrMapNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(StrMapNode, Object); }; /*! @@ -138,13 +138,13 @@ class IterAdapter { */ template::value>::type > -class Array : public NodeRef { +class Array : public ObjectRef { public: /*! * \brief default constructor */ Array() { - data_ = make_node(); + data_ = make_object(); } /*! * \brief move constructor @@ -164,7 +164,7 @@ class Array : public NodeRef { * \brief constructor from pointer * \param n the container pointer */ - explicit Array(ObjectPtr n) : NodeRef(n) {} + explicit Array(ObjectPtr n) : ObjectRef(n) {} /*! * \brief constructor from iterator * \param begin begin of iterator @@ -195,7 +195,7 @@ class Array : public NodeRef { * \param val The init value */ explicit Array(size_t n, const T& val) { - auto tmp_node = make_node(); + auto tmp_node = make_object(); for (size_t i = 0; i < n; ++i) { tmp_node->data.push_back(val); } @@ -227,7 +227,7 @@ class Array : public NodeRef { */ template void assign(IterType begin, IterType end) { - auto n = make_node(); + auto n = make_object(); for (IterType it = begin; it != end; ++it) { n->data.push_back(T(*it)); } @@ -257,7 +257,7 @@ class Array : public NodeRef { */ inline ArrayNode* CopyOnWrite() { if (data_.get() == nullptr || !data_.unique()) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->data = static_cast(data_.get())->data; ObjectPtr(std::move(n)).swap(data_); } @@ -333,13 +333,13 @@ template::value || std::is_base_of::value >::type, typename = typename std::enable_if::value>::type> -class Map : public NodeRef { +class Map : public ObjectRef { public: /*! * \brief default constructor */ Map() { - data_ = make_node(); + data_ = make_object(); } /*! * \brief move constructor @@ -352,13 +352,13 @@ class Map : public NodeRef { * \brief copy constructor * \param other source */ - Map(const Map &other) : NodeRef(other.data_) { // NOLINT(*) + Map(const Map &other) : ObjectRef(other.data_) { // NOLINT(*) } /*! * \brief constructor from pointer * \param n the container pointer */ - explicit Map(ObjectPtr n) : NodeRef(n) {} + explicit Map(ObjectPtr n) : ObjectRef(n) {} /*! * \brief constructor from iterator * \param begin begin of iterator @@ -410,7 +410,7 @@ class Map : public NodeRef { */ template void assign(IterType begin, IterType end) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); for (IterType i = begin; i != end; ++i) { n->data.emplace(std::make_pair(i->first, i->second)); } @@ -454,7 +454,7 @@ class Map : public NodeRef { */ inline MapNode* CopyOnWrite() { if (data_.get() == nullptr || !data_.unique()) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->data = static_cast(data_.get())->data; ObjectPtr(std::move(n)).swap(data_); } @@ -507,18 +507,18 @@ class Map : public NodeRef { // specialize of string map template -class Map : public NodeRef { +class Map : public ObjectRef { public: // for code reuse Map() { - data_ = make_node(); + data_ = make_object(); } Map(Map && other) { // NOLINT(*) data_ = std::move(other.data_); } - Map(const Map &other) : NodeRef(other.data_) { // NOLINT(*) + Map(const Map &other) : ObjectRef(other.data_) { // NOLINT(*) } - explicit Map(ObjectPtr n) : NodeRef(n) {} + explicit Map(ObjectPtr n) : ObjectRef(n) {} template Map(IterType begin, IterType end) { assign(begin, end); @@ -541,7 +541,7 @@ class Map : public NodeRef { } template void assign(IterType begin, IterType end) { - auto n = make_node(); + auto n = make_object(); for (IterType i = begin; i != end; ++i) { n->data.emplace(std::make_pair(i->first, i->second)); } @@ -565,7 +565,7 @@ class Map : public NodeRef { } inline StrMapNode* CopyOnWrite() { if (data_.get() == nullptr || !data_.unique()) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->data = static_cast(data_.get())->data; ObjectPtr(std::move(n)).swap(data_); } diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index 4014c3700596f..bb5da415c4638 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -56,105 +56,5 @@ using runtime::ObjectHash; using runtime::ObjectEqual; using runtime::make_object; -using NodeHash = ObjectHash; -using NodeEqual = ObjectEqual; -using Node = Object; - -/*! - * \brief Base class of all references to AST/IR nodes. - */ -class NodeRef : public ObjectRef { - public: - NodeRef() {} - explicit NodeRef(ObjectPtr n) : ObjectRef(n) {} -}; - -/*! - * \brief Allocate a node object. - * \param args arguments to the constructor. - * \tparam T the node type. - * \return The NodePtr to the allocated object. - * \note This function is an alias of make_object. - */ -template -inline NodePtr make_node(Args&&... args) { - return runtime::make_object(std::forward(args)...); -} - -/*! - * \brief helper macro to declare type information in a base node. - */ -#define TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent) \ - TVM_DECLARE_BASE_OBJECT_INFO(TypeName, Parent) - -/*! - * \brief helper macro to declare type information in a terminal node - */ -#define TVM_DECLARE_NODE_TYPE_INFO(TypeName, Parent) \ - TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, Parent); - - -/*! - * \brief Macro to define common node ref methods. - * \param TypeName The name of the NodeRef. - * \param BaseTypeName The Base type. - * \param NodeName The node container type. - */ -#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \ - TypeName() {} \ - explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \ - : BaseTypeName(n) {} \ - const NodeName* operator->() const { \ - return static_cast(data_.get()); \ - } \ - operator bool() const { return this->defined(); } \ - using ContainerType = NodeName; - -/*! - * \brief Macro to define CopyOnWrite function in a NodeRef. - * \param NodeName The Type of the Node. - * - * CopyOnWrite will generate a unique copy of the internal node. - * The node will be copied if it is referenced by multiple places. - * The function returns the raw pointer to the node to allow modification - * of the content. - * - * \code - * - * MyCOWNodeRef ref, ref2; - * ref2 = ref; - * ref.CopyOnWrite()->value = new_value; - * assert(ref2->value == old_value); - * assert(ref->value == new_value); - * - * \endcode - */ -#define TVM_DEFINE_NODE_REF_COW(NodeName) \ - NodeName* CopyOnWrite() { \ - CHECK(data_ != nullptr); \ - if (!data_.unique()) { \ - NodePtr n = make_node(*(operator->())); \ - ObjectPtr(std::move(n)).swap(data_); \ - } \ - return static_cast(data_.get()); \ - } - -/*! \brief Macro to make it easy to define node ref type given node */ -#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \ - class TypeName : public ::tvm::NodeRef { \ - public: \ - TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName); \ - }; \ - -/*! - * \brief Macro to make it easy to define node ref type that - * has a CopyOnWrite member function. - */ -#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \ - class TypeName : public BaseType { \ - public: \ - TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName); \ - TVM_DEFINE_NODE_REF_COW(NodeName); \ - }; } // namespace tvm #endif // TVM_NODE_NODE_H_ diff --git a/include/tvm/operation.h b/include/tvm/operation.h index 34f584b632615..681d06897355b 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -60,7 +60,7 @@ class OperationNode : public ir::FunctionBaseNode { /*! \brief optional tag of the operation */ std::string tag; /*! \brief additional attributes of the operation*/ - Map attrs; + Map attrs; /*! \return name of the operation */ const std::string& func_name() const final { return name; @@ -149,7 +149,7 @@ class OperationNode : public ir::FunctionBaseNode { static constexpr const char* _type_key = "Operation"; - TVM_DECLARE_BASE_NODE_INFO(OperationNode, Node); + TVM_DECLARE_BASE_OBJECT_INFO(OperationNode, Object); }; /*! @@ -200,7 +200,7 @@ class PlaceholderOpNode : public OperationNode { DataType dtype); static constexpr const char* _type_key = "PlaceholderOp"; - TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode, OperationNode); + TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode); }; /*! @@ -228,7 +228,7 @@ class TVM_DLL BaseComputeOpNode : public OperationNode { virtual size_t num_schedulable_dims() const = 0; static constexpr const char* _type_key = "BaseComputeOp"; - TVM_DECLARE_BASE_NODE_INFO(BaseComputeOpNode, OperationNode); + TVM_DECLARE_BASE_OBJECT_INFO(BaseComputeOpNode, OperationNode); }; @@ -269,12 +269,12 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { } static Operation make(std::string name, std::string tag, - Map attrs, + Map attrs, Array axis, Array body); static constexpr const char* _type_key = "ComputeOp"; - TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, BaseComputeOpNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode); }; /*! @@ -334,7 +334,7 @@ class TensorComputeOpNode : public BaseComputeOpNode { Array scalar_inputs); static constexpr const char* _type_key = "TensorComputeOp"; - TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, BaseComputeOpNode); + TVM_DECLARE_FINAL_OBJECT_INFO(TensorComputeOpNode, BaseComputeOpNode); }; /*! @@ -407,7 +407,7 @@ class ScanOpNode : public OperationNode { } static Operation make(std::string name, std::string tag, - Map attrs, + Map attrs, IterVar axis, Array init, Array update, @@ -415,7 +415,7 @@ class ScanOpNode : public OperationNode { Array input); static constexpr const char* _type_key = "ScanOp"; - TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode, OperationNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode); }; /*! @@ -472,14 +472,14 @@ class ExternOpNode : public OperationNode { } TVM_DLL static Operation make(std::string name, std::string tag, - Map attrs, + Map attrs, Array inputs, Array input_placeholders, Array output_placeholders, Stmt body); static constexpr const char* _type_key = "ExternOp"; - TVM_DECLARE_NODE_TYPE_INFO(ExternOpNode, OperationNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode); }; /*! @@ -540,13 +540,13 @@ class HybridOpNode : public OperationNode { } TVM_DLL static Operation make(std::string name, std::string tag, - Map attrs, + Map attrs, Array inputs, Array outputs, Stmt body); static constexpr const char* _type_key = "HybridOp"; - TVM_DECLARE_NODE_TYPE_INFO(HybridOpNode, OperationNode); + TVM_DECLARE_FINAL_OBJECT_INFO(HybridOpNode, OperationNode); }; /*! \brief The compute function to specify the input source of a Tensor */ @@ -578,7 +578,7 @@ TVM_DLL Tensor compute(Array shape, FCompute fcompute, std::string name = "tensor", std::string tag = "", - Map attrs = {}); + Map attrs = {}); /*! * \brief Construct a new tensor by computing over shape, @@ -593,7 +593,7 @@ TVM_DLL Array compute(Array shape, FBatchCompute fcompute, std::string name = "tensor", std::string tag = "", - Map attrs = {}); + Map attrs = {}); /*! * \brief Construct new tensors by scan. @@ -613,14 +613,14 @@ TVM_DLL Array scan(Array init, Array inputs = Array(), std::string name = "scan", std::string tag = "", - Map attrs = {}); + Map attrs = {}); // same as compute, specialized for different fcompute function inline Tensor compute(Array shape, std::function f, std::string name = "tensor", std::string tag = "", - Map attrs = {}) { + Map attrs = {}) { FCompute fc = [f] (const Array& i) { return f(i[0]); }; return compute(shape, fc, name, tag, attrs); } @@ -628,7 +628,7 @@ inline Tensor compute(Array shape, std::function f, std::string name = "tensor", std::string tag = "", - Map attrs = {}) { + Map attrs = {}) { FCompute fc = [f] (const Array& i) { return f(i[0], i[1]); }; return compute(shape, fc, name, tag, attrs); } @@ -636,7 +636,7 @@ inline Tensor compute(Array shape, std::function f, std::string name = "tensor", std::string tag = "", - Map attrs = {}) { + Map attrs = {}) { FCompute fc = [f] (const Array& i) { return f(i[0], i[1], i[2]); }; return compute(shape, fc, name, tag, attrs); } @@ -644,7 +644,7 @@ inline Tensor compute(Array shape, std::function f, std::string name = "tensor", std::string tag = "", - Map attrs = {}) { + Map attrs = {}) { FCompute fc = [f] (const Array& i) { return f(i[0], i[1], i[2], i[3]); }; return compute(shape, fc, name, tag, attrs); } diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h index c9f7a580621f6..b301a18ea3130 100644 --- a/include/tvm/packed_func_ext.h +++ b/include/tvm/packed_func_ext.h @@ -115,15 +115,15 @@ inline TVMPODValue_::operator tvm::Expr() const { Object* ptr = static_cast(value_.v_handle); if (ptr->IsInstance()) { - return IterVar(ObjectPtr(ptr))->var; + return IterVar(ObjectPtr(ptr))->var; } if (ptr->IsInstance()) { - return Tensor(ObjectPtr(ptr))(); + return Tensor(ObjectPtr(ptr))(); } CHECK(ObjectTypeChecker::Check(ptr)) << "Expect type " << ObjectTypeChecker::TypeName() << " but get " << ptr->GetTypeKey(); - return Expr(ObjectPtr(ptr)); + return Expr(ObjectPtr(ptr)); } inline TVMPODValue_::operator tvm::Integer() const { @@ -138,7 +138,7 @@ inline TVMPODValue_::operator tvm::Integer() const { CHECK(ObjectTypeChecker::Check(ptr)) << "Expect type " << ObjectTypeChecker::TypeName() << " but get " << ptr->GetTypeKey(); - return Integer(ObjectPtr(ptr)); + return Integer(ObjectPtr(ptr)); } } // namespace runtime } // namespace tvm diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index a74353239a008..dac39e014cc7c 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -38,7 +38,7 @@ namespace relay { class PatternNode : public RelayNode { public: static constexpr const char* _type_key = "relay.Pattern"; - TVM_DECLARE_BASE_NODE_INFO(PatternNode, Node); + TVM_DECLARE_BASE_OBJECT_INFO(PatternNode, Object); }; /*! @@ -49,10 +49,10 @@ class PatternNode : public RelayNode { * * ADT pattern matching thus takes a list of values and binds to the first that accepts the value. */ -class Pattern : public NodeRef { +class Pattern : public ObjectRef { public: Pattern() {} - explicit Pattern(ObjectPtr p) : NodeRef(p) {} + explicit Pattern(ObjectPtr p) : ObjectRef(p) {} using ContainerType = PatternNode; }; @@ -71,10 +71,13 @@ class PatternWildcardNode : public PatternNode { } static constexpr const char* _type_key = "relay.PatternWildcard"; - TVM_DECLARE_NODE_TYPE_INFO(PatternWildcardNode, PatternNode); + TVM_DECLARE_FINAL_OBJECT_INFO(PatternWildcardNode, PatternNode); }; -RELAY_DEFINE_NODE_REF(PatternWildcard, PatternWildcardNode, Pattern); +class PatternWildcard : public Pattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(PatternWildcard, Pattern, PatternWildcardNode); +}; /*! \brief A var pattern. Accept all input and bind to a var. */ class PatternVar; @@ -94,10 +97,13 @@ class PatternVarNode : public PatternNode { } static constexpr const char* _type_key = "relay.PatternVar"; - TVM_DECLARE_NODE_TYPE_INFO(PatternVarNode, PatternNode); + TVM_DECLARE_FINAL_OBJECT_INFO(PatternVarNode, PatternNode); }; -RELAY_DEFINE_NODE_REF(PatternVar, PatternVarNode, Pattern); +class PatternVar : public Pattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(PatternVar, Pattern, PatternVarNode); +}; /*! * \brief ADT constructor. @@ -132,10 +138,13 @@ class ConstructorNode : public ExprNode { } static constexpr const char* _type_key = "relay.Constructor"; - TVM_DECLARE_NODE_TYPE_INFO(ConstructorNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(Constructor, ConstructorNode, Expr); +class Constructor : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Constructor, Expr, ConstructorNode); +}; /*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */ class PatternConstructor; @@ -158,10 +167,13 @@ class PatternConstructorNode : public PatternNode { } static constexpr const char* _type_key = "relay.PatternConstructor"; - TVM_DECLARE_NODE_TYPE_INFO(PatternConstructorNode, PatternNode); + TVM_DECLARE_FINAL_OBJECT_INFO(PatternConstructorNode, PatternNode); }; -RELAY_DEFINE_NODE_REF(PatternConstructor, PatternConstructorNode, Pattern); +class PatternConstructor : public Pattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(PatternConstructor, Pattern, PatternConstructorNode); +}; /*! \brief A tuple pattern. Matches a tuple, binds recursively. */ class PatternTuple; @@ -181,10 +193,13 @@ class PatternTupleNode : public PatternNode { } static constexpr const char* _type_key = "relay.PatternTuple"; - TVM_DECLARE_NODE_TYPE_INFO(PatternTupleNode, PatternNode); + TVM_DECLARE_FINAL_OBJECT_INFO(PatternTupleNode, PatternNode); }; -RELAY_DEFINE_NODE_REF(PatternTuple, PatternTupleNode, Pattern); +class PatternTuple : public Pattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(PatternTuple, Pattern, PatternTupleNode); +}; /*! * \brief Stores all data for an Algebraic Data Type (ADT). @@ -225,15 +240,18 @@ class TypeDataNode : public TypeNode { tvm::Array constructors); static constexpr const char* _type_key = "relay.TypeData"; - TVM_DECLARE_NODE_TYPE_INFO(TypeDataNode, TypeNode); + TVM_DECLARE_FINAL_OBJECT_INFO(TypeDataNode, TypeNode); }; -RELAY_DEFINE_NODE_REF(TypeData, TypeDataNode, Type); +class TypeData : public Type { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TypeData, Type, TypeDataNode); +}; /*! \brief A clause in a match expression. */ class Clause; /*! \brief Clause container node. */ -class ClauseNode : public Node { +class ClauseNode : public Object { public: /*! \brief The pattern the clause matches. */ Pattern lhs; @@ -248,10 +266,13 @@ class ClauseNode : public Node { TVM_DLL static Clause make(Pattern lhs, Expr rhs); static constexpr const char* _type_key = "relay.Clause"; - TVM_DECLARE_NODE_TYPE_INFO(ClauseNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(ClauseNode, Object); }; -RELAY_DEFINE_NODE_REF(Clause, ClauseNode, NodeRef); +class Clause : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Clause, ObjectRef, ClauseNode); +}; /*! \brief ADT pattern matching exression. */ class Match; @@ -280,10 +301,13 @@ class MatchNode : public ExprNode { TVM_DLL static Match make(Expr data, tvm::Array pattern, bool complete = true); static constexpr const char* _type_key = "relay.Match"; - TVM_DECLARE_NODE_TYPE_INFO(MatchNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(MatchNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(Match, MatchNode, Expr); +class Match : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Match, Expr, MatchNode); +}; } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index ccdc871e8a782..1c7fc1c454800 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -196,7 +196,7 @@ struct SqueezeAttrs : public tvm::AttrsNode { }; // struct SqueezeAttrs struct SplitAttrs : public tvm::AttrsNode { - NodeRef indices_or_sections; + ObjectRef indices_or_sections; int axis; TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") { diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index 32f9c32f468a2..d64d05f119bb2 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -53,53 +53,11 @@ namespace relay { (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \ } -/*! - * \brief We always used NodeRef for referencing nodes. - * - * By default, NodeRef is a std::shared_ptr of node - */ -using NodeRef = tvm::NodeRef; - -/*! - * \brief Content data type. - */ -using DataType = ::tvm::DataType; - /*! * \brief Symbolic expression for tensor shape. */ using IndexExpr = ::tvm::Expr; -/*! - * \brief Hash function for nodes. - * e.g. std::unordered_map - */ -using NodeHash = ::tvm::NodeHash; -/*! - * \brief Equality check function for nodes. - */ -using NodeEqual = ::tvm::NodeEqual; - -/*! - * \brief Macro to make it easy to define node ref type given node - * \param TypeName The name of the reference type. - * \param NodeName The internal container name. - * \param NodeRefBase The base type. - */ -#define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase) \ - class TypeName : public NodeRefBase { \ - public: \ - TypeName() {} \ - explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \ - : NodeRefBase(n) { \ - } \ - const NodeName* operator->() const { \ - return static_cast(get()); \ - } \ - operator bool() { return this->defined(); } \ - using ContainerType = NodeName; \ - }; - /*! * \brief The source name in the Span * \sa SourceNameNode, Span @@ -108,7 +66,7 @@ class SourceName; /*! * \brief The name of a source fragment. */ -class SourceNameNode : public Node { +class SourceNameNode : public Object { public: /*! \brief The source name. */ std::string name; @@ -116,20 +74,20 @@ class SourceNameNode : public Node { void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } static constexpr const char* _type_key = "relay.SourceName"; - TVM_DECLARE_NODE_TYPE_INFO(SourceNameNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object); }; /*! * \brief The source name of a file span. * \sa SourceNameNode, Span */ -class SourceName : public NodeRef { +class SourceName : public ObjectRef { public: /*! \brief default constructor */ SourceName() {} /*! \brief constructor from node pointer */ - explicit SourceName(NodePtr n) : NodeRef(n) {} + explicit SourceName(ObjectPtr n) : ObjectRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -157,7 +115,7 @@ class Span; /*! * \brief Stores locations in frontend source that generated a node. */ -class SpanNode : public Node { +class SpanNode : public Object { public: /*! \brief The source name */ SourceName source; @@ -175,22 +133,25 @@ class SpanNode : public Node { TVM_DLL static Span make(SourceName source, int lineno, int col_offset); static constexpr const char* _type_key = "relay.Span"; - TVM_DECLARE_NODE_TYPE_INFO(SpanNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object); }; -RELAY_DEFINE_NODE_REF(Span, SpanNode, NodeRef); +class Span : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode); +}; /*! * \brief This is the base node container of all relay structures. */ -class RelayNode : public Node { +class RelayNode : public Object { public: /*! \brief The location of the program in a SourceFragment can be null, * check with span.defined() */ mutable Span span; static constexpr const char* _type_key = "relay.Node"; - TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node); + TVM_DECLARE_BASE_OBJECT_INFO(RelayNode, Object); }; /*! @@ -201,7 +162,7 @@ class RelayNode : public Node { * * \note Do not create Id directly, they are created in Var. */ -class IdNode : public Node { +class IdNode : public Object { public: /*! * \brief The name of the variable, @@ -215,10 +176,13 @@ class IdNode : public Node { } static constexpr const char* _type_key = "relay.Id"; - TVM_DECLARE_NODE_TYPE_INFO(IdNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(IdNode, Object); }; -RELAY_DEFINE_NODE_REF(Id, IdNode, NodeRef); +class Id : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode); +}; struct Module; diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h index ef3387b1893b6..4cd999fb44800 100644 --- a/include/tvm/relay/error.h +++ b/include/tvm/relay/error.h @@ -118,7 +118,7 @@ class ErrorReporter { * \param node The expression or type to report the error at. * \param err The error message to report. */ - inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) { + inline void ReportAt(const GlobalVar& global, const ObjectRef& node, std::stringstream& err) { std::string err_msg = err.str(); this->ReportAt(global, node, Error(err_msg)); } @@ -134,7 +134,7 @@ class ErrorReporter { * \param node The expression or type to report the error at. * \param err The error to report. */ - void ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err); + void ReportAt(const GlobalVar& global, const ObjectRef& node, const Error& err); /*! \brief Render all reported errors and exit the program. * @@ -154,8 +154,8 @@ class ErrorReporter { private: std::vector errors_; - std::unordered_map, NodeHash, NodeEqual> node_to_error_; - std::unordered_map node_to_gv_; + std::unordered_map, ObjectHash, ObjectEqual> node_to_error_; + std::unordered_map node_to_gv_; }; } // namespace relay diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 01a73d5396cc8..47c83696c3e50 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -67,10 +67,13 @@ class ExprNode : public RelayNode { inline const TTypeNode* type_as() const; static constexpr const char* _type_key = "relay.Expr"; - TVM_DECLARE_BASE_NODE_INFO(ExprNode, RelayNode); + TVM_DECLARE_BASE_OBJECT_INFO(ExprNode, RelayNode); }; -RELAY_DEFINE_NODE_REF(Expr, ExprNode, NodeRef); +class Expr : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Expr, ObjectRef, ExprNode); +}; /*! * \brief Constant tensor, backed by an NDArray on the cpu(0) device. @@ -104,10 +107,13 @@ class ConstantNode : public ExprNode { TVM_DLL static Constant make(runtime::NDArray data); static constexpr const char* _type_key = "relay.Constant"; - TVM_DECLARE_NODE_TYPE_INFO(ConstantNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(Constant, ConstantNode, Expr); +class Constant : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Constant, Expr, ConstantNode); +}; /*! \brief Tuple of multiple Exprs */ class Tuple; @@ -126,10 +132,13 @@ class TupleNode : public ExprNode { TVM_DLL static Tuple make(tvm::Array fields); static constexpr const char* _type_key = "relay.Tuple"; - TVM_DECLARE_NODE_TYPE_INFO(TupleNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(TupleNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(Tuple, TupleNode, Expr); +class Tuple : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Tuple, Expr, TupleNode); +}; /*! * \brief Local variables used in the let expression. @@ -179,10 +188,13 @@ class VarNode : public ExprNode { Type type_annotation); static constexpr const char* _type_key = "relay.Var"; - TVM_DECLARE_NODE_TYPE_INFO(VarNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(Var, VarNode, Expr); +class Var : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Var, Expr, VarNode); +}; /*! * \brief Global variable that leaves in the top-level module. @@ -206,10 +218,13 @@ class GlobalVarNode : public ExprNode { TVM_DLL static GlobalVar make(std::string name_hint); static constexpr const char* _type_key = "relay.GlobalVar"; - TVM_DECLARE_NODE_TYPE_INFO(GlobalVarNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(GlobalVar, GlobalVarNode, Expr); +class GlobalVar : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, Expr, GlobalVarNode); +}; /*! * \brief Function (subgraph in computational graph) @@ -297,14 +312,19 @@ class FunctionNode : public ExprNode { tvm::Map GetParams() const; static constexpr const char* _type_key = "relay.Function"; - TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr); +class Function : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Function, Expr, FunctionNode); +}; -TVM_DLL NodeRef FunctionGetAttr(const Function& func, const std::string& key); -TVM_DLL Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data); +TVM_DLL ObjectRef FunctionGetAttr(const Function& func, const std::string& key); +TVM_DLL Function FunctionSetAttr(const Function& func, + const std::string& key, + const ObjectRef& data); /*! * \brief Call corresponds to operator invocation. @@ -363,10 +383,13 @@ class CallNode : public ExprNode { Array type_args = Array()); static constexpr const char* _type_key = "relay.Call"; - TVM_DECLARE_NODE_TYPE_INFO(CallNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(Call, CallNode, Expr); +class Call : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Call, Expr, CallNode); +}; /*! * \brief Let binding that binds a local var and optionally a type annotation. @@ -401,10 +424,13 @@ class LetNode : public ExprNode { TVM_DLL static Let make(Var var, Expr value, Expr body); static constexpr const char* _type_key = "relay.Let"; - TVM_DECLARE_NODE_TYPE_INFO(LetNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(Let, LetNode, Expr); +class Let : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Let, Expr, LetNode); +}; /*! * \brief Condition expression @@ -439,10 +465,13 @@ class IfNode : public ExprNode { TVM_DLL static If make(Expr cond, Expr true_branch, Expr false_branch); static constexpr const char* _type_key = "relay.If"; - TVM_DECLARE_NODE_TYPE_INFO(IfNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(If, IfNode, Expr); +class If : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(If, Expr, IfNode); +}; /*! \brief Get index-th field out of a tuple. */ class TupleGetItem; @@ -463,10 +492,13 @@ class TupleGetItemNode : public ExprNode { TVM_DLL static TupleGetItem make(Expr tuple, int index); static constexpr const char* _type_key = "relay.TupleGetItem"; - TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr); +class TupleGetItem : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, Expr, TupleGetItemNode); +}; /*! \brief Create a new Reference out of initial value. */ class RefCreate; @@ -484,10 +516,13 @@ class RefCreateNode : public ExprNode { TVM_DLL static RefCreate make(Expr value); static constexpr const char* _type_key = "relay.RefCreate"; - TVM_DECLARE_NODE_TYPE_INFO(RefCreateNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(RefCreateNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(RefCreate, RefCreateNode, Expr); +class RefCreate : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(RefCreate, Expr, RefCreateNode); +}; /*! \brief Get value out of Reference. */ class RefRead; @@ -505,10 +540,13 @@ class RefReadNode : public ExprNode { TVM_DLL static RefRead make(Expr ref); static constexpr const char* _type_key = "relay.RefRead"; - TVM_DECLARE_NODE_TYPE_INFO(RefReadNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(RefReadNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(RefRead, RefReadNode, Expr); +class RefRead : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(RefRead, Expr, RefReadNode); +}; /*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */ class RefWrite; class RefWriteNode : public ExprNode { @@ -528,10 +566,13 @@ class RefWriteNode : public ExprNode { TVM_DLL static RefWrite make(Expr ref, Expr value); static constexpr const char* _type_key = "relay.RefWrite"; - TVM_DECLARE_NODE_TYPE_INFO(RefWriteNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(RefWriteNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(RefWrite, RefWriteNode, Expr); +class RefWrite : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(RefWrite, Expr, RefWriteNode); +}; /*! * \brief Base class of the temporary expression. @@ -554,10 +595,13 @@ class TempExprNode : public ExprNode { virtual Expr Realize() const = 0; static constexpr const char* _type_key = "relay.TempExpr"; - TVM_DECLARE_BASE_NODE_INFO(TempExprNode, ExprNode); + TVM_DECLARE_BASE_OBJECT_INFO(TempExprNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(TempExpr, TempExprNode, Expr); +class TempExpr : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TempExpr, Expr, TempExprNode); +}; // implementataions inline const Type& ExprNode::checked_type() const { @@ -583,7 +627,7 @@ inline const TTypeNode* ExprNode::type_as() const { } /*! \brief Pretty print a Relay node, producing a fragment of the Relay text format. */ -std::string PrettyPrint(const NodeRef& node); +std::string PrettyPrint(const ObjectRef& node); /*! * \brief Render the node as a string in the Relay text format. @@ -593,7 +637,7 @@ std::string PrettyPrint(const NodeRef& node); * additional comment block to an expr. * \return The text representation. */ -std::string AsText(const NodeRef& node, +std::string AsText(const ObjectRef& node, bool show_meta_data = true, runtime::TypedPackedFunc annotate = nullptr); diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 722f73f038269..f1d7152f48c01 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -116,7 +116,7 @@ class ExprFunctor { virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExprDefault_(const Node* op, Args...) { + virtual R VisitExprDefault_(const Object* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); throw; } @@ -177,7 +177,7 @@ class ExprVisitor protected: // Internal visiting counter - std::unordered_map visit_counter_; + std::unordered_map visit_counter_; }; /*! @@ -227,7 +227,7 @@ class ExprMutator protected: /*! \brief Internal map used for memoization. */ - std::unordered_map memo_; + std::unordered_map memo_; }; /*! diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index d5d783d4804a4..8ef7f6e4ed891 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -72,13 +72,13 @@ CreateInterpreter(Module mod, DLContext context, Target target); class ValueNode : public RelayNode { public: static constexpr const char* _type_key = "relay.Value"; - TVM_DECLARE_BASE_NODE_INFO(ValueNode, RelayNode); + TVM_DECLARE_BASE_OBJECT_INFO(ValueNode, RelayNode); }; -class Value : public NodeRef { +class Value : public ObjectRef { public: Value() {} - explicit Value(ObjectPtr n) : NodeRef(n) {} + explicit Value(ObjectPtr n) : ObjectRef(n) {} const ValueNode* operator->() const { return static_cast(get()); } @@ -114,10 +114,13 @@ class ClosureNode : public ValueNode { TVM_DLL static Closure make(tvm::Map env, Function func); static constexpr const char* _type_key = "relay.Closure"; - TVM_DECLARE_NODE_TYPE_INFO(ClosureNode, ValueNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, ValueNode); }; -RELAY_DEFINE_NODE_REF(Closure, ClosureNode, Value); +class Closure : public Value { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Closure, Value, ClosureNode); +}; /*! \brief A Relay Recursive Closure. A closure that has a name. */ class RecClosure; @@ -140,10 +143,13 @@ class RecClosureNode : public ValueNode { TVM_DLL static RecClosure make(Closure clos, Var bind); static constexpr const char* _type_key = "relay.RecClosure"; - TVM_DECLARE_NODE_TYPE_INFO(RecClosureNode, ValueNode); + TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, ValueNode); }; -RELAY_DEFINE_NODE_REF(RecClosure, RecClosureNode, Value); +class RecClosure : public Value { + public: + TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, Value, RecClosureNode); +}; /*! \brief A tuple value. */ class TupleValue; @@ -159,10 +165,13 @@ struct TupleValueNode : ValueNode { TVM_DLL static TupleValue make(tvm::Array value); static constexpr const char* _type_key = "relay.TupleValue"; - TVM_DECLARE_NODE_TYPE_INFO(TupleValueNode, ValueNode); + TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, ValueNode); }; -RELAY_DEFINE_NODE_REF(TupleValue, TupleValueNode, Value); +class TupleValue : public Value { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, Value, TupleValueNode); +}; /*! \brief A tensor value. */ class TensorValue; @@ -179,10 +188,13 @@ struct TensorValueNode : ValueNode { TVM_DLL static TensorValue make(runtime::NDArray data); static constexpr const char* _type_key = "relay.TensorValue"; - TVM_DECLARE_NODE_TYPE_INFO(TensorValueNode, ValueNode); + TVM_DECLARE_FINAL_OBJECT_INFO(TensorValueNode, ValueNode); }; -RELAY_DEFINE_NODE_REF(TensorValue, TensorValueNode, Value); +class TensorValue : public Value { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TensorValue, Value, TensorValueNode); +}; /*! \brief A reference value. */ class RefValue; @@ -199,10 +211,13 @@ struct RefValueNode : ValueNode { TVM_DLL static RefValue make(Value val); static constexpr const char* _type_key = "relay.RefValue"; - TVM_DECLARE_NODE_TYPE_INFO(RefValueNode, ValueNode); + TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, ValueNode); }; -RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value); +class RefValue : public Value { + public: + TVM_DEFINE_OBJECT_REF_METHODS(RefValue, Value, RefValueNode); +}; /*! \brief An ADT constructor value. */ class ConstructorValue; @@ -226,10 +241,13 @@ struct ConstructorValueNode : ValueNode { Constructor construtor = {}); static constexpr const char* _type_key = "relay.ConstructorValue"; - TVM_DECLARE_NODE_TYPE_INFO(ConstructorValueNode, ValueNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, ValueNode); }; -RELAY_DEFINE_NODE_REF(ConstructorValue, ConstructorValueNode, Value); +class ConstructorValue : public Value { + public: + TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, Value, ConstructorValueNode); +}; } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 0d3f46cd3cc04..262c82df5c5d2 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -258,7 +258,7 @@ class ModuleNode : public RelayNode { const tvm::Map& type_definitions = {}); static constexpr const char* _type_key = "relay.Module"; - TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(ModuleNode, Object); private: /*! \brief Helper function for registering a typedef's constructors */ @@ -285,9 +285,9 @@ class ModuleNode : public RelayNode { std::unordered_set import_set_; }; -struct Module : public NodeRef { +struct Module : public ObjectRef { Module() {} - explicit Module(ObjectPtr<::tvm::Object> p) : NodeRef(p) {} + explicit Module(ObjectPtr<::tvm::Object> p) : ObjectRef(p) {} ModuleNode* operator->() const { return static_cast(get_mutable()); diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 90f2937c929b5..b4495191dd240 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -106,7 +106,7 @@ class OpNode : public relay::ExprNode { } static constexpr const char* _type_key = "relay.Op"; - TVM_DECLARE_NODE_TYPE_INFO(OpNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, ExprNode); private: // friend class @@ -431,7 +431,7 @@ inline OpRegistry& OpRegistry::describe( inline OpRegistry& OpRegistry::add_argument(const std::string& name, const std::string& type, const std::string& description) { - auto n = make_node(); + auto n = make_object(); n->name = name; n->type_info = type; n->description = description; diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h index 54ea707905e57..9cfa755ef8132 100644 --- a/include/tvm/relay/op_attr_types.h +++ b/include/tvm/relay/op_attr_types.h @@ -180,7 +180,7 @@ using FTVMLegalize = runtime::TypedPackedFunc< using FForwardRewrite = runtime::TypedPackedFunc< Expr(const Call& ref_call, const Array& new_args, - const NodeRef& ctx)>; + const ObjectRef& ctx)>; /*! * \brief Gradient for a specific op. diff --git a/include/tvm/relay/pattern_functor.h b/include/tvm/relay/pattern_functor.h index d84d43af82a70..71a024f37a197 100644 --- a/include/tvm/relay/pattern_functor.h +++ b/include/tvm/relay/pattern_functor.h @@ -102,7 +102,7 @@ class PatternFunctor { Args... args) PATTERN_FUNCTOR_DEFAULT; virtual R VisitPattern_(const PatternTupleNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT; - virtual R VisitPatternDefault_(const Node* op, Args...) { + virtual R VisitPatternDefault_(const Object* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); throw; } @@ -162,7 +162,7 @@ class PatternMutator /*! \brief Used to visit the vars inside of patterns. */ virtual Constructor VisitConstructor(const Constructor& c); private: - std::unordered_map var_map_; + std::unordered_map var_map_; }; } // namespace relay diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 52be6a0f37817..2d1e45f8ee0f5 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -109,7 +109,7 @@ class PassContextNode : public RelayNode { } static constexpr const char* _type_key = "relay.PassContext"; - TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode); + TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, RelayNode); }; /*! @@ -125,10 +125,10 @@ class PassContextNode : public RelayNode { * * \endcode */ -class PassContext : public NodeRef { +class PassContext : public ObjectRef { public: PassContext() {} - explicit PassContext(NodePtr<::tvm::Node> n) : NodeRef(n) {} + explicit PassContext(ObjectPtr<::tvm::Object> n) : ObjectRef(n) {} /*! * \brief const accessor. * \return const access pointer. @@ -207,10 +207,13 @@ class PassInfoNode : public RelayNode { tvm::Array required); static constexpr const char* _type_key = "relay.PassInfo"; - TVM_DECLARE_NODE_TYPE_INFO(PassInfoNode, RelayNode); + TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, RelayNode); }; -TVM_DEFINE_NODE_REF(PassInfo, PassInfoNode) +class PassInfo : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode); +}; class Pass; @@ -251,10 +254,10 @@ class PassNode : public RelayNode { void VisitAttrs(tvm::AttrVisitor* v) {} static constexpr const char* _type_key = "relay.Pass"; - TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode); + TVM_DECLARE_BASE_OBJECT_INFO(PassNode, RelayNode); }; -class Pass : public NodeRef { +class Pass : public ObjectRef { public: /*! * \brief Transform mod using the default PassContext in the current scope. @@ -283,7 +286,7 @@ class Pass : public NodeRef { return node->operator()(mod, pass_ctx); } - TVM_DEFINE_NODE_REF_METHODS(Pass, NodeRef, PassNode); + TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode); }; class SequentialNode; @@ -309,7 +312,7 @@ class Sequential : public Pass { TVM_DLL Sequential(tvm::Array passes, std::string name = "sequential"); Sequential() = default; - explicit Sequential(tvm::NodePtr<::tvm::Node> n) : Pass(n) {} + explicit Sequential(tvm::ObjectPtr<::tvm::Object> n) : Pass(n) {} const SequentialNode* operator->() const; using ContainerType = Sequential; @@ -638,7 +641,7 @@ TVM_DLL Function InferType(const Function& f, */ TVM_DLL Expr ForwardRewrite(const Expr& expr, const std::string& rewrite_map_attr_name, - std::function fcontext = nullptr, + std::function fcontext = nullptr, std::function fmulti_ref_trigger = nullptr); /*! @@ -655,7 +658,7 @@ TVM_DLL Expr ForwardRewrite(const Expr& expr, */ TVM_DLL Expr ForwardRewrite(const Expr& expr, const FForwardRewrite& rewrite_func, - std::function fcontext = nullptr, + std::function fcontext = nullptr, std::function fmulti_ref_trigger = nullptr); /*! diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index e0c056c1216bd..08fe957d8a78f 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -41,7 +41,7 @@ using Any = tvm::ir::Any; class TypeNode : public RelayNode { public: static constexpr const char* _type_key = "relay.Type"; - TVM_DECLARE_BASE_NODE_INFO(TypeNode, Node); + TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object); }; /*! @@ -55,10 +55,10 @@ class TypeNode : public RelayNode { * There are also advanced types to support generic(polymorphic types), * which can be ignored when first reading the code base. */ -class Type : public NodeRef { +class Type : public ObjectRef { public: Type() {} - explicit Type(ObjectPtr p) : NodeRef(p) {} + explicit Type(ObjectPtr p) : ObjectRef(p) {} using ContainerType = TypeNode; }; @@ -70,10 +70,13 @@ class Type : public NodeRef { class BaseTensorTypeNode : public TypeNode { public: static constexpr const char* _type_key = "relay.BaseTensorType"; - TVM_DECLARE_BASE_NODE_INFO(BaseTensorTypeNode, TypeNode); + TVM_DECLARE_BASE_OBJECT_INFO(BaseTensorTypeNode, TypeNode); }; -RELAY_DEFINE_NODE_REF(BaseTensorType, BaseTensorTypeNode, Type); +class BaseTensorType : public Type { + public: + TVM_DEFINE_OBJECT_REF_METHODS(BaseTensorType, Type, BaseTensorTypeNode); +}; /*! * \brief This is the most commonly used type in relay. @@ -113,10 +116,13 @@ class TensorTypeNode : public BaseTensorTypeNode { TVM_DLL static TensorType Scalar(DataType dtype); static constexpr const char* _type_key = "relay.TensorType"; - TVM_DECLARE_NODE_TYPE_INFO(TensorTypeNode, BaseTensorTypeNode); + TVM_DECLARE_FINAL_OBJECT_INFO(TensorTypeNode, BaseTensorTypeNode); }; -RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type); +class TensorType : public Type { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TensorType, Type, TensorTypeNode); +}; /*! \brief Possible kinds of Type. */ enum Kind : int { @@ -168,10 +174,13 @@ class TypeVarNode : public TypeNode { TVM_DLL static TypeVar make(std::string name, Kind kind); static constexpr const char* _type_key = "relay.TypeVar"; - TVM_DECLARE_NODE_TYPE_INFO(TypeVarNode, TypeNode); + TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode); }; -RELAY_DEFINE_NODE_REF(TypeVar, TypeVarNode, Type); +class TypeVar : public Type { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TypeVar, Type, TypeVarNode); +}; /*! * \brief A global type variable that is used for defining new types or type aliases. @@ -197,10 +206,13 @@ class GlobalTypeVarNode : public TypeNode { TVM_DLL static GlobalTypeVar make(std::string name, Kind kind); static constexpr const char* _type_key = "relay.GlobalTypeVar"; - TVM_DECLARE_NODE_TYPE_INFO(GlobalTypeVarNode, TypeNode); + TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode); }; -RELAY_DEFINE_NODE_REF(GlobalTypeVar, GlobalTypeVarNode, Type); +class GlobalTypeVar : public Type { + public: + TVM_DEFINE_OBJECT_REF_METHODS(GlobalTypeVar, Type, GlobalTypeVarNode); +}; /*! * \brief Type application. @@ -225,10 +237,13 @@ class TypeCallNode : public TypeNode { TVM_DLL static TypeCall make(Type func, tvm::Array args); static constexpr const char* _type_key = "relay.TypeCall"; - TVM_DECLARE_NODE_TYPE_INFO(TypeCallNode, TypeNode); + TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode); }; -RELAY_DEFINE_NODE_REF(TypeCall, TypeCallNode, Type); +class TypeCall : public Type { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TypeCall, Type, TypeCallNode); +}; /*! * \brief IncompleteType. @@ -253,10 +268,13 @@ class IncompleteTypeNode : public TypeNode { TVM_DLL static IncompleteType make(Kind kind); static constexpr const char* _type_key = "relay.IncompleteType"; - TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode); + TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode); }; -RELAY_DEFINE_NODE_REF(IncompleteType, IncompleteTypeNode, Type); +class IncompleteType : public Type { + public: + TVM_DEFINE_OBJECT_REF_METHODS(IncompleteType, Type, IncompleteTypeNode); +}; /*! * \brief Potential Constraints in the type. @@ -267,10 +285,13 @@ class TypeConstraint; class TypeConstraintNode : public TypeNode { public: static constexpr const char* _type_key = "relay.TypeConstraint"; - TVM_DECLARE_BASE_NODE_INFO(TypeConstraintNode, TypeNode); + TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode); }; -RELAY_DEFINE_NODE_REF(TypeConstraint, TypeConstraintNode, Type); +class TypeConstraint : public Type { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TypeConstraint, Type, TypeConstraintNode); +}; class FuncType; /*! @@ -311,10 +332,13 @@ class FuncTypeNode : public TypeNode { tvm::Array type_constraints); static constexpr const char* _type_key = "relay.FuncType"; - TVM_DECLARE_NODE_TYPE_INFO(FuncTypeNode, TypeNode); + TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode); }; -RELAY_DEFINE_NODE_REF(FuncType, FuncTypeNode, Type); +class FuncType : public Type { + public: + TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode); +}; /*! * \brief The type of tuple values. @@ -338,10 +362,13 @@ class TupleTypeNode : public TypeNode { TVM_DLL static TupleType make(tvm::Array fields); static constexpr const char* _type_key = "relay.TupleType"; - TVM_DECLARE_NODE_TYPE_INFO(TupleTypeNode, TypeNode); + TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode); }; -RELAY_DEFINE_NODE_REF(TupleType, TupleTypeNode, Type); +class TupleType : public Type { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TupleType, Type, TupleTypeNode); +}; /*! * \brief The type of reference values. @@ -365,10 +392,13 @@ class RefTypeNode : public TypeNode { TVM_DLL static RefType make(Type value); static constexpr const char* _type_key = "relay.RefType"; - TVM_DECLARE_NODE_TYPE_INFO(RefTypeNode, TypeNode); + TVM_DECLARE_FINAL_OBJECT_INFO(RefTypeNode, TypeNode); }; -RELAY_DEFINE_NODE_REF(RefType, RefTypeNode, Type); +class RefType : public Type { + public: + TVM_DEFINE_OBJECT_REF_METHODS(RefType, Type, RefTypeNode); +}; class TypeReporter; @@ -376,7 +406,7 @@ class TypeReporter; * \brief reporter that reports back to the * type resolution information. */ -class TypeReporterNode : public Node { +class TypeReporterNode : public Object { public: /*! * \brief Create a type equality constraint. @@ -408,7 +438,7 @@ class TypeReporterNode : public Node { * \brief Set the location at which to report unification errors. * \param ref The program node to report the error. */ - TVM_DLL virtual void SetLocation(const NodeRef& ref) = 0; + TVM_DLL virtual void SetLocation(const ObjectRef& ref) = 0; /*! * \brief Retrieve the current global module. @@ -420,17 +450,17 @@ class TypeReporterNode : public Node { void VisitAttrs(tvm::AttrVisitor* v) {} static constexpr const char* _type_key = "relay.TypeReporter"; - TVM_DECLARE_NODE_TYPE_INFO(TypeReporterNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object); }; /*! * \brief Container class of TypeReporter. * \sa TypeReporterNode */ -class TypeReporter : public NodeRef { +class TypeReporter : public ObjectRef { public: TypeReporter() {} - explicit TypeReporter(::tvm::ObjectPtr<::tvm::Object> n) : NodeRef(n) { + explicit TypeReporter(::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) { } TypeReporterNode* operator->() const { return const_cast( @@ -502,10 +532,13 @@ class TypeRelationNode : public TypeConstraintNode { Attrs attrs); static constexpr const char* _type_key = "relay.TypeRelation"; - TVM_DECLARE_NODE_TYPE_INFO(TypeRelationNode, TypeConstraintNode); + TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode); }; -RELAY_DEFINE_NODE_REF(TypeRelation, TypeRelationNode, TypeConstraint); +class TypeRelation : public TypeConstraint { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TypeRelation, TypeConstraint, TypeRelationNode); +}; // The following fields contains advanced typing // Only keep the class name and reserved for future usage. diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 96215daf4a7ac..7d1494707af8b 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -700,7 +700,12 @@ struct ObjectEqual { TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = \ TypeName::_GetOrAllocRuntimeTypeIndex() - +/* + * \brief Define object reference methods. + * \param TypeName The object type name + * \param ParentType The parent type of the objectref + * \param ObjectName The type name of the object. + */ #define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ TypeName() {} \ explicit TypeName( \ @@ -712,17 +717,54 @@ struct ObjectEqual { operator bool() const { return data_ != nullptr; } \ using ContainerType = ObjectName; -#define TVM_DEFINE_OBJECT_REF_METHODS_MUT(TypeName, ParentType, ObjectName) \ +/* + * \brief Define object reference methods of whose content is mutable. + * \param TypeName The object type name + * \param ParentType The parent type of the objectref + * \param ObjectName The type name of the object. + * \note We recommend making objects immutable when possible. + * This macro is only reserved for objects that stores runtime states. + */ +#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ TypeName() {} \ explicit TypeName( \ ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ : ParentType(n) {} \ - ObjectName* operator->() { \ + ObjectName* operator->() const { \ return static_cast(data_.get()); \ } \ operator bool() const { return data_ != nullptr; } \ using ContainerType = ObjectName; +/*! + * \brief Define CopyOnWrite function in an ObjectRef. + * \param ObjectName The Type of the Node. + * + * CopyOnWrite will generate a unique copy of the internal node. + * The node will be copied if it is referenced by multiple places. + * The function returns the raw pointer to the node to allow modification + * of the content. + * + * \code + * + * MyCOWObjectRef ref, ref2; + * ref2 = ref; + * ref.CopyOnWrite()->value = new_value; + * assert(ref2->value == old_value); + * assert(ref->value == new_value); + * + * \endcode + */ +#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \ + ObjectName* CopyOnWrite() { \ + CHECK(data_ != nullptr); \ + if (!data_.unique()) { \ + auto n = make_object(*(operator->())); \ + ObjectPtr(std::move(n)).swap(data_); \ + } \ + return static_cast(data_.get()); \ + } + // Implementations details below // Object reference counting. #if TVM_OBJECT_ATOMIC_REF_COUNTER @@ -832,10 +874,6 @@ inline SubRef Downcast(BaseRef ref) { } } // namespace runtime - -template -using NodePtr = runtime::ObjectPtr; - } // namespace tvm #endif // TVM_RUNTIME_OBJECT_H_ diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index 3f4ee38a76952..01caf5a02c91b 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -53,10 +53,10 @@ enum AttachType : int { }; /*! \brief Stage, contains scheduling for a stage of computation. */ -class Stage : public NodeRef { +class Stage : public ObjectRef { public: Stage() {} - explicit Stage(ObjectPtr n) : NodeRef(n) {} + explicit Stage(ObjectPtr n) : ObjectRef(n) {} /*! * \brief create a new schedule for op. * \param op The operator in the schedule @@ -277,10 +277,10 @@ class Stage : public NodeRef { * For operations and all the operations they depend on. * The schedule per Operation is named as stage. */ -class Schedule : public NodeRef { +class Schedule : public ObjectRef { public: Schedule() {} - explicit Schedule(ObjectPtr n) : NodeRef(n) {} + explicit Schedule(ObjectPtr n) : ObjectRef(n) {} /*! * \brief Get a copy of current schedule. * \return The copied schedule. @@ -400,10 +400,10 @@ class Schedule : public NodeRef { * \brief The schedule relation between IterVars * can be Split, Fuse. */ -class IterVarRelation : public NodeRef { +class IterVarRelation : public ObjectRef { public: IterVarRelation() {} - explicit IterVarRelation(ObjectPtr n) : NodeRef(n) {} + explicit IterVarRelation(ObjectPtr n) : ObjectRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -414,10 +414,10 @@ class IterVarRelation : public NodeRef { /*! * \brief Additional scheduable attributes about IterVar. */ -class IterVarAttr : public NodeRef { +class IterVarAttr : public ObjectRef { public: IterVarAttr() {} - explicit IterVarAttr(ObjectPtr n) : NodeRef(n) {} + explicit IterVarAttr(ObjectPtr n) : ObjectRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -440,7 +440,7 @@ class IterVarAttr : public NodeRef { * * The group stage node can be attached to IterVars as in normal stage. */ -class StageNode : public Node { +class StageNode : public Object { public: /*! * \brief The operation of stage, can be different from original op. @@ -515,11 +515,11 @@ class StageNode : public Node { } static constexpr const char* _type_key = "Stage"; - TVM_DECLARE_NODE_TYPE_INFO(StageNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object); }; /*! \brief node container for schedule */ -class ScheduleNode : public Node { +class ScheduleNode : public Object { public: /*! \brief The output operations in original data flow graph */ Array outputs; @@ -538,7 +538,7 @@ class ScheduleNode : public Node { * \brief Internal stage map to map internal ops to stages. * This is created on demand and can be invalidated. */ - std::unordered_map op2stage_cache_; + std::unordered_map op2stage_cache_; void VisitAttrs(AttrVisitor* v) { v->Visit("outputs", &outputs); @@ -576,7 +576,7 @@ class ScheduleNode : public Node { TVM_DLL static Schedule make(Array ops); static constexpr const char* _type_key = "Schedule"; - TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, Object); }; /*! @@ -589,7 +589,7 @@ inline Schedule create_schedule(Array ops) { } /*! \brief node container for IterVar attr */ -class IterVarAttrNode : public Node { +class IterVarAttrNode : public Object { public: /*! \brief The iteration type. */ IterVarType iter_type{kDataPar}; @@ -630,14 +630,14 @@ class IterVarAttrNode : public Node { } static constexpr const char* _type_key = "IterVarAttr"; - TVM_DECLARE_NODE_TYPE_INFO(IterVarAttrNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(IterVarAttrNode, Object); }; /*! \brief base node of iteration var */ -class IterVarRelationNode : public Node { +class IterVarRelationNode : public Object { public: static constexpr const char* _type_key = "IterVarRelation"; - TVM_DECLARE_BASE_NODE_INFO(IterVarRelationNode, Node); + TVM_DECLARE_BASE_OBJECT_INFO(IterVarRelationNode, Object); }; /*! @@ -672,7 +672,7 @@ class SplitNode : public IterVarRelationNode { Expr nparts); static constexpr const char* _type_key = "Split"; - TVM_DECLARE_NODE_TYPE_INFO(SplitNode, IterVarRelationNode); + TVM_DECLARE_FINAL_OBJECT_INFO(SplitNode, IterVarRelationNode); }; /*! @@ -697,7 +697,7 @@ class FuseNode : public IterVarRelationNode { IterVar outer, IterVar inner, IterVar fused); static constexpr const char* _type_key = "Fuse"; - TVM_DECLARE_NODE_TYPE_INFO(FuseNode, IterVarRelationNode); + TVM_DECLARE_FINAL_OBJECT_INFO(FuseNode, IterVarRelationNode); }; /*! @@ -720,7 +720,7 @@ class RebaseNode : public IterVarRelationNode { static IterVarRelation make(IterVar parent, IterVar rebased); static constexpr const char* _type_key = "Rebase"; - TVM_DECLARE_NODE_TYPE_INFO(RebaseNode, IterVarRelationNode); + TVM_DECLARE_FINAL_OBJECT_INFO(RebaseNode, IterVarRelationNode); }; @@ -739,7 +739,7 @@ class SingletonNode : public IterVarRelationNode { static IterVarRelation make(IterVar iter); static constexpr const char* _type_key = "Singleton"; - TVM_DECLARE_NODE_TYPE_INFO(SingletonNode, IterVarRelationNode); + TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode); }; diff --git a/include/tvm/target_info.h b/include/tvm/target_info.h index 86cb0e2756094..25fb7243eaf2d 100644 --- a/include/tvm/target_info.h +++ b/include/tvm/target_info.h @@ -34,7 +34,7 @@ namespace tvm { * \brief Memory information of special memory region. * Use MemoryInfo as its container type */ -struct MemoryInfoNode : public Node { +struct MemoryInfoNode : public Object { /*! \brief The addressable unit */ int unit_bits; /*! \brief Maximum number of bits supported in the memory */ @@ -55,11 +55,14 @@ struct MemoryInfoNode : public Node { } static constexpr const char* _type_key = "MemoryInfo"; - TVM_DECLARE_NODE_TYPE_INFO(MemoryInfoNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(MemoryInfoNode, Object); }; /*! \brief Defines memory info */ -TVM_DEFINE_NODE_REF(MemoryInfo, MemoryInfoNode); +class MemoryInfo : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(MemoryInfo, ObjectRef, MemoryInfoNode); +}; /*! * \brief get memory info given scope diff --git a/include/tvm/tensor.h b/include/tvm/tensor.h index f44498a0aa7a5..d6e93f567e50a 100644 --- a/include/tvm/tensor.h +++ b/include/tvm/tensor.h @@ -46,11 +46,11 @@ class OperationNode; * \brief Tensor structure representing a possible input, * or intermediate computation result. */ -class Tensor : public NodeRef { +class Tensor : public ObjectRef { public: /*! \brief default constructor, used internally */ Tensor() {} - explicit Tensor(ObjectPtr n) : NodeRef(n) {} + explicit Tensor(ObjectPtr n) : ObjectRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -158,7 +158,7 @@ class Operation : public ir::FunctionRef { }; /*! \brief Node to represent a tensor */ -class TensorNode : public Node { +class TensorNode : public Object { public: /*! \brief The shape of the tensor */ Array shape; @@ -183,7 +183,7 @@ class TensorNode : public Node { int value_index); static constexpr const char* _type_key = "Tensor"; - TVM_DECLARE_NODE_TYPE_INFO(TensorNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, Object); }; @@ -250,13 +250,13 @@ DEFINE_OVERLOAD_SLICE_BINARY_OP(<); // NOLINT(*) namespace std { template <> -struct hash<::tvm::Operation> : public ::tvm::NodeHash { +struct hash<::tvm::Operation> : public ::tvm::ObjectHash { }; template <> struct hash<::tvm::Tensor> { std::size_t operator()(const ::tvm::Tensor& k) const { - ::tvm::NodeHash hasher; + ::tvm::ObjectHash hasher; if (k.defined() && k->op.defined()) { return hasher(k->op); } else{ diff --git a/include/tvm/tensor_intrin.h b/include/tvm/tensor_intrin.h index 0d4795ad54409..f973909ae398e 100644 --- a/include/tvm/tensor_intrin.h +++ b/include/tvm/tensor_intrin.h @@ -34,10 +34,10 @@ namespace tvm { class TensorIntrinNode; /*! \brief Tensor intrinsic node. */ -class TensorIntrin : public NodeRef { +class TensorIntrin : public ObjectRef { public: TensorIntrin() {} - explicit TensorIntrin(NodePtr n) : NodeRef(n) {} + explicit TensorIntrin(ObjectPtr n) : ObjectRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -49,7 +49,7 @@ class TensorIntrin : public NodeRef { }; /*! \brief Node to represent a Tensor intrinsic operator */ -class TensorIntrinNode : public Node { +class TensorIntrinNode : public Object { public: /*! \brief The name of the intrinsic */ std::string name; @@ -108,7 +108,7 @@ class TensorIntrinNode : public Node { Stmt reduce_update); static constexpr const char* _type_key = "TensorIntrin"; - TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object); }; inline const TensorIntrinNode* TensorIntrin::operator->() const { @@ -119,10 +119,10 @@ inline const TensorIntrinNode* TensorIntrin::operator->() const { class TensorIntrinCallNode; /*! \brief Tensor intrinsic calling node. */ -class TensorIntrinCall : public NodeRef { +class TensorIntrinCall : public ObjectRef { public: TensorIntrinCall() {} - explicit TensorIntrinCall(NodePtr n) : NodeRef(n) {} + explicit TensorIntrinCall(ObjectPtr n) : ObjectRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -133,7 +133,7 @@ class TensorIntrinCall : public NodeRef { using ContainerType = TensorIntrinCallNode; }; -class TensorIntrinCallNode : public Node { +class TensorIntrinCallNode : public Object { public: /*! \brief the tensor intrinsic */ TensorIntrin intrin; @@ -166,7 +166,7 @@ class TensorIntrinCallNode : public Node { Array scalar_inputs); static constexpr const char* _type_key = "TensorIntrinCall"; - TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinCallNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinCallNode, Object); }; inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const { diff --git a/nnvm/include/nnvm/graph.h b/nnvm/include/nnvm/graph.h index 6bda2f57c4bff..1911a0337ac2f 100644 --- a/nnvm/include/nnvm/graph.h +++ b/nnvm/include/nnvm/graph.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -306,7 +306,7 @@ void PostOrderDFSVisit(const std::vector& heads, template inline void DFSVisit(const std::vector& heads, FVisit fvisit) { - typedef const NodePtr* GNode; + typedef const ObjectPtr* GNode; std::vector head_nodes(heads.size()); std::transform(heads.begin(), heads.end(), head_nodes.begin(), [](const NodeEntry& e)->GNode { diff --git a/nnvm/include/nnvm/node.h b/nnvm/include/nnvm/node.h index 220115bd0c9d1..95a7ce23e4da0 100644 --- a/nnvm/include/nnvm/node.h +++ b/nnvm/include/nnvm/node.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -40,22 +40,22 @@ class Node; class Symbol; /*! - * \brief we always used NodePtr for a reference pointer + * \brief we always used ObjectPtr for a reference pointer * to the node, so this alias can be changed in case. * - * By default, NodePtr is a std::shared_ptr of node + * By default, ObjectPtr is a std::shared_ptr of node */ -using NodePtr = std::shared_ptr; +using ObjectPtr = std::shared_ptr; /*! \brief an entry that represents output data from a node */ struct NodeEntry { - NodeEntry(NodePtr node, uint32_t index, uint32_t version): + NodeEntry(ObjectPtr node, uint32_t index, uint32_t version): node(std::move(node)), index(index), version(version) {} - explicit NodeEntry(NodePtr node): + explicit NodeEntry(ObjectPtr node): node(std::move(node)), index(), version() @@ -72,7 +72,7 @@ struct NodeEntry { {} /*! \brief the source node of this data */ - NodePtr node; + ObjectPtr node; /*! \brief index of output from the source. */ uint32_t index; /*! @@ -169,7 +169,7 @@ class NNVM_DLL Node { * \brief Optional control flow dependencies * Gives operation must be performed before this operation. */ - std::vector control_deps; + std::vector control_deps; /*! \brief additional fields for this node */ any info; /*! \brief destructor of node */ @@ -191,7 +191,7 @@ class NNVM_DLL Node { * \return a created empty node. */ template - static NodePtr Create(Args&&... args) { + static ObjectPtr Create(Args&&... args) { return std::make_shared(std::forward(args)...); } }; @@ -210,7 +210,7 @@ inline NodeEntry MakeNode( std::vector inputs, std::unordered_map attrs = std::unordered_map()) { - NodePtr p = Node::Create(); + ObjectPtr p = Node::Create(); p->attrs.op = nnvm::Op::Get(op_name); p->attrs.name = std::move(node_name); p->attrs.dict = attrs; diff --git a/nnvm/include/nnvm/op_attr_types.h b/nnvm/include/nnvm/op_attr_types.h index 04b71b75e1ac2..bf001e0f1be79 100644 --- a/nnvm/include/nnvm/op_attr_types.h +++ b/nnvm/include/nnvm/op_attr_types.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -192,7 +192,7 @@ using FIgnoreInputs = std::function< * \note Register under "FGradient" */ using FGradient = std::function( - const NodePtr& nodeptr, + const ObjectPtr& nodeptr, const std::vector& out_grads)>; /*! @@ -204,7 +204,7 @@ using FGradient = std::function( */ using FSetInputVarAttrOnCompose = std::function; /*! diff --git a/nnvm/include/nnvm/symbolic.h b/nnvm/include/nnvm/symbolic.h index dda79d468173c..d3555ec726b27 100644 --- a/nnvm/include/nnvm/symbolic.h +++ b/nnvm/include/nnvm/symbolic.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -97,7 +97,7 @@ class NNVM_DLL Symbol { * \return The arguments list of this symbol, they can be either named or unnamed (empty string). * \sa ListInputOption */ - std::vector ListInputs(ListInputOption option) const; + std::vector ListInputs(ListInputOption option) const; /*! * \brief List the input names. * diff --git a/nnvm/src/c_api/c_api_symbolic.cc b/nnvm/src/c_api/c_api_symbolic.cc index ae819480eff87..7ca56035acaeb 100644 --- a/nnvm/src/c_api/c_api_symbolic.cc +++ b/nnvm/src/c_api/c_api_symbolic.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -259,7 +259,7 @@ int NNSymbolListInputVariables(SymbolHandle symbol, Symbol *s = static_cast(symbol); NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); - std::vector vs = s->ListInputs(Symbol::ListInputOption(option)); + std::vector vs = s->ListInputs(Symbol::ListInputOption(option)); ret->ret_handles.resize(0); ret->ret_handles.reserve(vs.size()); for (size_t i = 0; i < vs.size(); ++i) { diff --git a/nnvm/src/core/graph.cc b/nnvm/src/core/graph.cc index 829924ea7d5c4..8930e49ecc58f 100644 --- a/nnvm/src/core/graph.cc +++ b/nnvm/src/core/graph.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -50,7 +50,7 @@ static void SubgraphSanityCheck(const std::vector> &subg next_level.clear(); for (const std::vector *graph_ptr : curr_level) { const std::vector &graph = *graph_ptr; - DFSVisit(graph, [&next_level, &node2level, level](const NodePtr& n) { + DFSVisit(graph, [&next_level, &node2level, level](const ObjectPtr& n) { nnvm::Node *node = n.get(); // if the node is visited, but on a different level, then check failed // if check failed here or before, we stop doing anything, but raise an error @@ -74,7 +74,7 @@ IndexedGraph::IndexedGraph(const Graph &g) { std::vector> subgraphs; DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs] - (const NodePtr& n) { + (const ObjectPtr& n) { const auto& is_ghost = Op::GetAttr("TIsGhost"); if (!n->is_variable() && is_ghost.get(n->op(), false)) return; CHECK_LT(nodes_.size(), std::numeric_limits::max()); diff --git a/nnvm/src/core/node.cc b/nnvm/src/core/node.cc index 59e35243d8f81..32d5e7f913b34 100644 --- a/nnvm/src/core/node.cc +++ b/nnvm/src/core/node.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -30,7 +30,7 @@ Node::~Node() { // explicit deletion via DFS // this is used to avoid stackoverflow caused by chain of deletions std::vector stack{this}; - std::vector to_delete; + std::vector to_delete; while (!stack.empty()) { Node* n = stack.back(); stack.pop_back(); @@ -42,7 +42,7 @@ Node::~Node() { e.node.reset(); } } - for (NodePtr& sp : n->control_deps) { + for (ObjectPtr& sp : n->control_deps) { if (sp.unique()) { stack.push_back(sp.get()); to_delete.emplace_back(std::move(sp)); diff --git a/nnvm/src/core/symbolic.cc b/nnvm/src/core/symbolic.cc index 884dae7372f88..86dc7e63c4034 100644 --- a/nnvm/src/core/symbolic.cc +++ b/nnvm/src/core/symbolic.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -36,8 +36,8 @@ struct VariableParam { uint32_t version{0}; }; -NodePtr CreateVariableNode(const std::string& name) { - NodePtr n = Node::Create(); +ObjectPtr CreateVariableNode(const std::string& name) { + ObjectPtr n = Node::Create(); n->attrs.op = nullptr; n->attrs.name = name; n->attrs.parsed = VariableParam(); @@ -114,10 +114,10 @@ inline bool IsAtomic(const std::vector& outputs) { // public functions Symbol Symbol::Copy() const { - std::unordered_map old_new; + std::unordered_map old_new; // use DFSVisit to copy all the nodes - DFSVisit(this->outputs, [&old_new](const NodePtr& node) { - NodePtr np = Node::Create(); + DFSVisit(this->outputs, [&old_new](const ObjectPtr& node) { + ObjectPtr np = Node::Create(); np->attrs = node->attrs; old_new[node.get()] = std::move(np); }); @@ -127,7 +127,7 @@ Symbol Symbol::Copy() const { Node *ptr = e.node.get(); kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index, e.version}); } - for (const NodePtr& p : kv.first->control_deps) { + for (const ObjectPtr& p : kv.first->control_deps) { kv.second->control_deps.emplace_back(old_new[p.get()]); } } @@ -155,7 +155,7 @@ void Symbol::Print(std::ostream &os) const { os << "\toutput[" << i << "]=" << outputs[i].node->attrs.name << '(' << outputs[i].index << ")\n"; } - DFSVisit(this->outputs, [&os](const NodePtr& node) { + DFSVisit(this->outputs, [&os](const ObjectPtr& node) { if (node->is_variable()) { os << "Variable:" << node->attrs.name << '\n'; } else { @@ -204,21 +204,21 @@ Symbol Symbol::operator[] (size_t index) const { } } -std::vector Symbol::ListInputs(ListInputOption option) const { - std::vector ret; +std::vector Symbol::ListInputs(ListInputOption option) const { + std::vector ret; if (option == kAll) { ret.reserve(this->outputs.size()); - DFSVisit(this->outputs, [&ret](const NodePtr &node) { + DFSVisit(this->outputs, [&ret](const ObjectPtr &node) { if (node->is_variable()) { ret.push_back(node); } }); } else { std::unordered_set mutable_set; - std::vector vlist; + std::vector vlist; vlist.reserve(this->outputs.size()); static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); - DFSVisit(this->outputs, [&mutable_set, &vlist](const NodePtr &node) { + DFSVisit(this->outputs, [&mutable_set, &vlist](const ObjectPtr &node) { if (node->is_variable()) { vlist.push_back(node); } else if (fmutate_inputs.count(node->op())) { @@ -228,7 +228,7 @@ std::vector Symbol::ListInputs(ListInputOption option) const { } }); ret.reserve(vlist.size()); - for (const NodePtr& node : vlist) { + for (const ObjectPtr& node : vlist) { if ((option == kReadOnlyArgs && mutable_set.count(node.get()) == 0) || (option == kAuxiliaryStates && mutable_set.count(node.get()) != 0)) { ret.emplace_back(node); @@ -239,7 +239,7 @@ std::vector Symbol::ListInputs(ListInputOption option) const { } std::vector Symbol::ListInputNames(ListInputOption option) const { - std::vector inputs = ListInputs(option); + std::vector inputs = ListInputs(option); std::vector ret(inputs.size()); for (size_t i = 0; i < inputs.size(); ++i) { ret[i] = inputs[i]->attrs.name; @@ -416,7 +416,7 @@ void Symbol::Compose(const array_view& args, std::unordered_map replace_map; // replace map stores the existing replacement plan for arguments node auto find_replace_map = [&nmatched, &arg_counter, &args, &kwargs, &replace_map] - (const NodePtr &node) { + (const ObjectPtr &node) { if (node->is_variable()) { if (arg_counter < args.size()) { replace_map[node.get()] = &(args[arg_counter]->outputs[0]); @@ -437,7 +437,7 @@ void Symbol::Compose(const array_view& args, std::vector update_nodes; std::vector > replace_plan; auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes] - (const NodePtr &node) { + (const ObjectPtr &node) { // visit all the childs, find possible replacement bool repl = false; for (size_t i = 0; i < node->inputs.size(); ++i) { @@ -499,7 +499,7 @@ void Symbol::AddControlDeps(const Symbol& src) { Symbol Symbol::GetInternals() const { static auto& fnum_vis_output = Op::GetAttr("FNumVisibleOutputs"); Symbol ret; - DFSVisit(this->outputs, [&ret](const NodePtr& node) { + DFSVisit(this->outputs, [&ret](const ObjectPtr& node) { Node* n = node.get(); if (n->is_variable()) { // grab version from variable. @@ -582,7 +582,7 @@ bool Symbol::GetAttr(const std::string& key, std::string* out) const { std::unordered_map Symbol::ListAttrs(ListAttrOption option) const { if (option == kRecursive) { std::unordered_map ret; - DFSVisit(this->outputs, [&ret](const NodePtr& n) { + DFSVisit(this->outputs, [&ret](const ObjectPtr& n) { for (const auto& it : n->attrs.dict) { ret[n->attrs.name + symbol_constants::kNamespaceSeparator + it.first] = it.second; } @@ -596,7 +596,7 @@ std::unordered_map Symbol::ListAttrs(ListAttrOption op std::vector > Symbol::ListAttrsRecursive() const { std::vector > ret; - DFSVisit(this->outputs, [&ret](const NodePtr& n) { + DFSVisit(this->outputs, [&ret](const ObjectPtr& n) { for (const auto& it : n->attrs.dict) { ret.emplace_back(std::make_tuple(n->attrs.name, it.first, it.second)); } @@ -608,7 +608,7 @@ Symbol Symbol::CreateFunctor(const Op* op, std::unordered_map attrs) { static auto& fnum_vis_output = Op::GetAttr("FNumVisibleOutputs"); Symbol s; - NodePtr n = Node::Create(); + ObjectPtr n = Node::Create(); n->attrs.op = op; n->attrs.dict = std::move(attrs); if (n->op()->attr_parser != nullptr) { @@ -628,7 +628,7 @@ Symbol Symbol::CreateFunctor(const Op* op, Symbol Symbol::CreateFunctor(const NodeAttrs& attrs) { static auto& fnum_vis_output = Op::GetAttr("FNumVisibleOutputs"); Symbol s; - NodePtr n = Node::Create(); + ObjectPtr n = Node::Create(); n->attrs = attrs; uint32_t nout = n->num_outputs(); diff --git a/nnvm/src/pass/correct_layout.cc b/nnvm/src/pass/correct_layout.cc index ae1a6971cd047..e988ebd879150 100644 --- a/nnvm/src/pass/correct_layout.cc +++ b/nnvm/src/pass/correct_layout.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -30,11 +30,11 @@ namespace nnvm { namespace pass { -nnvm::NodePtr CreateLayoutTransformNode(const Layout& src, +nnvm::ObjectPtr CreateLayoutTransformNode(const Layout& src, const Layout& dst) { static const nnvm::Op* trans_op = nnvm::Op::Get("__layout_transform__"); static int count = 0; - nnvm::NodePtr n = nnvm::Node::Create(); + nnvm::ObjectPtr n = nnvm::Node::Create(); n->attrs.op = trans_op; n->attrs.name = src.name() + "_to_" + dst.name() + std::to_string(count++); n->attrs.dict["src_layout"] = src.name(); @@ -56,14 +56,14 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) { nnvm::Op::GetAttr("FCorrectLayoutEx"); const IndexedGraph& idx = src.indexed_graph(); - std::vector mirror_vec(idx.num_nodes(), nullptr); + std::vector mirror_vec(idx.num_nodes(), nullptr); - // (new) NodePtr -> output_layouts + // (new) ObjectPtr -> output_layouts LayoutAttrDict new_layouts; for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { const auto& inode = idx[nid]; - nnvm::NodePtr new_node = nnvm::Node::Create(); + nnvm::ObjectPtr new_node = nnvm::Node::Create(); *new_node = *(inode.source); if (new_node->is_variable()) { // Variable node. No operator. Only one output entry. @@ -87,7 +87,7 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) { std::vector request_ilayouts(num_inputs, Layout::Undef()); for (size_t i = 0; i < num_inputs; ++i) { const IndexedGraph::NodeEntry& input_entry = inode.inputs[i]; - const NodePtr& new_input_node = mirror_vec[input_entry.node_id]; + const ObjectPtr& new_input_node = mirror_vec[input_entry.node_id]; CHECK(new_input_node != nullptr); // fill inputs by previous node (DFS order) inferred layouts. @@ -138,14 +138,14 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) { for (uint32_t i = 0; i < inode.inputs.size(); ++i) { const auto& e = inode.inputs[i]; - const nnvm::NodePtr& in = mirror_vec[e.node_id]; + const nnvm::ObjectPtr& in = mirror_vec[e.node_id]; new_node->inputs[i] = nnvm::NodeEntry{in, e.index, e.version}; // insert layout_transform if necessary const Layout& produce = produce_ilayouts[i]; const Layout& request = request_ilayouts[i]; if (produce != request && produce.defined()) { - nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, request); + nnvm::ObjectPtr tnode = CreateLayoutTransformNode(produce, request); tnode->attrs.name = idx[e.node_id].source->attrs.name + "_" + request.name(); tnode->inputs.emplace_back(new_node->inputs[i]); nnvm::NodeEntry tnode_output(std::move(tnode), 0, 0); diff --git a/nnvm/src/pass/gradient.cc b/nnvm/src/pass/gradient.cc index 3e925222504c4..9c30a785cac22 100644 --- a/nnvm/src/pass/gradient.cc +++ b/nnvm/src/pass/gradient.cc @@ -37,13 +37,13 @@ NodeEntry DefaultAggregateGradient(std::vector&& v) { if (v.size() == 1) { return std::move(v[0]); } else if (v.size() == 0) { - NodePtr zero_node = Node::Create(); + ObjectPtr zero_node = Node::Create(); zero_node->attrs.op = Op::Get("zeros"); zero_node->attrs.name = "zero_grad"; zero_node->attrs.op->attr_parser(&(zero_node->attrs)); return NodeEntry{zero_node, 0, 0}; } else { - NodePtr sum_node = Node::Create(); + ObjectPtr sum_node = Node::Create(); sum_node->attrs.op = Op::Get("elemwise_sum"); sum_node->inputs = std::move(v); sum_node->attrs.name = "grad_sum"; @@ -119,10 +119,10 @@ Graph Gradient(Graph src) { nullptr; // topo sort - std::vector topo_order; + std::vector topo_order; std::unordered_map > output_grads; - DFSVisit(ys, [&](const NodePtr& node) { + DFSVisit(ys, [&](const ObjectPtr& node) { if (output_grads.count(node.get()) == 0) { output_grads[node.get()].resize(node->num_outputs()); } @@ -143,11 +143,11 @@ Graph Gradient(Graph src) { } // construct mirror as memory reduction strategy if needed - std::unordered_map mirror_map; + std::unordered_map mirror_map; if (mirror_fun != nullptr) { - for (const NodePtr& node_ptr : topo_order) { + for (const ObjectPtr& node_ptr : topo_order) { if (mirror_fun(*node_ptr)) { - NodePtr new_node = Node::Create(); + ObjectPtr new_node = Node::Create(); *new_node = *node_ptr; new_node->attrs.name += "_mirror"; for (auto& e : new_node->inputs) { @@ -169,7 +169,7 @@ Graph Gradient(Graph src) { std::vector out_agg_grads; for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) { - const NodePtr& ptr = *rit; + const ObjectPtr& ptr = *rit; if (ptr->is_variable()) continue; out_agg_grads.clear(); auto& out_grad_vec = output_grads.at(ptr.get()); @@ -182,7 +182,7 @@ Graph Gradient(Graph src) { out_agg_grads.push_back(e.sum); } if ((*rit)->inputs.size() != 0) { - NodePtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get())); + ObjectPtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get())); std::vector input_grads; // Check for FGradient if (grad_fun_map.contains(ptr->op())) { @@ -244,7 +244,7 @@ Graph Gradient(Graph src) { if (kv == unique_grads.end()) { unique_grads.emplace(std::move(entry.sum), std::make_pair(1, counter)); } else { - NodePtr copy_node = Node::Create(); + ObjectPtr copy_node = Node::Create(); std::ostringstream os; os << entry.sum.node->attrs.name << "_" << kv->second.first << "_copy"; kv->second.first++; diff --git a/nnvm/src/pass/infer_shape_type.cc b/nnvm/src/pass/infer_shape_type.cc index a5797736209f0..876dce1c113d0 100644 --- a/nnvm/src/pass/infer_shape_type.cc +++ b/nnvm/src/pass/infer_shape_type.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -112,7 +112,7 @@ Graph InferAttr(Graph &&ret, CHECK_GE(inode.control_deps.size(), 1U) << "BackwardOp need to have control_deps to its forward op"; const IndexedGraph::Node& fnode = idx[inode.control_deps[0]]; - NodePtr fwd_ptr = inode.source->control_deps[0]; + ObjectPtr fwd_ptr = inode.source->control_deps[0]; CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable"; // use gradient function to find out the correspondence. std::vector ograd(fwd_ptr->num_outputs()); diff --git a/nnvm/src/pass/order_mutation.cc b/nnvm/src/pass/order_mutation.cc index 6f43da282ee4d..b2fa2ca33e07a 100644 --- a/nnvm/src/pass/order_mutation.cc +++ b/nnvm/src/pass/order_mutation.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -45,7 +45,7 @@ inline bool IsMutate(const std::vector& mutate_inputs, uint32_t i) { Graph OrderMutation(const Graph& src) { std::unordered_map > version_hist; - DFSVisit(src.outputs, [&version_hist](const NodePtr& n) { + DFSVisit(src.outputs, [&version_hist](const ObjectPtr& n) { for (const NodeEntry& e : n->inputs) { if (e.node->is_variable()) { if (e.version != 0 && version_hist.count(e.node.get()) == 0) { @@ -57,8 +57,8 @@ Graph OrderMutation(const Graph& src) { // no mutation happens, everything if fine. if (version_hist.size() == 0) return src; // start preparing for remapping the nodes. - std::unordered_map old_new; - auto prepare = [&version_hist, &old_new] (const NodePtr& n) { + std::unordered_map old_new; + auto prepare = [&version_hist, &old_new] (const ObjectPtr& n) { static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); std::vector mutate_inputs; if (!n->is_variable() && fmutate_inputs.count(n->op())) { @@ -80,11 +80,11 @@ Graph OrderMutation(const Graph& src) { if (old_new.count(e.node.get()) != 0) need_repl = true; } } - for (const NodePtr& p : n->control_deps) { + for (const ObjectPtr& p : n->control_deps) { if (old_new.count(p.get()) != 0) need_repl = true; } if (need_repl) { - NodePtr np = Node::Create(); + ObjectPtr np = Node::Create(); np->attrs = n->attrs; old_new[n.get()] = std::move(np); } @@ -111,7 +111,7 @@ Graph OrderMutation(const Graph& src) { kv.second->inputs.push_back(e); } } - for (const NodePtr& p : kv.first->control_deps) { + for (const ObjectPtr& p : kv.first->control_deps) { kv.second->control_deps.emplace_back( get_with_default(old_new, p.get(), p)); } diff --git a/nnvm/src/pass/place_device.cc b/nnvm/src/pass/place_device.cc index a0c0fb2f534af..6d6866e472d6a 100644 --- a/nnvm/src/pass/place_device.cc +++ b/nnvm/src/pass/place_device.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -105,8 +105,8 @@ Graph PlaceDevice(Graph src) { src.attrs["device"] = std::make_shared(std::move(device)); return src; } - std::map, NodePtr> copy_map; - std::vector new_node_map(idx.num_nodes(), nullptr); + std::map, ObjectPtr> copy_map; + std::vector new_node_map(idx.num_nodes(), nullptr); std::unordered_map new_device_map; static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); @@ -142,7 +142,7 @@ Graph PlaceDevice(Graph src) { CHECK(!need_mutate) << "consistency check"; } if (need_mutate) { - NodePtr new_node = Node::Create(); + ObjectPtr new_node = Node::Create(); new_node->attrs = inode.source->attrs; new_node->inputs.reserve(inode.inputs.size()); for (size_t i = 0; i < inode.inputs.size(); ++i) { @@ -154,7 +154,7 @@ Graph PlaceDevice(Graph src) { new_node->inputs.emplace_back( NodeEntry{it->second, 0, 0}); } else { - NodePtr copy_node = Node::Create(); + ObjectPtr copy_node = Node::Create(); std::ostringstream os; os << inode.source->inputs[i].node->attrs.name << "_" << e.index <<"_copy"; copy_node->attrs.op = copy_op; diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc index 69d4a05f66e89..9389995d05211 100644 --- a/nnvm/src/pass/saveload_json.cc +++ b/nnvm/src/pass/saveload_json.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -86,7 +86,7 @@ struct JSONNode { }; // pointer to the graph node - NodePtr node; + ObjectPtr node; // inputs std::vector inputs; // control flow dependencies @@ -190,7 +190,7 @@ struct JSONGraph { void Symbol2JSONGraph(std::shared_ptr src, JSONGraph *jgraph) { std::unordered_map node2index; jgraph->node_row_ptr.push_back(0); - DFSVisit(src->outputs, [&node2index, jgraph](const NodePtr& n) { + DFSVisit(src->outputs, [&node2index, jgraph](const ObjectPtr& n) { uint32_t nid = static_cast(jgraph->nodes.size()); node2index[n.get()] = nid; if (n->is_variable()) { @@ -202,7 +202,7 @@ void Symbol2JSONGraph(std::shared_ptr src, JSONGraph *jgraph) { for (const NodeEntry& e : n->inputs) { jnode.inputs.emplace_back(node2index.at(e.node.get()), e.index, e.version); } - for (const NodePtr& c : n->control_deps) { + for (const ObjectPtr& c : n->control_deps) { jnode.control_deps.push_back(node2index.at(c.get())); } jgraph->node_row_ptr.push_back(jgraph->node_row_ptr.back() + n->num_outputs()); diff --git a/src/api/api_base.cc b/src/api/api_base.cc index cbefaa464ded6..bcfd82bee7fee 100644 --- a/src/api/api_base.cc +++ b/src/api/api_base.cc @@ -32,7 +32,7 @@ TVM_REGISTER_API("_format_str") .set_body([](TVMArgs args, TVMRetValue *ret) { CHECK(args[0].type_code() == kObjectHandle); std::ostringstream os; - os << args[0].operator NodeRef(); + os << args[0].operator ObjectRef(); *ret = os.str(); }); diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 8a74fe5cdb7df..00ceaf72118c4 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -65,7 +65,7 @@ TVM_REGISTER_API("_Array") data.push_back(ObjectRef(nullptr)); } } - auto node = make_node(); + auto node = make_object(); node->data = std::move(data); *ret = Array(node); }); @@ -105,7 +105,7 @@ TVM_REGISTER_API("_Map") data.emplace(std::make_pair(args[i].operator std::string(), args[i + 1].operator ObjectRef())); } - auto node = make_node(); + auto node = make_object(); node->data = std::move(data); *ret = Map(node); } else { @@ -119,7 +119,7 @@ TVM_REGISTER_API("_Map") data.emplace(std::make_pair(args[i].operator ObjectRef(), args[i + 1].operator ObjectRef())); } - auto node = make_node(); + auto node = make_object(); node->data = std::move(data); *ret = Map(node); } @@ -186,7 +186,7 @@ TVM_REGISTER_API("_MapItems") if (ptr->IsInstance()) { auto* n = static_cast(ptr); - auto rkvs = make_node(); + auto rkvs = make_object(); for (const auto& kv : n->data) { rkvs->data.push_back(kv.first); rkvs->data.push_back(kv.second); @@ -194,7 +194,7 @@ TVM_REGISTER_API("_MapItems") *ret = Array(rkvs); } else { auto* n = static_cast(ptr); - auto rkvs = make_node(); + auto rkvs = make_object(); for (const auto& kv : n->data) { rkvs->data.push_back(ir::StringImm::make(kv.first)); rkvs->data.push_back(kv.second); diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index c62cc8ad16a05..339b25a518946 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -100,12 +100,13 @@ TVM_REGISTER_API("ir_pass.RewriteForTensorCore") }); TVM_REGISTER_API("ir_pass.AttrsEqual") -.set_body_typed([](const NodeRef& lhs, const NodeRef& rhs) { +.set_body_typed( + [](const ObjectRef& lhs, const ObjectRef& rhs) { return AttrsEqual()(lhs, rhs); }); TVM_REGISTER_API("ir_pass.AttrsHash") -.set_body_typed([](const NodeRef &node) { +.set_body_typed([](const ObjectRef &node) { return AttrsHash()(node); }); @@ -118,7 +119,7 @@ TVM_REGISTER_API("ir_pass.ExprUseVar") TVM_REGISTER_API("ir_pass.PostOrderVisit") .set_body([](TVMArgs args, TVMRetValue *ret) { PackedFunc f = args[1]; - ir::PostOrderVisit(args[0], [f](const NodeRef& n) { + ir::PostOrderVisit(args[0], [f](const ObjectRef& n) { f(n); }); }); @@ -126,7 +127,7 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit") TVM_REGISTER_API("ir_pass.LowerStorageAccess") .set_body([](TVMArgs args, TVMRetValue *ret) { LoweredFunc f = args[0]; - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); n->body = LowerStorageAccessInfo(f->body); *ret = LoweredFunc(n); }); diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 19f045241915e..0b84be291f71e 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -42,7 +42,7 @@ class VariablePathFinder: public IRVisitor { public: explicit VariablePathFinder(Expr target) : target_(target) {} - void Visit(const NodeRef& node) final { + void Visit(const ObjectRef& node) final { if (visited_.count(node.get()) != 0) return; visited_.insert(node.get()); @@ -82,7 +82,7 @@ class BoundDeducer: public IRVisitor { void Deduce(); - void Visit(const NodeRef& e) final { + void Visit(const ObjectRef& e) final { if (!success_) return; if (e.get() == path_[iter_++]) { IRVisitor::Visit(e); @@ -202,7 +202,7 @@ class BoundDeduceInputChecker: public IRVisitor { return target_count == 1; } - void Visit(const NodeRef& e) final { + void Visit(const ObjectRef& e) final { if (e.same_as(deducer_->target_)) ++target_count; IRVisitor::Visit(e); } diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc index 022dd8e94dbb6..6a19a7aeb3f2a 100644 --- a/src/arithmetic/canonical_simplify.cc +++ b/src/arithmetic/canonical_simplify.cc @@ -56,7 +56,7 @@ class CanonicalExprNode : public BaseExprNode { } static constexpr const char* _type_key = "arith.CanonicalExpr"; - TVM_DECLARE_BASE_NODE_INFO(CanonicalExprNode, BaseExprNode); + TVM_DECLARE_BASE_OBJECT_INFO(CanonicalExprNode, BaseExprNode); }; enum DivMode { @@ -147,10 +147,14 @@ class SplitExprNode : public CanonicalExprNode { /*! \brief positive infty */ static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf; static constexpr const char* _type_key = "arith.SplitExpr"; - TVM_DECLARE_NODE_TYPE_INFO(SplitExprNode, CanonicalExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(SplitExprNode, CanonicalExprNode); }; -TVM_DEFINE_COW_NODE_REF(SplitExpr, Expr, SplitExprNode); +class SplitExpr : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(SplitExpr, Expr, SplitExprNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(SplitExprNode); +}; inline bool SplitExprNode::IndexEqual(const SplitExpr& other) const { if (index.same_as(other->index)) return true; @@ -272,7 +276,7 @@ class SumExprNode : public CanonicalExprNode { void AddToSelf(const SumExpr& other, int64_t scale); static constexpr const char* _type_key = "arith.SumExpr"; - TVM_DECLARE_NODE_TYPE_INFO(SumExprNode, CanonicalExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(SumExprNode, CanonicalExprNode); private: /*! @@ -405,7 +409,11 @@ class SumExprNode : public CanonicalExprNode { } }; -TVM_DEFINE_COW_NODE_REF(SumExpr, Expr, SumExprNode); +class SumExpr : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(SumExpr, Expr, SumExprNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(SumExprNode); +}; void SumExprNode::AddToSelf(const SumExpr& other, int64_t scale) { // NOTE: it is rare to have a balanced long expression, @@ -507,7 +515,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { if (const auto* op = expr.as()) { expr = op->Normalize(); } - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->dtype = expr.dtype(); n->index = std::move(expr); n->div_mode = kTruncDiv; @@ -544,7 +552,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { if (const auto* op = expr.as()) { return GetRef(op); } - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->dtype = expr.dtype(); if (const auto* op = expr.as()) { n->base = op->value; @@ -655,8 +663,8 @@ SeparateDivisibleParts(const SumExprNode* psum, int64_t coeff, SumExpr* out_divisible, SumExpr* out_non_divisible) { - auto divisible = make_node(); - auto non_divisible = make_node(); + auto divisible = make_object(); + auto non_divisible = make_object(); divisible->dtype = psum->dtype; non_divisible->dtype = psum->dtype; diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index c0519107d5b8c..16e489a9c8187 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -35,7 +35,7 @@ TVM_REGISTER_NODE_TYPE(ConstIntBoundNode); ConstIntBound::ConstIntBound( int64_t min_value, int64_t max_value) { - auto node = make_node(); + auto node = make_object(); node->min_value = min_value; node->max_value = max_value; data_ = std::move(node); @@ -123,7 +123,7 @@ class ConstIntBoundAnalyzer::Impl : } // Override visitor behaviors - Entry VisitExprDefault_(const Node* op) final { + Entry VisitExprDefault_(const Object* op) final { return Everything( static_cast(op)->dtype); } diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc index cf37545502ba7..c4ee40f12da80 100644 --- a/src/arithmetic/detect_linear_equation.cc +++ b/src/arithmetic/detect_linear_equation.cc @@ -106,7 +106,7 @@ class LinearEqDetector } return ret; } - LinearEqEntry VisitExprDefault_(const Node* op, const Expr& e) final { + LinearEqEntry VisitExprDefault_(const Object* op, const Expr& e) final { if (fail_) return LinearEqEntry(); if (ExprUseVar(e, var_)) { fail_ = true; @@ -171,7 +171,7 @@ bool DetectClipBound( std::unordered_map* bmap) { int flag = 0; Var var; - auto fvisit = [&bmap, &flag, &var](const NodeRef& n) { + auto fvisit = [&bmap, &flag, &var](const ObjectRef& n) { if (const Variable* v = n.as()) { if (bmap->count(v)) { if (flag == 0) { diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index e4f2042a19d70..79b39748426d3 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -37,7 +37,7 @@ Expr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle()); Expr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle()); IntervalSet::IntervalSet(Expr min_value, Expr max_value) { - auto node = make_node(); + auto node = make_object(); node->min_value = std::move(min_value); node->max_value = std::move(max_value); data_ = std::move(node); @@ -505,7 +505,7 @@ class IntervalSetEvaluator : return Union(analyzer_, false_set, true_set); } - IntervalSet VisitExprDefault_(const Node* op) final { + IntervalSet VisitExprDefault_(const Object* op) final { DLOG(WARNING) << "cannot evaluate set type " << op->GetTypeKey(); return IntervalSet::Everything(); } diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h index 831b444090306..2e072127b4498 100644 --- a/src/arithmetic/int_set.h +++ b/src/arithmetic/int_set.h @@ -75,7 +75,7 @@ class IntervalSetNode : public IntSetNode { } static constexpr const char* _type_key = "arith.IntervalSet"; - TVM_DECLARE_NODE_TYPE_INFO(IntervalSetNode, IntSetNode); + TVM_DECLARE_FINAL_OBJECT_INFO(IntervalSetNode, IntSetNode); }; /*! @@ -116,8 +116,8 @@ class IntervalSet : public IntSet { return IntervalSet(pos_inf(), neg_inf()); } - TVM_DEFINE_NODE_REF_COW(IntervalSetNode); - TVM_DEFINE_NODE_REF_METHODS(IntervalSet, IntSet, IntervalSetNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(IntervalSetNode); + TVM_DEFINE_OBJECT_REF_METHODS(IntervalSet, IntSet, IntervalSetNode); }; /*! diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index 25c7391fd9c4f..5ab1bd3867488 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -37,7 +37,7 @@ using namespace ir; TVM_REGISTER_NODE_TYPE(ModularSetNode); ModularSet::ModularSet(int64_t coeff, int64_t base) { - auto node = make_node(); + auto node = make_object(); node->coeff = coeff; node->base = base; // finish construction. @@ -120,7 +120,7 @@ class ModularSetAnalyzer::Impl : } // Override visitor behaviors - Entry VisitExprDefault_(const Node* op) final { + Entry VisitExprDefault_(const Object* op) final { return Everything(); } diff --git a/src/arithmetic/pattern_match.h b/src/arithmetic/pattern_match.h index fd07a377e955c..bff956473c87d 100644 --- a/src/arithmetic/pattern_match.h +++ b/src/arithmetic/pattern_match.h @@ -250,7 +250,7 @@ class PBinaryExpr : b_.InitMatch_(); } - bool Match_(const NodeRef& node) const { + bool Match_(const ObjectRef& node) const { if (const NodeType* ptr = node.as()) { if (!a_.Match_(ptr->a)) return false; if (!b_.Match_(ptr->b)) return false; @@ -282,7 +282,7 @@ class PConstWithTypeLike : void InitMatch_() const {} - bool Match_(const NodeRef& node) const { + bool Match_(const ObjectRef& node) const { if (const ir::IntImm* ptr = node.as()) { return ptr->value == value_; } else { @@ -364,7 +364,7 @@ class PNotExpr : public Pattern > { value_.InitMatch_(); } - bool Match_(const NodeRef& node) const { + bool Match_(const ObjectRef& node) const { if (const ir::Not* ptr = node.as()) { if (!value_.Match_(ptr->a)) return false; return true; @@ -410,7 +410,7 @@ class PSelectExpr : false_value_.InitMatch_(); } - bool Match_(const NodeRef& node) const { + bool Match_(const ObjectRef& node) const { if (const ir::Select* ptr = node.as()) { if (!condition_.Match_(ptr->condition)) return false; if (!true_value_.Match_(ptr->true_value)) return false; @@ -472,7 +472,7 @@ class PCastExpr : value_.InitMatch_(); } - bool Match_(const NodeRef& node) const { + bool Match_(const ObjectRef& node) const { if (const ir::Cast* ptr = node.as()) { if (!dtype_.Match_(ptr->dtype)) return false; if (!value_.Match_(ptr->value)) return false; @@ -530,7 +530,7 @@ class PRampExpr : lanes_.InitMatch_(); } - bool Match_(const NodeRef& node) const { + bool Match_(const ObjectRef& node) const { if (const ir::Ramp* ptr = node.as()) { if (!base_.Match_(ptr->base)) return false; if (!stride_.Match_(ptr->stride)) return false; @@ -592,7 +592,7 @@ class PBroadcastExpr : lanes_.InitMatch_(); } - bool Match_(const NodeRef& node) const { + bool Match_(const ObjectRef& node) const { if (const ir::Broadcast* ptr = node.as()) { if (!value_.Match_(ptr->value)) return false; if (!lanes_.Match_(ptr->lanes)) return false; @@ -704,7 +704,7 @@ class PCallExpr : detail::tuple_for_each(finit, args_); } - bool Match_(const NodeRef& node) const { + bool Match_(const ObjectRef& node) const { if (const ir::Call* ptr = node.as()) { if (ptr->args.size() != sizeof...(TArgs)) return false; if (ptr->name != Op::kName) return false; diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index ca25731cafef2..3ea2cb77d316f 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -53,7 +53,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) */ Target CreateTarget(const std::string& target_name, const std::vector& options) { - auto t = make_node(); + auto t = make_object(); t->target_name = target_name; std::string libs_flag = "-libs="; @@ -366,7 +366,7 @@ void GetBinds(const Array& args, bool compact, const std::unordered_map& binds, Map* out_binds, - Array* out_arg_list, + Array* out_arg_list, const BuildConfig& config) { *out_binds = binds; @@ -396,7 +396,7 @@ Stmt BuildStmt(Schedule sch, const Array& args, const std::unordered_map& binds, bool loop_partition, - Array *out_arg_list, + Array *out_arg_list, const BuildConfig& config) { sch = sch.normalize(); @@ -445,7 +445,7 @@ Array lower(Schedule sch, const std::string& name, const std::unordered_map& binds, const BuildConfig& config) { - Array out_arg_list; + Array out_arg_list; auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config); return Array({ ir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) }); } @@ -618,7 +618,7 @@ runtime::Module build(const Array& funcs, } BuildConfig BuildConfig::Create() { - return BuildConfig(make_node()); + return BuildConfig(make_object()); } /*! \brief Entry to hold the BuildConfig context stack. */ @@ -701,7 +701,7 @@ GenericFunc GenericFunc::Get(const std::string& name) { std::lock_guard(m->mutex); auto it = m->fmap.find(name); if (it == m->fmap.end()) { - auto f = make_node(); + auto f = make_object(); f->name_ = name; auto gf = GenericFunc(f); m->fmap[name] = gf; @@ -825,7 +825,7 @@ TVM_REGISTER_API("_BuildConfigGetAddLowerPassInfo") TVM_REGISTER_API("_GenericFuncCreate") .set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = GenericFunc(make_node()); + *ret = GenericFunc(make_object()); }); TVM_REGISTER_API("_GenericFuncGetGlobal") diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 2bb86093e2f8e..c723a2284ebf6 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -408,7 +408,7 @@ void CodeGenHybrid::PrintIndent() { std::string CodeGenHybrid::GetVarID(const Variable *v) { if (binds_.count(v)) return binds_[v]; - auto key = std::make_pair(static_cast(v), 0); + auto key = std::make_pair(static_cast(v), 0); if (id_map_.count(key)) { return id_map_[key]; } @@ -472,7 +472,7 @@ void CodeGenHybrid::ReserveKeywords() { } void CodeGenHybrid::DumpStmt(const Stmt &stmt, - const Array &inputs, + const Array &inputs, const Array &outputs, const std::string &name) { ReserveKeywords(); diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index 2c719b0b3ecfe..647ef77fc5341 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -56,7 +56,7 @@ class CodeGenHybrid : * \param outputs Output tensors of this schedule. * \param name The name of the function. */ - void DumpStmt(const Stmt &stmt, const Array &inputs, const Array &outputs, + void DumpStmt(const Stmt &stmt, const Array &inputs, const Array &outputs, const std::string &name = "hybrid_func"); /*! * \brief Finalize the compilation and return the code. @@ -152,7 +152,7 @@ class CodeGenHybrid : /*! * \brief Keys are either (tensors, value_index) or (variables, 0). * Values are the corresponding IDs.*/ - std::map, std::string> id_map_; + std::map, std::string> id_map_; /*! \brief Variables (keys) binded to the threads (values). */ std::map binds_; /*! diff --git a/src/lang/api_registry.cc b/src/lang/api_registry.cc index d6a413e987cfd..68d42a2c14338 100644 --- a/src/lang/api_registry.cc +++ b/src/lang/api_registry.cc @@ -33,7 +33,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ObjectPtr CreateEnvNode(const std::string& name) { auto* f = runtime::Registry::Get(name); CHECK(f != nullptr) << "Cannot find global function \'" << name << '\''; - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->func = *f; n->name = name; return n; diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc index b83734beacb39..1c341d53168ef 100644 --- a/src/lang/attrs.cc +++ b/src/lang/attrs.cc @@ -39,8 +39,8 @@ void DictAttrsNode::InitByPackedArgs( for (int i = 0; i < args.size(); i += 2) { std::string key = args[i]; runtime::TVMArgValue val = args[i + 1]; - if (val.type_code() == kObjectHandle) { - dict.Set(key, val.operator NodeRef()); + if (val.IsObjectRef()) { + dict.Set(key, val.operator ObjectRef()); } else if (val.type_code() == kStr) { dict.Set(key, Expr(val.operator std::string())); } else { @@ -53,8 +53,8 @@ Array DictAttrsNode::ListFieldInfo() const { return {}; } -Attrs DictAttrsNode::make(Map dict) { - NodePtr n = make_node(); +Attrs DictAttrsNode::make(Map dict) { + ObjectPtr n = make_object(); n->dict = std::move(dict); return Attrs(n); } diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc index eb5d87efbbfa0..9bbd8d62105fc 100644 --- a/src/lang/buffer.cc +++ b/src/lang/buffer.cc @@ -334,7 +334,7 @@ Buffer Buffer::MakeStrideView() const { if ((*this)->strides.size() != 0) return *this; if ((*this)->shape.size() == 0) return *this; std::vector temp; - auto n = make_node(*operator->()); + auto n = make_object(*operator->()); Expr acc = make_const(n->DefaultIndexType(), 1); for (size_t i = n->shape.size(); i != 0 ; --i) { temp.push_back(acc); @@ -419,7 +419,7 @@ Buffer BufferNode::make(Var data, int data_alignment, int offset_factor, BufferType buffer_type) { - auto n = make_node(); + auto n = make_object(); n->data = std::move(data); n->dtype = dtype; n->shape = std::move(shape); diff --git a/src/lang/data_layout.cc b/src/lang/data_layout.cc index 5393bbffb1483..58f033b69e510 100644 --- a/src/lang/data_layout.cc +++ b/src/lang/data_layout.cc @@ -68,7 +68,7 @@ const LayoutAxis& LayoutAxis::make(const std::string& name) { } Layout::Layout(const Array& axes) { - auto node = make_node(); + auto node = make_object(); node->axes = axes; std::ostringstream repr; for (const IterVar& axis : axes) { @@ -89,7 +89,7 @@ Layout::Layout(const Array& axes) { Layout::Layout(const std::string& name) { // NOLINT(*) if (name == "__undef__") return; - auto node = make_node(); + auto node = make_object(); node->name = name; if (name.empty()) return; // scalar @@ -347,7 +347,7 @@ Array BijectiveLayout::BackwardShape(const Array& shape) const { BijectiveLayout BijectiveLayoutNode::make(const Layout& src_layout, const Layout& dst_layout) { - auto n = make_node(); + auto n = make_object(); n->src_layout = src_layout; n->dst_layout = dst_layout; diff --git a/src/lang/expr.cc b/src/lang/expr.cc index 997c15177546e..5a54f0407c8de 100644 --- a/src/lang/expr.cc +++ b/src/lang/expr.cc @@ -42,14 +42,14 @@ Var::Var(std::string name_hint, DataType t) : Var(Variable::make(t, name_hint)) {} Var Variable::make(DataType t, std::string name_hint) { - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = t; node->name_hint = std::move(name_hint); return Var(node); } Range::Range(Expr begin, Expr end) - : Range(make_node( + : Range(make_object( begin, is_zero(begin) ? end : (end - begin))) { } @@ -57,21 +57,21 @@ Range::Range(Expr begin, Expr end) Integer IntImm::make(DataType t, int64_t value) { CHECK(t.is_int() && t.is_scalar()) << "ValueError: IntImm can only take scalar."; - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = t; node->value = value; return Integer(node); } Range Range::make_by_min_extent(Expr min, Expr extent) { - return Range(make_node(min, extent)); + return Range(make_object(min, extent)); } IterVar IterVarNode::make(Range dom, Var var, IterVarType t, std::string thread_tag) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->dom = dom; n->var = var; n->iter_type = t; @@ -89,7 +89,7 @@ IterVar reduce_axis(Range dom, std::string name) { dom, Var(name), kCommReduce); } -void Dump(const NodeRef& n) { +void Dump(const ObjectRef& n) { std::cerr << n << "\n"; } diff --git a/src/lang/ir.cc b/src/lang/ir.cc index 427e026bc7284..d5cc285ac861a 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -34,7 +34,7 @@ namespace ir { Expr UIntImm::make(DataType t, uint64_t value) { CHECK(t.is_uint() && t.lanes() == 1) << "ValueError: UIntImm can only take scalar"; - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = t; node->value = value; return Expr(node); @@ -43,14 +43,14 @@ Expr UIntImm::make(DataType t, uint64_t value) { Expr FloatImm::make(DataType t, double value) { CHECK_EQ(t.lanes(), 1) << "ValueError: FloatImm can only take scalar"; - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = t; node->value = value; return Expr(node); } Expr StringImm::make(std::string value) { - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = DataType::Handle(); node->value = std::move(value); return Expr(node); @@ -59,7 +59,7 @@ Expr StringImm::make(std::string value) { Expr Cast::make(DataType t, Expr value) { CHECK(value.defined()); CHECK_EQ(t.lanes(), value.dtype().lanes()); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = t; node->value = std::move(value); return Expr(node); @@ -72,7 +72,7 @@ Expr And::make(Expr a, Expr b) { CHECK(b.dtype().is_bool()); CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types"; - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); node->b = std::move(b); @@ -86,7 +86,7 @@ Expr Or::make(Expr a, Expr b) { CHECK(b.dtype().is_bool()); CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types"; - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); node->b = std::move(b); @@ -97,7 +97,7 @@ Expr Not::make(Expr a) { CHECK(a.defined()) << "ValueError: a is undefined"; CHECK(a.dtype().is_bool()); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); return Expr(node); @@ -111,7 +111,7 @@ Expr Select::make(Expr condition, Expr true_value, Expr false_value) { CHECK_EQ(condition.dtype().lanes(), true_value.dtype().lanes()); CHECK(false_value.dtype() == true_value.dtype()) << "TypeError: mismatched types"; - NodePtr(); + ObjectPtr(); node->dtype = true_value.dtype(); node->condition = std::move(condition); node->true_value = std::move(true_value); @@ -126,7 +126,7 @@ Expr Load::make(DataType dtype, Var buffer_var, Expr index, Expr predicate) { CHECK_EQ(dtype.lanes(), index.dtype().lanes()); CHECK_EQ(dtype.lanes(), predicate.dtype().lanes()); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = dtype; node->buffer_var = std::move(buffer_var); node->index = std::move(index); @@ -143,7 +143,7 @@ Expr Ramp::make(Expr base, Expr stride, int lanes) { CHECK_GT(lanes, 1); CHECK_EQ(stride.dtype(), base.dtype()); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = base.dtype().with_lanes(lanes); node->base = base; node->stride = stride; @@ -156,7 +156,7 @@ Expr Broadcast::make(Expr value, int lanes) { CHECK(value.dtype().is_scalar()); CHECK_GT(lanes, 1); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = value.dtype().with_lanes(lanes); node->value = std::move(value); node->lanes = lanes; @@ -168,7 +168,7 @@ Expr Let::make(Var var, Expr value, Expr body) { CHECK(body.defined()); CHECK_EQ(value.dtype(), var.dtype()); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = body.dtype(); node->var = std::move(var); node->value = std::move(value); @@ -208,7 +208,7 @@ Expr Call::make(DataType dtype, } } - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = dtype; node->name = std::move(name); node->args = std::move(args); @@ -232,7 +232,7 @@ Expr Shuffle::make(Array vectors, } CHECK_LE(indices.size(), static_cast(total_lanes)); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = base_type.with_lanes(static_cast(indices.size())); node->vectors = std::move(vectors); node->indices = std::move(indices); @@ -262,7 +262,7 @@ CommReducer CommReducerNode::make(Array lhs, Array rhs, Array result, Array identity_element) { - auto node = make_node(); + auto node = make_object(); node->lhs = lhs; node->rhs = rhs; node->result = result; @@ -293,7 +293,7 @@ Expr Reduce::make(CommReducer combiner, Array source, if (!condition.defined()) { condition = const_true(); } - auto n = make_node(); + auto n = make_object(); CHECK(source.defined()); for (size_t i = 0; i < axis.size(); ++i) { CHECK(axis[i].defined()); @@ -308,7 +308,7 @@ Expr Reduce::make(CommReducer combiner, Array source, } Expr Any::make() { - auto n = make_node(); + auto n = make_object(); return Expr(n); } @@ -317,18 +317,18 @@ Stmt LetStmt::make(Var var, Expr value, Stmt body) { CHECK(body.defined()); CHECK_EQ(value.dtype(), var.dtype()); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->var = std::move(var); node->value = std::move(value); node->body = std::move(body); return Stmt(node); } -Stmt AttrStmt::make(NodeRef node, +Stmt AttrStmt::make(ObjectRef node, std::string attr_key, Expr value, Stmt body) { - auto n = make_node(); + auto n = make_object(); n->node = node; n->attr_key = std::move(attr_key); n->value = std::move(value); @@ -343,7 +343,7 @@ Stmt AssertStmt::make(Expr condition, Expr message, Stmt body) { << "TypeError: AssertStmt message must be an int or string:" << message << "\n"; - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->condition = std::move(condition); node->message = std::move(message); node->body = std::move(body); @@ -353,7 +353,7 @@ Stmt AssertStmt::make(Expr condition, Expr message, Stmt body) { Stmt ProducerConsumer::make(FunctionRef func, bool is_producer, Stmt body) { CHECK(body.defined()); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->func = std::move(func); node->is_producer = is_producer; node->body = std::move(body); @@ -373,7 +373,7 @@ Stmt For::make(Var loop_var, CHECK(loop_var.dtype().is_scalar()); CHECK(body.defined()); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->loop_var = std::move(loop_var); node->min = std::move(min); node->extent = std::move(extent); @@ -390,7 +390,7 @@ Stmt Store::make(Var buffer_var, Expr value, Expr index, Expr predicate) { CHECK_EQ(value.dtype().lanes(), index.dtype().lanes()); CHECK_EQ(value.dtype().lanes(), predicate.dtype().lanes()); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->buffer_var = std::move(buffer_var); node->value = std::move(value); node->index = std::move(index); @@ -407,7 +407,7 @@ Stmt Provide::make(FunctionRef func, int value_index, Expr value, Array ar CHECK(args[i].defined()) << "Provide to undefined location\n"; } - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->func = std::move(func); node->value_index = value_index; node->value = std::move(value); @@ -430,7 +430,7 @@ Stmt Allocate::make(Var buffer_var, CHECK(condition.defined()); CHECK(condition.dtype().is_bool()); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->buffer_var = std::move(buffer_var); node->dtype = dtype; node->extents = std::move(extents); @@ -457,7 +457,7 @@ int32_t Allocate::constant_allocation_size(const Array& extents) { } Stmt Free::make(Var buffer_var) { - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->buffer_var = buffer_var; return Stmt(node); } @@ -478,7 +478,7 @@ Stmt Realize::make(FunctionRef func, CHECK(condition.defined()); CHECK(condition.dtype().is_bool()); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->func = std::move(func); node->value_index = value_index; node->dtype = dtype; @@ -496,7 +496,7 @@ Stmt Prefetch::make(FunctionRef func, int value_index, DataType dtype, Region bo CHECK(bounds[i]->extent.dtype().is_scalar()); } - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->func = std::move(func); node->value_index = value_index; node->dtype = dtype; @@ -507,7 +507,7 @@ Stmt Prefetch::make(FunctionRef func, int value_index, DataType dtype, Region bo Stmt Block::make(Stmt first, Stmt rest) { CHECK(first.defined()); CHECK(rest.defined()); - NodePtr node = make_node(); + ObjectPtr node = make_object(); // canonicalize. if (const Block* b = first.as()) { @@ -536,7 +536,7 @@ Stmt IfThenElse::make(Expr condition, Stmt then_case, Stmt else_case) { CHECK(then_case.defined()); // else_case may be null. - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->condition = std::move(condition); node->then_case = std::move(then_case); node->else_case = std::move(else_case); @@ -546,7 +546,7 @@ Stmt IfThenElse::make(Expr condition, Stmt then_case, Stmt else_case) { Stmt Evaluate::make(Expr value) { CHECK(value.defined()); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->value = std::move(value); return Stmt(node); } diff --git a/src/lang/tensor.cc b/src/lang/tensor.cc index 1c110936b3ef1..e9ca89a4b31ed 100644 --- a/src/lang/tensor.cc +++ b/src/lang/tensor.cc @@ -47,7 +47,7 @@ Expr Tensor::operator()(Array indices) const { } Tensor Operation::output(size_t i) const { - auto node = make_node(); + auto node = make_object(); node->op = *this; node->value_index = i; node->dtype = (*this)->output_dtype(i); @@ -59,7 +59,7 @@ Tensor TensorNode::make(Array shape, DataType dtype, Operation op, int value_index) { - auto n = make_node(); + auto n = make_object(); n->shape = std::move(shape); n->dtype = dtype; n->op = op; @@ -87,7 +87,7 @@ TensorIntrin TensorIntrinNode::make(std::string name, Stmt body, Stmt reduce_init, Stmt reduce_update) { - auto n = make_node(); + auto n = make_object(); n->name = std::move(name); n->op = std::move(op); n->inputs = std::move(inputs); @@ -115,7 +115,7 @@ TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin, Array regions, Array reduce_axis, Array scalar_inputs) { - auto n = make_node(); + auto n = make_object(); n->intrin = std::move(intrin); n->tensors = std::move(tensors); n->regions = std::move(regions); diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 5a991aa3ad1bc..5e8a0f709f81f 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -79,7 +79,7 @@ class NodeIndexer : public AttrVisitor { // make index of all the children of node void MakeIndex(Object* node) { if (node == nullptr) return; - CHECK(node->IsInstance()); + CHECK(node->IsInstance()); if (node_index_.count(node)) return; CHECK_EQ(node_index_.size(), node_list_.size()); diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index bd129ac330582..c0cae269ffc3f 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -90,8 +90,8 @@ Tensor compute(Array shape, FCompute fcompute, std::string name, std::string tag, - Map attrs) { - auto op_node = make_node(); + Map attrs) { + auto op_node = make_object(); // compute dimension. size_t ndim = shape.size(); std::vector axis; @@ -112,8 +112,8 @@ Array compute(Array shape, FBatchCompute fcompute, std::string name, std::string tag, - Map attrs) { - auto op_node = make_node(); + Map attrs) { + auto op_node = make_object(); // compute dimension. size_t ndim = shape.size(); std::vector axis; @@ -136,13 +136,13 @@ Array compute(Array shape, Operation ComputeOpNode::make(std::string name, std::string tag, - Map attrs, + Map attrs, Array axis, Array body) { if (!attrs.defined()) { - attrs = Map(); + attrs = Map(); } - auto n = make_node(); + auto n = make_object(); n->name = std::move(name); n->tag = std::move(tag); n->attrs = std::move(attrs); @@ -161,7 +161,7 @@ Array ComputeOpNode::InputTensors() const { Array ret; std::unordered_set visited; for (auto& e : body) { - ir::PostOrderVisit(e, [&ret, &visited](const NodeRef& n) { + ir::PostOrderVisit(e, [&ret, &visited](const ObjectRef& n) { const ir::Call *call = n.as(); if (call != nullptr && call->func.defined()) { Tensor t = Downcast(call->func).output(call->value_index); @@ -188,7 +188,7 @@ Operation ComputeOpNode::ReplaceInputs( if (!new_reduce.same_as(this->body[0])) { const ir::Reduce* r = new_reduce.as(); for (size_t k = 0; k < this->body.size(); ++k) { - auto n = make_node(*r); + auto n = make_object(*r); n->value_index = static_cast(k); n->dtype = r->source[k].dtype(); arr.push_back(Expr(n)); @@ -215,7 +215,7 @@ void ComputeOpNode::PropBoundToInputs( const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const { CHECK_EQ(self.operator->(), this); - auto fvisit = [&dom_map, out_dom_map, analyzer](const NodeRef& n) { + auto fvisit = [&dom_map, out_dom_map, analyzer](const ObjectRef& n) { auto *call = n.as(); if (call != nullptr && call->func.defined()) { Tensor t = Downcast(call->func).output(call->value_index); @@ -574,7 +574,7 @@ class ComputeVerifier final : protected ir::IRVisitor { protected: /// Visitor implementation //@{ - void Visit(const NodeRef& n) final { + void Visit(const ObjectRef& n) final { ++level_; ir::IRVisitor::Visit(n); --level_; diff --git a/src/op/extern_op.cc b/src/op/extern_op.cc index 883ebdc4a0f76..b921c86f35561 100644 --- a/src/op/extern_op.cc +++ b/src/op/extern_op.cc @@ -57,15 +57,15 @@ Array ExternOpNode::output_shape(size_t i) const { Operation ExternOpNode::make(std::string name, std::string tag, - Map attrs, + Map attrs, Array inputs, Array input_placeholders, Array output_placeholders, Stmt body) { if (!attrs.defined()) { - attrs = Map(); + attrs = Map(); } - auto n = make_node(); + auto n = make_object(); n->name = std::move(name); n->tag = std::move(tag); n->attrs = std::move(attrs); @@ -93,7 +93,7 @@ Operation ExternOpNode::ReplaceInputs( const Operation& self, const std::unordered_map& rmap) const { CHECK_EQ(self.operator->(), this); - auto n = make_node(*this); + auto n = make_object(*this); n->body = op::ReplaceTensor(this->body, rmap); for (size_t i = 0; i < n->inputs.size(); ++i) { Tensor t = n->inputs[i]; @@ -161,7 +161,7 @@ Stmt ExternOpNode::BuildProvide( CHECK_EQ(stage->op.operator->(), this); Stmt ret = AttrStmt::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body); auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) { - Array bind_spec; + Array bind_spec; Array tuple; bind_spec.push_back(buffer); bind_spec.push_back(tensor); diff --git a/src/op/hybrid_op.cc b/src/op/hybrid_op.cc index 1e1a81423b69c..061929a31ef1e 100644 --- a/src/op/hybrid_op.cc +++ b/src/op/hybrid_op.cc @@ -63,14 +63,14 @@ Array HybridOpNode::output_shape(size_t i) const { Operation HybridOpNode::make(std::string name, std::string tag, - Map attrs, + Map attrs, Array inputs, Array outputs, Stmt body) { if (!attrs.defined()) { - attrs = Map(); + attrs = Map(); } - auto n = make_node(); + auto n = make_object(); n->name = std::move(name); n->tag = std::move(tag); n->attrs = std::move(attrs); @@ -91,7 +91,7 @@ Array HybridOpNode::InputTensors() const { } std::unordered_set visited; Array curr_inputs; - ir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const NodeRef& n) { + ir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const ObjectRef& n) { const ir::Call *call = n.as(); if (call != nullptr && call->func.defined()) { Tensor t = Downcast(call->func).output(call->value_index); @@ -108,7 +108,7 @@ Operation HybridOpNode::ReplaceInputs( const Operation &self, const std::unordered_map &rmap) const { CHECK_EQ(self.operator->(), this); - auto n = make_node(*this); + auto n = make_object(*this); n->body = op::ReplaceTensor(this->body, rmap); for (size_t i = 0; i < n->inputs.size(); ++i) { Tensor t = n->inputs[i]; @@ -185,7 +185,7 @@ Stmt HybridOpNode::BuildProvide( for (int i = 0; i < this->num_outputs(); ++i) { rmap[outputs[i]] = stage->op.output(i); } - auto n = make_node(*this); + auto n = make_object(*this); /* This is a story little bit complicated. * The following two lines of codes replace output tensors' usage. * This is the simplest way I (@were) can come up with to glue @@ -369,7 +369,8 @@ Stmt ApplyLoopAnnotations(const Stage &stage, expected = IterVarTypeToForType(attr->iter_type); } - PostOrderVisit(stmt, [&found, &var, &attr, &expected, &need_change](const NodeRef &node) { + PostOrderVisit(stmt, + [&found, &var, &attr, &expected, &need_change](const ObjectRef& node) { if (const For *op = node.as()) { if (op->loop_var.get() == var) { ++found; @@ -390,7 +391,7 @@ Stmt ApplyLoopOrder(const Stage &stage, const std::unordered_map &dom_map, const std::unordered_map &rebased, Stmt stmt) { std::vector current_order; - PostOrderVisit(stmt, [¤t_order](const NodeRef &node) { + PostOrderVisit(stmt, [¤t_order](const ObjectRef& node) { if (const For *op = node.as()) current_order.push_back(op->loop_var.get()); }); @@ -466,7 +467,7 @@ Stmt ApplySchedule(const Stage &stage, std::vector GatherLoopVars(Stmt stmt) { // TODO(@were): Write a comprehensive pass to analyze iter var types std::vector res_; - PostOrderVisit(stmt, [&res_](const NodeRef &node) { + PostOrderVisit(stmt, [&res_](const ObjectRef& node) { if (const For *op = node.as()) { Var loop_var(op->loop_var); Range dom = Range::make_by_min_extent(op->min, op->extent); diff --git a/src/op/placeholder_op.cc b/src/op/placeholder_op.cc index 6910f63b44d31..7863c8a522659 100644 --- a/src/op/placeholder_op.cc +++ b/src/op/placeholder_op.cc @@ -55,7 +55,7 @@ Array PlaceholderOpNode::output_shape(size_t i) const { Operation PlaceholderOpNode::make(std::string name, Array shape, DataType dtype) { - auto n = make_node(); + auto n = make_object(); n->name = name; n->shape = shape; n->dtype = dtype; diff --git a/src/op/scan_op.cc b/src/op/scan_op.cc index e83a23194cf83..57f16f82c54bf 100644 --- a/src/op/scan_op.cc +++ b/src/op/scan_op.cc @@ -64,16 +64,16 @@ Array ScanOpNode::output_shape(size_t i) const { Operation ScanOpNode::make(std::string name, std::string tag, - Map attrs, + Map attrs, IterVar axis, Array init, Array update, Array state_placeholder, Array inputs) { if (!attrs.defined()) { - attrs = Map(); + attrs = Map(); } - auto n = make_node(); + auto n = make_object(); CHECK_EQ(init.size(), update.size()); CHECK_EQ(init.size(), state_placeholder.size()); @@ -126,7 +126,7 @@ Array scan(Array init, Array inputs, std::string name, std::string tag, - Map attrs) { + Map attrs) { IterVar scan_axis = IterVarNode::make( Range::make_by_min_extent( @@ -157,7 +157,7 @@ Operation ScanOpNode::ReplaceInputs( const Operation& self, const std::unordered_map& rmap) const { CHECK_EQ(self.operator->(), this); - auto n = make_node(*this); + auto n = make_object(*this); for (size_t i = 0; i < n->init.size(); ++i) { if (rmap.count(n->init[i])) { n->init.Set(i, rmap.at(n->init[i])); diff --git a/src/op/tensor_compute_op.cc b/src/op/tensor_compute_op.cc index e59f90f4948e8..cfd6e23a0db4c 100644 --- a/src/op/tensor_compute_op.cc +++ b/src/op/tensor_compute_op.cc @@ -59,7 +59,7 @@ Operation TensorComputeOpNode::make(std::string name, Array tensors, Array regions, Array scalar_inputs) { - auto n = make_node(); + auto n = make_object(); n->name = std::move(name); n->tag = std::move(tag); n->axis = std::move(axis); @@ -80,8 +80,8 @@ Operation TensorComputeOpNode::ReplaceInputs( const Operation& self, const std::unordered_map& rmap) const { CHECK_EQ(self.operator->(), this); - auto n = make_node(*this); - auto intrin = make_node(*(this->intrin.operator->())); + auto n = make_object(*this); + auto intrin = make_object(*(this->intrin.operator->())); intrin->body = op::ReplaceTensor(this->intrin->body, rmap); if (intrin->reduce_init.defined()) { intrin->reduce_init = op::ReplaceTensor(this->intrin->reduce_init, rmap); @@ -146,7 +146,7 @@ Stmt TensorComputeOpNode::BuildProvide( Tensor tensor = inputs[i]; Region region = this->input_regions[i]; Buffer buffer = this->intrin->buffers[i]; - Array bind_spec{buffer, tensor}; + Array bind_spec{buffer, tensor}; Array tuple; for (size_t i = 0; i < region.size(); ++i) { @@ -162,7 +162,7 @@ Stmt TensorComputeOpNode::BuildProvide( for (int i = 0; i < this->num_outputs(); ++i) { Tensor tensor = stage->op.output(i); Buffer buffer = this->intrin->buffers[num_inputs + i]; - Array bind_spec{buffer, tensor}; + Array bind_spec{buffer, tensor}; Array tuple; for (size_t i = 0; i < this->axis.size(); ++i) { diff --git a/src/op/tensorize.cc b/src/op/tensorize.cc index b7f32de8b5add..7ab54e9830289 100644 --- a/src/op/tensorize.cc +++ b/src/op/tensorize.cc @@ -379,7 +379,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, for (size_t i = 0; i < intrin->inputs.size(); ++i) { Tensor tensor = inputs[i]; Buffer buffer = intrin->buffers[i]; - Array bind_spec{buffer, tensor}; + Array bind_spec{buffer, tensor}; auto it = in_region.find(tensor); CHECK(it != in_region.end()); const Array& region = it->second; @@ -407,7 +407,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, for (size_t i = intrin->inputs.size(); i < intrin->buffers.size(); ++i) { Tensor tensor = stage->op.output(i - intrin->inputs.size()); Buffer buffer = intrin->buffers[i]; - Array bind_spec{buffer, tensor}; + Array bind_spec{buffer, tensor}; output_bind_nest.emplace_back(AttrStmt::make( bind_spec, ir::attr::buffer_bind_scope, Call::make(DataType::Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop)); @@ -507,7 +507,7 @@ TVM_REGISTER_API("test.op.InferTensorizeRegion") stage, as_unordered_map(dmap), &out_dom, &in_region); - *ret = Array{Map(out_dom), + *ret = Array{Map(out_dom), Map >(in_region)}; }); diff --git a/src/pass/combine_context_call.cc b/src/pass/combine_context_call.cc index f1cb8fe10a4b6..e050fee98e678 100644 --- a/src/pass/combine_context_call.cc +++ b/src/pass/combine_context_call.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -108,7 +108,7 @@ class ContextCallCombiner final : public IRMutator { }; LoweredFunc CombineContextCall(LoweredFunc f) { - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); n->body = ContextCallCombiner().Combine(n->body); return LoweredFunc(n); } diff --git a/src/pass/coproc_sync.cc b/src/pass/coproc_sync.cc index 4aa8879f679bd..a5b3285f7fa92 100644 --- a/src/pass/coproc_sync.cc +++ b/src/pass/coproc_sync.cc @@ -104,7 +104,7 @@ class CoProcSyncPlanner : public StorageAccessVisitor { } // Write synchronization to be inserted before or after stmt. - std::unordered_map > sync_; + std::unordered_map > sync_; protected: bool Enabled(const Variable* buf, @@ -229,8 +229,8 @@ class CoProcBarrierDetector : public StorageAccessVisitor { PlanWriteBarrier(scope_.back(), nullptr); } - std::unordered_map > barrier_before_; - std::unordered_map > barrier_after_; + std::unordered_map > barrier_before_; + std::unordered_map > barrier_after_; protected: bool Enabled(const Variable* buf, @@ -458,14 +458,14 @@ class CoProcInstDepDetector : public IRVisitor { // insert before is stored in reverse order // the first element is closest to the node. - std::unordered_map > insert_before_; - std::unordered_map > insert_after_; + std::unordered_map > insert_before_; + std::unordered_map > insert_after_; private: // state in the sync entry struct SyncState { // The statement of the state. - const Node* node{nullptr}; + const Object* node{nullptr}; // Set of all possible contexts in the entering moment. std::unordered_set enter_ctx; // Set of all possible contexts in the exit moment. @@ -679,8 +679,8 @@ class CoProcSyncInserter : public IRMutator { private: // insert before is stored in reverse order // the first element is closest to the node. - std::unordered_map > insert_before_; - std::unordered_map > insert_after_; + std::unordered_map > insert_before_; + std::unordered_map > insert_after_; }; diff --git a/src/pass/hoist_if_then_else.cc b/src/pass/hoist_if_then_else.cc index a1c635e2692b1..e3ffcc4f15f39 100644 --- a/src/pass/hoist_if_then_else.cc +++ b/src/pass/hoist_if_then_else.cc @@ -35,8 +35,8 @@ namespace tvm { namespace ir { -using HoistMap = std::unordered_map>; -using VarMap = std::unordered_map>; +using HoistMap = std::unordered_map>; +using VarMap = std::unordered_map>; /* * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant. @@ -124,12 +124,12 @@ class IfThenElseHoist { // Check whether a given IfThenElse stmt is the first one appearing // in a For stmt. bool is_first_if(const Stmt& for_stmt, const Stmt& if_stmt) { - std::vector if_node_list; + std::vector if_node_list; const For* for_node = for_stmt.as(); CHECK(for_node); CHECK(if_stmt.as()); - PostOrderVisit(for_node->body, [&](const NodeRef& node) { + PostOrderVisit(for_node->body, [&](const ObjectRef& node) { if (node.as()) { if_node_list.push_back(node.get()); } @@ -141,12 +141,12 @@ bool is_first_if(const Stmt& for_stmt, const Stmt& if_stmt) { // With this function we only need to visit and mutate top level For node // in the main VisitAndMutate function. Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) { - const Node* top_for_node; + const Object* top_for_node; const For* parent_for_node = parent_for_stmt.as(); CHECK(parent_for_node); CHECK(new_if_stmt.as()); - PostOrderVisit(parent_for_node->body, [&](const NodeRef& node) { + PostOrderVisit(parent_for_node->body, [&](const ObjectRef& node) { if (node.as()) { top_for_node = node.get(); } @@ -154,7 +154,7 @@ Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) { PackedFunc replace_target_for = PackedFunc( [&](TVMArgs args, TVMRetValue *ret){ - const NodeRef& current_for = args[0]; + const ObjectRef& current_for = args[0]; if (current_for.get() == top_for_node) { *ret = new_if_stmt; } @@ -173,7 +173,7 @@ std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { PackedFunc replace_then_case = PackedFunc( [&](TVMArgs args, TVMRetValue *ret){ - const NodeRef& node = args[0]; + const ObjectRef& node = args[0]; if (node == if_stmt) { *ret = node.as()->then_case; } @@ -181,7 +181,7 @@ std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { PackedFunc replace_else_case = PackedFunc( [&](TVMArgs args, TVMRetValue *ret){ - const NodeRef& node = args[0]; + const ObjectRef& node = args[0]; if (node == if_stmt) { *ret = node.as()->else_case; } @@ -199,13 +199,13 @@ std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { // Locate all For nodes and capture child IfThenElse nodes. void IfThenElseHoist::SelectCandidates(const Stmt& stmt) { - PostOrderVisit(stmt, [&](const NodeRef& node){ + PostOrderVisit(stmt, [&](const ObjectRef& node){ const For* for_node = node.as(); if (!for_node) return; std::queue tracker; tracker.push(for_node->body); - Stmt for_stmt = Downcast(node); + Stmt for_stmt = Downcast(node); for2if_map_.insert({for_stmt.get(), std::vector()}); while (!tracker.empty()) { Stmt head = tracker.front(); @@ -227,9 +227,9 @@ void IfThenElseHoist::SelectCandidates(const Stmt& stmt) { // Record condition variables. if (!cond_var_map_.count(head.get())) { - std::unordered_set new_var_set; + std::unordered_set new_var_set; cond_var_map_.insert({head.get(), new_var_set}); - PostOrderVisit(if_node->condition, [&](const NodeRef& cond_node) { + PostOrderVisit(if_node->condition, [&](const ObjectRef& cond_node) { if (cond_node.as()) { cond_var_map_[head.get()].insert(cond_node.get()); } @@ -239,15 +239,15 @@ void IfThenElseHoist::SelectCandidates(const Stmt& stmt) { continue; } } - ordered_for_list_.emplace_back(Downcast(node)); + ordered_for_list_.emplace_back(Downcast(node)); }); } // For each IfThenElse node, find the highest For node which // meets loop invariant condition. void IfThenElseHoist::LocateTopFor() { - std::unordered_map if_position_map; - std::unordered_set top_for_var_set; + std::unordered_map if_position_map; + std::unordered_set top_for_var_set; // Create IfThenElse -> For map. for (const Stmt& for_stmt : ordered_for_list_) { @@ -256,7 +256,7 @@ void IfThenElseHoist::LocateTopFor() { CHECK(for_node); top_for_var_map_.insert({for_node->loop_var.get(), if_list}); for (const Stmt& if_stmt : if_list) { - const Node* if_node = if_stmt.get(); + const Object* if_node = if_stmt.get(); if2for_map_[if_node].push_back(for_stmt); } } @@ -264,7 +264,7 @@ void IfThenElseHoist::LocateTopFor() { // Locate the highest For node which is loop invariant. for (const auto& item : if2for_map_) { Stmt top_for; - const Node* if_stmt = item.first; + const Object* if_stmt = item.first; std::vector for_list = item.second; for (size_t i = 0; i < for_list.size(); ++i) { const Stmt& for_stmt = for_list.at(i); @@ -291,9 +291,9 @@ void IfThenElseHoist::LocateTopFor() { top_for_var_set.insert(item.second.as()->loop_var.get()); } - std::vector removed_for_var_list; + std::vector removed_for_var_list; for (const auto& item : top_for_var_map_) { - const Node* top_for_var = item.first; + const Object* top_for_var = item.first; std::vector if_list = item.second; if (!top_for_var_set.count(top_for_var)) { removed_for_var_list.push_back(top_for_var); @@ -307,7 +307,7 @@ void IfThenElseHoist::LocateTopFor() { top_for_var_map_[top_for_var] = actual_if_list; } } - for (const Node* top_for_var : removed_for_var_list) { + for (const Object* top_for_var : removed_for_var_list) { top_for_var_map_.erase(top_for_var); } } @@ -374,7 +374,7 @@ Stmt IfThenElseHoist::HoistIf(const Stmt& if_stmt) { Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) { PackedFunc replace_top_for = PackedFunc( [&](TVMArgs args, TVMRetValue *ret){ - const NodeRef& current_for = args[0]; + const ObjectRef& current_for = args[0]; const For* for_node = current_for.as(); if (!for_node) return; diff --git a/src/pass/infer_fragment.cc b/src/pass/infer_fragment.cc index 71da645474b06..13f9ebade9b10 100644 --- a/src/pass/infer_fragment.cc +++ b/src/pass/infer_fragment.cc @@ -214,7 +214,7 @@ Stmt InferFragment(Stmt stmt) { LoweredFunc InferFragment(LoweredFunc f) { CHECK_NE(f->func_type, kHostFunc); - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); n->body = InferFragment(f->body); return LoweredFunc(n); } diff --git a/src/pass/inject_virtual_thread.cc b/src/pass/inject_virtual_thread.cc index c80c7fcdaa8c5..7e7af187dce12 100644 --- a/src/pass/inject_virtual_thread.cc +++ b/src/pass/inject_virtual_thread.cc @@ -37,7 +37,7 @@ class ExprTouched final : public IRVisitor { bool check_write) : touched_var_(touched), check_write_(check_write) {} - void Visit(const NodeRef& n) final { + void Visit(const ObjectRef& n) final { // early stopping if (expr_touched_ && !check_write_) return; IRVisitor::Visit(n); diff --git a/src/pass/ir_deep_compare.cc b/src/pass/ir_deep_compare.cc index e399e7f2c54f6..6a61d5e402f94 100644 --- a/src/pass/ir_deep_compare.cc +++ b/src/pass/ir_deep_compare.cc @@ -358,7 +358,7 @@ class IRDeepCompare : return order_; } - int CompareNodeRef(const NodeRef& lhs, const NodeRef& rhs) { + int CompareNodeRef(const ObjectRef& lhs, const ObjectRef& rhs) { if (order_ != 0) return order_; if (lhs.get() < rhs.get()) { order_ = -1; return order_; diff --git a/src/pass/ir_util.cc b/src/pass/ir_util.cc index 8b6e66135235b..cdc708ce5faf2 100644 --- a/src/pass/ir_util.cc +++ b/src/pass/ir_util.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -31,38 +31,38 @@ Stmt MergeNest(const std::vector& nest, Stmt body) { for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { Stmt s = *ri; if (const auto* for_ = s.as()) { - auto n = make_node(*for_); + auto n = make_object(*for_); CHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (const auto* let = s.as()) { - auto n = make_node(*let); + auto n = make_object(*let); CHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (const auto* attr = s.as()) { - auto n = make_node(*attr); + auto n = make_object(*attr); CHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (const auto* ite = s.as()) { - auto n = make_node(*ite); + auto n = make_object(*ite); CHECK(is_no_op(n->then_case)); CHECK(!n->else_case.defined()); n->then_case = body; body = Stmt(n); } else if (const auto* block = s.as()) { - auto n = make_node(*block); + auto n = make_object(*block); CHECK(is_no_op(n->rest)); n->rest = body; body = Stmt(n); } else if (const auto* assert_ = s.as()) { - auto n = make_node(*assert_); + auto n = make_object(*assert_); CHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (const auto* alloc = s.as()) { - auto n = make_node(*alloc); + auto n = make_object(*alloc); CHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); diff --git a/src/pass/ir_visitor.cc b/src/pass/ir_visitor.cc index d6f163ccedc61..467cd5de2ef71 100644 --- a/src/pass/ir_visitor.cc +++ b/src/pass/ir_visitor.cc @@ -29,9 +29,9 @@ namespace ir { // visitor to implement apply class IRApplyVisit : public IRVisitor { public: - explicit IRApplyVisit(std::function f) : f_(f) {} + explicit IRApplyVisit(std::function f) : f_(f) {} - void Visit(const NodeRef& node) final { + void Visit(const ObjectRef& node) final { if (visited_.count(node.get()) != 0) return; visited_.insert(node.get()); IRVisitor::Visit(node); @@ -39,11 +39,11 @@ class IRApplyVisit : public IRVisitor { } private: - std::function f_; - std::unordered_set visited_; + std::function f_; + std::unordered_set visited_; }; -void PostOrderVisit(const NodeRef& node, std::function fvisit) { +void PostOrderVisit(const ObjectRef& node, std::function fvisit) { IRApplyVisit(fvisit).Visit(node); } diff --git a/src/pass/lift_attr_scope.cc b/src/pass/lift_attr_scope.cc index cfc6e5a7fc687..7f5b4cca0bb41 100644 --- a/src/pass/lift_attr_scope.cc +++ b/src/pass/lift_attr_scope.cc @@ -54,7 +54,7 @@ class AttrScopeLifter : public IRMutator { Stmt body = AttrStmt::make( attr_node_, attr_key_, attr_value_, op->body); // undefine them - attr_node_ = NodeRef(); + attr_node_ = ObjectRef(); attr_value_ = Expr(); return Allocate::make( op->buffer_var, op->dtype, @@ -93,7 +93,7 @@ class AttrScopeLifter : public IRMutator { return IRMutator::Mutate_(op, s); } Stmt then_case = this->Mutate(op->then_case); - NodeRef first_node; + ObjectRef first_node; Expr first_value; std::swap(first_node, attr_node_); std::swap(first_value, attr_value_); @@ -119,7 +119,7 @@ class AttrScopeLifter : public IRMutator { else_case = AttrStmt::make( attr_node_, attr_key_, attr_value_, else_case); // undefine them - attr_node_ = NodeRef(); + attr_node_ = ObjectRef(); attr_value_ = Expr(); } if (then_case.same_as(op->then_case) && @@ -149,11 +149,11 @@ class AttrScopeLifter : public IRMutator { std::vector MutateSeq(const std::vector& seq) { std::vector res_seq; - NodeRef curr_node; + ObjectRef curr_node; Expr curr_value; Stmt curr_stmt; for (const Stmt & stmt : seq) { - attr_node_ = NodeRef(); + attr_node_ = ObjectRef(); attr_value_ = Expr(); Stmt rest = this->Mutate(stmt); if (attr_node_.defined() && @@ -188,7 +188,7 @@ class AttrScopeLifter : public IRMutator { } res_seq.push_back(curr_stmt); // reset - attr_node_ = NodeRef(); + attr_node_ = ObjectRef(); attr_value_ = Expr(); } return res_seq; @@ -209,7 +209,7 @@ class AttrScopeLifter : public IRMutator { } std::string attr_key_; - NodeRef attr_node_; + ObjectRef attr_node_; Expr attr_value_; }; diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index 1ac386767ae3a..e68387f1baad3 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -37,10 +37,10 @@ using arith::IntSet; using arith::DeduceBound; using arith::Intersect; -using PartitionKey = std::pair; +using PartitionKey = std::pair; struct PartitionKeyHash { std::size_t operator()(PartitionKey const& k) const noexcept { - std::size_t h1 = std::hash{}(k.first); + std::size_t h1 = std::hash{}(k.first); std::size_t h2 = std::hash{}(k.second); return h1 ^ h2; } @@ -53,7 +53,7 @@ using Partition = std::unordered_map; bool ExprUseVars(Expr expr, const std::unordered_set& vars) { bool success = false; - PostOrderVisit(expr, [&vars, &success](const NodeRef& node) { + PostOrderVisit(expr, [&vars, &success](const ObjectRef& node) { if (const Variable* v = node.as()) { if (vars.count(v)) { success = true; @@ -138,7 +138,7 @@ class CandidateSelector final : public IRVisitor { } } - std::unordered_set candidates; + std::unordered_set candidates; private: bool in_likely_{false}; @@ -257,7 +257,7 @@ class PartitionFinder : public IRVisitor { // Replace the set of conditions given by ps with cond_value (true or false) class ConditionEliminator : public IRMutator { public: - explicit ConditionEliminator(const std::unordered_set& ps, bool cond_value = true) + explicit ConditionEliminator(const std::unordered_set& ps, bool cond_value = true) : ps_(ps), cond_value_(cond_value) {} using IRMutator::Mutate; @@ -269,7 +269,7 @@ class ConditionEliminator : public IRMutator { } private: - std::unordered_set ps_; + std::unordered_set ps_; bool cond_value_; }; @@ -277,7 +277,7 @@ class ConditionEliminator : public IRMutator { // Insert the partition branch at the innermost thread scope class ThreadPartitionInserter : public IRMutator { public: - explicit ThreadPartitionInserter(const std::unordered_set& ps, + explicit ThreadPartitionInserter(const std::unordered_set& ps, Expr cond) : ps_(ps), cond_(cond), innermost_thread_scope_(false) {} Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { @@ -299,7 +299,7 @@ class ThreadPartitionInserter : public IRMutator { } private: - const std::unordered_set& ps_; + const std::unordered_set& ps_; Expr cond_; bool innermost_thread_scope_; }; @@ -364,15 +364,15 @@ class LoopPartitioner : public IRMutator { } private: - Stmt TryPartition(const Node* op, const Stmt& stmt, VarExpr var, + Stmt TryPartition(const Object* op, const Stmt& stmt, VarExpr var, Expr min, Expr max, Stmt body, bool partition_thread_scope); - std::pair> + std::pair> GetIntervalAndCondset(const Partition &partitions, const arith::IntervalSet &for_interval, bool cond_value); - inline Stmt MakeFor(const Node* op, Expr extent, Stmt body); + inline Stmt MakeFor(const Object* op, Expr extent, Stmt body); /* Candidate IRs that may be partitioned potentially */ std::unordered_map hint_map_; @@ -383,12 +383,12 @@ class LoopPartitioner : public IRMutator { // Returns an interval (in the first component) in which all the conditions // given in the second component provably have value given by cond_value -std::pair> +std::pair> LoopPartitioner::GetIntervalAndCondset(const Partition &partitions, const arith::IntervalSet &for_interval, bool cond_value) { Array sets; - std::unordered_set cond_set; + std::unordered_set cond_set; for (const auto &kv : partitions) { if (kv.first.second == cond_value) { @@ -461,7 +461,7 @@ Stmt AppendStmts(const Stmt& a, const Stmt& b) { * which will eventually be simplified to empty code. And because only one loop was generated * from loop 2 we stop recursing. */ -Stmt LoopPartitioner::TryPartition(const Node* node, +Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, VarExpr var, Expr min, @@ -481,7 +481,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, arith::IntervalSet for_interval(min, max); bool cond_value; IntSet middle_interval; - std::unordered_set cond_set; + std::unordered_set cond_set; // find an interval in which all conditions on var are true std::tie(middle_interval, cond_set) = GetIntervalAndCondset(finder.partitions, for_interval, true); @@ -592,7 +592,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, return s; } -inline Stmt LoopPartitioner::MakeFor(const Node *node, Expr extent, Stmt body) { +inline Stmt LoopPartitioner::MakeFor(const Object *node, Expr extent, Stmt body) { const For *for_node = static_cast(node); CHECK(for_node); if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1))) { diff --git a/src/pass/lower_custom_datatypes.cc b/src/pass/lower_custom_datatypes.cc index e24cddd97f254..c45019ab38b82 100644 --- a/src/pass/lower_custom_datatypes.cc +++ b/src/pass/lower_custom_datatypes.cc @@ -130,7 +130,7 @@ class CustomDatatypesLowerer : public IRMutator { }; LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target) { - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); n->body = CustomDatatypesLowerer(target).Mutate(n->body); return LoweredFunc(n); } diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index f0b0b3c36d42b..dd81826a59889 100644 --- a/src/pass/lower_intrin.cc +++ b/src/pass/lower_intrin.cc @@ -282,7 +282,7 @@ Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) { LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target) { - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); n->body = LowerIntrinStmt(n->body, target); return LoweredFunc(n); } diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc index 2a121180d6958..03470271b0295 100644 --- a/src/pass/lower_thread_allreduce.cc +++ b/src/pass/lower_thread_allreduce.cc @@ -338,7 +338,7 @@ class ThreadAllreduceBuilder final : public IRMutator { LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size) { CHECK_NE(f->func_type, kHostFunc); - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); n->body = ThreadAllreduceBuilder(warp_size).Mutate(n->body); return LoweredFunc(n); } diff --git a/src/pass/lower_tvm_builtin.cc b/src/pass/lower_tvm_builtin.cc index c8c8fa9c62d02..9a33d647b683d 100644 --- a/src/pass/lower_tvm_builtin.cc +++ b/src/pass/lower_tvm_builtin.cc @@ -360,7 +360,7 @@ class BuiltinLower : public IRMutator { }; LoweredFunc LowerTVMBuiltin(LoweredFunc f) { - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); n->body = BuiltinLower().Build(n->body); return LoweredFunc(n); } diff --git a/src/pass/lower_warp_memory.cc b/src/pass/lower_warp_memory.cc index 0ed2b6232fc12..0749127b905b8 100644 --- a/src/pass/lower_warp_memory.cc +++ b/src/pass/lower_warp_memory.cc @@ -380,7 +380,7 @@ class WarpMemoryRewriter : private IRMutator { LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size) { CHECK_EQ(f->func_type, kDeviceFunc); - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); n->body = WarpMemoryRewriter(warp_size).Rewrite(n->body); return LoweredFunc(n); } diff --git a/src/pass/make_api.cc b/src/pass/make_api.cc index 74b8f891299a2..b0f9482545d35 100644 --- a/src/pass/make_api.cc +++ b/src/pass/make_api.cc @@ -42,7 +42,7 @@ inline Stmt MakeAssertEQ(Expr lhs, Expr rhs, std::string msg) { LoweredFunc MakeAPI(Stmt body, std::string name, - Array api_args, + Array api_args, int num_unpacked_args, bool is_restricted) { const Stmt nop = Evaluate::make(0); @@ -168,7 +168,7 @@ LoweredFunc MakeAPI(Stmt body, buf_arg.second, buf_arg.second->name_hint); } - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->name = name; n->args = args; n->handle_data_type = binder.def_handle_dtype(); @@ -266,7 +266,7 @@ class DeviceTypeBinder: public IRMutator { LoweredFunc BindDeviceType(LoweredFunc f, int device_type) { - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); n->body = DeviceTypeBinder(device_type).Mutate(n->body); return LoweredFunc(n); } diff --git a/src/pass/remap_thread_axis.cc b/src/pass/remap_thread_axis.cc index f3f0d009573d7..49d92d027193e 100644 --- a/src/pass/remap_thread_axis.cc +++ b/src/pass/remap_thread_axis.cc @@ -85,7 +85,7 @@ RemapThreadAxis(LoweredFunc f, Map thread_map) { } CHECK_EQ(f->func_type, kDeviceFunc); - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); // replace the thread axis for (size_t i = 0; i < n->thread_axis.size(); ++i) { auto it = tmap.find(n->thread_axis[i]->thread_tag); diff --git a/src/pass/simple_passes.cc b/src/pass/simple_passes.cc index 06579f31e17a3..1159e568f519a 100644 --- a/src/pass/simple_passes.cc +++ b/src/pass/simple_passes.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -31,7 +31,7 @@ namespace ir { class IRSideEffect : public IRVisitor { public: - void Visit(const NodeRef& e) final { + void Visit(const ObjectRef& e) final { if (has_side_effect_) return; IRVisitor::Visit(e); } @@ -103,7 +103,7 @@ Expr Substitute(Expr expr, const Map& value_map) { class VarTouchVisitor : public IRVisitor { public: - void Visit(const NodeRef& e) final { + void Visit(const ObjectRef& e) final { if (use_var_) return; IRVisitor::Visit(e); } diff --git a/src/pass/skip_assert.cc b/src/pass/skip_assert.cc index 5f310a61dfe3e..817416d9fd2cb 100644 --- a/src/pass/skip_assert.cc +++ b/src/pass/skip_assert.cc @@ -38,7 +38,7 @@ Stmt SkipAssert(Stmt stmt) { } LoweredFunc SkipAssert(LoweredFunc f) { - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); n->body = SkipAssert(f->body); return LoweredFunc(n); } diff --git a/src/pass/split_host_device.cc b/src/pass/split_host_device.cc index 5076300d968a1..f045c271456cb 100644 --- a/src/pass/split_host_device.cc +++ b/src/pass/split_host_device.cc @@ -176,8 +176,8 @@ class HostDeviceSplitter : public IRMutator { handle_data_type_[kv.first.get()] = kv.second; } name_ = f->name; - NodePtr n = - make_node(*f.operator->()); + ObjectPtr n = + make_object(*f.operator->()); n->body = this->Mutate(f->body); n->func_type = kHostFunc; Array ret{LoweredFunc(n)}; @@ -191,7 +191,7 @@ class HostDeviceSplitter : public IRMutator { Stmt SplitDeviceFunc(Stmt body) { std::ostringstream os; os << name_ << "_kernel" << device_funcs_.size(); - NodePtr n = make_node(); + ObjectPtr n = make_object(); // isolate the device function. IRUseDefAnalysis m; m.visit_thread_extent_ = false; diff --git a/src/pass/ssa.cc b/src/pass/ssa.cc index 0fff1e6e67744..37db29c580796 100644 --- a/src/pass/ssa.cc +++ b/src/pass/ssa.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -38,7 +38,7 @@ class IRVerifySSA final : public IRVisitor { public: bool is_ssa{true}; - void Visit(const NodeRef& n) final { + void Visit(const ObjectRef& n) final { if (!is_ssa) return; IRVisitor::Visit(n); } diff --git a/src/pass/storage_access.cc b/src/pass/storage_access.cc index c146a8709b1ec..bf8d4e0205217 100644 --- a/src/pass/storage_access.cc +++ b/src/pass/storage_access.cc @@ -341,7 +341,7 @@ Stmt LowerStorageAccessInfo(Stmt stmt) { } LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) { - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); n->body = LowerStorageAccessInfo(f->body); return LoweredFunc(n); } diff --git a/src/pass/storage_access.h b/src/pass/storage_access.h index 028645b78640e..302ca929581d6 100644 --- a/src/pass/storage_access.h +++ b/src/pass/storage_access.h @@ -71,7 +71,7 @@ class StorageAccessVisitor : public IRVisitor { /*! \brief Access pattern about a single statement */ struct StmtEntry { /*! \brief The statement */ - const Node* stmt; + const Object* stmt; /*! \brief access patterns in the statement */ std::vector access; }; diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc index d6dde29a519d8..2df2672adcb1f 100644 --- a/src/pass/storage_flatten.cc +++ b/src/pass/storage_flatten.cc @@ -402,7 +402,7 @@ class StorageFlattener : public IRMutator { // We do support a few relaxed case, such as bindingx // region with shape [1, 1, n, m] to buffer with shape [n, m] Stmt HandleBufferBindScope(const AttrStmt* op) { - Array arr = Downcast > (op->node); + Array arr = Downcast > (op->node); CHECK_EQ(arr.size(), 2U); const BufferNode* buffer = arr[0].as(); const TensorNode* tensor = arr[1].as(); @@ -512,7 +512,7 @@ class StorageFlattener : public IRMutator { // Dimension alignment std::unordered_map > dim_align_; // Storage scope - std::unordered_map storage_scope_; + std::unordered_map storage_scope_; // The current thread scope. std::vector curr_thread_scope_; // Collects shapes. diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index 12a06da8007f6..01c6f983d692b 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -59,7 +59,7 @@ class LinearAccessPatternFinder final : public IRVisitor { /*! \brief record the touch hist of statment. */ struct StmtEntry { // The statment - const Node* stmt; + const Object* stmt; // The index in the linear_seq_ to point to end of the nested scope. // This is only set to non-zero if stmt is a nested scope. // if offset > 0, means this is the begin, the end entry is current_index + offset @@ -236,7 +236,7 @@ class LinearAccessPatternFinder final : public IRVisitor { // class InplaceOpVerifier : public IRVisitor { public: - bool Check(const Node* stmt, + bool Check(const Object* stmt, const Variable* dst, const Variable* src) { dst_ = dst; @@ -258,7 +258,7 @@ class InplaceOpVerifier : public IRVisitor { using IRVisitor::Visit_; - void Visit(const NodeRef& e) final { + void Visit(const ObjectRef& e) final { if (!result_) return; IRVisitor::Visit(e); } @@ -471,7 +471,7 @@ class StoragePlanRewriter : public IRMutator { // The scope that this alloc attaches after // For shared/local memory it is beginning of the thread extent. // for global memory it is nullptr, means beginning of everything. - const Node* attach_scope_{nullptr}; + const Object* attach_scope_{nullptr}; // The constant size of the buffer in bits, only used if it is constant uint64_t const_nbits{0}; // The storage scope. @@ -695,7 +695,7 @@ class StoragePlanRewriter : public IRMutator { } } } - void PlanNewScope(const Node* op) { + void PlanNewScope(const Object* op) { if (thread_scope_ != nullptr) { CHECK(thread_scope_ == op); // erase all memory atatched to this scope. @@ -808,7 +808,7 @@ class StoragePlanRewriter : public IRMutator { } // Allocate new storage entry. StorageEntry* NewAlloc(const Allocate* op, - const Node* attach_scope, + const Object* attach_scope, const StorageScope& scope, size_t const_nbits) { CHECK(op != nullptr); @@ -824,7 +824,7 @@ class StoragePlanRewriter : public IRMutator { } StorageEntry* FindAlloc(const Allocate* op, - const Node* attach_scope, + const Object* attach_scope, const StorageScope& scope) { CHECK(op != nullptr); // skip plan for local variable, @@ -908,17 +908,17 @@ class StoragePlanRewriter : public IRMutator { } } // thread scope. - const Node* thread_scope_{nullptr}; + const Object* thread_scope_{nullptr}; // whether enable inplace detection. bool detect_inplace_{false}; // Locations of free ops. - std::unordered_map event_map_; + std::unordered_map event_map_; // constant size free map. std::multimap const_free_map_; // symbolic free list, for non constant items. std::list sym_free_list_; // The allocation attach map - std::unordered_map > attach_map_; + std::unordered_map > attach_map_; // The allocation assign map std::unordered_map alloc_map_; // The allocations @@ -987,7 +987,7 @@ class VectorAllocRewriter : public IRMutator { LoweredFunc PointerValueTypeRewrite(LoweredFunc f) { - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); VectorAllocRewriter rewriter; n->body = rewriter.Mutate(n->body); for (Var arg : f->args) { diff --git a/src/pass/storage_sync.cc b/src/pass/storage_sync.cc index 018a6bb2e79ef..0f8bef8383f2e 100644 --- a/src/pass/storage_sync.cc +++ b/src/pass/storage_sync.cc @@ -39,7 +39,7 @@ class ThreadSyncPlanner : public StorageAccessVisitor { : sync_scope_(sync_scope) {} // The syncs inserted before each statement - std::unordered_set syncs_inserted_; + std::unordered_set syncs_inserted_; protected: bool Enabled(const Variable* buf, @@ -200,7 +200,7 @@ class ThreadSyncPlanner : public StorageAccessVisitor { class ThreadSyncInserter : public IRMutator { public: ThreadSyncInserter(StorageScope sync_scope, - const std::unordered_set& syncs) + const std::unordered_set& syncs) : sync_scope_(sync_scope), syncs_(syncs) {} Stmt Mutate(Stmt stmt) final { @@ -346,11 +346,11 @@ class ThreadSyncInserter : public IRMutator { } // data structure. StorageScope sync_scope_; - const std::unordered_set& syncs_; + const std::unordered_set& syncs_; // The storage scope of each buffer std::unordered_map storage_scope_; // The read write statistics of storage - std::unordered_map rw_stats_; + std::unordered_map rw_stats_; // The statistics for global barrier bool in_thread_env_{false}; // memorized results @@ -369,7 +369,7 @@ Stmt ThreadSync(Stmt stmt, std::string storage_scope) { LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) { CHECK_NE(f->func_type, kHostFunc); - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); n->body = ThreadSync(f->body, storage_scope); return LoweredFunc(n); } diff --git a/src/pass/tensor_core.cc b/src/pass/tensor_core.cc index 2ead2b934d7e7..8dcc0e49e1195 100644 --- a/src/pass/tensor_core.cc +++ b/src/pass/tensor_core.cc @@ -225,9 +225,9 @@ class MMAMatcher: public IRVisitor { } std::unordered_map buf_map_; - std::unordered_map storage_scope_; + std::unordered_map storage_scope_; std::unordered_map> mma_sync_; - std::unordered_map buf_name_; + std::unordered_map buf_name_; std::unordered_set frag_reg_; bool matched_{false}; bool tensor_core_on_{false}; @@ -365,7 +365,7 @@ class ScheduleAnalyser { std::unordered_map matrix_abc_; std::unordered_map matrix_major_; std::unordered_map> mma_sync_; - std::unordered_map buf_name_; + std::unordered_map buf_name_; }; // IndexVisitor visits access index of fragment @@ -745,7 +745,7 @@ class BufferAnalyser : public IRVisitor { std::unordered_map buf_map_; std::unordered_map > dim_align_; - std::unordered_map storage_scope_; + std::unordered_map storage_scope_; std::unordered_map matrix_abc_; std::unordered_map matrix_major_; std::unordered_set frag_reg_; @@ -868,9 +868,9 @@ class TensorCoreIRMutator : public IRMutator { Expr c = operands[2]; auto cc = c.as(); - NodePtr buffer_node_a = make_node(); - NodePtr buffer_node_b = make_node(); - NodePtr buffer_node_c = make_node(); + ObjectPtr buffer_node_a = make_object(); + ObjectPtr buffer_node_b = make_object(); + ObjectPtr buffer_node_c = make_object(); auto mma_sync_call = [&buffer_node_a, &buffer_node_b] @@ -921,7 +921,7 @@ class TensorCoreIRMutator : public IRMutator { Call::Intrinsic)); }; - NodePtr buffer_node = make_node(); + ObjectPtr buffer_node = make_object(); return add_buffer_bind_scope_(call, buffer_node, TensorKey{call->func, call->value_index}, fill_fragment_call, call->dtype); @@ -971,7 +971,7 @@ class TensorCoreIRMutator : public IRMutator { Call::Intrinsic)); }; - NodePtr buffer_node = make_node(); + ObjectPtr buffer_node = make_object(); return add_buffer_bind_scope_(call, buffer_node, TensorKey{op->func, op->value_index}, load_matrix_call, call->dtype); @@ -1011,7 +1011,7 @@ class TensorCoreIRMutator : public IRMutator { Call::Intrinsic)); }; - NodePtr buffer_node = make_node(); + ObjectPtr buffer_node = make_object(); return add_buffer_bind_scope_(call, buffer_node, TensorKey{call->func, call->value_index}, store_matrix_call, call->dtype); @@ -1073,7 +1073,7 @@ class TensorCoreIRMutator : public IRMutator { } Stmt add_buffer_bind_scope_(const Call* call, - const NodePtr &buffer_node, const TensorKey &key, + const ObjectPtr &buffer_node, const TensorKey &key, const std::function &call_back, DataType datatype) { auto it = bounds_.find(key); @@ -1124,7 +1124,7 @@ class TensorCoreIRMutator : public IRMutator { buffer_node->offset_factor = 1; Buffer buffer(buffer_node); - NodePtr tensor_node = make_node(); + ObjectPtr tensor_node = make_object(); tensor_node->value_index = key.value_index; tensor_node->op = Downcast(key.f); tensor_node->shape = shape; @@ -1140,7 +1140,7 @@ class TensorCoreIRMutator : public IRMutator { intrinsic::tvm_tuple, args, Call::Intrinsic); - Array node = {buffer, tensor}; + Array node = {buffer, tensor}; return AttrStmt::make(node, "buffer_bind_scope", tuple, diff --git a/src/pass/verify_memory.cc b/src/pass/verify_memory.cc index 1d7bb3d8425ba..4a5c8adeb8e7b 100644 --- a/src/pass/verify_memory.cc +++ b/src/pass/verify_memory.cc @@ -64,7 +64,7 @@ class MemoryAccessVerifier final : protected IRVisitor { protected: /// Visitor implementation //@{ - void Visit(const NodeRef &n) final { + void Visit(const ObjectRef &n) final { if (Failed()) return; IRVisitor::Visit(n); } diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 780e19bd017f9..102e4c2997742 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -256,7 +256,7 @@ class RelayBuildModule : public runtime::ModuleNode { relay::Function func, const std::unordered_map& params) { std::unordered_map name_dict; - std::unordered_set repeat_var; + std::unordered_set repeat_var; for (auto arg : func->params) { const auto &name = arg->name_hint(); if (name_dict.count(name)) { @@ -266,7 +266,7 @@ class RelayBuildModule : public runtime::ModuleNode { } } - std::unordered_map bind_dict; + std::unordered_map bind_dict; for (auto &kv : params) { if (name_dict.count(kv.first) == 0) { continue; diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 7c33ac9ed61a3..68a3bed3bc4b6 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -51,7 +51,7 @@ TVM_REGISTER_NODE_TYPE(CCacheValueNode); TVM_REGISTER_OBJECT_TYPE(CompileEngineNode); CCacheKey CCacheKeyNode::make(Function source_func, Target target) { - auto n = make_node(); + auto n = make_object(); n->source_func = std::move(source_func); n->target = std::move(target); return CCacheKey(n); @@ -109,7 +109,7 @@ class ScheduleGetter : std::pair Create(const Function& prim_func) { static auto fschedule = Op::GetAttr("FTVMSchedule"); - auto cache_node = make_node(); + auto cache_node = make_object(); cache_node->target = target_; for (Var param : prim_func->params) { Array inputs; @@ -330,7 +330,7 @@ class ScheduleGetter : Attrs master_attrs_; int master_op_pattern_{0}; std::ostringstream readable_name_stream_; - std::unordered_map, NodeHash, NodeEqual> memo_; + std::unordered_map, ObjectHash, ObjectEqual> memo_; Array scalars_; // Cache device copy op for equivalence checking to reduce registry lookup // overhead for each invocation of call node when retrieving schedules. @@ -380,7 +380,7 @@ class MakeShapeFunc : public ExprFunctor(const Expr&)> { param_shapes_[param] = shape_inputs; } readable_name_stream_ << "shape_func"; - auto cache_node = make_node(); + auto cache_node = make_object(); cache_node->outputs = VisitExpr(prim_func->body); auto candidate_name = readable_name_stream_.str(); constexpr static size_t kMaxFuncNameLength = 80; @@ -574,13 +574,13 @@ class MakeShapeFunc : public ExprFunctor(const Expr&)> { /*! \brief String stream for function name */ std::ostringstream readable_name_stream_; /*! \brief Map from parameter to its shape function usage state */ - std::unordered_map param_states_; + std::unordered_map param_states_; /*! \brief Map from parameter to list of data placeholder */ - std::unordered_map, NodeHash, NodeEqual> param_data_; + std::unordered_map, ObjectHash, ObjectEqual> param_data_; /*! \brief Map from parameter to list of shape placeholder */ - std::unordered_map, NodeHash, NodeEqual> param_shapes_; + std::unordered_map, ObjectHash, ObjectEqual> param_shapes_; /*! \brief Memoized visit result */ - std::unordered_map, NodeHash, NodeEqual> memo_; + std::unordered_map, ObjectHash, ObjectEqual> memo_; /*! \brief Stack of data dependencies for shape function */ std::vector data_dependants_; /*! \brief Scalars used in the shape function */ @@ -656,9 +656,9 @@ class CompileEngineImpl : public CompileEngineNode { cache_.clear(); } // List all items in the cache. - Array ListItems() { + Array ListItems() { std::lock_guard lock(mutex_); - Array items; + Array items; for (auto& kv : cache_) { items.push_back(kv.first); items.push_back(kv.second); @@ -688,14 +688,14 @@ class CompileEngineImpl : public CompileEngineNode { if (it->second->cached_func.defined()) return it->second; value = it->second; } else { - value = CCacheValue(make_node()); + value = CCacheValue(make_object()); value->use_count = 0; cache_[key] = value; } // No need to lower external functions for now. We will invoke the external // codegen tool once and lower all functions together. if (!key->source_func->UseDefaultCompiler()) { - auto cache_node = make_node(); + auto cache_node = make_object(); const auto name_node = FunctionGetAttr(key->source_func, attr::kExternalSymbol).as(); CHECK(name_node != nullptr) << "External function has not been attached a name yet."; @@ -709,7 +709,7 @@ class CompileEngineImpl : public CompileEngineNode { CHECK(!value->cached_func.defined()); auto spair = CreateSchedule(key->source_func, key->target); - auto cache_node = make_node( + auto cache_node = make_object( *(spair.second.operator->())); // Skip lowering for device copy node. @@ -749,7 +749,7 @@ class CompileEngineImpl : public CompileEngineNode { if (it->second->cached_func.defined()) return it->second; value = it->second; } else { - value = CCacheValue(make_node()); + value = CCacheValue(make_object()); value->use_count = 0; shape_func_cache_[key] = value; } @@ -758,7 +758,7 @@ class CompileEngineImpl : public CompileEngineNode { CHECK(!value->cached_func.defined()); auto spair = MakeShapeFunc().Create(key->source_func); - auto cache_node = make_node( + auto cache_node = make_object( *(spair.second.operator->())); cache_node->func_name = GetUniqueName(cache_node->func_name); cache_node->target = key->target; @@ -811,7 +811,7 @@ const CompileEngine& CompileEngine::Global() { // intentionally allocate raw pointer to avoid // free during destructuion. static CompileEngine* inst = new CompileEngine( - make_node()); + make_object()); return *inst; } @@ -852,7 +852,7 @@ TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT") }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems") -.set_body_typed(CompileEngine)>( +.set_body_typed(CompileEngine)>( [](CompileEngine self){ return static_cast(self.operator->())->ListItems(); }); diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index 596dfa7154f7f..f6c38ba6b9a97 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -45,7 +45,7 @@ enum ShapeFuncParamState { }; /*! \brief Node container to represent a cached function. */ -struct CachedFuncNode : public Node { +struct CachedFuncNode : public Object { /* \brief compiled target */ tvm::Target target; /*! \brief Function name */ @@ -69,15 +69,17 @@ struct CachedFuncNode : public Node { } static constexpr const char* _type_key = "relay.CachedFunc"; - TVM_DECLARE_NODE_TYPE_INFO(CachedFuncNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(CachedFuncNode, Object); }; -TVM_DEFINE_NODE_REF(CachedFunc, CachedFuncNode); - +class CachedFunc : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(CachedFunc, ObjectRef, CachedFuncNode); +}; class CCacheKey; /*! \brief Compile cache key */ -class CCacheKeyNode : public Node { +class CCacheKeyNode : public Object { public: /*! \brief The source function to be lowered. */ Function source_func; @@ -106,7 +108,7 @@ class CCacheKeyNode : public Node { Target target); static constexpr const char* _type_key = "relay.CCacheKey"; - TVM_DECLARE_NODE_TYPE_INFO(CCacheKeyNode, tvm::Node); + TVM_DECLARE_FINAL_OBJECT_INFO(CCacheKeyNode, tvm::Object); private: /*! @@ -116,10 +118,10 @@ class CCacheKeyNode : public Node { }; /*! \brief cache entry used in compile engine */ -class CCacheKey : public NodeRef { +class CCacheKey : public ObjectRef { public: CCacheKey() {} - explicit CCacheKey(ObjectPtr n) : NodeRef(n) {} + explicit CCacheKey(ObjectPtr n) : ObjectRef(n) {} const CCacheKeyNode* operator->() const { return static_cast(get()); } @@ -132,7 +134,7 @@ class CCacheKey : public NodeRef { }; /*! \brief Node container for compile cache. */ -class CCacheValueNode : public Node { +class CCacheValueNode : public Object { public: /*! \brief The corresponding function */ CachedFunc cached_func; @@ -146,14 +148,14 @@ class CCacheValueNode : public Node { v->Visit("use_count", &use_count); } static constexpr const char* _type_key = "relay.CCacheValue"; - TVM_DECLARE_NODE_TYPE_INFO(CCacheValueNode, tvm::Node); + TVM_DECLARE_FINAL_OBJECT_INFO(CCacheValueNode, tvm::Object); }; /*! \brief cache entry used in compile engine */ -class CCacheValue : public NodeRef { +class CCacheValue : public ObjectRef { public: CCacheValue() {} - explicit CCacheValue(ObjectPtr n) : NodeRef(n) {} + explicit CCacheValue(ObjectPtr n) : ObjectRef(n) {} CCacheValueNode* operator->() { return static_cast(get_mutable()); } @@ -167,7 +169,7 @@ class CCacheValue : public NodeRef { * \brief Backend compilation engine for * low level code generation. */ -class CompileEngineNode : public Node { +class CompileEngineNode : public Object { public: /*! * \brief Get lowered result. @@ -200,14 +202,14 @@ class CompileEngineNode : public Node { void VisitAttrs(AttrVisitor*) {} static constexpr const char* _type_key = "relay.CompileEngine"; - TVM_DECLARE_NODE_TYPE_INFO(CompileEngineNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(CompileEngineNode, Object); }; /*! \brief cache entry used in compile engine */ -class CompileEngine : public NodeRef { +class CompileEngine : public ObjectRef { public: CompileEngine() {} - explicit CompileEngine(ObjectPtr n) : NodeRef(n) {} + explicit CompileEngine(ObjectPtr n) : ObjectRef(n) {} CompileEngineNode* operator->() { return static_cast(get_mutable()); } diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index cdaf813c44e40..84fada060744f 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -152,7 +152,7 @@ class CSourceCodegen : public CSourceModuleCodegenBase { code_stream_ << builder.JIT(); } - runtime::Module CreateCSourceModule(const NodeRef& ref) override { + runtime::Module CreateCSourceModule(const ObjectRef& ref) override { // Create headers code_stream_ << "#include \n"; code_stream_ << "#include \n"; @@ -170,7 +170,7 @@ class CSourceCodegen : public CSourceModuleCodegenBase { out[i] = a[i] p_OP_ b[i]; \ } \ } - + #define CSOURCE_BINARY_OP_2D(p_ID_, p_OP_, p_DIM1_, p_DIM2_) \ extern "C" void p_ID_(float* a, float* b, float* out) { \ for (int64_t i = 0; i < p_DIM1_; ++i) { \ @@ -214,7 +214,7 @@ class CSourceCodegen : public CSourceModuleCodegenBase { * CUDA, etc, under TVM, so the generated code could be packed in a runtime * module. This module simplifies code serialization and invocation. */ -runtime::Module CCompiler(const NodeRef& ref) { +runtime::Module CCompiler(const ObjectRef& ref) { CSourceCodegen csource; return csource.CreateCSourceModule(ref); } diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index 1319ca2ff787f..d97f5dcd9103f 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -49,7 +49,7 @@ class CSourceModuleCodegenBase { * * \return A runtime module. */ - virtual runtime::Module CreateCSourceModule(const NodeRef& ref) = 0; + virtual runtime::Module CreateCSourceModule(const ObjectRef& ref) = 0; /*! * \brief Get the external symbol of the Relay function name. diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index e7f7bd6ff5597..675198fcc9b38 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -254,7 +254,7 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { * * \return The runtime module that contains C source code. */ - runtime::Module CreateCSourceModule(const NodeRef& ref) override { + runtime::Module CreateCSourceModule(const ObjectRef& ref) override { // Create headers code_stream_ << "#include \n"; code_stream_ << "#include \n"; @@ -298,7 +298,7 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { * \brief The external compiler/codegen tool. It takes a Relay expression/module and * compile it into a runtime module. */ -runtime::Module DNNLCompiler(const NodeRef& ref) { +runtime::Module DNNLCompiler(const ObjectRef& ref) { DNNLModuleCodegen dnnl; return dnnl.CreateCSourceModule(ref); } diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index fc12cf66900fa..5f210436f9b9f 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -47,9 +47,9 @@ class GraphOpNode; using IntegerArray = Array; using ShapeVector = std::vector >; using GraphAttrs = std::unordered_map; -using GraphNodePtr = std::shared_ptr; -using GraphInputNodePtr = std::shared_ptr; -using GraphOpNodePtr = std::shared_ptr; +using GraphObjectPtr = std::shared_ptr; +using GraphInputObjectPtr = std::shared_ptr; +using GraphOpObjectPtr = std::shared_ptr; using TargetsMap = std::unordered_map; /*! \brief Lowered outputs */ @@ -255,7 +255,7 @@ class GraphRuntimeCodegen * \param expr * \return std::vector<_NodeRef> */ - std::vector AddNode(GraphNodePtr node, Expr expr) { + std::vector AddNode(GraphObjectPtr node, Expr expr) { auto checked_type = expr->checked_type(); size_t count = storage_device_map_.count(expr); CHECK_GT(count, 0) << "Expr is not existing in storage plan"; @@ -319,7 +319,7 @@ class GraphRuntimeCodegen } /*! \brief Visitors */ - std::unordered_map, NodeHash, NodeEqual> visitor_cache_; + std::unordered_map, ObjectHash, ObjectEqual> visitor_cache_; std::vector VisitExpr(const Expr& expr) override { if (visitor_cache_.count(expr)) return visitor_cache_.at(expr); @@ -587,13 +587,13 @@ class GraphRuntimeCodegen protected: /*! \brief nodes */ - std::vector nodes_; + std::vector nodes_; /*! \brief output of graph */ std::vector heads_; /*! \brief mod */ runtime::Module* mod_; /*! \brief variable map */ - std::unordered_map> var_map_; + std::unordered_map> var_map_; /*! \brief target device */ TargetsMap targets_; /*! \brief params */ @@ -601,7 +601,7 @@ class GraphRuntimeCodegen /*! \brief plan memory of device result */ Map> storage_device_map_; /*! \brief lowered funcs */ - std::unordered_map> + std::unordered_map> lowered_funcs_; /*! \brief name map */ std::unordered_map name_map_; diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index b5fd0c914b62a..b4777845670ad 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -45,7 +45,7 @@ inline const PackedFunc& GetPackedFunc(const std::string& name) { /* Value Implementation */ Closure ClosureNode::make(tvm::Map env, Function func) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->env = std::move(env); n->func = std::move(func); return Closure(n); @@ -64,7 +64,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) // TODO(@jroesch): this doesn't support mutual letrec /* Value Implementation */ RecClosure RecClosureNode::make(Closure clos, Var bind) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->clos = std::move(clos); n->bind = std::move(bind); return RecClosure(n); @@ -80,7 +80,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); TupleValue TupleValueNode::make(tvm::Array value) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->fields = value; return TupleValue(n); } @@ -95,7 +95,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); TensorValue TensorValueNode::make(runtime::NDArray data) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->data = std::move(data); return TensorValue(n); } @@ -112,7 +112,7 @@ TVM_REGISTER_API("relay._make.TensorValue") .set_body_typed(TensorValueNode::make); RefValue RefValueNode::make(Value value) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->value = value; return RefValue(n); } @@ -131,7 +131,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ConstructorValue ConstructorValueNode::make(int32_t tag, tvm::Array fields, Constructor constructor) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->tag = tag; n->fields = fields; n->constructor = constructor; @@ -204,7 +204,7 @@ struct Stack { class InterpreterState; /*! \brief A container capturing the state of the interpreter. */ -class InterpreterStateNode : public Node { +class InterpreterStateNode : public Object { public: using Frame = tvm::Map; using Stack = tvm::Array; @@ -223,13 +223,16 @@ class InterpreterStateNode : public Node { static InterpreterState make(Expr current_expr, Stack stack); static constexpr const char* _type_key = "relay.InterpreterState"; - TVM_DECLARE_NODE_TYPE_INFO(InterpreterStateNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterStateNode, Object); }; -RELAY_DEFINE_NODE_REF(InterpreterState, InterpreterStateNode, NodeRef); +class InterpreterState : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(InterpreterState, ObjectRef, InterpreterStateNode); +}; InterpreterState InterpreterStateNode::make(Expr current_expr, Stack stack) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->current_expr = std::move(current_expr); n->stack = std::move(stack); return InterpreterState(n); diff --git a/src/relay/backend/param_dict.cc b/src/relay/backend/param_dict.cc index 9bde3a0b4edd6..e517fee3a4af9 100644 --- a/src/relay/backend/param_dict.cc +++ b/src/relay/backend/param_dict.cc @@ -95,7 +95,7 @@ TVM_REGISTER_GLOBAL("tvm.relay._load_param_dict") for (size_t i = 0; i < size; ++i) { tvm::runtime::NDArray temp; temp.Load(strm); - auto n = tvm::make_node(); + auto n = tvm::make_object(); n->name = std::move(names[i]); n->array = temp; ret.push_back(NamedNDArray(n)); diff --git a/src/relay/backend/param_dict.h b/src/relay/backend/param_dict.h index aa3c0244118fe..e2d225aadd19b 100644 --- a/src/relay/backend/param_dict.h +++ b/src/relay/backend/param_dict.h @@ -40,7 +40,7 @@ constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7; /*! * \brief Wrapper node for naming `NDArray`s. */ -struct NamedNDArrayNode : public ::tvm::Node { +struct NamedNDArrayNode : public ::tvm::Object { std::string name; tvm::runtime::NDArray array; @@ -50,11 +50,13 @@ struct NamedNDArrayNode : public ::tvm::Node { } static constexpr const char* _type_key = "NamedNDArray"; - TVM_DECLARE_NODE_TYPE_INFO(NamedNDArrayNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(NamedNDArrayNode, Object); }; -TVM_DEFINE_NODE_REF(NamedNDArray, NamedNDArrayNode); - +class NamedNDArray : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(NamedNDArray, ObjectRef, NamedNDArrayNode); +}; } // namespace relay } // namespace tvm diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 0de47bda0bbc2..af425a4966d02 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -112,7 +112,7 @@ struct ConditionNode { virtual ~ConditionNode() {} }; -using ConditionNodePtr = std::shared_ptr; +using ConditionObjectPtr = std::shared_ptr; /*! * \brief A var binding condition @@ -144,15 +144,15 @@ struct TagCompare : ConditionNode { ~TagCompare() {} }; -using TreeNodePtr = typename relay::TreeNode::pointer; -using TreeLeafNode = relay::TreeLeafNode; -using TreeLeafFatalNode = relay::TreeLeafFatalNode; -using TreeBranchNode = relay::TreeBranchNode; +using TreeObjectPtr = typename relay::TreeNode::pointer; +using TreeLeafNode = relay::TreeLeafNode; +using TreeLeafFatalNode = relay::TreeLeafFatalNode; +using TreeBranchNode = relay::TreeBranchNode; -TreeNodePtr BuildDecisionTreeFromPattern(MatchValuePtr data, +TreeObjectPtr BuildDecisionTreeFromPattern(MatchValuePtr data, Pattern pattern, - TreeNodePtr then_branch, - TreeNodePtr else_branch) { + TreeObjectPtr then_branch, + TreeObjectPtr else_branch) { if (pattern.as()) { // We ignore wildcard binding since it's not producing new vars return then_branch; @@ -185,16 +185,16 @@ TreeNodePtr BuildDecisionTreeFromPattern(MatchValuePtr data, } } -TreeNodePtr BuildDecisionTreeFromClause(MatchValuePtr data, +TreeObjectPtr BuildDecisionTreeFromClause(MatchValuePtr data, Clause clause, - TreeNodePtr else_branch) { + TreeObjectPtr else_branch) { return BuildDecisionTreeFromPattern(data, clause->lhs, TreeLeafNode::Make(clause->rhs), else_branch); } -TreeNodePtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array clauses) { +TreeObjectPtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array clauses) { // When nothing matches, the VM throws fatal error - TreeNodePtr else_branch = TreeLeafFatalNode::Make(); + TreeObjectPtr else_branch = TreeLeafFatalNode::Make(); // Start from the last clause for (auto it = clauses.rbegin(); it != clauses.rend(); ++it) { else_branch = BuildDecisionTreeFromClause(data, *it, else_branch); @@ -674,7 +674,7 @@ class VMFunctionCompiler : ExprFunctor { } } - void CompileTreeNode(TreeNodePtr tree) { + void CompileTreeNode(TreeObjectPtr tree) { if (std::dynamic_pointer_cast(tree)) { auto node = std::dynamic_pointer_cast(tree); VisitExpr(node->body); @@ -731,13 +731,13 @@ class VMFunctionCompiler : ExprFunctor { protected: /*! \brief Store the expression a variable points to. */ - std::unordered_map expr_map_; + std::unordered_map expr_map_; /*! \brief Instructions in the VMFunction. */ std::vector instructions_; /*! \brief Parameter names of the function. */ std::vector params_; /*! \brief Map from var to register number. */ - std::unordered_map var_register_map_; + std::unordered_map var_register_map_; /*! \brief Last used register number. */ size_t last_register_; /*! \brief Total number of virtual registers allocated. */ @@ -786,7 +786,7 @@ relay::Function VMCompiler::BindParamsByName( relay::Function func, const std::unordered_map& params) { std::unordered_map name_dict; - std::unordered_set repeat_var; + std::unordered_set repeat_var; for (auto arg : func->params) { const auto &name = arg->name_hint(); if (name_dict.count(name)) { @@ -795,7 +795,7 @@ relay::Function VMCompiler::BindParamsByName( name_dict[name] = arg; } } - std::unordered_map bind_dict; + std::unordered_map bind_dict; for (auto &kv : params) { if (name_dict.count(kv.first) == 0) { continue; diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index 8cdb12e4dafa0..2beab1536a187 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -52,7 +52,7 @@ using namespace tvm::runtime::vm; using namespace relay::transform; template -using NodeMap = std::unordered_map; +using NodeMap = std::unordered_map; using TagMap = NodeMap; using TagNameMap = std::unordered_map; using GlobalMap = NodeMap; @@ -76,7 +76,7 @@ struct VMCompilerContext { // List of cached functions std::vector cached_funcs; // The functions that have been lowered. - std::unordered_map seen_funcs; + std::unordered_map seen_funcs; }; diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc index 3bb1458b0758e..f94f837ef550a 100644 --- a/src/relay/backend/vm/inline_primitives.cc +++ b/src/relay/backend/vm/inline_primitives.cc @@ -53,7 +53,7 @@ namespace vm { */ struct PrimitiveInliner : ExprMutator { Module module_; - std::unordered_map var_map; + std::unordered_map var_map; explicit PrimitiveInliner(const Module& module) : module_(module) {} diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index ab9dc8cbec639..7298c50e6f1f9 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -43,7 +43,7 @@ inline std::string GenerateName(const Function& func) { } bool IsClosure(const Function& func) { - NodeRef res = FunctionGetAttr(func, attr::kClosure); + ObjectRef res = FunctionGetAttr(func, attr::kClosure); const ir::IntImm* pval = res.as(); return pval && pval->value != 0; } @@ -200,7 +200,7 @@ class LambdaLifter : public ExprMutator { } private: - std::unordered_map lambda_map_; + std::unordered_map lambda_map_; std::vector letrec_; Module module_; }; diff --git a/src/relay/backend/vm/removed_unused_funcs.cc b/src/relay/backend/vm/removed_unused_funcs.cc index ee44e26fdfa10..546f1d30cb41a 100644 --- a/src/relay/backend/vm/removed_unused_funcs.cc +++ b/src/relay/backend/vm/removed_unused_funcs.cc @@ -46,7 +46,7 @@ struct CallTracer : ExprVisitor { std::unordered_set called_funcs_; // Record the expressions that are being visited - std::unordered_set visiting_; + std::unordered_set visiting_; explicit CallTracer(const Module& module) : module_{module}, @@ -96,7 +96,7 @@ struct CallTracer : ExprVisitor { * * \param module The Relay module. * \param entry_funcs The set of functions that can be entry function. - * + * * \return The module with dead functions removed. */ Module RemoveUnusedFunctions(const Module& module, diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index 1f51ecc84fdca..73172879d393c 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -28,7 +28,7 @@ namespace tvm { namespace relay { PatternWildcard PatternWildcardNode::make() { - NodePtr n = make_node(); + ObjectPtr n = make_object(); return PatternWildcard(n); } @@ -43,7 +43,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); PatternVar PatternVarNode::make(tvm::relay::Var var) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->var = std::move(var); return PatternVar(n); } @@ -61,7 +61,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) PatternConstructor PatternConstructorNode::make(Constructor constructor, tvm::Array patterns) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->constructor = std::move(constructor); n->patterns = std::move(patterns); return PatternConstructor(n); @@ -80,7 +80,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); PatternTuple PatternTupleNode::make(tvm::Array patterns) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->patterns = std::move(patterns); return PatternTuple(n); } @@ -99,7 +99,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) Constructor ConstructorNode::make(std::string name_hint, tvm::Array inputs, GlobalTypeVar belong_to) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); n->inputs = std::move(inputs); n->belong_to = std::move(belong_to); @@ -121,7 +121,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TypeData TypeDataNode::make(GlobalTypeVar header, tvm::Array type_vars, tvm::Array constructors) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->header = std::move(header); n->type_vars = std::move(type_vars); n->constructors = std::move(constructors); @@ -141,7 +141,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); Clause ClauseNode::make(Pattern lhs, Expr rhs) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->lhs = std::move(lhs); n->rhs = std::move(rhs); return Clause(n); @@ -160,7 +160,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); Match MatchNode::make(Expr data, tvm::Array clauses, bool complete) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->data = std::move(data); n->clauses = std::move(clauses); n->complete = complete; diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index df91f794f6d16..589de09b0b81d 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -49,7 +49,7 @@ class AlphaEqualHandler: * \param rhs The right hand operand. * \return The comparison result. */ - bool Equal(const NodeRef& lhs, const NodeRef& rhs) { + bool Equal(const ObjectRef& lhs, const ObjectRef& rhs) { if (lhs.same_as(rhs)) return true; if (!lhs.defined() || !rhs.defined()) return false; if (lhs->IsInstance()) { @@ -88,7 +88,7 @@ class AlphaEqualHandler: * \param rhs The right hand operand. * \return The comparison result. */ - bool AttrEqual(const NodeRef& lhs, const NodeRef& rhs) { + bool AttrEqual(const ObjectRef& lhs, const ObjectRef& rhs) { auto compute = [&]() { if (&lhs == &rhs) return true; if (auto lhsd = lhs.as()) { @@ -127,7 +127,7 @@ class AlphaEqualHandler: return Compare(compute(), lhs, rhs); } - bool Compare(bool result, const NodeRef& lhs, const NodeRef& rhs) { + bool Compare(bool result, const ObjectRef& lhs, const ObjectRef& rhs) { if (assert_mode_) { CHECK(result) << "\n" << AsText(lhs, true) << "\nis not equal to:\n" << AsText(rhs, true); } @@ -180,7 +180,7 @@ class AlphaEqualHandler: * \param rhs The right hand operand. * \return The compare result. */ - bool LeafNodeEqual(const ObjectRef& lhs, const ObjectRef& rhs) { + bool LeafObjectEqual(const ObjectRef& lhs, const ObjectRef& rhs) { if (lhs.same_as(rhs)) return true; auto it = equal_map_.find(lhs); if (it != equal_map_.end()) { @@ -197,7 +197,7 @@ class AlphaEqualHandler: } using AttrsEqualHandler::VisitAttr_; bool VisitAttr_(const Variable* lhs, const ObjectRef& other) final { - return LeafNodeEqual(GetRef(lhs), other); + return LeafObjectEqual(GetRef(lhs), other); } // Type equality @@ -211,13 +211,13 @@ class AlphaEqualHandler: } bool VisitType_(const IncompleteTypeNode* lhs, const Type& other) final { - return LeafNodeEqual(GetRef(lhs), other); + return LeafObjectEqual(GetRef(lhs), other); } bool VisitType_(const TypeVarNode* lhs, const Type& other) final { if (const TypeVarNode* rhs = other.as()) { if (lhs->kind != rhs->kind) return false; - return LeafNodeEqual(GetRef(lhs), other); + return LeafObjectEqual(GetRef(lhs), other); } else { return false; } @@ -290,7 +290,7 @@ class AlphaEqualHandler: } bool VisitType_(const GlobalTypeVarNode* lhs, const Type& other) final { - return LeafNodeEqual(GetRef(lhs), other); + return LeafObjectEqual(GetRef(lhs), other); } bool VisitType_(const TypeCallNode* lhs, const Type& other) final { @@ -366,7 +366,7 @@ class AlphaEqualHandler: if (const VarNode* rhs = other.as()) { if (lhs->name_hint() != rhs->name_hint()) return false; if (!TypeEqual(lhs->type_annotation, rhs->type_annotation)) return false; - return LeafNodeEqual(GetRef(lhs), other); + return LeafObjectEqual(GetRef(lhs), other); } else { return false; } @@ -600,23 +600,23 @@ bool AlphaEqual(const Expr& lhs, const Expr& rhs) { // TODO(@jroesch): move to correct namespace? TVM_REGISTER_API("relay._make._alpha_equal") -.set_body_typed([](NodeRef a, NodeRef b) { +.set_body_typed([](ObjectRef a, ObjectRef b) { return AlphaEqualHandler(false, false).Equal(a, b); }); TVM_REGISTER_API("relay._make._assert_alpha_equal") -.set_body_typed([](NodeRef a, NodeRef b) { +.set_body_typed([](ObjectRef a, ObjectRef b) { bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b); CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal"; }); TVM_REGISTER_API("relay._make._graph_equal") -.set_body_typed([](NodeRef a, NodeRef b) { +.set_body_typed([](ObjectRef a, ObjectRef b) { return AlphaEqualHandler(true, false).Equal(a, b); }); TVM_REGISTER_API("relay._make._assert_graph_equal") -.set_body_typed([](NodeRef a, NodeRef b) { +.set_body_typed([](ObjectRef a, ObjectRef b) { bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b); CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal"; }); diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index 3bc916d9a4066..ca8755730d801 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -33,11 +33,11 @@ using namespace tvm::runtime; ObjectPtr GetSourceNameNode(const std::string& name) { // always return pointer as the reference can change as map re-allocate. // or use another level of indirection by creating a unique_ptr - static std::unordered_map > source_map; + static std::unordered_map > source_map; auto sn = source_map.find(name); if (sn == source_map.end()) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); source_map[name] = n; n->name = std::move(name); return n; @@ -66,7 +66,7 @@ TVM_REGISTER_NODE_TYPE(SourceNameNode) }); Span SpanNode::make(SourceName source, int lineno, int col_offset) { - auto n = make_node(); + auto n = make_object(); n->source = std::move(source); n->lineno = lineno; n->col_offset = col_offset; @@ -88,7 +88,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(IdNode); TVM_REGISTER_API("relay._base.set_span") -.set_body_typed([](NodeRef node_ref, Span sp) { +.set_body_typed([](ObjectRef node_ref, Span sp) { auto rn = node_ref.as(); CHECK(rn); rn->span = sp; diff --git a/src/relay/ir/error.cc b/src/relay/ir/error.cc index 33273f972ea81..7c47c7441dbb5 100644 --- a/src/relay/ir/error.cc +++ b/src/relay/ir/error.cc @@ -37,7 +37,7 @@ void RelayErrorStream::Raise() const { } template -using NodeMap = std::unordered_map; +using NodeMap = std::unordered_map; void ErrorReporter::RenderErrors(const Module& module, bool use_color) { // First we pick an error reporting strategy for each error. @@ -46,7 +46,7 @@ void ErrorReporter::RenderErrors(const Module& module, bool use_color) { CHECK(!err.sp.defined()) << "attempting to use spanned errors, currently not supported"; } - NodeMap> error_maps; + NodeMap> error_maps; // Set control mode in order to produce colors; if (use_color) { @@ -132,7 +132,7 @@ void ErrorReporter::RenderErrors(const Module& module, bool use_color) { LOG(FATAL) << annotated_prog.str() << std::endl; } -void ErrorReporter::ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err) { +void ErrorReporter::ReportAt(const GlobalVar& global, const ObjectRef& node, const Error& err) { size_t index_to_insert = this->errors_.size(); this->errors_.push_back(err); auto it = this->node_to_error_.find(node); diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index cae35895dbbf4..66e083d498cbb 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -30,7 +30,7 @@ using tvm::IRPrinter; using namespace tvm::runtime; Constant ConstantNode::make(runtime::NDArray data) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->data = std::move(data); return Constant(n); } @@ -63,7 +63,7 @@ TensorType ConstantNode::tensor_type() const { } Tuple TupleNode::make(tvm::Array fields) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->fields = std::move(fields); return Tuple(n); } @@ -81,14 +81,14 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) Var VarNode::make(Id vid, Type type_annotation) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->vid = std::move(vid); n->type_annotation = std::move(type_annotation); return Var(n); } Var VarNode::make(std::string name_hint, Type type_annotation) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); return VarNode::make(Id(n), type_annotation); } @@ -110,7 +110,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); GlobalVar GlobalVarNode::make(std::string name_hint) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); return GlobalVar(n); } @@ -132,7 +132,7 @@ Function FunctionNode::make(tvm::Array params, Type ret_type, tvm::Array type_params, tvm::Attrs attrs) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); CHECK(params.defined()); CHECK(type_params.defined()); n->params = std::move(params); @@ -157,7 +157,7 @@ FuncType FunctionNode::func_type_annotation() const { } bool FunctionNode::IsPrimitive() const { - NodeRef res = FunctionGetAttr(GetRef(this), attr::kPrimitive); + ObjectRef res = FunctionGetAttr(GetRef(this), attr::kPrimitive); const ir::IntImm* pval = res.as(); return pval && pval->value != 0; } @@ -183,13 +183,13 @@ TVM_REGISTER_API("relay._expr.FunctionGetParams") }); bool FunctionNode::UseDefaultCompiler() const { - NodeRef res = FunctionGetAttr(GetRef(this), attr::kCompiler); + ObjectRef res = FunctionGetAttr(GetRef(this), attr::kCompiler); const ir::StringImm* pval = res.as(); return pval == nullptr || pval->value == "default"; } -NodeRef FunctionGetAttr(const Function& func, const std::string& key) { - if (!func->attrs.defined()) { return NodeRef(); } +ObjectRef FunctionGetAttr(const Function& func, const std::string& key) { + if (!func->attrs.defined()) { return ObjectRef(); } const DictAttrsNode* dict_attrs = func->attrs.as(); CHECK(dict_attrs); @@ -197,19 +197,19 @@ NodeRef FunctionGetAttr(const Function& func, const std::string& key) { if (it != dict_attrs->dict.end()) { return (*it).second; } else { - return NodeRef(); + return ObjectRef(); } } -Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data) { +Function FunctionSetAttr(const Function& func, const std::string& key, const ObjectRef& data) { const DictAttrsNode* dattrs = func->attrs.as(); Attrs func_attrs; if (dattrs) { - Map dict = dattrs->dict; + Map dict = dattrs->dict; dict.Set(key, data); func_attrs = DictAttrsNode::make(dict); } else { - Map dict = {{key, data}}; + Map dict = {{key, data}}; func_attrs = DictAttrsNode::make(dict); } @@ -236,7 +236,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) Call CallNode::make(Expr op, Array args, Attrs attrs, Array type_args) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->op = std::move(op); n->args = std::move(args); n->attrs = std::move(attrs); @@ -257,7 +257,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); Let LetNode::make(Var var, Expr value, Expr body) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->var = std::move(var); n->value = std::move(value); n->body = std::move(body); @@ -277,7 +277,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->cond = std::move(cond); n->true_branch = std::move(true_branch); n->false_branch = std::move(false_branch); @@ -297,7 +297,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); TupleGetItem TupleGetItemNode::make(Expr tuple, int index) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->tuple = std::move(tuple); n->index = index; return TupleGetItem(n); @@ -315,7 +315,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); RefCreate RefCreateNode::make(Expr value) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->value = std::move(value); return RefCreate(n); } @@ -332,7 +332,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); RefRead RefReadNode::make(Expr ref) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->ref = std::move(ref); return RefRead(n); } @@ -349,7 +349,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); RefWrite RefWriteNode::make(Expr ref, Expr value) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->ref = std::move(ref); n->value = std::move(value); return RefWrite(n); @@ -372,8 +372,8 @@ TVM_REGISTER_API("relay._expr.TempExprRealize") }); TVM_REGISTER_API("relay._expr.FunctionSetAttr") -.set_body_typed( - [](Function func, std::string name, NodeRef ref) { +.set_body_typed( + [](Function func, std::string name, ObjectRef ref) { return FunctionSetAttr(func, name, ref); }); diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index ac45d61e873d4..e3846c93d49a2 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -340,7 +340,7 @@ class ExprApplyVisit : public ExprVisitor { private: std::function f_; - std::unordered_set visited_; + std::unordered_set visited_; }; void PostOrderVisit(const Expr& e, std::function fvisit) { @@ -422,7 +422,7 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { func->ret_type, func->type_params, func->attrs); - std::unordered_set set; + std::unordered_set set; for (const auto& v : FreeVars(expr)) { set.insert(v); } @@ -445,7 +445,7 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { TVM_REGISTER_API("relay._expr.Bind") .set_body([](TVMArgs args, TVMRetValue* ret) { - NodeRef input = args[0]; + ObjectRef input = args[0]; if (input->IsInstance()) { *ret = Bind(Downcast(input), args[1]); } else { diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index f37b1a4c10be7..15f5105808aa3 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -47,8 +47,8 @@ class RelayHashHandler: * \param ref The node to hash. * \return the hash value. */ - size_t Hash(const NodeRef& ref) { - if (!ref.defined()) return NodeHash()(ref); + size_t Hash(const ObjectRef& ref) { + if (!ref.defined()) return ObjectHash()(ref); if (ref->IsInstance()) { return TypeHash(Downcast(ref)); @@ -64,9 +64,9 @@ class RelayHashHandler: * \param ref The attributes. * \return the hash value */ - size_t AttrHash(const NodeRef& ref) { + size_t AttrHash(const ObjectRef& ref) { if (!ref.defined()) { - return NodeHash()(ref); + return ObjectHash()(ref); } return AttrsHashHandler::Hash(ref); } @@ -78,7 +78,7 @@ class RelayHashHandler: */ size_t TypeHash(const Type& type) { if (!type.defined()) { - return NodeHash()(type); + return ObjectHash()(type); } auto found = hash_map_.find(type); if (found != hash_map_.end()) { @@ -102,7 +102,7 @@ class RelayHashHandler: */ size_t ExprHash(const Expr& expr) { if (!expr.defined()) { - return NodeHash()(expr); + return ObjectHash()(expr); } auto found = hash_map_.find(expr); if (found != hash_map_.end()) { @@ -221,7 +221,7 @@ class RelayHashHandler: return hash; } - size_t BindVar(const NodeRef& var) { + size_t BindVar(const ObjectRef& var) { size_t hash = std::hash()(var_counter++); CHECK_EQ(hash_map_.count(var), 0); if (auto var_node = var.as()) { @@ -238,7 +238,7 @@ class RelayHashHandler: size_t VisitExpr_(const VarNode* var) final { // hash free variable - size_t name_hash = std::hash()(var->vid.get()); + size_t name_hash = std::hash()(var->vid.get()); return Combine(name_hash, TypeHash(var->type_annotation)); } @@ -308,7 +308,7 @@ class RelayHashHandler: } size_t VisitExpr_(const OpNode* op) final { - return NodeHash()(GetRef(op)); + return ObjectHash()(GetRef(op)); } size_t VisitExpr_(const ConstantNode* rconst) final { @@ -416,7 +416,7 @@ class RelayHashHandler: } private: // renaming of NodeRef to indicate two nodes equals to each other - std::unordered_map hash_map_; + std::unordered_map hash_map_; int var_counter = 0; }; @@ -429,7 +429,7 @@ size_t StructuralHash::operator()(const Expr& expr) const { } TVM_REGISTER_API("relay._analysis._expr_hash") -.set_body_typed([](NodeRef ref) { +.set_body_typed([](ObjectRef ref) { return static_cast(RelayHashHandler().Hash(ref)); }); diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 3bd8d59aaf497..2fa79c7b63221 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -38,7 +38,7 @@ Module ModuleNode::make(tvm::Map global_funcs, tvm::Map global_type_defs, std::unordered_set imports ) { - auto n = make_node(); + auto n = make_object(); n->functions = std::move(global_funcs); n->type_definitions = std::move(global_type_defs); n->global_type_var_map_ = {}; @@ -327,14 +327,14 @@ TVM_REGISTER_API("relay._module.Module_Add") .set_body([](TVMArgs args, TVMRetValue* ret) { Module mod = args[0]; GlobalVar var = args[1]; - NodeRef val = args[2]; + ObjectRef val = args[2]; bool update = args[3]; CHECK(val->IsInstance()); if (val->IsInstance()) { mod->Add(var, Downcast(val), update); } else if (val->IsInstance()) { GlobalVar gv = Downcast(val); - auto mod_copy = Module(make_node(*mod.operator->())); + auto mod_copy = Module(make_object(*mod.operator->())); mod_copy = transform::EtaExpand( /* expand_constructor */ false, /* expand_global_var */ true)(mod_copy); auto func = mod_copy->Lookup(gv->name_hint); diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index 7b5217d4c0665..05788b1a78b52 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -67,7 +67,7 @@ const Op& Op::Get(const std::string& name) { OpRegistry::OpRegistry() { OpManager* mgr = OpManager::Global(); - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->index_ = mgr->op_counter++; op_ = Op(n); } @@ -205,17 +205,17 @@ TVM_REGISTER_API("relay.op._Register") }); // helper to get internal dev function in objectref. -struct Op2NodePtr : public ObjectRef { - static NodePtr Get(const Op& op) { - return GetDataPtr(op); +struct Op2ObjectPtr : public ObjectRef { + static ObjectPtr Get(const Op& op) { + return GetDataPtr(op); } }; -NodePtr CreateOp(const std::string& name) { +ObjectPtr CreateOp(const std::string& name) { // Hack use TVMRetValue as exchange auto op = Op::Get(name); CHECK(op.defined()) << "Cannot find op \'" << name << '\''; - return Op2NodePtr::Get(op); + return Op2ObjectPtr::Get(op); } TVM_REGISTER_NODE_TYPE(OpNode) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 597ef4abee4f9..478469c586efd 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -116,14 +116,14 @@ class TextMetaDataContext { * \param node The node to be converted to meta node. * \return A string representation of the meta node. */ - Doc GetMetaNode(const NodeRef& node) { + Doc GetMetaNode(const ObjectRef& node) { auto it = meta_repr_.find(node); if (it != meta_repr_.end()) { return it->second; } std::string type_key = node->GetTypeKey(); CHECK(!type_key.empty()); - Array& mvector = + Array& mvector = meta_data_[type_key]; int64_t index = static_cast(mvector.size()); mvector.push_back(node); @@ -143,7 +143,7 @@ class TextMetaDataContext { */ Doc GetMetaSection() const { if (meta_data_.size() == 0) return Doc(); - return Doc(SaveJSON(Map(meta_data_.begin(), meta_data_.end()))); + return Doc(SaveJSON(Map(meta_data_.begin(), meta_data_.end()))); } /*! \return whether the meta data context is empty. */ @@ -153,9 +153,9 @@ class TextMetaDataContext { private: /*! \brief additional metadata stored in TVM json format */ - std::unordered_map > meta_data_; + std::unordered_map > meta_data_; /*! \brief map from meta data into its string representation */ - std::unordered_map meta_repr_; + std::unordered_map meta_repr_; }; class PrettyPrinter : @@ -191,7 +191,7 @@ class PrettyPrinter : } // indent a new body - Doc PrintBody(const NodeRef& node, int indent = 2) { + Doc PrintBody(const ObjectRef& node, int indent = 2) { Doc doc; Doc body; doc << "{"; @@ -202,7 +202,7 @@ class PrettyPrinter : // create a new scope by creating a new printer object. This allows temp var // numbers to be reused and prevents hoisted vars from escaping too far - Doc PrintScope(const NodeRef& node) { + Doc PrintScope(const ObjectRef& node) { // print in a new scope doc_stack_.push_back(Doc()); // must print first so doc_stack_.back() reference doesn't become stale @@ -212,7 +212,7 @@ class PrettyPrinter : return doc; } - Doc PrintFinal(const NodeRef& node) { + Doc PrintFinal(const ObjectRef& node) { if (node.as()) { Expr expr = Downcast(node); dg_ = DependencyGraph::Create(&arena_, expr); @@ -235,7 +235,7 @@ class PrettyPrinter : std::vector PrintCallAttrs(const Attrs& attrs, const Expr& op); std::vector PrintFuncAttrs(const Attrs& attrs); - Doc Print(const NodeRef& node, bool meta = false, bool try_inline = false) { + Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false) { if (node.as()) { return PrintExpr(Downcast(node), meta, try_inline); } else if (node.as()) { @@ -383,7 +383,7 @@ class PrettyPrinter : Doc printed_expr; if (meta) { - printed_expr = meta_.GetMetaNode(GetRef(expr.get())); + printed_expr = meta_.GetMetaNode(GetRef(expr.get())); } else if (!inline_expr && expr.as()) { // wrap GNFed let in brackets Doc body; @@ -440,7 +440,7 @@ class PrettyPrinter : } // default fall-back, record it as meta node. Doc doc; - return doc << Print(GetRef(op), true); + return doc << Print(GetRef(op), true); } Doc VisitExpr_(const TupleNode* op) final { @@ -624,7 +624,7 @@ class PrettyPrinter : if (it != memo_pattern_.end()) return it->second; Doc printed_pattern; if (meta) { - printed_pattern = meta_.GetMetaNode(GetRef(pattern.get())); + printed_pattern = meta_.GetMetaNode(GetRef(pattern.get())); } else { printed_pattern = VisitPattern(pattern); } @@ -687,7 +687,7 @@ class PrettyPrinter : if (it != memo_type_.end()) return it->second; Doc printed_type; if (meta) { - printed_type = meta_.GetMetaNode(GetRef(type.get())); + printed_type = meta_.GetMetaNode(GetRef(type.get())); } else { printed_type = VisitType(type); } @@ -695,9 +695,9 @@ class PrettyPrinter : return printed_type; } - Doc VisitTypeDefault_(const Node* node) final { + Doc VisitTypeDefault_(const Object* node) final { // by default always print as meta data - return Print(GetRef(node), true); + return Print(GetRef(node), true); } Doc VisitType_(const TypeVarNode* node) final { @@ -728,7 +728,7 @@ class PrettyPrinter : Doc doc; doc << "Tensor[("; std::vector shapes; - for (NodeRef shape : node->shape) { + for (ObjectRef shape : node->shape) { shapes.push_back(PrintAttr(shape)); } doc << PrintSep(shapes); @@ -816,7 +816,7 @@ class PrettyPrinter : if (value.as()) { printed_attr << "?"; } else if (meta) { - printed_attr = meta_.GetMetaNode(Downcast(value)); + printed_attr = meta_.GetMetaNode(Downcast(value)); } else { printed_attr = VisitAttr(value); } @@ -866,11 +866,11 @@ class PrettyPrinter : /*! \brief Stack of docs to implement scoped GNFing. */ std::vector doc_stack_{}; /*! \brief Map from Expr to Doc */ - std::unordered_map memo_; + std::unordered_map memo_; /*! \brief Map from Type to Doc */ - std::unordered_map memo_type_; + std::unordered_map memo_type_; /*! \brief Map from Type to Doc */ - std::unordered_map memo_pattern_; + std::unordered_map memo_pattern_; /*! \brief name allocation map */ std::unordered_map name_alloc_map_; /*! \brief meta data context */ @@ -969,7 +969,7 @@ std::vector PrettyPrinter::PrintFuncAttrs(const Attrs& attrs) { return docs; } -std::string PrettyPrint_(const NodeRef& node, +std::string PrettyPrint_(const ObjectRef& node, bool show_meta_data, runtime::TypedPackedFunc annotate) { Doc doc; @@ -978,20 +978,20 @@ std::string PrettyPrint_(const NodeRef& node, return doc.str(); } -std::string PrettyPrint(const NodeRef& node) { +std::string PrettyPrint(const ObjectRef& node) { Doc doc; doc << PrettyPrinter(false, runtime::TypedPackedFunc()).PrintFinal(node); return doc.str(); } -std::string AsText(const NodeRef& node, +std::string AsText(const ObjectRef& node, bool show_meta_data, runtime::TypedPackedFunc annotate) { return PrettyPrint_(node, show_meta_data, annotate); } TVM_REGISTER_API("relay._expr.AsText") -.set_body_typed)>(AsText); diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index 94e9883d4e411..70071d0445aa3 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -30,7 +30,7 @@ using tvm::IRPrinter; using namespace tvm::runtime; TensorType TensorTypeNode::make(Array shape, DataType dtype) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->shape = std::move(shape); n->dtype = std::move(dtype); return TensorType(n); @@ -64,7 +64,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); TypeVar TypeVarNode::make(std::string name, Kind kind) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->var = tvm::Var(name); n->kind = std::move(kind); return TypeVar(n); @@ -85,7 +85,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); GlobalTypeVar GlobalTypeVarNode::make(std::string name, Kind kind) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->var = tvm::Var(name); n->kind = std::move(kind); return GlobalTypeVar(n); @@ -106,7 +106,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); TypeCall TypeCallNode::make(Type func, tvm::Array args) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->func = std::move(func); n->args = std::move(args); return TypeCall(n); @@ -125,7 +125,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); IncompleteType IncompleteTypeNode::make(Kind kind) { - auto n = make_node(); + auto n = make_object(); n->kind = std::move(kind); return IncompleteType(n); } @@ -147,7 +147,7 @@ FuncType FuncTypeNode::make(tvm::Array arg_types, Type ret_type, tvm::Array type_params, tvm::Array type_constraints) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->arg_types = std::move(arg_types); n->ret_type = std::move(ret_type); n->type_params = std::move(type_params); @@ -172,7 +172,7 @@ TypeRelation TypeRelationNode::make(TypeRelationFn func, Array args, int num_inputs, Attrs attrs) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->func = std::move(func); n->args = std::move(args); n->num_inputs = num_inputs; @@ -194,7 +194,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); TupleType TupleTypeNode::make(Array fields) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->fields = std::move(fields); return TupleType(n); } @@ -211,7 +211,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); RefType RefTypeNode::make(Type value) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->value = std::move(value); return RefType(n); } diff --git a/src/relay/ir/type_functor.h b/src/relay/ir/type_functor.h index 67c139185ebf6..09049cf83f867 100644 --- a/src/relay/ir/type_functor.h +++ b/src/relay/ir/type_functor.h @@ -93,7 +93,7 @@ class TypeFunctor { virtual R VisitType_(const GlobalTypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeDataNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; - virtual R VisitTypeDefault_(const Node* op, Args...) { + virtual R VisitTypeDefault_(const Object* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); throw; // unreachable, written to stop compiler warning } diff --git a/src/relay/op/algorithm/argsort.cc b/src/relay/op/algorithm/argsort.cc index b64d656b66a07..7a58cfd258a97 100644 --- a/src/relay/op/algorithm/argsort.cc +++ b/src/relay/op/algorithm/argsort.cc @@ -51,7 +51,7 @@ Expr MakeArgsort(Expr data, int axis, bool is_ascend, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; attrs->is_ascend = is_ascend; attrs->dtype = dtype; diff --git a/src/relay/op/algorithm/topk.cc b/src/relay/op/algorithm/topk.cc index ecb3f7d3be058..055d65bf32523 100644 --- a/src/relay/op/algorithm/topk.cc +++ b/src/relay/op/algorithm/topk.cc @@ -72,7 +72,7 @@ Expr MakeTopK(Expr data, std::string ret_type, bool is_ascend, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->k = k; attrs->axis = axis; attrs->ret_type = ret_type; diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index 6835525c35854..9234591659c5c 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -41,7 +41,7 @@ TVM_REGISTER_NODE_TYPE(OnDeviceAttrs); TVM_REGISTER_API("relay.op.annotation._make.on_device") .set_body_typed([](Expr data, int device_type) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->device_type = device_type; static const Op& op = Op::Get("on_device"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -87,7 +87,7 @@ TVM_ADD_FILELINE) TVM_REGISTER_NODE_TYPE(CastHintAttrs); Expr CastHint(Expr data, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("annotation.cast_hint"); return CallNode::make(op, {data}, Attrs{attrs}, {}); diff --git a/src/relay/op/debug.cc b/src/relay/op/debug.cc index f7f800fccb10f..f592d3ed3f74c 100644 --- a/src/relay/op/debug.cc +++ b/src/relay/op/debug.cc @@ -55,7 +55,7 @@ RELAY_REGISTER_OP("debug") .set_attr("FTVMCompute", DebugCompute); Expr MakeDebug(Expr expr, std::string name) { - auto dattrs = make_node(); + auto dattrs = make_object(); if (name.size() > 0) { dattrs->debug_func = EnvFunc::Get(name); } else { diff --git a/src/relay/op/device_copy.cc b/src/relay/op/device_copy.cc index 3b997a273fa5c..290ccef06d990 100644 --- a/src/relay/op/device_copy.cc +++ b/src/relay/op/device_copy.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -44,7 +44,7 @@ TVM_REGISTER_NODE_TYPE(DeviceCopyAttrs); TVM_REGISTER_API("relay.op._make.device_copy") .set_body_typed([](Expr data, int src_dev_type, int dst_dev_type) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->src_dev_type = src_dev_type; attrs->dst_dev_type = dst_dev_type; static const Op& op = Op::Get("device_copy"); diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index a65312316076f..f6329f7af7094 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -73,7 +73,7 @@ Expr MakeResize(Expr data, std::string method, bool align_corners, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->size = std::move(size); attrs->layout = std::move(layout); attrs->method = std::move(method); diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index c535d76838c82..72edeac053997 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -43,7 +43,7 @@ TVM_REGISTER_NODE_TYPE(ShapeFuncAttrs); // being able to see the arguments as well? TVM_REGISTER_API("relay.op.memory._make.alloc_storage") .set_body_typed([](Expr size, Expr alignment, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("memory.alloc_storage"); return CallNode::make(op, {size, alignment}, Attrs(attrs), {}); @@ -90,7 +90,7 @@ RELAY_REGISTER_OP("memory.alloc_storage") TVM_REGISTER_API("relay.op.memory._make.alloc_tensor") .set_body_typed assert_shape)>( [](Expr storage, tvm::relay::Expr shape, DataType dtype, Array assert_shape) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->dtype = dtype; if (assert_shape.defined()) { attrs->assert_shape = assert_shape; @@ -260,7 +260,7 @@ TVM_REGISTER_API("relay.op.memory._make.shape_func") .set_body_typed)>( [](Expr func, Expr inputs, Expr outputs, Array is_input) { static const Op& op = Op::Get("memory.shape_func"); - auto attrs = make_node(); + auto attrs = make_object(); attrs->is_input = is_input; return CallNode::make(op, {func, inputs, outputs}, Attrs(attrs), {}); }); diff --git a/src/relay/op/nn/bitserial.cc b/src/relay/op/nn/bitserial.cc index d651baeccb4ce..973ee0b3fe051 100644 --- a/src/relay/op/nn/bitserial.cc +++ b/src/relay/op/nn/bitserial.cc @@ -86,7 +86,7 @@ bool BitPackRel(const Array& types, int num_inputs, const Attrs& attrs, Expr MakeBitPack(Expr data, int bits, int pack_axis, int bit_axis, DataType pack_type, std::string name) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->bits = bits; attrs->pack_axis = pack_axis; attrs->bit_axis = bit_axis; @@ -151,7 +151,7 @@ Expr MakeBinaryConv2D(Expr data, Expr weight, Array strides, Array kernel_size, int activation_bits, int weight_bits, std::string data_layout, std::string kernel_layout, DataType pack_dtype, DataType out_dtype, bool unipolar) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); attrs->channels = std::move(channels); @@ -224,7 +224,7 @@ bool BinaryDenseRel(const Array& types, int num_inputs, const Attrs& attrs // Positional relay function to create bitserial dense operator used by frontend FFI. Expr MakeBinaryDense(Expr data, Expr weight, IndexExpr units, int data_bits, int weight_bits, DataType pack_dtype, DataType out_dtype, bool unipolar) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->units = units; attrs->data_bits = data_bits; attrs->weight_bits = weight_bits; diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 4a1fd466108d5..40c24462c8f7b 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -66,7 +66,7 @@ Expr MakeConv2D(Expr data, std::string kernel_layout, std::string out_layout, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); attrs->dilation = std::move(dilation); @@ -124,7 +124,7 @@ Expr MakeConv3D(Expr data, std::string kernel_layout, std::string out_layout, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); attrs->dilation = std::move(dilation); @@ -289,7 +289,7 @@ Expr MakeConv2DTranspose(Expr data, std::string out_layout, Array output_padding, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->channels = std::move(channels); attrs->kernel_size = std::move(kernel_size); attrs->strides = std::move(strides); @@ -448,7 +448,7 @@ Expr MakeConv1DTranspose(Expr data, std::string out_layout, Array output_padding, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->channels = std::move(channels); attrs->kernel_size = std::move(kernel_size); attrs->strides = std::move(strides); @@ -595,7 +595,7 @@ Expr MakeConv2DWinograd(Expr data, std::string kernel_layout, std::string out_layout, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->tile_size = tile_size; attrs->strides = std::move(strides); attrs->padding = std::move(padding); @@ -668,7 +668,7 @@ bool Conv2DWinogradWeightTransformRel(const Array& types, Expr MakeConv2DWinogradWeightTransform(Expr weight, int tile_size) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->tile_size = tile_size; static const Op& op = Op::Get("nn.contrib_conv2d_winograd_weight_transform"); return CallNode::make(op, {weight}, Attrs(attrs), {}); @@ -708,7 +708,7 @@ Expr MakeConv2DWinogradNNPACK(Expr data, std::string kernel_layout, std::string out_layout, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); attrs->dilation = std::move(dilation); @@ -783,7 +783,7 @@ bool Conv2DWinogradNNPACKWeightTransformRel(const Array& types, Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight, int convolution_algorithm, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->convolution_algorithm = convolution_algorithm; attrs->out_dtype = std::move(out_dtype); static const Op& op = Op::Get("nn.contrib_conv2d_winograd_nnpack_weight_transform"); @@ -821,7 +821,7 @@ Expr MakeConv2DNCHWcInt8(Expr data, std::string kernel_layout, std::string out_layout, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); attrs->dilation = std::move(dilation); @@ -870,7 +870,7 @@ Expr MakeConv2DNCHWc(Expr data, std::string kernel_layout, std::string out_layout, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); attrs->dilation = std::move(dilation); @@ -920,7 +920,7 @@ Expr MakeDepthwiseConv2DNCHWc(Expr data, std::string kernel_layout, std::string out_layout, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); attrs->dilation = std::move(dilation); @@ -1079,7 +1079,7 @@ Expr MakeDeformableConv2D(Expr data, std::string kernel_layout, std::string out_layout, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->strides = strides; attrs->padding = padding; attrs->dilation = dilation; diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index dfb360a2dec06..79c3e687db367 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -72,7 +72,7 @@ bool BiasAddRel(const Array& types, Expr MakeBiasAdd(Expr data, Expr bias, int axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.bias_add"); return CallNode::make(op, {data, bias}, Attrs(attrs), {}); @@ -104,7 +104,7 @@ RELAY_REGISTER_OP("nn.bias_add") TVM_REGISTER_NODE_TYPE(FIFOBufferAttrs); Expr MakeFIFOBuffer(Expr input, Expr buffer, int axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.fifo_buffer"); return CallNode::make(op, {input, buffer}, Attrs(attrs), {}); @@ -175,7 +175,7 @@ Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->units = units; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("nn.dense"); @@ -208,7 +208,7 @@ TVM_REGISTER_NODE_TYPE(LeakyReluAttrs); // Positional relay function to create leaky relu operator used by frontend FFI. Expr MakeLeakyRelu(Expr data, double alpha) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->alpha = alpha; static const Op& op = Op::Get("nn.leaky_relu"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -288,7 +288,7 @@ Array > PReluInferCorrectLayout( Expr MakePRelu(Expr data, Expr alpha, int axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.prelu"); return CallNode::make(op, {data, alpha}, Attrs(attrs), {}); @@ -327,7 +327,7 @@ TVM_REGISTER_NODE_TYPE(SoftmaxAttrs); TVM_REGISTER_API("relay.op.nn._make.softmax") .set_body_typed([](Expr data, int axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.softmax"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -362,7 +362,7 @@ RELAY_REGISTER_OP("nn.softmax") // relay.nn.log_softmax TVM_REGISTER_API("relay.op.nn._make.log_softmax") .set_body_typed([](Expr data, int axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.log_softmax"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -504,7 +504,7 @@ Expr MakeLRN(Expr data, double alpha, double beta, double bias) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->size = size; attrs->axis = axis; attrs->alpha = alpha; @@ -545,7 +545,7 @@ TVM_REGISTER_NODE_TYPE(L2NormalizeAttrs); Expr MakeL2Normalize(Expr data, double eps, Array axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->eps = eps; attrs->axis = std::move(axis); static const Op& op = Op::Get("nn.l2_normalize"); @@ -591,7 +591,7 @@ bool DropoutRel(const Array& types, } Expr MakeDropout(Expr data, double rate) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->rate = rate; static const Op& op = Op::Get("nn.dropout"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -680,7 +680,7 @@ bool BatchNormRel(const Array& types, Expr MakeBatchNorm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, int axis, double epsilon, bool center, bool scale) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; attrs->epsilon = epsilon; attrs->center = center; @@ -763,7 +763,7 @@ bool InstanceNormRel(const Array& types, Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, bool center, bool scale) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; attrs->epsilon = epsilon; attrs->center = center; @@ -833,7 +833,7 @@ bool LayerNormRel(const Array& types, Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, bool center, bool scale) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; attrs->epsilon = epsilon; attrs->center = center; @@ -1024,7 +1024,7 @@ bool DepthToSpaceRel(const Array& types, int num_inputs, const Attrs& attr // Positional relay function to create DepthToSpace operator // used by frontend FFI Expr MakeDepthToSpace(Expr data, int block_size, std::string layout, std::string mode) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->block_size = block_size; attrs->layout = std::move(layout); attrs->mode = std::move(mode); @@ -1082,7 +1082,7 @@ bool SpaceToDepthRel(const Array& types, int num_inputs, const Attrs& attr // Positional relay function to create SpaceToDepth operator // used by frontend FFI Expr MakeSpaceToDepth(Expr data, int block_size, std::string layout) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->block_size = block_size; attrs->layout = std::move(layout); static const Op& op = Op::Get("nn.space_to_depth"); diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index 519619f8812a7..5cde41446fe63 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -192,7 +192,7 @@ Expr MakePad(Expr data, Array > pad_width, double pad_value, std::string pad_mode) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->pad_value = pad_value; attrs->pad_width = std::move(pad_width); attrs->pad_mode = std::move(pad_mode); @@ -267,7 +267,7 @@ bool MirrorPadRel(const Array& types, // Handler to create a call to the padding op used by front-end FFI Expr MakeMirrorPad(Expr data, Array > pad_width, std::string mode) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->mode = mode; attrs->pad_width = std::move(pad_width); static const Op& op = Op::Get("nn.mirror_pad"); diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index e7529a9d7bb9e..00216900e2b59 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -63,7 +63,7 @@ Expr MakeMaxPool(Expr data, std::string layout, bool ceil_mode, std::string op_name) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); attrs->padding = std::move(padding); @@ -82,7 +82,7 @@ Expr MakeAvgPool(Expr data, bool ceil_mode, bool count_include_pad, std::string op_name) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); attrs->padding = std::move(padding); @@ -359,7 +359,7 @@ Array GlobalPool2DCompute(const Attrs& attrs, Expr MakeGlobalAvgPool2D(Expr data, std::string layout) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->layout = std::move(layout); static const Op& op = Op::Get("nn.global_avg_pool2d"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -391,7 +391,7 @@ RELAY_REGISTER_OP("nn.global_avg_pool2d") // GlobalMaxPool Expr MakeGlobalMaxPool2D(Expr data, std::string layout) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->layout = std::move(layout); static const Op& op = Op::Get("nn.global_max_pool2d"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -511,7 +511,7 @@ Array AdaptivePool2DCompute(const Attrs& attrs, Expr MakeAdaptiveAvgPool2D(Expr data, Array output_size, std::string layout) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); static const Op& op = Op::Get("contrib.adaptive_avg_pool2d"); @@ -550,7 +550,7 @@ RELAY_REGISTER_OP("contrib.adaptive_avg_pool2d") Expr MakeAdaptiveMaxPool2D(Expr data, Array output_size, std::string layout) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); static const Op& op = Op::Get("contrib.adaptive_max_pool2d"); @@ -647,7 +647,7 @@ Array Pool2DGradCompute(const Attrs& attrs, const Array& inputs, // MaxPool2DGrad Expr MakeMaxPool2DGrad(Expr out_grad, Expr data, Array pool_size, Array strides, Array padding, std::string layout, bool ceil_mode) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); attrs->padding = std::move(padding); @@ -695,7 +695,7 @@ RELAY_REGISTER_OP("nn.max_pool2d_grad") Expr MakeAvgPool2DGrad(Expr out_grad, Expr data, Array pool_size, Array strides, Array padding, std::string layout, bool ceil_mode, bool count_include_pad) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); attrs->padding = std::move(padding); diff --git a/src/relay/op/nn/sparse.cc b/src/relay/op/nn/sparse.cc index 7cf8a27f3b56f..fc22725977f2a 100644 --- a/src/relay/op/nn/sparse.cc +++ b/src/relay/op/nn/sparse.cc @@ -65,7 +65,7 @@ bool SparseDenseRel(const Array& types, int num_inputs, const Attrs& attrs // Positional relay function to create dense operator used by frontend FFI. Expr MakeSparseDense(Expr data, Expr weight_data, Expr weight_indices, Expr weight_indptr) { - auto attrs = make_node(); + auto attrs = make_object(); static const Op& op = Op::Get("nn.sparse_dense"); return CallNode::make(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {}); } @@ -114,7 +114,7 @@ bool SparseTransposeRel(const Array& types, int num_inputs, const Attrs& a } Expr MakeSparseTranspose(Expr sparse_data, Expr sparse_indices, Expr sparse_indptr) { - auto attrs = make_node(); + auto attrs = make_object(); static const Op& op = Op::Get("nn.sparse_transpose"); return CallNode::make(op, {sparse_data, sparse_indices, sparse_indptr}, Attrs(attrs), {}); } diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc index 61b40588c3d72..2ba7b2f7bcf43 100644 --- a/src/relay/op/nn/upsampling.cc +++ b/src/relay/op/nn/upsampling.cc @@ -102,7 +102,7 @@ Expr MakeUpSampling(Expr data, std::string layout, std::string method, bool align_corners) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->layout = std::move(layout); attrs->method = std::move(method); attrs->scale_h = scale_h; @@ -182,7 +182,7 @@ Expr MakeUpSampling3D(Expr data, std::string layout, std::string method, std::string coordinate_transformation_mode) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->layout = std::move(layout); attrs->method = std::move(method); attrs->scale_d = scale_d; diff --git a/src/relay/op/op_common.h b/src/relay/op/op_common.h index 42c1fc485a630..53495ccff15d1 100644 --- a/src/relay/op/op_common.h +++ b/src/relay/op/op_common.h @@ -145,7 +145,7 @@ class OpMatch { private: /*! \brief The match function map. */ - std::unordered_map match_map_; + std::unordered_map match_map_; /*! \brief An optional default case. */ MatchFunc default_; }; diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 7e499bae7683b..4e9a900c7cd69 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -308,7 +308,7 @@ bool ReduceRel(const Array& types, Array axis, \ bool keepdims, \ bool exclude) { \ - auto attrs = make_node(); \ + auto attrs = make_object(); \ attrs->axis = std::move(axis); \ attrs->keepdims = keepdims; \ attrs->exclude = exclude; \ @@ -625,7 +625,7 @@ Expr MakeVariance(Expr data, Array axis, bool keepdims, bool exclude) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = std::move(axis); attrs->keepdims = keepdims; attrs->exclude = exclude; diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index ff018e43aea76..7407f21e8e9a5 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -77,7 +77,7 @@ Array CastCompute(const Attrs& attrs, Expr MakeCast(Expr data, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("cast"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -165,7 +165,7 @@ Array ReinterpretCompute(const Attrs& attrs, const Array& inputs } Expr MakeReinterpret(Expr data, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("reinterpret"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -242,7 +242,7 @@ Array ExpandDimsCompute(const Attrs& attrs, Expr MakeExpandDims(Expr data, int axis, int num_newaxis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; attrs->num_newaxis = num_newaxis; static const Op& op = Op::Get("expand_dims"); @@ -328,7 +328,7 @@ Array> ConcatenateLayout( Expr MakeConcatenate(Expr data, int axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("concatenate"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -423,7 +423,7 @@ Array StackCompute(const Attrs& attrs, Expr MakeStack(Expr data, int axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("stack"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -515,7 +515,7 @@ Array TransposeCompute(const Attrs& attrs, Expr MakeTranspose(Expr data, Array axes) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axes = std::move(axes); static const Op& op = Op::Get("transpose"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -706,7 +706,7 @@ Array ReshapeCompute(const Attrs& attrs, Expr MakeReshape(Expr data, Array newshape) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->newshape = std::move(newshape); attrs->reverse = false; static const Op& op = Op::Get("reshape"); @@ -860,7 +860,7 @@ bool ArgWhereRel(const Array& types, TVM_REGISTER_API("relay.op._make.argwhere") .set_body_typed([](Expr data) { static const Op& op = Op::Get("argwhere"); - auto attrs = make_node(); + auto attrs = make_object(); return CallNode::make(op, {data}, Attrs(attrs), {}); }); @@ -938,7 +938,7 @@ Expr MakeTake(Expr data, Expr indices, Integer axis, std::string mode) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = std::move(axis); attrs->mode = std::move(mode); static const Op& op = Op::Get("take"); @@ -1019,7 +1019,7 @@ Array FullCompute(const Attrs& attrs, Expr MakeFull(Expr fill_value, Array shape, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->shape = std::move(shape); attrs->dtype = std::move(dtype); static const Op& op = Op::Get("full"); @@ -1054,7 +1054,7 @@ bool InitOpRel(const Array& types, Expr MakeZeros(Array shape, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->shape = std::move(shape); attrs->dtype = std::move(dtype); static const Op& op = Op::Get("zeros"); @@ -1075,7 +1075,7 @@ RELAY_REGISTER_OP("zeros") Expr MakeOnes(Array shape, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->shape = std::move(shape); attrs->dtype = std::move(dtype); static const Op& op = Op::Get("ones"); @@ -1244,7 +1244,7 @@ Expr MakeArange(Expr start, Expr stop, Expr step, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->start = start; attrs->stop = stop; attrs->step = step; @@ -1335,7 +1335,7 @@ Array RepeatCompute(const Attrs& attrs, Expr MakeRepeat(Expr data, int repeats, int axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->repeats = repeats; attrs->axis = axis; static const Op& op = Op::Get("repeat"); @@ -1445,7 +1445,7 @@ Array TileCompute(const Attrs& attrs, Expr MakeTile(Expr data, Array reps) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->reps = reps; static const Op& op = Op::Get("tile"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -1506,7 +1506,7 @@ Array ReverseCompute(const Attrs& attrs, Expr MakeReverse(Expr data, int axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("reverse"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -1623,7 +1623,7 @@ TVM_REGISTER_NODE_TYPE(SqueezeAttrs); Expr MakeSqueeze(Expr data, Array axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = std::move(axis); static const Op& op = Op::Get("squeeze"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -1764,7 +1764,7 @@ bool BroadCastToRel(const Array& types, Expr MakeBroadCastTo(Expr data, Array shape) { static const Op& op = Op::Get("broadcast_to"); - auto attrs = make_node(); + auto attrs = make_object(); attrs->shape = std::move(shape); return CallNode::make(op, {data}, Attrs(attrs), {}); } @@ -2006,7 +2006,7 @@ Expr MakeStridedSlice(Expr data, Array begin, Array end, Array strides) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->begin = std::move(begin); attrs->end = std::move(end); attrs->strides = std::move(strides); @@ -2189,9 +2189,9 @@ Array SplitCompute(const Attrs& attrs, } Expr MakeSplit(Expr data, - NodeRef indices_or_sections, + ObjectRef indices_or_sections, int axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; attrs->indices_or_sections = std::move(indices_or_sections); static const Op& op = Op::Get("split"); @@ -2294,7 +2294,7 @@ bool SliceLikeRel(const Array& types, Expr MakeSliceLike(Expr data, Expr shape_like, Array axes) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axes = std::move(axes); static const Op& op = Op::Get("slice_like"); return CallNode::make(op, {data, shape_like}, Attrs(attrs), {}); @@ -2403,7 +2403,7 @@ bool LayoutTransformRel(const Array& types, Expr MakeLayoutTransform(Expr data, std::string src_layout, std::string dst_layout) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->src_layout = std::move(src_layout); attrs->dst_layout = std::move(dst_layout); static const Op& op = Op::Get("layout_transform"); @@ -2431,7 +2431,7 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w] /* relay._contrib_reverse_reshape */ Expr MakeReverseReshape(Expr data, Array newshape) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->newshape = std::move(newshape); attrs->reverse = true; static const Op& op = Op::Get("_contrib_reverse_reshape"); @@ -2566,7 +2566,7 @@ Expr MakeSequenceMask(Expr data, Expr valid_length, double mask_value, int axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->mask_value = std::move(mask_value); attrs->axis = std::move(axis); static const Op& op = Op::Get("sequence_mask"); @@ -2687,7 +2687,7 @@ Expr MakeOneHot(Expr indices, int depth, int axis, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->depth = std::move(depth); attrs->axis = axis; attrs->dtype = dtype; diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index 710c910794c84..d4cd7be807b17 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -159,7 +159,7 @@ TVM_REGISTER_NODE_TYPE(ClipAttrs); TVM_REGISTER_API("relay.op._make.clip") .set_body_typed([](Expr a, double a_min, double a_max) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->a_min = a_min; attrs->a_max = a_max; static const Op& op = Op::Get("clip"); @@ -302,7 +302,7 @@ Array ShapeOfCompute(const Attrs& attrs, TVM_REGISTER_API("relay.op._make.shape_of") .set_body_typed([](Expr data, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("shape_of"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -353,7 +353,7 @@ Array NdarraySizeCompute(const Attrs& attrs, TVM_REGISTER_API("relay.op.contrib._make.ndarray_size") .set_body_typed([](Expr data, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("contrib.ndarray_size"); return CallNode::make(op, {data}, Attrs(attrs), {}); diff --git a/src/relay/op/vision/multibox_op.cc b/src/relay/op/vision/multibox_op.cc index 28289e76810ff..2dd09403f1442 100644 --- a/src/relay/op/vision/multibox_op.cc +++ b/src/relay/op/vision/multibox_op.cc @@ -60,7 +60,7 @@ Expr MakeMultiBoxPrior(Expr data, Array steps, Array offsets, bool clip) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->sizes = std::move(sizes); attrs->ratios = std::move(ratios); attrs->steps = std::move(steps); @@ -135,7 +135,7 @@ Expr MakeMultiBoxTransformLoc(Expr cls_prob, bool clip, double threshold, Array variances) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->clip = std::move(clip); attrs->threshold = std::move(threshold); attrs->variances = std::move(variances); diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index cba5b6bc7c501..6759e186eeda0 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -52,7 +52,7 @@ Expr MakeGetValidCounts(Expr data, double score_threshold, int id_index, int score_index) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->score_threshold = score_threshold; attrs->id_index = id_index; attrs->score_index = score_index; @@ -114,7 +114,7 @@ Expr MakeNMS(Expr data, int id_index, bool return_indices, bool invalid_to_bottom) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->max_output_size = max_output_size; attrs->iou_threshold = iou_threshold; attrs->force_suppress = force_suppress; diff --git a/src/relay/op/vision/rcnn_op.cc b/src/relay/op/vision/rcnn_op.cc index 52440969ae596..24f4b98b8ed0a 100644 --- a/src/relay/op/vision/rcnn_op.cc +++ b/src/relay/op/vision/rcnn_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -51,7 +51,7 @@ bool ROIAlignRel(const Array& types, int num_inputs, const Attrs& attrs, Expr MakeROIAlign(Expr data, Expr rois, Array pooled_size, double spatial_scale, int sample_ratio, std::string layout) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->pooled_size = pooled_size; attrs->spatial_scale = spatial_scale; attrs->sample_ratio = sample_ratio; @@ -102,7 +102,7 @@ bool ROIPoolRel(const Array& types, int num_inputs, const Attrs& attrs, Expr MakeROIPool(Expr data, Expr rois, Array pooled_size, double spatial_scale, std::string layout) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->pooled_size = pooled_size; attrs->spatial_scale = spatial_scale; attrs->layout = layout; @@ -163,7 +163,7 @@ Expr MakeProposal(Expr cls_prob, Expr bbox_pred, Expr im_info, Array Array ratios, int feature_stride, double threshold, int rpn_pre_nms_top_n, int rpn_post_nms_top_n, int rpn_min_size, bool iou_loss) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->scales = scales; attrs->ratios = ratios; attrs->feature_stride = feature_stride; diff --git a/src/relay/op/vision/yolo.cc b/src/relay/op/vision/yolo.cc index fe0684376c390..74b59f649ccd7 100644 --- a/src/relay/op/vision/yolo.cc +++ b/src/relay/op/vision/yolo.cc @@ -62,7 +62,7 @@ bool YoloReorgRel(const Array& types, Expr MakeYoloReorg(Expr data, Integer stride) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->stride = stride; static const Op& op = Op::Get("vision.yolo_reorg"); return CallNode::make(op, {data}, Attrs(attrs), {}); diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc index bd89c5123bd7f..b3b08c1451013 100644 --- a/src/relay/pass/alter_op_layout.cc +++ b/src/relay/pass/alter_op_layout.cc @@ -107,8 +107,8 @@ class AlterTransformMemorizer : public TransformMemorizer { * 2. Do not support nested tuple arguments. */ Expr AlterOpLayout(const Expr& expr) { - AlterTransformMemorizer alterMemorizer(make_node()); - auto fcontext = [&](const Call& call) -> NodeRef { return alterMemorizer; }; + AlterTransformMemorizer alterMemorizer(make_object()); + auto fcontext = [&](const Call& call) -> ObjectRef { return alterMemorizer; }; return ForwardRewrite(expr, LayoutRewriter, fcontext); } diff --git a/src/relay/pass/canonicalize_cast.cc b/src/relay/pass/canonicalize_cast.cc index 6913eb2d80c52..c790659012eef 100644 --- a/src/relay/pass/canonicalize_cast.cc +++ b/src/relay/pass/canonicalize_cast.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -92,12 +92,11 @@ class CastCanonicalizer : public ExprMutator { } private: - std::unordered_map ref_counter_; + std::unordered_map ref_counter_; // cast op is frequently checked for equivalence. Therefore, we cache it to // reduce lookup overhead. const Op& cast_op_; - Expr GetNewCallArg(const Expr& e) { // if e is a upcast and ref count > 1, create an copy; otherwise call the default visitor Expr new_expr = this->VisitExpr(e); diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index 109d86e806f61..e5c253e4ac56b 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -91,7 +91,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { const CallNode* group_root = branches[0][0]; const auto* attrs = group_root->attrs.as(); CHECK(attrs); - const auto new_attrs = make_node(); + const auto new_attrs = make_object(); new_attrs->strides = attrs->strides; new_attrs->padding = attrs->padding; new_attrs->dilation = attrs->dilation; diff --git a/src/relay/pass/combine_parallel_op.h b/src/relay/pass/combine_parallel_op.h index 858926e662e6a..619a153595b72 100644 --- a/src/relay/pass/combine_parallel_op.h +++ b/src/relay/pass/combine_parallel_op.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -46,10 +46,10 @@ using Branch = std::vector; using Group = std::vector; using FIsSupportedOp = std::function; using FAreCompatibleOps = std::function; -using ExprSubstMap = std::unordered_map; +using ExprSubstMap = std::unordered_map; /* - * Class to find parallel branches starting with op that are + * Class to find parallel branches starting with op that are * grouped if they are able to be combined. They are eligible to * be combined if they have the same input data. * Op can be followed by zero or more elemwise or broadcast ops, @@ -91,22 +91,22 @@ class BranchGroupFinder : private ExprVisitor { const Op& cached_op_; /* \brief function to return true if op is eligible to be combined, - * false otherwise + * false otherwise */ FIsSupportedOp fis_supported_op_; /* \brief function to return true if two parallel ops are eligible - * to be combined, false otherwise + * to be combined, false otherwise */ FAreCompatibleOps fare_compatible_ops_; /* \brief ops that are on the first (logically, leftmost) branch * of parallel ops and are eligible to be combined */ - std::unordered_set op_roots_; + std::unordered_set op_roots_; /* \brief map of Expr to CallNodes that follow it */ - std::unordered_map, NodeHash, NodeEqual> children_map_; + std::unordered_map, ObjectHash, ObjectEqual> children_map_; /* * \brief Creates new branch from op and its children that have diff --git a/src/relay/pass/convert_layout.cc b/src/relay/pass/convert_layout.cc index fa8b8722f8143..8b223ee100d1a 100644 --- a/src/relay/pass/convert_layout.cc +++ b/src/relay/pass/convert_layout.cc @@ -117,8 +117,8 @@ class ConvertTransformMemorizer : public TransformMemorizer { */ Expr ConvertLayout(const Expr& expr, const std::string& desired_layout) { ConvertTransformMemorizer transformMemorizer( - make_node(desired_layout)); - auto fcontext = [&](const Call& call) -> NodeRef { return transformMemorizer; }; + make_object(desired_layout)); + auto fcontext = [&](const Call& call) -> ObjectRef { return transformMemorizer; }; return ForwardRewrite(expr, LayoutRewriter, fcontext); } diff --git a/src/relay/pass/de_duplicate.cc b/src/relay/pass/de_duplicate.cc index af25e9fbac5d7..6816cc7d2d830 100644 --- a/src/relay/pass/de_duplicate.cc +++ b/src/relay/pass/de_duplicate.cc @@ -104,8 +104,8 @@ Expr DeDup(const Expr& e) { } private: - std::unordered_map rename_; - std::unordered_map type_rename_; + std::unordered_map rename_; + std::unordered_map type_rename_; }; CHECK(WellFormed(e)) << AsText(e, false); Expr ret = DeDupMutator().VisitExpr(e); diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc index df16baeeed7be..14bca58cf3285 100644 --- a/src/relay/pass/dead_code.cc +++ b/src/relay/pass/dead_code.cc @@ -36,8 +36,8 @@ namespace tvm { namespace relay { template -using VarMap = std::unordered_map; -using VarSet = std::unordered_set; +using VarMap = std::unordered_map; +using VarSet = std::unordered_set; class CalcDep; class FindDef : private ExprVisitor { diff --git a/src/relay/pass/dependency_graph.cc b/src/relay/pass/dependency_graph.cc index 42b829fc3c731..81c205a33c2f2 100644 --- a/src/relay/pass/dependency_graph.cc +++ b/src/relay/pass/dependency_graph.cc @@ -64,7 +64,7 @@ class DependencyGraph::Creator : private ExprFunctor { parent->children.Push(child_link); } - std::unordered_set visited_; + std::unordered_set visited_; DependencyGraph::Node* NewNode(bool new_scope) { auto* ret = arena_->make(); diff --git a/src/relay/pass/dependency_graph.h b/src/relay/pass/dependency_graph.h index 6b2af7e156a8f..d6a4e9588df93 100644 --- a/src/relay/pass/dependency_graph.h +++ b/src/relay/pass/dependency_graph.h @@ -54,7 +54,7 @@ class DependencyGraph { }; /*! \brief Maps a Relay Expr to its node in the dependency graph. */ - std::unordered_map expr_node; + std::unordered_map expr_node; /*! \brief The dependency graph in post DFS order. */ std::vector post_dfs_order; diff --git a/src/relay/pass/device_annotation.cc b/src/relay/pass/device_annotation.cc index 6ad04b0e15e43..91a7fa315f5df 100644 --- a/src/relay/pass/device_annotation.cc +++ b/src/relay/pass/device_annotation.cc @@ -280,7 +280,7 @@ class RewriteAnnotation : public ExprMutator { * \return The created call node. */ Call CreateDeviceCopy(const Expr& src, int src_dev_type, int dst_dev_type) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->src_dev_type = src_dev_type; attrs->dst_dev_type = dst_dev_type; static const Op& op = Op::Get("device_copy"); diff --git a/src/relay/pass/eliminate_common_subexpr.cc b/src/relay/pass/eliminate_common_subexpr.cc index 07827d2c8e142..d180fcc150be8 100644 --- a/src/relay/pass/eliminate_common_subexpr.cc +++ b/src/relay/pass/eliminate_common_subexpr.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -76,7 +76,7 @@ class CommonSubexprEliminator : public ExprMutator { return new_expr; } - std::unordered_map, NodeHash, NodeEqual> expr_map_; + std::unordered_map, ObjectHash, ObjectEqual> expr_map_; runtime::TypedPackedFunc fskip_; }; diff --git a/src/relay/pass/eta_expand.cc b/src/relay/pass/eta_expand.cc index dca08cc834d1c..888874cf0f750 100644 --- a/src/relay/pass/eta_expand.cc +++ b/src/relay/pass/eta_expand.cc @@ -49,7 +49,7 @@ class TypeVarReplacer : public TypeMutator { private: /*! \brief variable replacement map to remap old type vars to fresh ones */ - std::unordered_map replace_map_; + std::unordered_map replace_map_; }; /*! diff --git a/src/relay/pass/expr_subst.cc b/src/relay/pass/expr_subst.cc index baca63233338d..d3e6aa8dbfe60 100644 --- a/src/relay/pass/expr_subst.cc +++ b/src/relay/pass/expr_subst.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -30,7 +30,7 @@ namespace relay { class ExprSubstituter : public ExprMutator { public: - explicit ExprSubstituter(std::unordered_map subst_map) + explicit ExprSubstituter(std::unordered_map subst_map) : subst_map_(subst_map) {} Expr VisitExpr(const Expr& expr) final { @@ -45,7 +45,8 @@ class ExprSubstituter : public ExprMutator { tvm::Map subst_map_; }; -Expr ExprSubst(const Expr& expr, std::unordered_map subst_map) { +Expr ExprSubst(const Expr& expr, + std::unordered_map subst_map) { return ExprSubstituter(std::move(subst_map)).Mutate(expr); } diff --git a/src/relay/pass/expr_subst.h b/src/relay/pass/expr_subst.h index bc53d3f51be0a..2ffefa25657d5 100644 --- a/src/relay/pass/expr_subst.h +++ b/src/relay/pass/expr_subst.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -29,7 +29,8 @@ namespace tvm { namespace relay { -Expr ExprSubst(const Expr& expr, std::unordered_map subst_map); +Expr ExprSubst(const Expr& expr, + std::unordered_map subst_map); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/feature.cc b/src/relay/pass/feature.cc index d610f9523f8ae..79830a709fe09 100644 --- a/src/relay/pass/feature.cc +++ b/src/relay/pass/feature.cc @@ -36,7 +36,7 @@ FeatureSet DetectFeature(const Expr& expr) { return FeatureSet::No(); } struct FeatureDetector : ExprVisitor { - std::unordered_set visited_; + std::unordered_set visited_; FeatureSet fs = FeatureSet::No(); void VisitExpr(const Expr& expr) final { diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index 1e22571f6b432..4a6417b174e35 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -52,7 +52,7 @@ class ConstantChecker : private ExprVisitor { } private: - std::unordered_map memo_; + std::unordered_map memo_; void VisitExpr_(const TupleNode* n) final { bool result = true; @@ -266,7 +266,7 @@ class ConstantFolder : public ExprMutator { } // Cast the constant into correct dtype - auto cast_attrs = make_node(); + auto cast_attrs = make_object(); cast_attrs->dtype = param->dtype; Expr ret = CallNode::make(cast_op_, { shape }, Attrs(cast_attrs), {}); return ConstEvaluate(ret); diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc index e13a50a99c588..711297ca1883d 100644 --- a/src/relay/pass/fold_scale_axis.cc +++ b/src/relay/pass/fold_scale_axis.cc @@ -95,13 +95,16 @@ class MessageNode : public RelayNode { static Message make(const AxesSet& axes, bool require_positive); static constexpr const char* _type_key = "relay.pass.fold_scale_axis.Message"; - TVM_DECLARE_NODE_TYPE_INFO(MessageNode, RelayNode); + TVM_DECLARE_FINAL_OBJECT_INFO(MessageNode, RelayNode); }; -RELAY_DEFINE_NODE_REF(Message, MessageNode, NodeRef); +class Message : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Message, ObjectRef, MessageNode); +}; Message MessageNode::make(const AxesSet& axes, bool require_positive) { - auto n = make_node(); + auto n = make_object(); n->axes = axes; n->require_positive = require_positive; return Message(n); @@ -183,7 +186,7 @@ class ScaledExprNode : public TempExprNode { } static constexpr const char* _type_key = "relay.fold_scale_axis.ScaledExpr"; - TVM_DECLARE_NODE_TYPE_INFO(ScaledExprNode, TempExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ScaledExprNode, TempExprNode); }; using FForwardRewrite = TypedPackedFunc< @@ -196,7 +199,7 @@ using FForwardRewrite = TypedPackedFunc< //---------------------------------------------- class ForwardPrep : private ExprVisitor { public: - std::unordered_map + std::unordered_map Prepare(const Expr& body) { this->Update(body, NullValue()); this->VisitExpr(body); @@ -215,7 +218,7 @@ class ForwardPrep : private ExprVisitor { // The invoke list std::vector > flist_; // The message on each node. - std::unordered_map message_; + std::unordered_map message_; // Update the message stored at node. void Update(const Expr& node, const Message& message) { // We run intersection of messages: @@ -228,7 +231,7 @@ class ForwardPrep : private ExprVisitor { // because %z2 will propagate null to %y, // the AxesSet on %y is also null, // and the forward folding won't be triggered. - const Node* key = node.get(); + const Object* key = node.get(); if (message_.count(key)) { message_[key] = Intersect(message_[key], message); } else { @@ -323,7 +326,7 @@ Expr ReluForwardRewrite(const Call& ref_call, const auto* input = new_args[0].as(); if (input == nullptr) return Expr(nullptr); // return transformed conv2d - auto rnode = make_node(); + auto rnode = make_object(); rnode->value = CallNode::make( ref_call->op, {input->value}, ref_call->attrs, ref_call->type_args); rnode->scale = input->scale; @@ -366,7 +369,7 @@ Expr AddSubForwardRewrite(const Call& ref_call, if (!slhs && !srhs) return Expr(); const auto* tlhs = ref_call->args[0]->type_as(); const auto* trhs = ref_call->args[1]->type_as(); - auto rnode = make_node(); + auto rnode = make_object(); if (slhs != nullptr) { CHECK(srhs == nullptr); @@ -422,7 +425,7 @@ Expr MultiplyForwardRewrite(const Call& ref_call, const auto* trhs = ref_call->args[1]->type_as(); Expr lhs = new_args[0]; Expr rhs = new_args[1]; - auto rnode = make_node(); + auto rnode = make_object(); if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs) && (!message->require_positive || IsAllPositiveConstant(rhs))) { @@ -531,12 +534,12 @@ RELAY_REGISTER_OP("nn.conv2d") Expr ForwardFoldScaleAxis(const Expr& data) { auto message = ForwardPrep().Prepare(data); - auto fcontext = [&](const Call& call) -> NodeRef{ + auto fcontext = [&](const Call& call) -> ObjectRef{ auto it = message.find(call.get()); if (it != message.end()) { return it->second; } else { - return NodeRef(nullptr); + return ObjectRef(nullptr); } }; return ForwardRewrite( @@ -571,7 +574,7 @@ using FBackwardTransform = TypedPackedFunc< class BackwardPrep : private ExprVisitor { public: // The message on each node. - std::unordered_map + std::unordered_map Prepare(const Expr& body) { ref_counter_ = GetExprRefCount(body); this->VisitExpr(body); @@ -580,9 +583,9 @@ class BackwardPrep : private ExprVisitor { private: // The message on each node. - std::unordered_map message_; + std::unordered_map message_; // reference counter of an internal expr - std::unordered_map ref_counter_; + std::unordered_map ref_counter_; // Visit the expression. void VisitExpr_(const CallNode* call) { ExprVisitor::VisitExpr_(call); @@ -612,7 +615,7 @@ class BackwardPrep : private ExprVisitor { }; class BackwardTransformerNode : - public Node, + public Object, private ExprMutator { public: // Run forward transform. @@ -667,11 +670,11 @@ class BackwardTransformerNode : void VisitAttrs(tvm::AttrVisitor* v) {} static constexpr const char* _type_key = "relay.fold_scale_axis.FBackwardTransformer"; - TVM_DECLARE_NODE_TYPE_INFO(BackwardTransformerNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(BackwardTransformerNode, Object); private: // Valid axes on each node. - std::unordered_map message_; + std::unordered_map message_; // Override mutation of call. Expr VisitExpr_(const CallNode* call_node) final { return Transform(call_node, NullValue(), NullValue()); @@ -680,11 +683,11 @@ class BackwardTransformerNode : Expr Transform(const CallNode* call_node, Message message, Expr scale); }; -class BackwardTransformer : public NodeRef { +class BackwardTransformer : public ObjectRef { public: BackwardTransformer() {} explicit BackwardTransformer( - ::tvm::ObjectPtr<::tvm::Object> n) : NodeRef(n) { + ::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) { } BackwardTransformerNode* operator->() const { return static_cast(get_mutable()); @@ -938,7 +941,7 @@ RELAY_REGISTER_OP("nn.conv2d") .set_attr("FScaleAxisBackwardTransform", Conv2DBackwardTransform); Expr BackwardFoldScaleAxis(const Expr& data) { - return make_node()->Fold(data); + return make_object()->Fold(data); } } // namespace fold_scale_axis diff --git a/src/relay/pass/forward_rewrite.cc b/src/relay/pass/forward_rewrite.cc index fe5cc36cba95f..fe0df010b6267 100644 --- a/src/relay/pass/forward_rewrite.cc +++ b/src/relay/pass/forward_rewrite.cc @@ -61,14 +61,14 @@ class TempRealizer : private ExprMutator { class ForwardRewriter : private ExprMutator { public: ForwardRewriter(const OpMap* rewrite_map, - std::function fcontext, + std::function fcontext, std::function fmulti_ref_trigger) : rewrite_map_(rewrite_map), fcontext_(fcontext), fmulti_ref_trigger_(fmulti_ref_trigger) {} ForwardRewriter(const FForwardRewrite* rewrite_func, - std::function fcontext, + std::function fcontext, std::function fmulti_ref_trigger) : rewrite_func_(rewrite_func), fcontext_(fcontext), @@ -88,11 +88,11 @@ class ForwardRewriter : private ExprMutator { const OpMap* rewrite_map_{nullptr}; const FForwardRewrite* rewrite_func_{nullptr}; // The context.const - std::function fcontext_{nullptr}; + std::function fcontext_{nullptr}; // The multiple reference trigger std::function fmulti_ref_trigger_{nullptr}; // Internal ref counter - std::unordered_map ref_counter_; + std::unordered_map ref_counter_; // internal realizer TempRealizer realizer_; @@ -172,7 +172,7 @@ class ForwardRewriter : private ExprMutator { if (frewrite != nullptr) { Expr res = frewrite( ref_call, call_args, - fcontext_ != nullptr ? fcontext_(ref_call) : NodeRef(nullptr)); + fcontext_ != nullptr ? fcontext_(ref_call) : ObjectRef(nullptr)); if (res.defined()) return res; // abort, use old rule for (size_t i = 0; i < call_args.size(); ++i) { @@ -192,7 +192,7 @@ class ForwardRewriter : private ExprMutator { Expr ForwardRewrite(const Expr& expr, const std::string& rewrite_map_name, - std::function fcontext, + std::function fcontext, std::function fmulti_ref_trigger) { auto rewrite_map = Op::GetAttr(rewrite_map_name); return ForwardRewriter(&rewrite_map, fcontext, fmulti_ref_trigger).Rewrite(expr); @@ -200,7 +200,7 @@ Expr ForwardRewrite(const Expr& expr, Expr ForwardRewrite(const Expr& expr, const FForwardRewrite& rewrite_func, - std::function fcontext, + std::function fcontext, std::function fmulti_ref_trigger) { return ForwardRewriter(&rewrite_func, fcontext, fmulti_ref_trigger).Rewrite(expr); } diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 8209a8010b986..7b8f6de743823 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -103,7 +103,7 @@ class IndexedForwardGraph { /*! \brief A node in the graph. */ struct Node { /*! \brief weak reference to the corresponding edge. */ - const tvm::Node* ref{nullptr}; + const tvm::Object* ref{nullptr}; /*! \brief The index of the node in topological order. */ size_t index{0}; /*! \brief Whether this node is referenced by external source */ @@ -114,7 +114,7 @@ class IndexedForwardGraph { LinkedList outputs; }; /*! \brief The node map that maps node to graph */ - std::unordered_map node_map; + std::unordered_map node_map; /*! \brief All the nodes in post DFS order */ std::vector post_dfs_order; @@ -124,7 +124,7 @@ class IndexedForwardGraph { for (size_t i = 0; i < post_dfs_order.size(); ++i) { Node* node = post_dfs_order[i]; os << "node[" << i << "], " - << GetRef(node->ref) + << GetRef(node->ref) << " outputs=["; for (auto* link = node->outputs.head; link != nullptr; link = link->next) { os << link->value.node->index << ", "; @@ -167,7 +167,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { void Update(const Expr& node, IndexedForwardGraph::Node* parent, OpPatternKind pattern) { - const tvm::Node* key = node.get(); + const tvm::Object* key = node.get(); IndexedForwardGraph::Node* current; auto it = graph_.node_map.find(key); if (it != graph_.node_map.end()) { @@ -186,10 +186,10 @@ class IndexedForwardGraph::Creator : private ExprVisitor { } } - void AddNode(const tvm::Node* key) { + void AddNode(const tvm::Object* key) { auto it = graph_.node_map.find(key); CHECK(it != graph_.node_map.end()) - << "Cannot find node " << GetRef(key); + << "Cannot find node " << GetRef(key); IndexedForwardGraph::Node* node = it->second; CHECK(node->ref == nullptr); node->ref = key; @@ -523,12 +523,12 @@ class GraphPartitioner { /*! \brief The pattern of the group */ OpPatternKind pattern; /*! \brief reference to the root node. */ - const tvm::Node* root_ref{nullptr}; + const tvm::Object* root_ref{nullptr}; /*! * \brief Reference to the master node, * this field is not nullptr only if pattern is kOutEWiseFusable. */ - const tvm::Node* master_ref{nullptr}; + const tvm::Object* master_ref{nullptr}; /*! * \brief Find the group root, perform path compression * \return The root type node. @@ -847,7 +847,7 @@ class FuseMutator : private ExprMutator { /*! \brief Internal arena. */ common::Arena arena_; /*! \brief The group assignment map. */ - std::unordered_map gmap_; + std::unordered_map gmap_; /* \brief Internal group information map. */ std::unordered_map ginfo_; diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc index 69d12c26f1037..61f7e2d8979d4 100644 --- a/src/relay/pass/gradient.cc +++ b/src/relay/pass/gradient.cc @@ -133,7 +133,7 @@ struct FirstOrderReverseAD : ExprFunctor { const OpMap rev_map = Op::GetAttr("FPrimalGradient"); std::vector> backprop_actions; // we assume no closure so no need for lexical scoping - std::unordered_map env; + std::unordered_map env; LetList* ll; FirstOrderReverseAD(LetList* ll) : ll(ll) { } @@ -385,7 +385,7 @@ Expr BPEmpty() { } struct ReverseAD : ExprMutator { - using ADVarMap = std::unordered_map; + using ADVarMap = std::unordered_map; Var bp; std::shared_ptr ad_vars; diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index afcc4935fa417..7a524ee233180 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -63,7 +63,7 @@ * so we have to deduplicate them. * * 4: In the generated code, as it call TypeSubst, multiple VarNode might have same Id. - * While it is permitted, most pass use NodeHash for Var, + * While it is permitted, most pass use ObjectHash for Var, * and having multiple VarNode for same Id break them. * Thus we remap them to a single Id for now. * @@ -110,7 +110,7 @@ using namespace runtime; */ struct VarHash { size_t operator()(const Var& v) const { - return NodeHash()(v->vid); + return ObjectHash()(v->vid); } }; @@ -130,13 +130,13 @@ Expr PostProcess(const Expr&); class StaticNode : public RelayNode { public: static constexpr const char* _type_key = "relay.Static"; - TVM_DECLARE_BASE_NODE_INFO(StaticNode, RelayNode); + TVM_DECLARE_BASE_OBJECT_INFO(StaticNode, RelayNode); }; -class Static : public NodeRef { +class Static : public ObjectRef { public: Static() {} - explicit Static(ObjectPtr n) : NodeRef(n) {} + explicit Static(ObjectPtr n) : ObjectRef(n) {} const StaticNode* operator->() const { return static_cast(get()); } @@ -146,7 +146,7 @@ class Static : public NodeRef { using Time = size_t; -struct PStaticNode : Node { +struct PStaticNode : Object { static Time time() { static Time time_ = 0; Time ret = time_; @@ -160,35 +160,44 @@ struct PStaticNode : Node { pstatic(pstatic), dynamic(dynamic), created_time(time()) { } explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { } static constexpr const char* _type_key = "relay.PStatic"; - TVM_DECLARE_NODE_TYPE_INFO(PStaticNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(PStaticNode, Object); }; -RELAY_DEFINE_NODE_REF(PStatic, PStaticNode, NodeRef); +class PStatic : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(PStatic, ObjectRef, PStaticNode); +}; struct STupleNode : StaticNode { std::vector fields; explicit STupleNode(const std::vector& fields) : fields(fields) { } static constexpr const char* _type_key = "relay.STuple"; - TVM_DECLARE_NODE_TYPE_INFO(STupleNode, StaticNode); + TVM_DECLARE_FINAL_OBJECT_INFO(STupleNode, StaticNode); }; -RELAY_DEFINE_NODE_REF(STuple, STupleNode, Static); +class STuple : public Static { + public: + TVM_DEFINE_OBJECT_REF_METHODS(STuple, Static, STupleNode); +}; Static MkSTuple(const std::vector& fields) { - return Static(make_node(fields)); + return Static(make_object(fields)); } struct STensorNode : StaticNode { runtime::NDArray data; explicit STensorNode(const NDArray& data) : data(data) { } static constexpr const char* _type_key = "relay.STensor"; - TVM_DECLARE_NODE_TYPE_INFO(STensorNode, StaticNode); + TVM_DECLARE_FINAL_OBJECT_INFO(STensorNode, StaticNode); }; -RELAY_DEFINE_NODE_REF(STensor, STensorNode, Static); +class STensor : public Static { + public: + TVM_DEFINE_OBJECT_REF_METHODS(STensor, Static, STensorNode); +}; Static MkSTensor(const NDArray& data) { - return Static(make_node(data)); + return Static(make_object(data)); } struct SConstructorNode : StaticNode { @@ -197,25 +206,31 @@ struct SConstructorNode : StaticNode { SConstructorNode(const Constructor& constructor, const std::vector& fields) : constructor(constructor), fields(fields) { } static constexpr const char* _type_key = "relay.SConstructor"; - TVM_DECLARE_NODE_TYPE_INFO(SConstructorNode, StaticNode); + TVM_DECLARE_FINAL_OBJECT_INFO(SConstructorNode, StaticNode); }; -RELAY_DEFINE_NODE_REF(SConstructor, SConstructorNode, Static); +class SConstructor : public Static { + public: + TVM_DEFINE_OBJECT_REF_METHODS(SConstructor, Static, SConstructorNode); +}; Static MkSConstructor(const Constructor& constructor, const std::vector& fields) { - return Static(make_node(constructor, fields)); + return Static(make_object(constructor, fields)); } struct SRefNode : StaticNode { static constexpr const char* _type_key = "relay.SRef"; // we will use the address as the guid for hashing - TVM_DECLARE_NODE_TYPE_INFO(SRefNode, StaticNode); + TVM_DECLARE_FINAL_OBJECT_INFO(SRefNode, StaticNode); }; -RELAY_DEFINE_NODE_REF(SRef, SRefNode, Static); +class SRef : public Static { + public: + TVM_DEFINE_OBJECT_REF_METHODS(SRef, Static, SRefNode); +}; Static MkSRef() { - return Static(make_node()); + return Static(make_object()); } using Func = std::function(func)); + return Static(make_object(func)); } @@ -246,10 +264,10 @@ class FuelNode; * Every time we recurse, we do a meet and require that progress must be made. * This ensures we do not recurse infinitely in the Partial Evaluator. */ -class Fuel : public NodeRef { +class Fuel : public ObjectRef { public: Fuel() {} - explicit Fuel(ObjectPtr n) : NodeRef(n) {} + explicit Fuel(ObjectPtr n) : ObjectRef(n) {} const FuelNode* operator->() const; using ContainerType = FuelNode; @@ -279,7 +297,7 @@ class FuelNode : public RelayNode { return std::get<0>(ret); } static constexpr const char* _type_key = "relay.Fuel"; - TVM_DECLARE_BASE_NODE_INFO(FuelNode, RelayNode); + TVM_DECLARE_BASE_OBJECT_INFO(FuelNode, RelayNode); }; const FuelNode* Fuel::operator->() const { @@ -301,13 +319,16 @@ struct FSeqNode : FuelNode { } explicit FSeqNode(const std::vector& fuels) : fuels(fuels) { } static constexpr const char* _type_key = "relay.FSeq"; - TVM_DECLARE_NODE_TYPE_INFO(FSeqNode, FuelNode); + TVM_DECLARE_FINAL_OBJECT_INFO(FSeqNode, FuelNode); }; -RELAY_DEFINE_NODE_REF(FSeq, FSeqNode, Fuel); +class FSeq : public Fuel { + public: + TVM_DEFINE_OBJECT_REF_METHODS(FSeq, Fuel, FSeqNode); +}; Fuel MkFSeq(const std::vector& fuels) { - return Fuel(make_node(fuels)); + return Fuel(make_object(fuels)); } Fuel MkFTime(Time time); @@ -321,13 +342,16 @@ struct FTimeNode : FuelNode { } explicit FTimeNode(Time time) : time(time) { } static constexpr const char* _type_key = "relay.FTime"; - TVM_DECLARE_NODE_TYPE_INFO(FTimeNode, FuelNode); + TVM_DECLARE_FINAL_OBJECT_INFO(FTimeNode, FuelNode); }; -RELAY_DEFINE_NODE_REF(FTime, FTimeNode, Fuel); +class FTime : public Fuel { + public: + TVM_DEFINE_OBJECT_REF_METHODS(FTime, Fuel, FTimeNode); +}; Fuel MkFTime(Time time) { - return Fuel(make_node(time)); + return Fuel(make_object(time)); } Fuel MkFTValue(size_t tvalue); @@ -342,13 +366,16 @@ struct FTValueNode : FuelNode { } explicit FTValueNode(size_t tvalue) : tvalue(tvalue) { } static constexpr const char* _type_key = "relay.FTValue"; - TVM_DECLARE_NODE_TYPE_INFO(FTValueNode, FuelNode); + TVM_DECLARE_FINAL_OBJECT_INFO(FTValueNode, FuelNode); }; -RELAY_DEFINE_NODE_REF(FTValue, FTValueNode, Fuel); +class FTValue : public Fuel { + public: + TVM_DEFINE_OBJECT_REF_METHODS(FTValue, Fuel, FTValueNode); +}; Fuel MkFTValue(size_t tvalue) { - return Fuel(make_node(tvalue)); + return Fuel(make_object(tvalue)); } /*! \brief Initially every element has Fuel of FTop. It is the largest element. @@ -361,13 +388,16 @@ struct FTopNode : FuelNode { return std::make_tuple(f, !f.as()); } static constexpr const char* _type_key = "relay.FTop"; - TVM_DECLARE_NODE_TYPE_INFO(FTopNode, FuelNode); + TVM_DECLARE_FINAL_OBJECT_INFO(FTopNode, FuelNode); }; -RELAY_DEFINE_NODE_REF(FTop, FTopNode, Fuel); +class FTop : public Fuel { + public: + TVM_DEFINE_OBJECT_REF_METHODS(FTop, Fuel, FTopNode); +}; Fuel MkFTop() { - return Fuel(make_node()); + return Fuel(make_object()); } /*! @@ -500,11 +530,11 @@ class Store { PStatic HasStatic(const Static& stat, const Expr& dynamic) { CHECK(stat.defined()); - return PStatic(make_node(stat, dynamic)); + return PStatic(make_object(stat, dynamic)); } PStatic NoStatic(const Expr& dynamic) { - return PStatic(make_node(dynamic)); + return PStatic(make_object(dynamic)); } enum struct MatchStatus { @@ -559,6 +589,7 @@ struct WithFuncIdAttrs : public tvm::AttrsNode { TVM_REGISTER_NODE_TYPE(WithFuncIdAttrs); + RELAY_REGISTER_OP("annotation.with_funcid") .describe(R"code(Annotate a function with a funcid.)code" TVM_ADD_FILELINE) @@ -569,7 +600,7 @@ TVM_ADD_FILELINE) static const Op& with_funcid_op = Op::Get("annotation.with_funcid"); Expr MkWithFuncId(const Expr& expr, FuncId fid) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->fid = fid; return CallNode::make(with_funcid_op, {expr}, Attrs(attrs), {}); } @@ -1147,7 +1178,7 @@ class PartialEvaluator : public ExprFunctor private: Environment env_; Module mod_; - std::unordered_map gv_map_; + std::unordered_map gv_map_; /*! Termination checking is done as follows: * We have finitely many FunctionIds. * Each FunctionId maps to a class of semantically equivalent function (ignoring type), @@ -1161,7 +1192,7 @@ class PartialEvaluator : public ExprFunctor * when we PE inside the Function body. * Termination is guaranteed because Fuel is finitely descending - there can only be so many meet. */ - std::unordered_map func_map_; + std::unordered_map func_map_; std::unordered_map fuel_map_; Store store_; DLContext context_ = CPUContext(); diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index 97b8fd681cb8c..909ba0b8d712c 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -44,7 +44,7 @@ struct RelayPassContextThreadLocalEntry { std::stack context_stack; RelayPassContextThreadLocalEntry() { - default_context = PassContext(make_node()); + default_context = PassContext(make_object()); } }; @@ -77,7 +77,7 @@ PassContext PassContext::Current() { } PassContext PassContext::Create() { - return PassContext(make_node()); + return PassContext(make_object()); } class ModulePass; @@ -126,10 +126,13 @@ class ModulePassNode : public PassNode { PassInfo pass_info); static constexpr const char* _type_key = "relay.ModulePass"; - TVM_DECLARE_NODE_TYPE_INFO(ModulePassNode, PassNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ModulePassNode, PassNode); }; -RELAY_DEFINE_NODE_REF(ModulePass, ModulePassNode, Pass); +class ModulePass : public Pass { + public: + TVM_DEFINE_OBJECT_REF_METHODS(ModulePass, Pass, ModulePassNode); +}; class FunctionPass; @@ -180,7 +183,7 @@ class FunctionPassNode : public PassNode { PassInfo pass_info); static constexpr const char* _type_key = "relay.FunctionPass"; - TVM_DECLARE_NODE_TYPE_INFO(FunctionPassNode, PassNode); + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPassNode, PassNode); private: /* @@ -193,7 +196,10 @@ class FunctionPassNode : public PassNode { bool SkipFunction(const Function& func) const; }; -RELAY_DEFINE_NODE_REF(FunctionPass, FunctionPassNode, Pass); +class FunctionPass : public Pass { + public: + TVM_DEFINE_OBJECT_REF_METHODS(FunctionPass, Pass, FunctionPassNode); +}; /*! * \brief The SequentialNode contains a set of passes that transform Relay @@ -258,13 +264,13 @@ class SequentialNode : public PassNode { Module operator()(const Module& mod, const PassContext& pass_ctx) const final; static constexpr const char* _type_key = "relay.Sequential"; - TVM_DECLARE_NODE_TYPE_INFO(SequentialNode, PassNode); + TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode); }; PassInfo PassInfoNode::make(int opt_level, std::string name, tvm::Array required) { - auto pass_info = make_node(); + auto pass_info = make_object(); pass_info->opt_level = opt_level; pass_info->name = std::move(name); pass_info->required = std::move(required); @@ -274,7 +280,7 @@ PassInfo PassInfoNode::make(int opt_level, ModulePass ModulePassNode::make( runtime::TypedPackedFunc pass_func, PassInfo pass_info) { - auto n = make_node(); + auto n = make_object(); n->pass_func = std::move(pass_func); n->pass_info = std::move(pass_info); return ModulePass(n); @@ -297,7 +303,7 @@ Module ModulePassNode::operator()(const Module& mod, FunctionPass FunctionPassNode::make( runtime::TypedPackedFunc pass_func, PassInfo pass_info) { - auto n = make_node(); + auto n = make_object(); n->pass_func = std::move(pass_func); n->pass_info = std::move(pass_info); return FunctionPass(n); @@ -330,20 +336,20 @@ Module FunctionPassNode::operator()(const Module& mod, } bool FunctionPassNode::SkipFunction(const Function& func) const { - NodeRef skip_opt = FunctionGetAttr(func, attr::kSkipOptimization); + ObjectRef skip_opt = FunctionGetAttr(func, attr::kSkipOptimization); const ir::IntImm* pval = skip_opt.as(); return (pval && pval->value != 0) || (!func->UseDefaultCompiler()); } Sequential::Sequential(tvm::Array passes, PassInfo pass_info) { - auto n = make_node(); + auto n = make_object(); n->passes = std::move(passes); n->pass_info = std::move(pass_info); data_ = std::move(n); } Sequential::Sequential(tvm::Array passes, std::string name) { - auto n = make_node(); + auto n = make_object(); n->passes = std::move(passes); PassInfo pass_info = PassInfoNode::make(2, std::move(name), {}); n->pass_info = std::move(pass_info); diff --git a/src/relay/pass/pass_util.h b/src/relay/pass/pass_util.h index 225ce610d9091..2d4722bc6759f 100644 --- a/src/relay/pass/pass_util.h +++ b/src/relay/pass/pass_util.h @@ -39,7 +39,7 @@ namespace relay { * \param body The body expression. * \return The reference count mapping. */ -std::unordered_map +std::unordered_map GetExprRefCount(const Expr& body); /*! @@ -108,57 +108,57 @@ inline bool IsAtomic(const Expr& e) { return e.as() || e.as() || e.as() || e.as(); } -template +template struct TreeNode { - typedef std::shared_ptr> pointer; + typedef std::shared_ptr> pointer; virtual ~TreeNode() {} }; -template -struct TreeLeafNode : TreeNode { - using TreeNodePtr = typename TreeNode::pointer; +template +struct TreeLeafNode : TreeNode { + using TreeObjectPtr = typename TreeNode::pointer; Expr body; explicit TreeLeafNode(Expr body): body(body) {} - static TreeNodePtr Make(Expr body) { + static TreeObjectPtr Make(Expr body) { return std::make_shared(body); } ~TreeLeafNode() {} }; -template -struct TreeLeafFatalNode : TreeNode { - using TreeNodePtr = typename TreeNode::pointer; +template +struct TreeLeafFatalNode : TreeNode { + using TreeObjectPtr = typename TreeNode::pointer; TreeLeafFatalNode() = default; - static TreeNodePtr Make() { + static TreeObjectPtr Make() { return std::make_shared(); } ~TreeLeafFatalNode() {} }; -template -struct TreeBranchNode : TreeNode { - using TreeNodePtr = typename TreeNode::pointer; +template +struct TreeBranchNode : TreeNode { + using TreeObjectPtr = typename TreeNode::pointer; - ConditionNodePtr cond; - TreeNodePtr then_branch; - TreeNodePtr else_branch; + ConditionObjectPtr cond; + TreeObjectPtr then_branch; + TreeObjectPtr else_branch; - TreeBranchNode(ConditionNodePtr cond, - TreeNodePtr then_branch, - TreeNodePtr else_branch) + TreeBranchNode(ConditionObjectPtr cond, + TreeObjectPtr then_branch, + TreeObjectPtr else_branch) : cond(cond), then_branch(then_branch), else_branch(else_branch) {} - static TreeNodePtr Make(ConditionNodePtr cond, - TreeNodePtr then_branch, - TreeNodePtr else_branch) { + static TreeObjectPtr Make(ConditionObjectPtr cond, + TreeObjectPtr then_branch, + TreeObjectPtr else_branch) { return std::make_shared(cond, then_branch, else_branch); } diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 5e93ea1ff0aad..d3ec342b883d6 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -104,9 +104,9 @@ inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs, size_t base = tlhs->shape.size() - trhs->shape.size(); size_t j = 0; - NodePtr squeeze_attrs; + ObjectPtr squeeze_attrs; if (rhs_value != nullptr) { - squeeze_attrs = make_node(); + squeeze_attrs = make_object(); } for (size_t i = 0; i < tlhs->shape.size(); ++i) { @@ -149,7 +149,7 @@ inline Expr ExpandBiasToMatchAxis(Expr bias, if (i == axes.size()) { int64_t num_pad_axis = target_ndim - axes[i - 1]->value - 1; if (num_pad_axis > 0) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = i; attrs->num_newaxis = static_cast(num_pad_axis); bias = CallNode::make(expand_dims, {bias}, Attrs(attrs), {}); @@ -158,7 +158,7 @@ inline Expr ExpandBiasToMatchAxis(Expr bias, int64_t diff = axes[i]->value - axes[i - 1]->value; CHECK_GE(diff, 0L); if (diff > 0) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = i; attrs->num_newaxis = static_cast(diff); bias = CallNode::make(expand_dims, {bias}, Attrs(attrs), {}); @@ -291,7 +291,7 @@ T GetScalarFromConstant(Expr expr) { inline Expr Cast(Expr x, DataType dtype) { static const Op& op = Op::Get("cast"); - auto attrs = make_node(); + auto attrs = make_object(); attrs->dtype = dtype; return CallNode::make(op, {x}, Attrs(attrs), {}); } @@ -322,7 +322,7 @@ inline Expr Round(Expr x) { inline Expr Clip(Expr x, double a_min, double a_max) { static const Op& op = Op::Get("clip"); - auto attrs = make_node(); + auto attrs = make_object(); attrs->a_min = a_min; attrs->a_max = a_max; return CallNode::make(op, {x}, Attrs(attrs), {}); @@ -358,7 +358,7 @@ inline Expr ZerosLike(Expr e) { } inline Expr Zeros(Array shape, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->shape = std::move(shape); attrs->dtype = std::move(dtype); static const Op& op = Op::Get("zeros"); @@ -406,7 +406,7 @@ inline Expr Copy(Expr data) { inline Expr Mean(Expr data, Array axis, bool keepdims, bool exclude) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = std::move(axis); attrs->keepdims = keepdims; attrs->exclude = exclude; @@ -415,7 +415,7 @@ inline Expr Mean(Expr data, Array axis, bool keepdims, bool exclude) { } inline Expr Variance(Expr data, Expr mean, Array axis, bool keepdims, bool exclude) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = std::move(axis); attrs->keepdims = keepdims; attrs->exclude = exclude; @@ -437,7 +437,7 @@ static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) { static inline Expr Full(Expr fill_value, Array shape, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->shape = std::move(shape); attrs->dtype = std::move(dtype); static const Op& op = Op::Get("full"); @@ -448,7 +448,7 @@ static inline Expr Conv2D(Expr data, Expr weight, Array strides, Array padding, Array dilation, int groups, IndexExpr channels, Array kernel_size, std::string data_layout, std::string kernel_layout, std::string out_layout, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); attrs->dilation = std::move(dilation); @@ -467,7 +467,7 @@ static inline Expr Dense(Expr data, Expr weight, IndexExpr units, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->units = units; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("nn.dense"); @@ -475,7 +475,7 @@ static inline Expr Dense(Expr data, } static inline Expr Sum(Expr data, Array axis, bool keepdims, bool exclude) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = std::move(axis); attrs->keepdims = keepdims; attrs->exclude = exclude; @@ -484,7 +484,7 @@ static inline Expr Sum(Expr data, Array axis, bool keepdims, bool exclu } static inline Expr Reshape(Expr data, Array newshape) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->newshape = std::move(newshape); attrs->reverse = false; static const Op& op = Op::Get("reshape"); @@ -494,7 +494,7 @@ static inline Expr Reshape(Expr data, Array newshape) { static inline Expr AvgPool2D(Expr data, Array pool_size, Array strides, Array padding, std::string layout, bool ceil_mode, bool count_include_pad) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); attrs->padding = std::move(padding); @@ -507,7 +507,7 @@ static inline Expr AvgPool2D(Expr data, Array pool_size, Array> pad_width, double pad_value, std::string pad_mode) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->pad_value = pad_value; attrs->pad_width = std::move(pad_width); attrs->pad_mode = std::move(pad_mode); @@ -516,7 +516,7 @@ static inline Expr Pad(Expr data, Array> pad_width, double pad_ } static inline Expr Tile(Expr data, Array reps) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->reps = reps; static const Op& op = Op::Get("tile"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -530,7 +530,7 @@ Expr MakeStridedSlice(Expr data, Array begin, Array end, Array Expr MakeStack(Expr data, int axis); -Expr MakeSplit(Expr data, NodeRef indices_or_sections, int axis); +Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis); Expr MakeSqueeze(Expr data, Array axis); diff --git a/src/relay/pass/quantize/annotate.cc b/src/relay/pass/quantize/annotate.cc index c834d8e868fa4..c3d01071caf70 100644 --- a/src/relay/pass/quantize/annotate.cc +++ b/src/relay/pass/quantize/annotate.cc @@ -50,10 +50,13 @@ class QAnnotateExprNode : public TempExprNode { Expr Realize() const final; static constexpr const char* _type_key = "relay.QAnnotateExpr"; - TVM_DECLARE_NODE_TYPE_INFO(QAnnotateExprNode, TempExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(QAnnotateExprNode, TempExprNode); }; -RELAY_DEFINE_NODE_REF(QAnnotateExpr, QAnnotateExprNode, TempExpr); +class QAnnotateExpr : public TempExpr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(QAnnotateExpr, TempExpr, QAnnotateExprNode); +}; Expr QAnnotateExprNode::Realize() const { @@ -61,7 +64,7 @@ Expr QAnnotateExprNode::Realize() const { } QAnnotateExpr QAnnotateExprNode::make(Expr expr, QAnnotateKind kind) { - auto rnode = make_node(); + auto rnode = make_object(); rnode->expr = expr; rnode->kind = kind; return QAnnotateExpr(rnode); diff --git a/src/relay/pass/quantize/calibrate.cc b/src/relay/pass/quantize/calibrate.cc index e78abbf6aee0c..f9893f57a2c12 100644 --- a/src/relay/pass/quantize/calibrate.cc +++ b/src/relay/pass/quantize/calibrate.cc @@ -56,7 +56,7 @@ class StatsCollector : private ExprMutator { if (new_call->op == simulated_quantize_op_) { auto attrs = new_call->attrs.as(); // rewrite the annotation - auto new_attrs = make_node(); + auto new_attrs = make_object(); const Expr& quantize_input = new_call->args[0]; // expression being quantized auto placeholder = MakeConstantScalar(DataType::Float(32), 0.); // unused argument Array new_args{quantize_input, placeholder, placeholder, placeholder}; diff --git a/src/relay/pass/quantize/partition.cc b/src/relay/pass/quantize/partition.cc index 6914831029379..710684caa1b1f 100644 --- a/src/relay/pass/quantize/partition.cc +++ b/src/relay/pass/quantize/partition.cc @@ -50,10 +50,13 @@ class QPartitionExprNode : public TempExprNode { Expr Realize() const final; static constexpr const char* _type_key = "relay.QPartitionExpr"; - TVM_DECLARE_NODE_TYPE_INFO(QPartitionExprNode, TempExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(QPartitionExprNode, TempExprNode); }; -RELAY_DEFINE_NODE_REF(QPartitionExpr, QPartitionExprNode, TempExpr); +class QPartitionExpr : public TempExpr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(QPartitionExpr, TempExpr, QPartitionExprNode); +}; Expr QPartitionExprNode::Realize() const { @@ -64,7 +67,7 @@ Expr QPartitionExprNode::Realize() const { } QPartitionExpr QPartitionExprNode::make(Expr expr) { - auto rnode = make_node(); + auto rnode = make_object(); rnode->expr = expr; return QPartitionExpr(rnode); } diff --git a/src/relay/pass/quantize/quantize.cc b/src/relay/pass/quantize/quantize.cc index c022d4236b05e..ef78bf2503d82 100644 --- a/src/relay/pass/quantize/quantize.cc +++ b/src/relay/pass/quantize/quantize.cc @@ -70,7 +70,7 @@ TVM_REGISTER_API("relay._quantize.simulated_quantize") .set_body_typed( [](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max, int kind, bool sign, std::string rounding) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->kind = kind; attrs->sign = sign; attrs->rounding = rounding; @@ -88,7 +88,7 @@ struct TVMQConfigThreadLocalEntry { std::stack context_stack; TVMQConfigThreadLocalEntry() : - default_config(make_node()) { + default_config(make_object()) { } }; diff --git a/src/relay/pass/quantize/quantize.h b/src/relay/pass/quantize/quantize.h index 77900ab33e7b6..bfb7653686b63 100644 --- a/src/relay/pass/quantize/quantize.h +++ b/src/relay/pass/quantize/quantize.h @@ -62,7 +62,7 @@ class QConfig; /*! * \brief Container for build configuration options */ -class QConfigNode : public Node { +class QConfigNode : public Object { public: int nbit_input = 8; int nbit_weight = 8; @@ -73,10 +73,10 @@ class QConfigNode : public Node { std::string calibrate_mode = "global_scale"; double global_scale = 8.0; std::string weight_scale = "power2"; - Array skip_conv_layers = Array(NodePtr(nullptr)); + Array skip_conv_layers = Array(ObjectPtr(nullptr)); bool do_simulation = false; bool round_for_shift = true; - Array debug_enabled_ops = Array(NodePtr(nullptr)); + Array debug_enabled_ops = Array(ObjectPtr(nullptr)); std::string rounding = "UPWARD"; void VisitAttrs(AttrVisitor* v) { @@ -97,16 +97,16 @@ class QConfigNode : public Node { } static constexpr const char* _type_key = "relay.quantize.QConfig"; - TVM_DECLARE_NODE_TYPE_INFO(QConfigNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(QConfigNode, Object); }; /*! * \brief Container for build configuration options */ -class QConfig : public NodeRef { +class QConfig : public ObjectRef { public: QConfig() {} - explicit QConfig(ObjectPtr n) : NodeRef(n) {} + explicit QConfig(ObjectPtr n) : ObjectRef(n) {} const QConfigNode* operator->() const { return static_cast(get()); diff --git a/src/relay/pass/quantize/realize.cc b/src/relay/pass/quantize/realize.cc index 7a7e218ced054..bb8edf1edda71 100644 --- a/src/relay/pass/quantize/realize.cc +++ b/src/relay/pass/quantize/realize.cc @@ -45,10 +45,13 @@ class QRealizeExprNode : public TempExprNode { public: Expr data; static constexpr const char* _type_key = "relay.quantize.QRealizeExpr"; - TVM_DECLARE_BASE_NODE_INFO(QRealizeExprNode, TempExprNode); + TVM_DECLARE_BASE_OBJECT_INFO(QRealizeExprNode, TempExprNode); }; -RELAY_DEFINE_NODE_REF(QRealizeExpr, QRealizeExprNode, TempExpr); +class QRealizeExpr : public TempExpr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(QRealizeExpr, TempExpr, QRealizeExprNode); +}; class QRealizeIntExprNode : public QRealizeExprNode { @@ -67,10 +70,13 @@ class QRealizeIntExprNode : public QRealizeExprNode { TVM_DLL static QRealizeIntExpr make(Expr data, Expr dom_scale, DataType dtype); static constexpr const char * _type_key = "relay.quantize.QRealizeIntExpr"; - TVM_DECLARE_NODE_TYPE_INFO(QRealizeIntExprNode, QRealizeExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(QRealizeIntExprNode, QRealizeExprNode); }; -RELAY_DEFINE_NODE_REF(QRealizeIntExpr, QRealizeIntExprNode, QRealizeExpr); +class QRealizeIntExpr : public QRealizeExpr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(QRealizeIntExpr, QRealizeExpr, QRealizeIntExprNode); +}; Expr QRealizeIntExprNode::Realize() const { @@ -82,7 +88,7 @@ Expr QRealizeIntExprNode::Realize() const { } QRealizeIntExpr QRealizeIntExprNode::make(Expr data, Expr dom_scale, DataType dtype) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->data = std::move(data); n->dom_scale = std::move(dom_scale); n->dtype = std::move(dtype); @@ -120,7 +126,7 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype, Expr QuantizeRealize(const Call& ref_call, const Array& new_args, - const NodeRef& ctx) { + const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); // do not handle data type cast const auto param = ref_call->attrs.as(); @@ -196,7 +202,7 @@ RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize") Expr Conv2dRealize(const Call& ref_call, const Array& new_args, - const NodeRef& ctx) { + const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); CHECK_EQ(new_args.size(), 2); if (!new_args[0]->IsInstance() && !new_args[1]->IsInstance()) { @@ -214,7 +220,7 @@ Expr Conv2dRealize(const Call& ref_call, Expr rdata = Cast(rhs->data, cfg->dtype_weight); const auto ref_attrs = ref_call->attrs.as(); - auto attrs = make_node(); + auto attrs = make_object(); *attrs = *ref_attrs; DataType out_dtype = cfg->dtype_activation; attrs->out_dtype = out_dtype; @@ -232,7 +238,7 @@ RELAY_REGISTER_OP("nn.conv2d") Expr DenseRealize(const Call& ref_call, const Array& new_args, - const NodeRef& ctx) { + const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); CHECK_EQ(new_args.size(), 2); if (!new_args[0]->IsInstance() || !new_args[1]->IsInstance()) { @@ -248,7 +254,7 @@ Expr DenseRealize(const Call& ref_call, Expr rdata = Cast(rhs->data, cfg->dtype_weight); const auto ref_attrs = ref_call->attrs.as(); - auto attrs = make_node(); + auto attrs = make_object(); *attrs = *ref_attrs; DataType out_dtype = cfg->dtype_activation; attrs->out_dtype = out_dtype; @@ -266,7 +272,7 @@ RELAY_REGISTER_OP("nn.dense") Expr MulRealize(const Call& ref_call, const Array& new_args, - const NodeRef& ctx) { + const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); CHECK_EQ(new_args.size(), 2); if (new_args[0].as() && new_args[1].as()) { @@ -364,7 +370,7 @@ Array UnifyDTypeScale(const Array& ref_args, const Array& args Expr AddRealize(const Call& ref_call, const Array& new_args, - const NodeRef& ctx) { + const ObjectRef& ctx) { CHECK_EQ(new_args.size(), 2); if (new_args[0].as() && new_args[1].as()) { DataType dtype; @@ -383,11 +389,11 @@ RELAY_REGISTER_OP("add") Expr ClipRealize(const Call& ref_call, const Array& new_args, - const NodeRef& ctx) { + const ObjectRef& ctx) { CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { const auto ref_attrs = ref_call->attrs.as(); - auto attrs = make_node(); + auto attrs = make_object(); double dom_scale = GetScalarFromConstant(n->dom_scale); attrs->a_min = ref_attrs->a_min / dom_scale; attrs->a_max = ref_attrs->a_max / dom_scale; @@ -406,7 +412,7 @@ RELAY_REGISTER_OP("clip") Expr ConcatenateRealize(const Call& ref_call, const Array& new_args, - const NodeRef& ctx) { + const ObjectRef& ctx) { CHECK_EQ(new_args.size(), 1); CHECK_EQ(ref_call->args.size(), 1); @@ -438,7 +444,7 @@ RELAY_REGISTER_OP("concatenate") /* \brief forward the original operator */ Expr IdentityRealize(const Call& ref_call, const Array& new_args, - const NodeRef& ctx) { + const ObjectRef& ctx) { CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { Expr ret = ForwardOp(ref_call, {n->data}); @@ -460,7 +466,7 @@ RELAY_REGISTER_OP("annotation.stop_fusion") /* \brief for unary operators which requantize its input to dtype_nbit */ Expr CastDtypeInputRealize(const Call& ref_call, const Array& new_args, - const NodeRef& ctx) { + const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { @@ -478,7 +484,7 @@ RELAY_REGISTER_OP("nn.max_pool2d") Expr AvgPoolRealize(const Call& ref_call, const Array& new_args, - const NodeRef& ctx) { + const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { @@ -501,7 +507,7 @@ RELAY_REGISTER_OP("nn.global_avg_pool2d") Expr CastHintRealize(const Call& ref_call, const Array& new_args, - const NodeRef& ctx) { + const ObjectRef& ctx) { const auto param = ref_call->attrs.as(); CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { diff --git a/src/relay/pass/simplify_inference.cc b/src/relay/pass/simplify_inference.cc index acd5163d13352..6d6171c9e461b 100644 --- a/src/relay/pass/simplify_inference.cc +++ b/src/relay/pass/simplify_inference.cc @@ -173,7 +173,7 @@ class InferenceSimplifier : public ExprMutator { const Op& dropout_op_; const Op& instance_norm_op_; const Op& layer_norm_op_; - std::unordered_map ty_map_; + std::unordered_map ty_map_; }; Expr SimplifyInference(const Expr& e) { diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc index 3cce4b6b81a50..57894e015f0b2 100644 --- a/src/relay/pass/to_a_normal_form.cc +++ b/src/relay/pass/to_a_normal_form.cc @@ -110,7 +110,7 @@ class Fill : ExprFunctor { private: const DependencyGraph& dg_; std::unordered_map* node_scope_; - std::unordered_map memo; + std::unordered_map memo; Fill(const DependencyGraph& dg, std::unordered_map* node_scope) : diff --git a/src/relay/pass/to_cps.cc b/src/relay/pass/to_cps.cc index c20695becd6c6..1dfa327d8b0ed 100644 --- a/src/relay/pass/to_cps.cc +++ b/src/relay/pass/to_cps.cc @@ -89,10 +89,10 @@ Type CPSType(const Type& t, const TypeVar& answer) { } // transform global functions into cps form. -using CPSMap = std::unordered_map; +using CPSMap = std::unordered_map; // transform vars from the original program into new vars, so their type will be correct. -using VarMap = std::unordered_map; +using VarMap = std::unordered_map; /* * The meta continuation. diff --git a/src/relay/pass/to_graph_normal_form.cc b/src/relay/pass/to_graph_normal_form.cc index 5060c13fc75f8..b00e0d420641e 100644 --- a/src/relay/pass/to_graph_normal_form.cc +++ b/src/relay/pass/to_graph_normal_form.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -52,7 +52,7 @@ class UseVarVisitor : public ExprVisitor { class GNF : public ExprMutator { private: - std::unordered_map var_map_; + std::unordered_map var_map_; Expr VisitExpr_(const VarNode* vn) override { Var v = GetRef(vn); return var_map_.count(v) == 0 ? v : var_map_.at(v); diff --git a/src/relay/pass/transform_layout.h b/src/relay/pass/transform_layout.h index f6c5e9af6d62d..d283a239f2f65 100644 --- a/src/relay/pass/transform_layout.h +++ b/src/relay/pass/transform_layout.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -41,15 +41,16 @@ namespace relay { /*! * \brief Memorizes layout transformations to reuse. */ -class TransformMemorizerNode : public Node { +class TransformMemorizerNode : public Object { public: /*! \brief The key for the memorizer map is (Expr, src_layout, dst_layout). */ - using TransformKey = std::tuple; + using TransformKey = std::tuple; struct key_hash : public std::function { std::size_t operator()(const TransformKey& k) const { return dmlc::HashCombine( - dmlc::HashCombine(std::hash()(std::get<0>(k)), std::get<1>(k)), + dmlc::HashCombine( + std::hash()(std::get<0>(k)), std::get<1>(k)), (std::get<2>(k))); } }; @@ -58,16 +59,16 @@ class TransformMemorizerNode : public Node { std::unordered_map memo; static constexpr const char* _type_key = "relay.alter_op_layout.TransformMemorizerNode"; - TVM_DECLARE_NODE_TYPE_INFO(TransformMemorizerNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(TransformMemorizerNode, Object); }; /*! * \brief Container that transforms the layouts and memorizes them. */ -class TransformMemorizer : public NodeRef { +class TransformMemorizer : public ObjectRef { public: TransformMemorizer() {} - explicit TransformMemorizer(ObjectPtr n) : NodeRef(n) {} + explicit TransformMemorizer(ObjectPtr n) : ObjectRef(n) {} TransformMemorizerNode* operator->() { return static_cast(get_mutable()); @@ -85,7 +86,7 @@ class TransformMemorizer : public NodeRef { return raw; } - std::tuple key = + std::tuple key = std::make_tuple<>(raw.get(), src_layout.name(), dst_layout.name()); auto& memo = operator->()->memo; @@ -179,7 +180,7 @@ class LayoutAlternatedExprNode : public TempExprNode { } static constexpr const char* _type_key = "relay.alter_op_layout.LayoutAlternatedExprNode"; - TVM_DECLARE_NODE_TYPE_INFO(LayoutAlternatedExprNode, TempExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(LayoutAlternatedExprNode, TempExprNode); }; /*! @@ -187,10 +188,10 @@ class LayoutAlternatedExprNode : public TempExprNode { * \tparam TransformMemorizerT The derived TransformMemorizer type. */ template -class LayoutAlternatedExpr : public NodeRef { +class LayoutAlternatedExpr : public ObjectRef { public: LayoutAlternatedExpr() {} - explicit LayoutAlternatedExpr(ObjectPtr n) : NodeRef(n) {} + explicit LayoutAlternatedExpr(ObjectPtr n) : ObjectRef(n) {} LayoutAlternatedExprNode* operator->() { return static_cast*>(get_mutable()); @@ -219,7 +220,7 @@ class LayoutAlternatedExpr : public NodeRef { * - Transform the original call to reuse the new layouts using TransformMemorizer. */ template -Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const NodeRef& ctx) { +Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { std::vector> inputs; std::vector normal_new_args; Array> input_shapes; @@ -239,7 +240,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const Nod inputs.push_back(GetRef>(inp)); return inp->value; } else { - auto inode = make_node>(); + auto inode = make_object>(); inode->value = arg; inode->memorizer = memorizer; inputs.push_back(LayoutAlternatedExpr(inode)); @@ -342,7 +343,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const Nod Expr tuple_output = CallNode::make(new_call->op, transformed_args, new_call->attrs); Array fields; for (size_t i = 0; i < new_out.size(); ++i) { - auto rnode = make_node>(); + auto rnode = make_object>(); rnode->value = TupleGetItemNode::make(tuple_output, i); rnode->old_layout = old_out[i]; rnode->new_layout = new_out[i]; @@ -351,7 +352,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const Nod } return TupleNode::make(fields); } else { - auto rnode = make_node>(); + auto rnode = make_object>(); CHECK_EQ(new_out.size(), 1); rnode->value = CallNode::make(new_call->op, transformed_args, new_call->attrs); rnode->old_layout = old_out[0]; diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 2c4cff4983a60..6e992bbeea1a0 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -90,7 +90,7 @@ struct ResolvedTypeInfo { Type checked_type; // Only allocated when the expression is a call. - Array type_args = Array(NodePtr(nullptr)); + Array type_args = Array(ObjectPtr(nullptr)); }; // @@ -128,7 +128,7 @@ class TypeInferencer : private ExprFunctor, // map from expression to checked type // type inferencer will populate it up - std::unordered_map type_map_; + std::unordered_map type_map_; // The solver used by the inferencer. TypeSolver solver_; @@ -138,7 +138,7 @@ class TypeInferencer : private ExprFunctor, // Perform unification on two types and report the error at the expression // or the span of the expression. - Type Unify(const Type& t1, const Type& t2, const NodeRef& expr) { + Type Unify(const Type& t1, const Type& t2, const ObjectRef& expr) { try { return solver_.Unify(t1, t2, expr); } catch (const dmlc::Error &e) { @@ -168,7 +168,7 @@ class TypeInferencer : private ExprFunctor, return ret; } - void ReportFatalError(const NodeRef& expr, const Error& err) { + void ReportFatalError(const ObjectRef& expr, const Error& err) { CHECK(this->current_func_.defined()); this->err_reporter.ReportAt(this->current_func_, expr, err); this->err_reporter.RenderErrors(this->mod_); @@ -215,7 +215,7 @@ class TypeInferencer : private ExprFunctor, } Type tuple_type = GetType(op->tuple); Type rtype = IncompleteTypeNode::make(Kind::kType); - auto attrs = make_node(); + auto attrs = make_object(); attrs->index = op->index; solver_.AddConstraint(TypeRelationNode::make( tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs)), GetRef(op)); @@ -235,7 +235,7 @@ class TypeInferencer : private ExprFunctor, unknown_args.push_back(IncompleteTypeNode::make(Kind::kType)); } Type expected = TypeCallNode::make(con->constructor->belong_to, unknown_args); - Type unified = Unify(t, expected, GetRef(con)); + Type unified = Unify(t, expected, GetRef(con)); auto* tc = unified.as(); if (!tc) { @@ -250,7 +250,7 @@ class TypeInferencer : private ExprFunctor, << "the number of type vars in the type data: " << td->type_vars.size() << " != " << tc->args.size())); } - std::unordered_map type_var_map_; + std::unordered_map type_var_map_; for (size_t i = 0; i < td->type_vars.size(); ++i) { type_var_map_[td->type_vars[i]] = tc->args[i]; } @@ -274,7 +274,7 @@ class TypeInferencer : private ExprFunctor, unknown_args.push_back(IncompleteTypeNode::make(Kind::kType)); } Type expected = TupleTypeNode::make(unknown_args); - Type unified = Unify(t, expected, GetRef(tup)); + Type unified = Unify(t, expected, GetRef(tup)); auto* tt = unified.as(); if (!tt) { @@ -372,7 +372,7 @@ class TypeInferencer : private ExprFunctor, Type PrimitiveCall(const FuncTypeNode* op, Array arg_types, const Attrs& attrs, - const NodeRef& loc) { + const ObjectRef& loc) { if (op->type_params.size() != arg_types.size() + 1) return Type(); if (op->type_constraints.size() != 1) return Type(); const TypeRelationNode* rel = op->type_constraints[0].as(); @@ -594,7 +594,7 @@ class TypeInferencer : private ExprFunctor, class TypeInferencer::Resolver : public ExprMutator, PatternMutator { public: - Resolver(const std::unordered_map& tmap, + Resolver(const std::unordered_map& tmap, TypeSolver* solver) : tmap_(tmap), solver_(solver) { } @@ -723,7 +723,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { // Copy on write optimization // If new_e is an old expression, // we make a copy mutating an existing reference. - NodePtr ptr = make_node(*new_e.as()); + ObjectPtr ptr = make_object(*new_e.as()); new_e = Expr(ptr); new_call = ( std::is_base_of::value ? @@ -763,8 +763,8 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { } private: - std::unordered_map vmap_; - const std::unordered_map& tmap_; + std::unordered_map vmap_; + const std::unordered_map& tmap_; TypeSolver* solver_; // whether attach the checked type as type_annotation // if original type anntation is missing. @@ -814,7 +814,7 @@ Function InferType(const Function& func, const Module& mod, const GlobalVar& var) { CHECK(mod.defined()) << "internal error: module must be set for type inference"; - Function func_copy = Function(make_node(*func.operator->())); + Function func_copy = Function(make_object(*func.operator->())); func_copy->checked_type_ = func_copy->func_type_annotation(); mod->AddUnchecked(var, func_copy); Expr func_ret = TypeInferencer(mod, var).Infer(func_copy); diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 8376d36698999..86ebe0f22c8dd 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -56,7 +56,7 @@ class TypeSolver::Reporter : public TypeReporterNode { return true; } - TVM_DLL void SetLocation(const NodeRef& ref) final { + TVM_DLL void SetLocation(const ObjectRef& ref) final { location = ref; } @@ -66,7 +66,7 @@ class TypeSolver::Reporter : public TypeReporterNode { private: /*! \brief The location to report unification errors at. */ - mutable NodeRef location; + mutable ObjectRef location; TypeSolver* solver_; }; @@ -95,7 +95,7 @@ class TypeSolver::OccursChecker : public TypeVisitor { class TypeSolver::Unifier : public TypeFunctor { public: - explicit Unifier(TypeSolver* solver, const NodeRef& loc) : solver_(solver), loc(loc) {} + explicit Unifier(TypeSolver* solver, const ObjectRef& loc) : solver_(solver), loc(loc) {} Type Unify(const Type& src, const Type& dst) { // Known limitation @@ -150,8 +150,8 @@ class TypeSolver::Unifier : public TypeFunctor { } // default: unify only if alpha-equal - Type VisitTypeDefault_(const Node* op, const Type& tn) final { - NodeRef nr = GetRef(op); + Type VisitTypeDefault_(const Object* op, const Type& tn) final { + ObjectRef nr = GetRef(op); Type t1 = GetRef(nr.as()); if (!AlphaEqual(t1, tn)) { return Type(nullptr); @@ -365,7 +365,7 @@ class TypeSolver::Unifier : public TypeFunctor { private: TypeSolver* solver_; - NodeRef loc; + ObjectRef loc; }; class TypeSolver::Resolver : public TypeMutator { @@ -408,8 +408,8 @@ class TypeSolver::Propagator : public TypeFunctor { } } - void VisitTypeDefault_(const Node* op) override { - NodeRef nr = GetRef(op); + void VisitTypeDefault_(const Object* op) override { + ObjectRef nr = GetRef(op); Type t = GetRef(nr.as()); UpdateRelSet(t); } @@ -492,8 +492,8 @@ class TypeSolver::Merger : public TypeFunctor { } } - void VisitTypeDefault_(const Node* op) override { - NodeRef nr = GetRef(op); + void VisitTypeDefault_(const Object* op) override { + ObjectRef nr = GetRef(op); Type t = GetRef(nr.as()); TransferLinks(t); } @@ -533,7 +533,7 @@ TypeSolver::TypeSolver( const GlobalVar& current_func, const Module& module, ErrorReporter* err_reporter) - : reporter_(make_node(this)), + : reporter_(make_object(this)), current_func(current_func), err_reporter_(err_reporter), module_(module) { @@ -558,19 +558,19 @@ void TypeSolver::MergeFromTo(TypeNode* src, TypeNode* dst) { } // Add equality constraint -Type TypeSolver::Unify(const Type& dst, const Type& src, const NodeRef& loc) { +Type TypeSolver::Unify(const Type& dst, const Type& src, const ObjectRef& loc) { Unifier unifier(this, loc); return unifier.Unify(dst, src); } -void TypeSolver::ReportError(const Error& err, const NodeRef& location) { +void TypeSolver::ReportError(const Error& err, const ObjectRef& location) { CHECK(location.defined()); CHECK(current_func.defined()); err_reporter_->ReportAt(current_func, location, err); } // Add type constraint to the solver. -void TypeSolver::AddConstraint(const TypeConstraint& constraint, const NodeRef& loc) { +void TypeSolver::AddConstraint(const TypeConstraint& constraint, const ObjectRef& loc) { if (const auto* op = constraint.as()) { // create a new relation node. RelationNode* rnode = arena_.make(); diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index fa9ef7a156466..bf1ac716cfc59 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -69,7 +69,7 @@ class TypeSolver { * \param constraint The constraint to be added. * \param location The location at which the constraint was incurred. */ - void AddConstraint(const TypeConstraint& constraint, const NodeRef& lcoation); + void AddConstraint(const TypeConstraint& constraint, const ObjectRef& lcoation); /*! * \brief Resolve type to the solution type in the solver. * \param type The type to be resolved. @@ -87,13 +87,13 @@ class TypeSolver { * \param rhs The right operand * \param location The location at which the unification problem arose. */ - Type Unify(const Type& lhs, const Type& rhs, const NodeRef& location); + Type Unify(const Type& lhs, const Type& rhs, const ObjectRef& location); /*! * \brief Report an error at the provided location. * \param err The error to report. * \param loc The location at which to report the error. */ - void ReportError(const Error& err, const NodeRef& location); + void ReportError(const Error& err, const ObjectRef& location); private: class OccursChecker; @@ -155,7 +155,7 @@ class TypeSolver { /*! \brief list types to this relation */ LinkedList type_list; /*! \brief The location this type relation originated from. */ - NodeRef location; + ObjectRef location; }; /*! \brief A simple union find between shapes. */ @@ -167,7 +167,7 @@ class TypeSolver { /*! \brief Number of resolved relations */ size_t num_resolved_rels_{0}; /*! \brief map from types to type nodes. */ - std::unordered_map tmap_; + std::unordered_map tmap_; /*! \brief Internal queue to update the relation */ std::queue update_queue_; /*! \brief allocator of all the internal node obhect*/ diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index 17c527b392374..2efb479c3156f 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -35,7 +35,7 @@ namespace relay { template struct InsertionSet { - std::unordered_set set; + std::unordered_set set; std::vector data; void Insert(const T& t) { if (set.count(t) == 0) { @@ -279,7 +279,7 @@ TVM_REGISTER_API("relay._analysis.free_vars") TVM_REGISTER_API("relay._analysis.bound_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { - NodeRef x = args[0]; + ObjectRef x = args[0]; if (x.as()) { *ret = BoundVars(Downcast(x)); } else { @@ -292,7 +292,7 @@ TVM_REGISTER_API("relay._analysis.all_vars") TVM_REGISTER_API("relay._analysis.free_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { - NodeRef x = args[0]; + ObjectRef x = args[0]; Module mod = args[1]; if (x.as()) { *ret = FreeTypeVars(Downcast(x), mod); @@ -303,7 +303,7 @@ TVM_REGISTER_API("relay._analysis.free_type_vars") TVM_REGISTER_API("relay._analysis.bound_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { - NodeRef x = args[0]; + ObjectRef x = args[0]; Module mod = args[1]; if (x.as()) { *ret = BoundTypeVars(Downcast(x), mod); @@ -314,7 +314,7 @@ TVM_REGISTER_API("relay._analysis.bound_type_vars") TVM_REGISTER_API("relay._analysis.all_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { - NodeRef x = args[0]; + ObjectRef x = args[0]; Module mod = args[1]; if (x.as()) { *ret = AllTypeVars(Downcast(x), mod); @@ -328,11 +328,11 @@ TVM_REGISTER_API("relay._analysis.all_type_vars") * \param body The body expression. * \return The reference count mapping. */ -std::unordered_map +std::unordered_map GetExprRefCount(const Expr& body) { class ExprRefCounter : private ExprVisitor { public: - std::unordered_map + std::unordered_map Get(const Expr& body) { this->VisitExpr(body); return std::move(this->visit_counter_); diff --git a/src/relay/pass/well_formed.cc b/src/relay/pass/well_formed.cc index abcedd2ab4831..2bbf9792dd1d3 100644 --- a/src/relay/pass/well_formed.cc +++ b/src/relay/pass/well_formed.cc @@ -34,10 +34,10 @@ namespace relay { class WellFormedChecker : private ExprVisitor, PatternVisitor { bool well_formed = true; - std::vector> scope; - std::unordered_set current_bound; - std::unordered_set total_bound; - std::unordered_set free; + std::vector> scope; + std::unordered_set current_bound; + std::unordered_set total_bound; + std::unordered_set free; struct Scope { WellFormedChecker* wfc; diff --git a/src/relay/qnn/op/concatenate.cc b/src/relay/qnn/op/concatenate.cc index 8c27d47632a1a..43d47e21822e9 100644 --- a/src/relay/qnn/op/concatenate.cc +++ b/src/relay/qnn/op/concatenate.cc @@ -39,7 +39,7 @@ TVM_REGISTER_NODE_TYPE(QnnConcatenateAttrs); Expr MakeQnnConcatenate(Expr data, Array input_scales, Array input_zero_points, double output_scale, int32_t output_zero_point, int axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->input_scales = std::move(input_scales); attrs->input_zero_points = std::move(input_zero_points); attrs->output_scale = output_scale; diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index a629bf2b462e3..669b04fdda48c 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -607,7 +607,7 @@ Expr MakeQnnConv2D(Expr data, Expr weight, int32_t input_zero_point, int32_t ker int groups, IndexExpr channels, Array kernel_size, std::string data_layout, std::string kernel_layout, std::string out_layout, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); attrs->dilation = std::move(dilation); diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc index ad0da52ec120c..2353e5a890961 100644 --- a/src/relay/qnn/op/dense.cc +++ b/src/relay/qnn/op/dense.cc @@ -60,7 +60,7 @@ Expr MakeQuantizedDense(Expr data, Expr weight, int32_t input_zero_point, int32_t kernel_zero_point, double input_scale, double kernel_scale, IndexExpr units, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->units = std::move(units); attrs->out_dtype = out_dtype; attrs->input_zero_point = input_zero_point; diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc index 7daee4664ac5d..a1e23808d4371 100644 --- a/src/relay/qnn/op/dequantize.cc +++ b/src/relay/qnn/op/dequantize.cc @@ -56,7 +56,7 @@ bool DequantizeRel(const Array& types, Expr MakeDequantize(Expr data, double input_scale, int32_t input_zero_point) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->input_scale = input_scale; attrs->input_zero_point = input_zero_point; // real_value = scale * (quantized_value - zero_point) diff --git a/src/relay/qnn/op/op_common.h b/src/relay/qnn/op/op_common.h index be8e197b78b07..2c116fedeaeea 100644 --- a/src/relay/qnn/op/op_common.h +++ b/src/relay/qnn/op/op_common.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -50,7 +50,7 @@ namespace qnn { .set_body_typed( \ [](Expr lhs, Expr rhs, double lhs_scale, int32_t lhs_zero_point, double rhs_scale, \ int32_t rhs_zero_point, double output_scale, int32_t output_zero_point) { \ - auto attrs = make_node(); \ + auto attrs = make_object(); \ attrs->lhs_scale = lhs_scale; \ attrs->lhs_zero_point = lhs_zero_point; \ attrs->rhs_scale = rhs_scale; \ diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index 6b7fecd191fc1..18dd9aa01af58 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -60,7 +60,7 @@ Expr MakeQuantize(Expr data, double output_scale, int32_t output_zero_point, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->output_scale = output_scale; attrs->output_zero_point = output_zero_point; attrs->out_dtype = std::move(out_dtype); diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index ec8c845dc8c69..93284cb38e87e 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -164,7 +164,7 @@ bool RequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, // used by frontend FFI. Expr MakeRequantize(Expr data, double input_scale, int32_t input_zero_point, double output_scale, int32_t output_zero_point, std::string rounding, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->input_scale = std::move(input_scale); attrs->input_zero_point = std::move(input_zero_point); attrs->output_scale = std::move(output_scale); diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h index 6659a44e63f65..e359296c1d1a8 100644 --- a/src/relay/qnn/util.h +++ b/src/relay/qnn/util.h @@ -84,7 +84,7 @@ static inline Expr Requantize(const Expr& data, const Array& input_sh double input_scale, int32_t input_zero_point, double output_scale, int32_t output_zero_point, const DataType& out_dtype, const std::string& rounding = "UPWARD") { - auto attrs = make_node(); + auto attrs = make_object(); attrs->input_scale = std::move(input_scale); attrs->input_zero_point = std::move(input_zero_point); attrs->output_scale = std::move(output_scale); diff --git a/src/runtime/vm/memory_manager.h b/src/runtime/vm/memory_manager.h index 292fb55e59950..95a154ce6beee 100644 --- a/src/runtime/vm/memory_manager.h +++ b/src/runtime/vm/memory_manager.h @@ -137,7 +137,7 @@ class Storage : public ObjectRef { public: explicit Storage(Buffer buffer); - TVM_DEFINE_OBJECT_REF_METHODS_MUT(Storage, ObjectRef, StorageObj); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Storage, ObjectRef, StorageObj); }; } // namespace vm diff --git a/src/schedule/auto_inline_elem_wise.cc b/src/schedule/auto_inline_elem_wise.cc index 62739bb22004f..e587f385734fa 100644 --- a/src/schedule/auto_inline_elem_wise.cc +++ b/src/schedule/auto_inline_elem_wise.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -33,7 +33,7 @@ class ElemWiseDetector : public ir::IRVisitor { public: explicit ElemWiseDetector(Array axis) : axis_(axis) {} - void Visit(const NodeRef& e) final { + void Visit(const ObjectRef& e) final { if (!is_elem_wise_) return; IRVisitor::Visit(e); } diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc index e213df5e659d6..d4baded91f7cd 100644 --- a/src/schedule/bound.cc +++ b/src/schedule/bound.cc @@ -47,7 +47,7 @@ struct GraphContext { /*! \brief The bind map */ std::unordered_map bind_map; /*! \brief map from op to stage */ - std::unordered_map op2stage_; + std::unordered_map op2stage_; }; bool NeedRelax(const IterVar& iv, diff --git a/src/schedule/graph.cc b/src/schedule/graph.cc index 518f05a03250a..c3024a71977fb 100644 --- a/src/schedule/graph.cc +++ b/src/schedule/graph.cc @@ -62,7 +62,7 @@ namespace std { template <> struct hash<::tvm::schedule::TensorDimKey> { std::size_t operator()(const ::tvm::schedule::TensorDimKey& k) const { - size_t lhs = ::tvm::NodeHash()(k.f); + size_t lhs = ::tvm::ObjectHash()(k.f); size_t rhs = static_cast(k.value_index) << 16UL | static_cast(k.dim); lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); @@ -80,7 +80,7 @@ namespace schedule { ReadGraph CreateReadGraph(const Array& roots) { ReadGraph rmap; std::vector stack; - std::unordered_set visited; + std::unordered_set visited; // initialize the roots for (Operation op : roots) { stack.push_back(op); @@ -106,9 +106,9 @@ ReadGraph CreateReadGraph(const Array& roots) { // Return if op is inside the subgraph. bool GetSubGraphByPostDFS_( const Operation& op, - const std::unordered_set& boundary, + const std::unordered_set& boundary, bool include_bounary, - std::unordered_map* visited, + std::unordered_map* visited, Array* result) { if (visited->count(op.get())) { return visited->at(op.get()); @@ -143,11 +143,11 @@ Array GetSubGraph(const Array& outputs, const Array& inputs, bool include_inputs) { Array result; - std::unordered_set boundary; + std::unordered_set boundary; for (Tensor t : inputs) { boundary.insert(t->op.get()); } - std::unordered_map visited; + std::unordered_map visited; for (Tensor t : outputs) { GetSubGraphByPostDFS_(t->op, boundary, include_inputs, &visited, &result); @@ -192,7 +192,7 @@ FeedGraph CreateFeedGraph(const ReadGraph& g) { AttachPath CreateAttachPath(Schedule sch) { AttachPath ret; for (Stage stage : sch->stages) { - std::unordered_set visited; + std::unordered_set visited; Array path; for (Stage s = stage; s.defined();) { CHECK(!visited.count(s.get())) @@ -236,7 +236,7 @@ using ReachGraph = std::unordered_map >; ReachGraph GetReachGraph(const Array& ops) { ReachGraph reach; - std::unordered_set bset; + std::unordered_set bset; for (size_t i = 0; i < ops.size(); ++i) { bset.insert(ops[i].get()); } @@ -255,20 +255,20 @@ ReachGraph GetReachGraph(const Array& ops) { } } } else if (const auto* compute_op = op.as()) { - std::unordered_map vmap; + std::unordered_map vmap; const auto& axis = compute_op->axis; Tensor t = op.output(0); for (size_t i = 0; i < axis.size(); ++i) { vmap[axis[i]->var.get()] = TensorDimKey(t, i); reach[TensorDimKey(t, i)] = {}; } - auto fvisit = [&vmap, &reach, &bset](const NodeRef& n) { + auto fvisit = [&vmap, &reach, &bset](const ObjectRef& n) { const ir::Call *call = n.as(); if (call != nullptr && call->func.defined()) { if (!bset.count(call->func.get())) return; for (size_t i = 0; i < call->args.size(); ++i) { TensorDimKey dkey(call, static_cast(i)); - auto fpush = [&dkey, &vmap, &reach](const NodeRef& node) { + auto fpush = [&dkey, &vmap, &reach](const ObjectRef& node) { const Variable *v = node.as(); auto it = vmap.find(v); if (it != vmap.end()) { @@ -304,8 +304,8 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { const ScanOpNode* scan = scan_op.as(); Array body = ScanGetBody(scan_op); - std::unordered_map exact_reach; - std::unordered_set fail_set; + std::unordered_map exact_reach; + std::unordered_set fail_set; for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) { for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) { @@ -342,7 +342,7 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { } } } else if (const auto* compute_op = op.as()) { - std::unordered_map > vmap; + std::unordered_map > vmap; const auto& axis = compute_op->axis; for (size_t i = 0; i < axis.size(); ++i) { std::vector keys; @@ -352,7 +352,7 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { vmap[axis[i]->var.get()] = std::move(keys); } auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set]( - const NodeRef& n) { + const ObjectRef& n) { const ir::Call *call = n.as(); if (call != nullptr && call->func.defined()) { for (size_t i = 0; i < call->args.size(); ++i) { diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index c9afcf45a1f2e..70a73abc46984 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -34,7 +34,7 @@ namespace tvm { // find first occurance location in leaf template size_t FindNodeRef(ArrayNode* array_node, const T& v) { - const Node* n = v.get(); + const Object* n = v.get(); for (size_t i = 0; i < array_node->data.size(); ++i) { if (array_node->data[i].get() == n) return i; } @@ -98,7 +98,7 @@ Expr InjectPredicate(const Array& predicates, if (predicates.size() == 0) return body; const Reduce* reduce = body.as(); if (reduce) { - auto n = make_node(*reduce); + auto n = make_object(*reduce); n->condition = n->condition && arith::ComputeReduce(predicates, Expr()); return Expr(n); } @@ -591,7 +591,7 @@ void InjectInline(ScheduleNode* sch) { CHECK_EQ(new_body[j].size(), r->source.size()); CHECK(r != nullptr); for (size_t k = 0; k < new_body[j].size(); ++k) { - auto n = make_node(*r); + auto n = make_object(*r); n->value_index = static_cast(k); n->dtype = r->source[k].dtype(); new_body[j].Set(k, Expr(n)); @@ -734,11 +734,11 @@ Array Schedule::rfactor(const Tensor& tensor, const int factor_axis_pos = \ factor_axis >= 0 ? factor_axis : static_cast(compute_op->axis.size() + 1) + factor_axis; CHECK_LE(factor_axis_pos, compute_op->axis.size()); - auto n = make_node(); + auto n = make_object(); n->name = compute_op->name + ".rf"; { // axis relacement. - auto iv_node = make_node(); + auto iv_node = make_object(); iv_node->dom = dom_map.at(axis); CHECK(is_zero(iv_node->dom->min)) << "Can only factor reduction domain starting from 0"; @@ -779,7 +779,7 @@ Array Schedule::rfactor(const Tensor& tensor, for (IterVar iv : reduce_stage->leaf_iter_vars) { if (touch_map.count(iv) && !iv.same_as(axis)) { CHECK_EQ(iv->iter_type, kCommReduce); - auto ncpy = make_node(*iv.operator->()); + auto ncpy = make_object(*iv.operator->()); ncpy->dom = dom_map.at(iv); n->reduce_axis.push_back(IterVar(ncpy)); } diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc index 7a2ab5a4d8b91..ec73c67bedff6 100644 --- a/src/schedule/schedule_lang.cc +++ b/src/schedule/schedule_lang.cc @@ -33,7 +33,7 @@ namespace { // find first occurance location in leaf template size_t FindNodeRef(ArrayNode* array_node, const T& v) { - const Node* n = v.get(); + const Object* n = v.get(); for (size_t i = 0; i < array_node->data.size(); ++i) { if (array_node->data[i].get() == n) return i; } @@ -88,7 +88,7 @@ void Split(StageNode* self, } // namespace Stage::Stage(Operation op) { - auto n = make_node(); + auto n = make_object(); n->op = op; n->origin_op = op; n->all_iter_vars = op->root_iter_vars(); @@ -182,16 +182,16 @@ Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { // NOLINT(*) FindLeafVar(all_vars, leaf_vars, ivar); auto it = self->iter_var_attrs.find(ivar); - NodePtr n; + ObjectPtr n; if (it != self->iter_var_attrs.end()) { - n = make_node(*(*it).second.operator->()); + n = make_object(*(*it).second.operator->()); if (n->bind_thread.defined() && !n->bind_thread.same_as(thread_ivar)) { LOG(WARNING) << "Axis " << ivar << " is already bind to another thread " << n->bind_thread; } } else { - n = make_node(); + n = make_object(); } n->bind_thread = thread_ivar; self->iter_var_attrs.Set(ivar, IterVarAttr(n)); @@ -353,11 +353,11 @@ inline void UpdateIterVarAttr(StageNode* self, FindLeafVar(all_vars, leaf_vars, var); } auto it = self->iter_var_attrs.find(var); - NodePtr n; + ObjectPtr n; if (it != self->iter_var_attrs.end()) { - n = make_node(*(*it).second.operator->()); + n = make_object(*(*it).second.operator->()); } else { - n = make_node(); + n = make_object(); } fupdate(n.get()); self->iter_var_attrs.Set(var, IterVarAttr(n)); @@ -422,11 +422,11 @@ Stage& Stage::prefetch(const Tensor &tensor, IterVar var, Expr offset) { ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); FindLeafVar(all_vars, leaf_vars, var); auto it = self->iter_var_attrs.find(var); - NodePtr n; + ObjectPtr n; if (it != self->iter_var_attrs.end()) { - n = make_node(*(*it).second.operator->()); + n = make_object(*(*it).second.operator->()); } else { - n = make_node(); + n = make_object(); } n->prefetch_data.push_back(tensor); n->prefetch_offset.push_back(offset); @@ -493,16 +493,16 @@ Stage& Stage::opengl() { } Stage CopyStage(const Stage& s) { - NodePtr n = - make_node(*s.operator->()); + ObjectPtr n = + make_object(*s.operator->()); return Stage(n); } Schedule Schedule::copy() const { // map of stages. const ScheduleNode* self = operator->(); - std::unordered_map smap; - NodePtr n = make_node(); + std::unordered_map smap; + ObjectPtr n = make_object(); n->outputs = self->outputs; // Copy the stages. for (Stage s : self->stages) { @@ -605,7 +605,7 @@ Stage Schedule::create_group(const Array& outputs, int count{0}; }; // Map of group->touched counter - std::unordered_map counter; + std::unordered_map counter; // The parent group; Stage parent_group; // Detect common parent and child. @@ -624,7 +624,7 @@ Stage Schedule::create_group(const Array& outputs, } } // Create the new group stage. - Stage gstage(make_node()); + Stage gstage(make_object()); gstage->group = parent_group; if (parent_group.defined()) { ++parent_group->num_child_stages; @@ -716,7 +716,7 @@ bool ScheduleNode::Contain(const Operation& op) const { } Schedule ScheduleNode::make(Array ops) { - auto n = make_node(); + auto n = make_object(); Schedule sch(n); n->outputs = ops; auto g = schedule::CreateReadGraph(n->outputs); @@ -759,7 +759,7 @@ IterVarRelation SplitNode::make(IterVar parent, IterVar inner, Expr factor, Expr nparts) { - auto n = make_node(); + auto n = make_object(); n->parent = parent; n->outer = outer; n->inner = inner; @@ -770,7 +770,7 @@ IterVarRelation SplitNode::make(IterVar parent, IterVarRelation FuseNode::make( IterVar outer, IterVar inner, IterVar fused) { - auto n = make_node(); + auto n = make_object(); n->outer = outer; n->inner = inner; n->fused = fused; @@ -778,14 +778,14 @@ IterVarRelation FuseNode::make( } IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) { - auto n = make_node(); + auto n = make_object(); n->parent = parent; n->rebased = rebased; return IterVarRelation(n); } IterVarRelation SingletonNode::make(IterVar iter) { - auto n = make_node(); + auto n = make_object(); n->iter = iter; return IterVarRelation(n); } diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc index 3a9d0bcb2a98a..0103410e6132b 100644 --- a/src/schedule/schedule_ops.cc +++ b/src/schedule/schedule_ops.cc @@ -220,13 +220,13 @@ class SchedulePostProc : public IRMutator { } } } else if (op->attr_key == ir::attr::buffer_bind_scope) { - Array tuple = Downcast >(op->node); + Array tuple = Downcast >(op->node); Tensor tensor = Downcast(tuple[1]); auto it = replace_op_.find(tensor->op.get()); if (it != replace_op_.end()) { if (it->second.defined()) { return AttrStmt::make( - Array{tuple[0], it->second.output(tensor->value_index)}, + Array{tuple[0], it->second.output(tensor->value_index)}, op->attr_key, op->value, Mutate(op->body)); } else { return this->Mutate(op->body); @@ -344,7 +344,7 @@ class SchedulePostProc : public IRMutator { replace_op_[src->op.get()] = repl_op; } // The thread extent scope. - std::unordered_map thread_extent_scope_; + std::unordered_map thread_extent_scope_; // The scan value std::unordered_map var_value_; // buffer replacement @@ -352,7 +352,7 @@ class SchedulePostProc : public IRMutator { // buffere realization to be replaced std::unordered_map replace_realize_; // replace producer consumer. - std::unordered_map replace_op_; + std::unordered_map replace_op_; }; Stmt ScheduleOps( diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index 4428642b281d6..7aab3edb6aafd 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -220,8 +220,8 @@ TEST(Map, Iterator) { using namespace tvm; Expr a = 1, b = 2; Map map1{{a, b}}; - std::unordered_map map2(map1.begin(), - map1.end()); + std::unordered_map + map2(map1.begin(), map1.end()); CHECK(map2[a].as()->value == 2); } diff --git a/tests/cpp/expr_test.cc b/tests/cpp/expr_test.cc index 7ecf4590ca12d..debfb36f936bd 100644 --- a/tests/cpp/expr_test.cc +++ b/tests/cpp/expr_test.cc @@ -25,7 +25,7 @@ TEST(Expr, Basic) { using namespace tvm; Var x("x"); auto z = max(x + 1 + 2, 100); - NodeRef tmp = z; + ObjectRef tmp = z; Expr zz = Downcast(tmp); std::ostringstream os; os << z; @@ -39,7 +39,7 @@ TEST(ExprNodeRef, Basic) { Var x("x"); Expr z = max(x + 1 + 2, 100); const ir::Max* op = z.as(); - CHECK(GetRef(op).same_as(z)); + CHECK(GetRef(op).same_as(z)); } diff --git a/tests/cpp/ir_visitor_test.cc b/tests/cpp/ir_visitor_test.cc index 079be65079ca9..4282a0026ee68 100644 --- a/tests/cpp/ir_visitor_test.cc +++ b/tests/cpp/ir_visitor_test.cc @@ -28,7 +28,7 @@ TEST(IRVisitor, CountVar) { Var x("x"), y; auto z = x + 1 + y + y; - ir::PostOrderVisit(z, [&n_var](const NodeRef& n) { + ir::PostOrderVisit(z, [&n_var](const ObjectRef& n) { if (n.as()) ++n_var; }); CHECK_EQ(n_var, 2); diff --git a/topi/include/topi/detail/extern.h b/topi/include/topi/detail/extern.h index 1b510a45661ff..fa184bfee7d17 100644 --- a/topi/include/topi/detail/extern.h +++ b/topi/include/topi/detail/extern.h @@ -81,7 +81,7 @@ inline Array make_extern(const Array< Array >& out_shapes, FExtern fextern, std::string name, std::string tag, - ::tvm::Map attrs) { + ::tvm::Map attrs) { CHECK_EQ(out_shapes.size(), out_types.size()) << "make_extern: out_shapes and out_types must have equal size"; diff --git a/topi/include/topi/nn/softmax.h b/topi/include/topi/nn/softmax.h index fa985d1b2086b..c3124bbe6f580 100644 --- a/topi/include/topi/nn/softmax.h +++ b/topi/include/topi/nn/softmax.h @@ -61,7 +61,7 @@ inline Tensor softmax(const Tensor &x, auto k2 = tvm::reduce_axis(Range(0, input_shape[axis]), "k2"); auto reduced_shape = MakeReduceTargetShape({axis}, x, false, false); - tvm::Map attrs; + tvm::Map attrs; attrs.Set("axis", Integer(axis)); auto insert_reduce_index = [axis, ndim](const Array &indices, diff --git a/topi/src/topi.cc b/topi/src/topi.cc index 87837f82635b5..11a90215d71fd 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -757,7 +757,7 @@ inline PackedFunc WrapSchedule(FTVMScheduleBuilder builder) { return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) { auto target = Target::Current(false); Array outs; - NodeRef argNodeRef = args[0]; + ObjectRef argNodeRef = args[0]; if (argNodeRef->type_index() == outs->type_index()) { outs = args[0]; } else {