Skip to content

Commit

Permalink
[REFACTOR][OBJECT] Consoldiate NodePtr/Ref/Hash/Equal to Object (#4603)
Browse files Browse the repository at this point in the history
* [REFACTOR][OBJECT] Consoldiate NodePtr/Ref/Hash/Equal and macros to Object.

Historically, we have classes like NodePtr/Ref/HashEqual.
After unified object protocol, these names are just alias of the object counterpart.
Moreover, there are helper macros defined over the places for defining these object.

This PR consoldiate the terminologies into the corresponding ones
in the Object system so we have a clean and consistent API moving forward.

* Update include/tvm/attrs.h

Co-Authored-By: Wei Chen <ipondering.weic@gmail.com>

* fix compilation

Co-authored-by: Wei Chen <ipondering.weic@gmail.com>
  • Loading branch information
tqchen and wweic committed Dec 31, 2019
1 parent 475158f commit a8c3692
Show file tree
Hide file tree
Showing 215 changed files with 1,623 additions and 1,517 deletions.
12 changes: 6 additions & 6 deletions include/tvm/api_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ namespace tvm {
* \brief Node container of EnvFunc
* \sa EnvFunc
*/
class EnvFuncNode : public Node {
class EnvFuncNode : public Object {
public:
/*! \brief Unique name of the global function */
std::string name;
Expand All @@ -63,7 +63,7 @@ class EnvFuncNode : public Node {
}

static constexpr const char* _type_key = "EnvFunc";
TVM_DECLARE_NODE_TYPE_INFO(EnvFuncNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(EnvFuncNode, Object);
};

/*!
Expand All @@ -73,10 +73,10 @@ class EnvFuncNode : public Node {
* An EnvFunc is saved by its name in the global registry
* under the assumption that the same function is registered during load.
*/
class EnvFunc : public NodeRef {
class EnvFunc : public ObjectRef {
public:
EnvFunc() {}
explicit EnvFunc(NodePtr<Node> n) : NodeRef(n) {}
explicit EnvFunc(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! \return The internal global function pointer */
const EnvFuncNode* operator->() const {
return static_cast<const EnvFuncNode*>(get());
Expand Down Expand Up @@ -119,12 +119,12 @@ class TypedEnvFunc;
* \sa EnvFunc
*/
template<typename R, typename... Args>
class TypedEnvFunc<R(Args...)> : public NodeRef {
class TypedEnvFunc<R(Args...)> : public ObjectRef {
public:
/*! \brief short hand for this function type */
using TSelf = TypedEnvFunc<R(Args...)>;
TypedEnvFunc() {}
explicit TypedEnvFunc(ObjectPtr<Object> n) : NodeRef(n) {}
explicit TypedEnvFunc(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief Assign global function to a TypedEnvFunc
* \param other Another global function.
Expand Down
26 changes: 13 additions & 13 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class Analyzer;
*
* set = [min_value, max_value]
*/
class ConstIntBoundNode : public Node {
class ConstIntBoundNode : public Object {
public:
int64_t min_value;
int64_t max_value;
Expand All @@ -74,14 +74,14 @@ class ConstIntBoundNode : public Node {
static const constexpr int64_t kNegInf = -kPosInf;

static constexpr const char* _type_key = "arith.ConstIntBound";
TVM_DECLARE_NODE_TYPE_INFO(ConstIntBoundNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(ConstIntBoundNode, Object);
};

/*!
* \brief reference class to ConstIntBoundNode
* \sa ConstIntBoundNode
*/
class ConstIntBound : public NodeRef {
class ConstIntBound : public ObjectRef {
public:
/*!
* \brief constructor by fields.
Expand All @@ -92,7 +92,7 @@ class ConstIntBound : public NodeRef {

static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf;
TVM_DEFINE_NODE_REF_METHODS(ConstIntBound, NodeRef, ConstIntBoundNode);
TVM_DEFINE_OBJECT_REF_METHODS(ConstIntBound, ObjectRef, ConstIntBoundNode);
};

/*!
Expand Down Expand Up @@ -155,7 +155,7 @@ class ConstIntBoundAnalyzer {
* This is useful to decide if the index is dividable by certain value.
* For example, if index = 0 + 4 x, then we know it can be divided by 4.
*/
class ModularSetNode : public Node {
class ModularSetNode : public Object {
public:
/*! \brief linear co-efficient */
int64_t coeff;
Expand All @@ -168,18 +168,18 @@ class ModularSetNode : public Node {
}

static constexpr const char* _type_key = "arith.ModularSet";
TVM_DECLARE_NODE_TYPE_INFO(ModularSetNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(ModularSetNode, Object);
};

/*!
* \brief reference of ModularSetNode
* \sa ModularSetNode
*/
class ModularSet : public NodeRef {
class ModularSet : public ObjectRef {
public:
TVM_DLL ModularSet(int64_t coeff, int64_t base);

TVM_DEFINE_NODE_REF_METHODS(ModularSet, NodeRef, ModularSetNode);
TVM_DEFINE_OBJECT_REF_METHODS(ModularSet, ObjectRef, ModularSetNode);
};

/*!
Expand Down Expand Up @@ -349,20 +349,20 @@ enum SignType {
/*!
* \brief Base class of all IntSet containers.
*/
struct IntSetNode : public Node {
struct IntSetNode : public Object {
static constexpr const char* _type_key = "IntSet";
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Object);
TVM_DECLARE_BASE_OBJECT_INFO(IntSetNode, Object);
};

/*!
* \brief Integer set class, represent a set of integers in one dimension.
*/
class IntSet : public NodeRef {
class IntSet : public ObjectRef {
public:
/*! \brief constructor */
IntSet() {}
// constructor from not container.
explicit IntSet(ObjectPtr<Object> n) : NodeRef(n) {}
explicit IntSet(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
Expand Down Expand Up @@ -598,7 +598,7 @@ IntSet EvalSet(Range r,
const std::unordered_map<const Variable*, IntSet>& dom_map);

/*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<Expr, IntSet, NodeHash, NodeEqual>;
using ExprIntSetMap = std::unordered_map<Expr, IntSet, ObjectHash, ObjectEqual>;
/*!
* \brief Find the integer set of every sub-expression, given the
* domain of each iteration variables.
Expand Down
39 changes: 21 additions & 18 deletions include/tvm/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ namespace tvm {
*/
#define TVM_DECLARE_ATTRS(ClassName, TypeKey) \
static constexpr const char* _type_key = TypeKey; \
TVM_DECLARE_NODE_TYPE_INFO(ClassName, ::tvm::BaseAttrsNode) \
TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \
template<typename FVisit> \
void __VisitAttrs__(FVisit& __fvisit__) // NOLINT(*)

Expand All @@ -83,9 +83,9 @@ namespace tvm {
* \tparam TNodeRef the type to be created.
* \return A instance that will represent None.
*/
template<typename TNodeRef>
inline TNodeRef NullValue() {
return TNodeRef(NodePtr<Node>(nullptr));
template<typename TObjectRef>
inline TObjectRef NullValue() {
return TObjectRef(ObjectPtr<Object>(nullptr));
}

template<>
Expand All @@ -106,7 +106,7 @@ struct AttrError : public dmlc::Error {
/*!
* \brief Information about attribute fields in string representations.
*/
class AttrFieldInfoNode : public Node {
class AttrFieldInfoNode : public Object {
public:
/*! \brief name of the field */
std::string name;
Expand All @@ -121,11 +121,14 @@ class AttrFieldInfoNode : public Node {
v->Visit("description", &description);
}
static constexpr const char* _type_key = "AttrFieldInfo";
TVM_DECLARE_NODE_TYPE_INFO(AttrFieldInfoNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object);
};

/*! \brief AttrFieldInfo */
TVM_DEFINE_NODE_REF(AttrFieldInfo, AttrFieldInfoNode);
class AttrFieldInfo : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(AttrFieldInfo, ObjectRef, AttrFieldInfoNode);
};

class AttrsHashHandler;
class AttrsEqualHandler;
Expand Down Expand Up @@ -217,7 +220,7 @@ class AttrsHash {
* subclass AttrsNode instead.
* \sa AttrsNode
*/
class BaseAttrsNode : public Node {
class BaseAttrsNode : public Object {
public:
using TVMArgs = runtime::TVMArgs;
using TVMRetValue = runtime::TVMRetValue;
Expand Down Expand Up @@ -271,16 +274,16 @@ class BaseAttrsNode : public Node {
TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0;

static constexpr const char* _type_key = "Attrs";
TVM_DECLARE_BASE_NODE_INFO(BaseAttrsNode, Node);
TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object);
};

/*! \brief Base attribute container for all attributes */
class Attrs : public NodeRef {
class Attrs : public ObjectRef {
public:
// normal constructor
Attrs() {}
// construct from shared ptr.
explicit Attrs(NodePtr<Node> n) : NodeRef(n) {}
explicit Attrs(ObjectPtr<Object> n) : ObjectRef(n) {}

/*! \return The attribute node */
const BaseAttrsNode* operator->() const {
Expand All @@ -305,13 +308,13 @@ class Attrs : public NodeRef {
class DictAttrsNode : public BaseAttrsNode {
public:
/*! \brief internal attrs map */
Map<std::string, NodeRef> dict;
Map<std::string, ObjectRef> dict;
/*!
* \brief Consruct a Attrs backed by DictAttrsNode.
* \param dict The attributes.
* \return The dict attributes.
*/
TVM_DLL static Attrs make(Map<std::string, NodeRef> dict);
TVM_DLL static Attrs make(Map<std::string, ObjectRef> dict);
// implementations
void VisitAttrs(AttrVisitor* v) final;
void VisitNonDefaultAttrs(AttrVisitor* v) final;
Expand All @@ -321,7 +324,7 @@ class DictAttrsNode : public BaseAttrsNode {
size_t ContentHash(AttrsHash hasher) const final;
// type info
static constexpr const char* _type_key = "DictAttrs";
TVM_DECLARE_NODE_TYPE_INFO(DictAttrsNode, BaseAttrsNode);
TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode);
};


Expand Down Expand Up @@ -639,7 +642,7 @@ class AttrDocEntry {
public:
using TSelf = AttrDocEntry;

explicit AttrDocEntry(NodePtr<AttrFieldInfoNode> info)
explicit AttrDocEntry(ObjectPtr<AttrFieldInfoNode> info)
: info_(info) {
}
TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) {
Expand All @@ -663,15 +666,15 @@ class AttrDocEntry {
}

private:
NodePtr<AttrFieldInfoNode> info_;
ObjectPtr<AttrFieldInfoNode> info_;
};

class AttrDocVisitor {
public:
template<typename T>
AttrDocEntry operator()(const char* key, T* v) {
NodePtr<AttrFieldInfoNode> info
= make_node<AttrFieldInfoNode>();
ObjectPtr<AttrFieldInfoNode> info
= make_object<AttrFieldInfoNode>();
info->name = key;
info->type_info = TypeName<T>::value;
fields_.push_back(AttrFieldInfo(info));
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ enum BufferType : int {
* It is a composition of primitive symbolic types,
* used to specify the memory layout of the Tensor used in program input.
*/
class Buffer : public NodeRef {
class Buffer : public ObjectRef {
public:
Buffer() {}
explicit Buffer(ObjectPtr<Object> n) : NodeRef(n) {}
explicit Buffer(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief Return a new buffer that is equivalent with current one
* but always add stride field.
Expand Down Expand Up @@ -101,7 +101,7 @@ class Buffer : public NodeRef {
};

/*! \brief Node to represent a buffer */
class BufferNode : public Node {
class BufferNode : public Object {
public:
// Data fields.
/*!
Expand Down Expand Up @@ -169,7 +169,7 @@ class BufferNode : public Node {
BufferType buffer_type);

static constexpr const char* _type_key = "Buffer";
TVM_DECLARE_NODE_TYPE_INFO(BufferNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object);
};

inline const BufferNode* Buffer::operator->() const {
Expand Down
24 changes: 12 additions & 12 deletions include/tvm/build_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ namespace tvm {
* \brief Container for target device information.
* Use target::llvm, target::cuda etc functions instead of constructing directly.
*/
class TargetNode : public Node {
class TargetNode : public Object {
public:
/*! \brief The name of the target device */
std::string target_name;
Expand Down Expand Up @@ -82,18 +82,18 @@ class TargetNode : public Node {
TVM_DLL std::unordered_set<std::string> libs() const;

static constexpr const char* _type_key = "Target";
TVM_DECLARE_NODE_TYPE_INFO(TargetNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(TargetNode, Object);

private:
/*! \brief Internal string repr. */
mutable std::string str_repr_;
};

/*! \brief reference cpass to the target. */
class Target : public NodeRef {
class Target : public ObjectRef {
public:
Target() {}
explicit Target(ObjectPtr<Object> n) : NodeRef(n) {}
explicit Target(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief Create a Target given a string
* \param target_str the string to parse
Expand Down Expand Up @@ -178,7 +178,7 @@ TVM_DLL Target ext_dev(const std::vector<std::string>& options =
/*!
* \brief Container for build configuration options
*/
class BuildConfigNode : public Node {
class BuildConfigNode : public Object {
public:
/*!
* \brief The data alignment to use when constructing buffers. If this is set to
Expand Down Expand Up @@ -254,16 +254,16 @@ class BuildConfigNode : public Node {
}

static constexpr const char* _type_key = "BuildConfig";
TVM_DECLARE_NODE_TYPE_INFO(BuildConfigNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(BuildConfigNode, Object);
};

/*!
* \brief Build configuration for compilations.
*/
class BuildConfig : public ::tvm::NodeRef {
class BuildConfig : public ::tvm::ObjectRef {
public:
BuildConfig() {}
explicit BuildConfig(ObjectPtr<Object> n) : NodeRef(n) {}
explicit BuildConfig(ObjectPtr<Object> n) : ObjectRef(n) {}
const BuildConfigNode* operator->() const {
return static_cast<const BuildConfigNode*>(get());
}
Expand Down Expand Up @@ -375,10 +375,10 @@ class GenericFuncNode;
/*!
* \brief Generic function that can be specialized on a per-target basis.
*/
class GenericFunc : public NodeRef {
class GenericFunc : public ObjectRef {
public:
GenericFunc() {}
explicit GenericFunc(ObjectPtr<Object> n) : NodeRef(n) {}
explicit GenericFunc(ObjectPtr<Object> n) : ObjectRef(n) {}

/*!
* \brief Set the default function implementaiton.
Expand Down Expand Up @@ -471,7 +471,7 @@ inline runtime::TVMRetValue GenericFunc::operator()(Args&& ...args) const {
/*!
* \brief Represents a generic function that can be specialized on a per-target basis.
*/
class GenericFuncNode : public Node {
class GenericFuncNode : public Object {
public:
/*! \brief name of the function */
std::string name_;
Expand All @@ -483,7 +483,7 @@ class GenericFuncNode : public Node {
void VisitAttrs(AttrVisitor* v) {}

static constexpr const char* _type_key = "GenericFunc";
TVM_DECLARE_NODE_TYPE_INFO(GenericFuncNode, Node);
TVM_DECLARE_FINAL_OBJECT_INFO(GenericFuncNode, Object);
};

inline GenericFuncNode* GenericFunc::operator->() {
Expand Down
Loading

0 comments on commit a8c3692

Please sign in to comment.