Skip to content

Commit

Permalink
[REFACTOR] Replace TensorObj and TensorValue with NDArray (#4643)
Browse files Browse the repository at this point in the history
* replace TensorObj and TensorValue with NDArray

* NodeBase to Object in Python

* rebase
  • Loading branch information
zhiics authored and jroesch committed Jan 11, 2020
1 parent dcf7fbf commit 86092de
Show file tree
Hide file tree
Showing 55 changed files with 508 additions and 781 deletions.
13 changes: 5 additions & 8 deletions docs/api/python/dev.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,14 @@ Developer API
This page contains modules that are used by developers of TVM.
Many of these APIs are PackedFunc registered in C++ backend.

tvm.node
~~~~~~~~
.. automodule:: tvm.node

.. autoclass:: tvm.node.NodeBase
:members:
tvm.object
~~~~~~~~~~
.. automodule:: tvm.object

.. autoclass:: tvm.node.Node
.. autoclass:: tvm.object.Object
:members:

.. autofunction:: tvm.register_node
.. autofunction:: tvm.register_object

tvm.expr
~~~~~~~~
Expand Down
8 changes: 4 additions & 4 deletions docs/dev/codebase_walkthrough.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,18 @@ We use a simple example that uses the low level TVM API directly. The example is
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")

Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/tensor.h`` and ``src/lang/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlying C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is a subclass of ``NodeBase``.
Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/tensor.h`` and ``src/lang/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlying C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is a subclass of ``Object``.

::

@register_node
class Tensor(NodeBase, _expr.ExprOp):
@register_object
class Tensor(Object, _expr.ExprOp):
"""Tensor object, to construct, see function.Tensor"""

def __call__(self, *indices):
...

The Node system is the basis of exposing C++ types to frontend languages, including Python. The way TVM implements Python wrapping is not straightforward. It is briefly covered in `this document <https://docs.tvm.ai/dev/runtime.html#tvm-node-and-compiler-stack>`_, and details are in ``python/tvm/_ffi/`` if you are interested.
The object protocol is the basis of exposing C++ types to frontend languages, including Python. The way TVM implements Python wrapping is not straightforward. It is briefly covered in `this document <https://docs.tvm.ai/dev/runtime.html#tvm-node-and-compiler-stack>`_, and details are in ``python/tvm/_ffi/`` if you are interested.

``Tensor`` is created by functions in ``python/tvm/api.py``, which in turn calls into C++ functions exposed in ``src/api/api_lang.cc``. All C++ functions that are callable from Python are exposed in the ``src/api`` subdirectory. For example, the ``tvm.compute()`` function above calls into ``_ComputeOp`` API exposed in ``src/api/api_lang.cc``:

Expand Down
105 changes: 30 additions & 75 deletions include/tvm/relay/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,11 @@
#include <tvm/build_module.h>
#include <tvm/relay/module.h>
#include <tvm/relay/expr.h>
#include <tvm/runtime/object.h>

namespace tvm {
namespace relay {

/*!
* \brief A Relay value.
*/
class Value;

/*!
*\brief Create a Interpreter function that can
* evaluate an expression and produce a value.
Expand All @@ -65,39 +61,21 @@ class Value;
* \param target Compiler target flag to compile the functions on the context.
* \return A function that takes in an expression and returns a value.
*/
runtime::TypedPackedFunc<Value(Expr)>
runtime::TypedPackedFunc<ObjectRef(Expr)>
CreateInterpreter(Module mod, DLContext context, Target target);

/*! \brief The base container type of Relay values. */
class ValueNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Value";
TVM_DECLARE_BASE_OBJECT_INFO(ValueNode, RelayNode);
};

class Value : public ObjectRef {
public:
Value() {}
explicit Value(ObjectPtr<Object> n) : ObjectRef(n) {}
const ValueNode* operator->() const {
return static_cast<const ValueNode*>(get());
}

using ContainerType = ValueNode;
};

/*! \brief A Relay closure, i.e a scope and a function. */
class Closure;

/*! \brief The container type of Closures. */
class ClosureNode : public ValueNode {
class ClosureNode : public Object {
public:
/*! \brief The set of free variables in the closure.
*
* These are the captured variables which are required for
* evaluation when we call the closure.
*/
tvm::Map<Var, Value> env;
tvm::Map<Var, ObjectRef> env;
/*! \brief The function which implements the closure.
*
* \note May reference the variables contained in the env.
Expand All @@ -111,22 +89,22 @@ class ClosureNode : public ValueNode {
v->Visit("func", &func);
}

TVM_DLL static Closure make(tvm::Map<Var, Value> env, Function func);
TVM_DLL static Closure make(tvm::Map<Var, ObjectRef> env, Function func);

static constexpr const char* _type_key = "relay.Closure";
TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, Object);
};

class Closure : public Value {
class Closure : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Closure, Value, ClosureNode);
TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureNode);
};

/*! \brief A Relay Recursive Closure. A closure that has a name. */
class RecClosure;

/*! \brief The container type of RecClosure. */
class RecClosureNode : public ValueNode {
class RecClosureNode : public Object {
public:
/*! \brief The closure. */
Closure clos;
Expand All @@ -143,89 +121,66 @@ class RecClosureNode : public ValueNode {
TVM_DLL static RecClosure make(Closure clos, Var bind);

static constexpr const char* _type_key = "relay.RecClosure";
TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, Object);
};

class RecClosure : public Value {
class RecClosure : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, Value, RecClosureNode);
TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureNode);
};

/*! \brief A tuple value. */
class TupleValue;

/*! \brief Tuple (x, ... y). */
struct TupleValueNode : ValueNode {
tvm::Array<Value> fields;
struct TupleValueNode : Object {
tvm::Array<ObjectRef> fields;

TupleValueNode() {}

void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); }

TVM_DLL static TupleValue make(tvm::Array<Value> value);
TVM_DLL static TupleValue make(tvm::Array<ObjectRef> value);

static constexpr const char* _type_key = "relay.TupleValue";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, ValueNode);
};

class TupleValue : public Value {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, Value, TupleValueNode);
};

