Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RUNTIME][OBJECT] Introduce static slots for common objects. #5423

Merged
merged 1 commit into from
Apr 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ namespace tvm {
*/
class BaseExprNode : public Object {
public:
static constexpr const char* _type_key = "Expr";
static constexpr const char* _type_key = "BaseExpr";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const uint32_t _type_child_slots = 58;
TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object);
};

Expand Down Expand Up @@ -88,6 +89,7 @@ class PrimExprNode : public BaseExprNode {
DataType dtype;

static constexpr const char* _type_key = "PrimExpr";
static constexpr const uint32_t _type_child_slots = 34;
TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode);
};

Expand Down Expand Up @@ -161,7 +163,8 @@ class RelayExprNode : public BaseExprNode {
template<typename TTypeNode>
inline const TTypeNode* type_as() const;

static constexpr const char* _type_key = "relay.Expr";
static constexpr const char* _type_key = "RelayExpr";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I like this change, its better to have the keys name spaced to Relay imo.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is mainly for the Base class(given the c++ name is RelayExpr under the root). The relay sub-classes are under the relay namespace.

static constexpr const uint32_t _type_child_slots = 22;
TVM_DECLARE_BASE_OBJECT_INFO(RelayExprNode, BaseExprNode);
};

Expand Down
1 change: 1 addition & 0 deletions include/tvm/ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ class BaseFuncNode : public RelayExprNode {
}

static constexpr const char* _type_key = "BaseFunc";
static constexpr const uint32_t _type_child_slots = 2;
TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode);
};

