Skip to content

Commit

Permalink
[RUNTIME][OBJECT] Introduce static slots for common objects. (apache#…
Browse files Browse the repository at this point in the history
…5423)

The _type_child_slots can be used to enable quick type checking optimization
by checking the whether the type index is within the bound.

This PR enables these static slots:

- Introduce a static assert to avoid the scenario when a developer forget to
  _type_child_slots when the field is set for the type's parent.
- Revamp and assign static type index to common runtime objects
- Add a DumpTypeTable call to allow developer monitor the current situation
  of type table and offers suggestions for the slots(ideally the slots equals
  the number of children so there is no overflow.
  • Loading branch information
tqchen authored and trevor-m committed Jun 18, 2020
1 parent b213814 commit 73dd012
Show file tree
Hide file tree
Showing 30 changed files with 112 additions and 72 deletions.
7 changes: 5 additions & 2 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ namespace tvm {
*/
class BaseExprNode : public Object {
public:
static constexpr const char* _type_key = "Expr";
static constexpr const char* _type_key = "BaseExpr";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const uint32_t _type_child_slots = 58;
TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object);
};

Expand Down Expand Up @@ -88,6 +89,7 @@ class PrimExprNode : public BaseExprNode {
DataType dtype;

static constexpr const char* _type_key = "PrimExpr";
static constexpr const uint32_t _type_child_slots = 34;
TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode);
};

Expand Down Expand Up @@ -161,7 +163,8 @@ class RelayExprNode : public BaseExprNode {
template<typename TTypeNode>
inline const TTypeNode* type_as() const;

static constexpr const char* _type_key = "relay.Expr";
static constexpr const char* _type_key = "RelayExpr";
static constexpr const uint32_t _type_child_slots = 22;
TVM_DECLARE_BASE_OBJECT_INFO(RelayExprNode, BaseExprNode);
};

Expand Down
1 change: 1 addition & 0 deletions include/tvm/ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ class BaseFuncNode : public RelayExprNode {
}

static constexpr const char* _type_key = "BaseFunc";
static constexpr const uint32_t _type_child_slots = 2;
TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode);
};