/*! \brief A tensor value. */
class TensorValue;

/*! \brief The tensor value container, wrapping an NDArray. */
struct TensorValueNode : ValueNode {
runtime::NDArray data;

TensorValueNode() {}

void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("data", &data); }

/*! \brief Build a value from an NDArray. */
TVM_DLL static TensorValue make(runtime::NDArray data);

static constexpr const char* _type_key = "relay.TensorValue";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorValueNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, Object);
};

class TensorValue : public Value {
class TupleValue : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TensorValue, Value, TensorValueNode);
TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, ObjectRef, TupleValueNode);
};

/*! \brief A reference value. */
class RefValue;

struct RefValueNode : ValueNode {
mutable Value value;
struct RefValueNode : Object {
mutable ObjectRef value;

RefValueNode() {}

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value);
}

TVM_DLL static RefValue make(Value val);
TVM_DLL static RefValue make(ObjectRef val);

static constexpr const char* _type_key = "relay.RefValue";
TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, Object);
};

class RefValue : public Value {
class RefValue : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RefValue, Value, RefValueNode);
TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueNode);
};

/*! \brief An ADT constructor value. */
class ConstructorValue;

struct ConstructorValueNode : ValueNode {
struct ConstructorValueNode : Object {
int32_t tag;

tvm::Array<Value> fields;
tvm::Array<ObjectRef> fields;

/*! \brief Optional field tracking ADT constructor. */
Constructor constructor;
Expand All @@ -237,16 +192,16 @@ struct ConstructorValueNode : ValueNode {
}

TVM_DLL static ConstructorValue make(int32_t tag,
tvm::Array<Value> fields,
tvm::Array<ObjectRef> fields,
Constructor construtor = {});

static constexpr const char* _type_key = "relay.ConstructorValue";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, Object);
};