Expand Down
1 change: 1 addition & 0 deletions include/tvm/ir/tensor_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ namespace tvm {
class BaseTensorTypeNode : public TypeNode {
public:
static constexpr const char* _type_key = "relay.BaseTensorType";
static constexpr const uint32_t _type_child_slots = 1;
TVM_DECLARE_BASE_OBJECT_INFO(BaseTensorTypeNode, TypeNode);
};

Expand Down
2 changes: 2 additions & 0 deletions include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class TypeNode : public Object {
static constexpr const char* _type_key = "Type";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const uint32_t _type_child_slots = 14;
TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
};

Expand Down Expand Up @@ -391,6 +392,7 @@ inline bool IsVoidType(const Type& type) {
class TypeConstraintNode : public TypeNode {
public:
static constexpr const char* _type_key = "TypeConstraint";
static constexpr const uint32_t _type_child_slots = 1;
TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode);
};

Expand Down
1 change: 1 addition & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,7 @@ class TempExprNode : public ExprNode {
static constexpr const char* _type_key = "relay.TempExpr";
static constexpr const bool _type_has_method_sequal_reduce = false;
static constexpr const bool _type_has_method_shash_reduce = false;
static constexpr const uint32_t _type_child_slots = 0;
TVM_DECLARE_BASE_OBJECT_INFO(TempExprNode, ExprNode);
};

Expand Down
6 changes: 3 additions & 3 deletions include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ class ADTObj : public Object, public InplaceArrayBase<ADTObj, ObjectRef> {
uint32_t size;
// The fields of the structure follows directly in memory.

static constexpr const uint32_t _type_index = TypeIndex::kVMADT;
static constexpr const char* _type_key = "vm.ADT";
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeADT;
static constexpr const char* _type_key = "runtime.ADT";
TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object);

private:
Expand Down Expand Up @@ -314,7 +314,7 @@ class StringObj : public Object {
/*! \brief The length of the string object. */
uint64_t size;

static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeString;
static constexpr const char* _type_key = "runtime.String";
TVM_DECLARE_FINAL_OBJECT_INFO(StringObj, Object);

Expand Down
4 changes: 2 additions & 2 deletions include/tvm/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,10 @@ class NDArray::Container :
using Object::IncRef;

// Information for object protocol.
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeNDArray;
static constexpr const uint32_t _type_child_slots = 0;
static constexpr const uint32_t _type_child_slots_can_overflow = true;
static constexpr const char* _type_key = "NDArray";
static constexpr const char* _type_key = "runtime.NDArray";
TVM_DECLARE_BASE_OBJECT_INFO(NDArray::Container, Object);

protected:
Expand Down
43 changes: 31 additions & 12 deletions include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,31 @@
namespace tvm {
namespace runtime {

/*! \brief list of the type index. */
enum TypeIndex {
/*! \brief Root object type. */
kRoot = 0,
kClosure = 1,
kVMADT = 2,
kRuntimeModule = 3,
kStaticIndexEnd,
/*! \brief Type index is allocated during runtime. */
kDynamic = kStaticIndexEnd
};
/*!
* \brief Namespace for the list of type index.
* \note Use struct so that we have to use TypeIndex::ENumName to refer to
* the constant, but still able to use enum.
*/
struct TypeIndex {
enum {
/*! \brief Root object type. */
kRoot = 0,
// Standard static index assignments,
// Frontends can take benefit of these constants.
/*! \brief runtime::Module. */
kRuntimeModule = 1,
/*! \brief runtime::NDArray. */
kRuntimeNDArray = 2,
/*! \brief runtime::String. */
kRuntimeString = 3,
// static assignments that may subject to change.
kRuntimeClosure,
kRuntimeADT,
kStaticIndexEnd,
/*! \brief Type index is allocated during runtime. */
kDynamic = kStaticIndexEnd
};
}; // namespace TypeIndex

/*!
* \brief base class of all object containers.
Expand Down Expand Up @@ -198,7 +212,7 @@ class Object {
using RefCounterType = int32_t;
#endif

static constexpr const char* _type_key = "Object";
static constexpr const char* _type_key = "runtime.Object";

static uint32_t _GetOrAllocRuntimeTypeIndex() {
return TypeIndex::kRoot;
Expand Down Expand Up @@ -675,6 +689,10 @@ struct ObjectEqual {
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \
static_assert(!ParentType::_type_final, "ParentObj maked as final"); \
static uint32_t RuntimeTypeIndex() { \
static_assert(TypeName::_type_child_slots == 0 || \
ParentType::_type_child_slots == 0 || \
TypeName::_type_child_slots < ParentType::_type_child_slots, \
"Need to set _type_child_slots when parent specifies it."); \
if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \
return TypeName::_type_index; \
} \
Expand All @@ -690,6 +708,7 @@ struct ObjectEqual {
return tidx; \
} \


/*!
* \brief helper macro to declare type information in a final class.
* \param TypeName The name of the current type.
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -1268,6 +1268,8 @@ struct unpack_call_dispatcher<void, 0, index, F> {

template<typename R, int nargs, typename F>
inline void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) {
CHECK_EQ(nargs, args.size())
<< "Expect " << nargs << " arguments but get " << args.size();
unpack_call_dispatcher<R, nargs, 0, F>::run(f, args, rv);
}

Expand Down
4 changes: 2 additions & 2 deletions include/tvm/runtime/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ namespace vm {
*/
class ClosureObj : public Object {
public:
static constexpr const uint32_t _type_index = TypeIndex::kClosure;
static constexpr const char* _type_key = "Closure";
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeClosure;
static constexpr const char* _type_key = "runtime.Closure";
TVM_DECLARE_BASE_OBJECT_INFO(ClosureObj, Object);
};

Expand Down
1 change: 1 addition & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class StmtNode : public Object {
static constexpr const char* _type_key = "Stmt";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const uint32_t _type_child_slots = 15;
TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object);
};

Expand Down
1 change: 1 addition & 0 deletions include/tvm/tir/var.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class VarNode : public PrimExprNode {
}

static constexpr const char* _type_key = "tir.Var";
static constexpr const uint32_t _type_child_slots = 1;
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode);
};

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/quantize/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def __enter__(self):
return self

def __exit__(self, ptype, value, trace):
_quantize._ExitQConfigScope(self)
_quantize._ExitQConfigScope()

def __setattr__(self, name, value):
if name in QConfig._node_defaults:
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/runtime/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def getitem_helper(obj, elem_getter, length, idx):
return elem_getter(obj, idx)


@tvm._ffi.register_object("vm.ADT")
@tvm._ffi.register_object("runtime.ADT")
class ADT(Object):
"""Algebatic data type(ADT) object.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/runtime/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from tvm._ffi._ctypes.ndarray import NDArrayBase


@tvm._ffi.register_object
@tvm._ffi.register_object("runtime.NDArray")
class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime.
Expand Down
1 change: 1 addition & 0 deletions src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class CanonicalExprNode : public PrimExprNode {
}

static constexpr const char* _type_key = "arith.CanonicalExpr";
static constexpr const uint32_t _type_child_slots = 2;
TVM_DECLARE_BASE_OBJECT_INFO(CanonicalExprNode, PrimExprNode);
};

Expand Down
2 changes: 2 additions & 0 deletions src/runtime/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,5 +188,7 @@ TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile")
.set_body_typed([](Module mod, std::string name, std::string fmt) {
mod->SaveToFile(name, fmt);
});

TVM_REGISTER_OBJECT_TYPE(ModuleNode);
} // namespace runtime
} // namespace tvm
34 changes: 30 additions & 4 deletions src/runtime/object.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ class TypeContext {
return it->second;
}
// try to allocate from parent's type table.
CHECK_LT(parent_tindex, type_table_.size());
CHECK_LT(parent_tindex, type_table_.size())
<< " skey= " << skey << "static_index=" << static_tindex;
TypeInfo& pinfo = type_table_[parent_tindex];
CHECK_EQ(pinfo.index, parent_tindex);

Expand All @@ -108,7 +109,7 @@ class TypeContext {
<< " between " << type_table_[allocated_tindex].name
<< " and "
<< skey;
} else if (pinfo.allocated_slots + num_slots < pinfo.num_slots) {
} else if (pinfo.allocated_slots + num_slots <= pinfo.num_slots) {
// allocate the slot from parent's reserved pool
allocated_tindex = parent_tindex + pinfo.allocated_slots;
// update parent's state
Expand All @@ -119,8 +120,8 @@ class TypeContext {
// allocate new entries.
allocated_tindex = type_counter_;
type_counter_ += num_slots;
CHECK_LE(type_table_.size(), allocated_tindex);
type_table_.resize(allocated_tindex + 1, TypeInfo());
CHECK_LE(type_table_.size(), type_counter_);
type_table_.resize(type_counter_, TypeInfo());
}
CHECK_GT(allocated_tindex, parent_tindex);
// initialize the slot.
Expand Down Expand Up @@ -161,6 +162,25 @@ class TypeContext {
return it->second;
}

void Dump(int min_children_count) {
std::vector<int> num_children(type_table_.size(), 0);
// reverse accumulation so we can get total counts in a bottom-up manner.
for (auto it = type_table_.rbegin(); it != type_table_.rend(); ++it) {
if (it->index != 0) {
num_children[it->parent_index] += num_children[it->index] + 1;
}
}

for (const auto& info : type_table_) {
if (info.index != 0 && num_children[info.index] >= min_children_count) {
std::cerr <<'[' << info.index << "] "<< info.name
<< "\tparent=" << type_table_[info.parent_index].name
<< "\tnum_child_slots=" << info.num_slots - 1
<< "\tnum_children=" << num_children[info.index] << std::endl;
}
}
}

static TypeContext* Global() {
static TypeContext inst;
return &inst;
Expand All @@ -169,6 +189,7 @@ class TypeContext {
private:
TypeContext() {
type_table_.resize(TypeIndex::kStaticIndexEnd, TypeInfo());
type_table_[0].name = "runtime.Object";
}
// mutex to avoid registration from multiple threads.
std::mutex mutex_;
Expand Down Expand Up @@ -208,6 +229,11 @@ TVM_REGISTER_GLOBAL("runtime.ObjectHash")
.set_body_typed([](ObjectRef obj) {
return static_cast<int64_t>(ObjectHash()(obj));
});

TVM_REGISTER_GLOBAL("runtime.DumpTypeTable")
.set_body_typed([](int min_child_count) {
TypeContext::Global()->Dump(min_child_count);
});
} // namespace runtime
} // namespace tvm

Expand Down
1 change: 1 addition & 0 deletions src/target/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ runtime::Module Build(IRModule mod, const Target& target) {
if (BuildConfig::Current()->disable_assert) {
mod = tir::transform::SkipAssert()(mod);
}

std::string build_f_name = "target.build." + target->target_name;
// the build function.
const PackedFunc* bf = runtime::Registry::Get(build_f_name);
Expand Down
2 changes: 1 addition & 1 deletion src/target/opt/build_cuda_on.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ std::string NVRTCCompile(const std::string& code, bool include_path = false) {
return ptx;
}

runtime::Module BuildCUDA(IRModule mod) {
runtime::Module BuildCUDA(IRModule mod, std::string target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenCUDA cg;
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NO
}
}

runtime::Module BuildOpenCL(IRModule mod) {
runtime::Module BuildOpenCL(IRModule mod, std::string target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenOpenCL cg;
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_opengl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ void CodeGenOpenGL::VisitStmt_(const EvaluateNode* op) {
this->stream << GetVarID(buffer) << " = " << PrintExpr(value) << ";\n";
}

runtime::Module BuildOpenGL(IRModule mod) {
runtime::Module BuildOpenGL(IRModule mod, std::string target) {
bool output_ssa = false;
CodeGenOpenGL cg;
cg.Init(output_ssa);
Expand Down
2 changes: 1 addition & 1 deletion src/target/spirv/build_vulkan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class SPIRVTools {
spv_context ctx_;
};

runtime::Module BuildSPIRV(IRModule mod) {
runtime::Module BuildSPIRV(IRModule mod, std::string target) {
using tvm::runtime::Registry;
using tvm::runtime::VulkanShader;

Expand Down
2 changes: 1 addition & 1 deletion src/target/stackvm/codegen_stackvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ void CodeGenStackVM::VisitExpr_(const LetNode* op) {
this->Push(op->body);
}

runtime::Module BuildStackVM(const IRModule& mod) {
runtime::Module BuildStackVM(const IRModule& mod, const std::string& target) {
std::unordered_map<std::string, StackVM> fmap;
std::string entry_func;

Expand Down
Loading