Expand Down
1 change: 1 addition & 0 deletions include/tvm/ir/tensor_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ namespace tvm {
class BaseTensorTypeNode : public TypeNode {
public:
static constexpr const char* _type_key = "relay.BaseTensorType";
static constexpr const uint32_t _type_child_slots = 1;
TVM_DECLARE_BASE_OBJECT_INFO(BaseTensorTypeNode, TypeNode);
};

Expand Down
2 changes: 2 additions & 0 deletions include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class TypeNode : public Object {
static constexpr const char* _type_key = "Type";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const uint32_t _type_child_slots = 14;
TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
};

Expand Down Expand Up @@ -391,6 +392,7 @@ inline bool IsVoidType(const Type& type) {
class TypeConstraintNode : public TypeNode {
public:
static constexpr const char* _type_key = "TypeConstraint";
static constexpr const uint32_t _type_child_slots = 1;
TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode);
};

Expand Down
1 change: 1 addition & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,7 @@ class TempExprNode : public ExprNode {
static constexpr const char* _type_key = "relay.TempExpr";
static constexpr const bool _type_has_method_sequal_reduce = false;
static constexpr const bool _type_has_method_shash_reduce = false;
static constexpr const uint32_t _type_child_slots = 0;
TVM_DECLARE_BASE_OBJECT_INFO(TempExprNode, ExprNode);
};

Expand Down
6 changes: 3 additions & 3 deletions include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ class ADTObj : public Object, public InplaceArrayBase<ADTObj, ObjectRef> {
uint32_t size;
// The fields of the structure follows directly in memory.

static constexpr const uint32_t _type_index = TypeIndex::kVMADT;
static constexpr const char* _type_key = "vm.ADT";
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeADT;
static constexpr const char* _type_key = "runtime.ADT";
TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object);

private:
Expand Down Expand Up @@ -314,7 +314,7 @@ class StringObj : public Object {
/*! \brief The length of the string object. */
uint64_t size;

static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeString;
static constexpr const char* _type_key = "runtime.String";
TVM_DECLARE_FINAL_OBJECT_INFO(StringObj, Object);

Expand Down
4 changes: 2 additions & 2 deletions include/tvm/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,10 @@ class NDArray::Container :
using Object::IncRef;

// Information for object protocol.
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeNDArray;
static constexpr const uint32_t _type_child_slots = 0;
static constexpr const uint32_t _type_child_slots_can_overflow = true;
static constexpr const char* _type_key = "NDArray";
static constexpr const char* _type_key = "runtime.NDArray";
TVM_DECLARE_BASE_OBJECT_INFO(NDArray::Container, Object);

protected:
Expand Down
43 changes: 31 additions & 12 deletions include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,31 @@
namespace tvm {
namespace runtime {

/*! \brief list of the type index. */
enum TypeIndex {
/*! \brief Root object type. */
kRoot = 0,
kClosure = 1,
kVMADT = 2,
kRuntimeModule = 3,
kStaticIndexEnd,
/*! \brief Type index is allocated during runtime. */
kDynamic = kStaticIndexEnd
};
/*!
* \brief Namespace for the list of type index.
* \note Use struct so that we have to use TypeIndex::ENumName to refer to
* the constant, but still able to use enum.
*/
struct TypeIndex {
enum {
/*! \brief Root object type. */
kRoot = 0,
// Standard static index assignments,
// Frontends can take benefit of these constants.
/*! \brief runtime::Module. */
kRuntimeModule = 1,
/*! \brief runtime::NDArray. */
kRuntimeNDArray = 2,
/*! \brief runtime::String. */
kRuntimeString = 3,
// static assignments that may subject to change.
kRuntimeClosure,
kRuntimeADT,
kStaticIndexEnd,
/*! \brief Type index is allocated during runtime. */
kDynamic = kStaticIndexEnd
};
}; // namespace TypeIndex

/*!
* \brief base class of all object containers.
Expand Down Expand Up @@ -198,7 +212,7 @@ class Object {
using RefCounterType = int32_t;
#endif

static constexpr const char* _type_key = "Object";
static constexpr const char* _type_key = "runtime.Object";

static uint32_t _GetOrAllocRuntimeTypeIndex() {
return TypeIndex::kRoot;
Expand Down Expand Up @@ -675,6 +689,10 @@ struct ObjectEqual {
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \
static_assert(!ParentType::_type_final, "ParentObj maked as final"); \
static uint32_t RuntimeTypeIndex() { \
static_assert(TypeName::_type_child_slots == 0 || \
ParentType::_type_child_slots == 0 || \
TypeName::_type_child_slots < ParentType::_type_child_slots, \
"Need to set _type_child_slots when parent specifies it."); \
if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \
return TypeName::_type_index; \
} \
Expand All @@ -690,6 +708,7 @@ struct ObjectEqual {
return tidx; \
} \


/*!
* \brief helper macro to declare type information in a final class.
* \param TypeName The name of the current type.
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -1268,6 +1268,8 @@ struct unpack_call_dispatcher<void, 0, index, F> {

template<typename R, int nargs, typename F>
inline void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) {
CHECK_EQ(nargs, args.size())
<< "Expect " << nargs << " arguments but get " << args.size();
unpack_call_dispatcher<R, nargs, 0, F>::run(f, args, rv);
}

Expand Down
4 changes: 2 additions & 2 deletions include/tvm/runtime/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ namespace vm {
*/
class ClosureObj : public Object {
public:
static constexpr const uint32_t _type_index = TypeIndex::kClosure;
static constexpr const char* _type_key = "Closure";
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeClosure;
static constexpr const char* _type_key = "runtime.Closure";
TVM_DECLARE_BASE_OBJECT_INFO(ClosureObj, Object);
};

Expand Down
1 change: 1 addition & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class StmtNode : public Object {
static constexpr const char* _type_key = "Stmt";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const uint32_t _type_child_slots = 15;
TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object);
};

Expand Down
1 change: 1 addition & 0 deletions include/tvm/tir/var.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class VarNode : public PrimExprNode {
}

static constexpr const char* _type_key = "tir.Var";
static constexpr const uint32_t _type_child_slots = 1;
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode);
};

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/quantize/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def __enter__(self):
return self

def __exit__(self, ptype, value, trace):
_quantize._ExitQConfigScope(self)
_quantize._ExitQConfigScope()

def __setattr__(self, name, value):
if name in QConfig._node_defaults:
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/runtime/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def getitem_helper(obj, elem_getter, length, idx):
return elem_getter(obj, idx)


@tvm._ffi.register_object("vm.ADT")
@tvm._ffi.register_object("runtime.ADT")
class ADT(Object):
"""Algebatic data type(ADT) object.
Expand Down
10 changes: 0 additions & 10 deletions python/tvm/runtime/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,16 +398,6 @@ def load_module(path, fmt=""):
This function will automatically call
cc.create_shared if the path is in format .o or .tar
"""
if os.stat(path).st_size == 0:
logging.info("The lib generated by the NNVM compiler does not contain optimized "
"functions for any operators. This usually happens when an external "
"accelerator, e.g. TensorRT, is employed along with TVM to compile "
"the model, and all the operators in the model are supported by the "
"external accelerator at runtime. Therefore, the NNVM compiler skipped "
"optimizing them at the compile time. The TVM runtime "
"will create an empty Module as a dummy module.")
return _ffi_api.CreateEmptyModule()

# High level handling for .o and .tar file.
# We support this to be consistent with RPC module load.
if path.endswith(".o"):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/runtime/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from tvm._ffi._ctypes.ndarray import NDArrayBase


@tvm._ffi.register_object
@tvm._ffi.register_object("runtime.NDArray")
class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime.
Expand Down
1 change: 1 addition & 0 deletions src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class CanonicalExprNode : public PrimExprNode {
}

static constexpr const char* _type_key = "arith.CanonicalExpr";
static constexpr const uint32_t _type_child_slots = 2;
TVM_DECLARE_BASE_OBJECT_INFO(CanonicalExprNode, PrimExprNode);
};