class ConstructorValue : public Value {
class ConstructorValue : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, Value, ConstructorValueNode);
TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueNode);
};

} // namespace relay
Expand Down
19 changes: 0 additions & 19 deletions include/tvm/runtime/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,25 +36,6 @@ namespace tvm {
namespace runtime {
namespace vm {

/*! \brief An object containing an NDArray. */
class TensorObj : public Object {
public:
/*! \brief The NDArray. */
NDArray data;

static constexpr const uint32_t _type_index = TypeIndex::kVMTensor;
static constexpr const char* _type_key = "vm.Tensor";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorObj, Object);
};

/*! \brief reference to tensor. */
class Tensor : public ObjectRef {
public:
explicit Tensor(NDArray data);

TVM_DEFINE_OBJECT_REF_METHODS(Tensor, ObjectRef, TensorObj);
};

/*! \brief An object representing a closure. */
class ClosureObj : public Object {
public:
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from . import container
from . import schedule
from . import module
from . import node
from . import object
from . import attrs
from . import ir_builder
from . import target
Expand All @@ -55,7 +55,7 @@
from .api import *
from .intrin import *
from .tensor_intrin import decl_tensor_intrin
from .node import register_node
from .object import register_object
from .ndarray import register_extension
from .schedule import create_schedule
from .build_module import build, lower, build_config
Expand Down
13 changes: 4 additions & 9 deletions python/tvm/_ffi/_ctypes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@

from ..base import _LIB, get_last_ffi_error, py2cerror
from ..base import c_str, string_types
from ..node_generic import convert_to_node, NodeGeneric
from ..object_generic import convert_to_object, ObjectGeneric
from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext
from . import ndarray as _nd
from .ndarray import NDArrayBase, _make_array
from .types import TVMValue, TypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64
from .object import ObjectBase, _set_class_node
from .object import ObjectBase, _set_class_object
from . import object as _object

FunctionHandle = ctypes.c_void_p
Expand Down Expand Up @@ -144,8 +144,8 @@ def _make_tvm_args(args, temp_args):
elif isinstance(arg, string_types):
values[i].v_str = c_str(arg)
type_codes[i] = TypeCode.STR
elif isinstance(arg, (list, tuple, dict, NodeGeneric)):
arg = convert_to_node(arg)
elif isinstance(arg, (list, tuple, dict, ObjectGeneric)):
arg = convert_to_object(arg)
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.OBJECT_HANDLE
temp_args.append(arg)
Expand Down Expand Up @@ -256,7 +256,6 @@ def _handle_return_func(x):

_CLASS_MODULE = None
_CLASS_FUNCTION = None
_CLASS_OBJECT = None

def _set_class_module(module_class):
"""Initialize the module."""
Expand All @@ -266,7 +265,3 @@ def _set_class_module(module_class):
def _set_class_function(func_class):
global _CLASS_FUNCTION
_CLASS_FUNCTION = func_class

def _set_class_object(obj_class):
global _CLASS_OBJECT
_CLASS_OBJECT = obj_class
10 changes: 5 additions & 5 deletions python/tvm/_ffi/_ctypes/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@
"""Maps object type to its constructor"""
OBJECT_TYPE = {}

_CLASS_NODE = None
_CLASS_OBJECT = None

def _set_class_node(node_class):
global _CLASS_NODE
_CLASS_NODE = node_class
def _set_class_object(object_class):
global _CLASS_OBJECT
_CLASS_OBJECT = object_class


def _register_object(index, cls):
Expand All @@ -51,7 +51,7 @@ def _return_object(x):
handle = ObjectHandle(handle)
tindex = ctypes.c_uint()
check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex)))
cls = OBJECT_TYPE.get(tindex.value, _CLASS_NODE)
cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT)
# Avoid calling __init__ of cls, instead directly call __new__
# This allows child class to implement their own __init__
obj = cls.__new__(cls)
Expand Down
Loading

0 comments on commit 86092de

Please sign in to comment.