Expand Down
12 changes: 1 addition & 11 deletions src/runtime/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,16 +193,6 @@ TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile")
mod->SaveToFile(name, fmt);
});

TVM_REGISTER_GLOBAL("runtime.IsEmpty")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator Module().IsEmpty();
});

TVM_REGISTER_GLOBAL("runtime.CreateEmptyModule")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Module m;
*ret = m;
});

TVM_REGISTER_OBJECT_TYPE(ModuleNode);
} // namespace runtime
} // namespace tvm
34 changes: 30 additions & 4 deletions src/runtime/object.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ class TypeContext {
return it->second;
}
// try to allocate from parent's type table.
CHECK_LT(parent_tindex, type_table_.size());
CHECK_LT(parent_tindex, type_table_.size())
<< " skey= " << skey << "static_index=" << static_tindex;
TypeInfo& pinfo = type_table_[parent_tindex];
CHECK_EQ(pinfo.index, parent_tindex);

Expand All @@ -108,7 +109,7 @@ class TypeContext {
<< " between " << type_table_[allocated_tindex].name
<< " and "
<< skey;
} else if (pinfo.allocated_slots + num_slots < pinfo.num_slots) {
} else if (pinfo.allocated_slots + num_slots <= pinfo.num_slots) {
// allocate the slot from parent's reserved pool
allocated_tindex = parent_tindex + pinfo.allocated_slots;
// update parent's state
Expand All @@ -119,8 +120,8 @@ class TypeContext {
// allocate new entries.
allocated_tindex = type_counter_;
type_counter_ += num_slots;
CHECK_LE(type_table_.size(), allocated_tindex);
type_table_.resize(allocated_tindex + 1, TypeInfo());
CHECK_LE(type_table_.size(), type_counter_);
type_table_.resize(type_counter_, TypeInfo());
}
CHECK_GT(allocated_tindex, parent_tindex);
// initialize the slot.
Expand Down Expand Up @@ -161,6 +162,25 @@ class TypeContext {
return it->second;
}

void Dump(int min_children_count) {
std::vector<int> num_children(type_table_.size(), 0);
// reverse accumulation so we can get total counts in a bottom-up manner.
for (auto it = type_table_.rbegin(); it != type_table_.rend(); ++it) {
if (it->index != 0) {
num_children[it->parent_index] += num_children[it->index] + 1;
}
}

for (const auto& info : type_table_) {
if (info.index != 0 && num_children[info.index] >= min_children_count) {
std::cerr <<'[' << info.index << "] "<< info.name
<< "\tparent=" << type_table_[info.parent_index].name
<< "\tnum_child_slots=" << info.num_slots - 1
<< "\tnum_children=" << num_children[info.index] << std::endl;
}
}
}

static TypeContext* Global() {
static TypeContext inst;
return &inst;
Expand All @@ -169,6 +189,7 @@ class TypeContext {
private:
TypeContext() {
type_table_.resize(TypeIndex::kStaticIndexEnd, TypeInfo());
type_table_[0].name = "runtime.Object";
}
// mutex to avoid registration from multiple threads.
std::mutex mutex_;
Expand Down Expand Up @@ -208,6 +229,11 @@ TVM_REGISTER_GLOBAL("runtime.ObjectHash")
.set_body_typed([](ObjectRef obj) {
return static_cast<int64_t>(ObjectHash()(obj));
});

TVM_REGISTER_GLOBAL("runtime.DumpTypeTable")
.set_body_typed([](int min_child_count) {
TypeContext::Global()->Dump(min_child_count);
});
} // namespace runtime
} // namespace tvm

Expand Down
1 change: 1 addition & 0 deletions src/target/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ runtime::Module Build(IRModule mod, const Target& target) {
if (BuildConfig::Current()->disable_assert) {
mod = tir::transform::SkipAssert()(mod);
}

std::string build_f_name = "target.build." + target->target_name;
// the build function.
const PackedFunc* bf = runtime::Registry::Get(build_f_name);
Expand Down
2 changes: 1 addition & 1 deletion src/target/opt/build_cuda_on.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ std::string NVRTCCompile(const std::string& code, bool include_path = false) {
return ptx;
}

runtime::Module BuildCUDA(IRModule mod) {
runtime::Module BuildCUDA(IRModule mod, std::string target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenCUDA cg;
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NO
}
}

runtime::Module BuildOpenCL(IRModule mod) {
runtime::Module BuildOpenCL(IRModule mod, std::string target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenOpenCL cg;
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_opengl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ void CodeGenOpenGL::VisitStmt_(const EvaluateNode* op) {
this->stream << GetVarID(buffer) << " = " << PrintExpr(value) << ";\n";
}

runtime::Module BuildOpenGL(IRModule mod) {
runtime::Module BuildOpenGL(IRModule mod, std::string target) {
bool output_ssa = false;
CodeGenOpenGL cg;
cg.Init(output_ssa);
Expand Down
2 changes: 1 addition & 1 deletion src/target/spirv/build_vulkan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class SPIRVTools {
spv_context ctx_;
};

runtime::Module BuildSPIRV(IRModule mod) {
runtime::Module BuildSPIRV(IRModule mod, std::string target) {
using tvm::runtime::Registry;
using tvm::runtime::VulkanShader;

Expand Down
Loading

0 comments on commit 73dd012

Please sign in to comment.