Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR] Replace TensorObj and TensorValue with NDArray #4643

Merged
merged 3 commits into from
Jan 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions docs/api/python/dev.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,14 @@ Developer API
This page contains modules that are used by developers of TVM.
Many of these APIs are PackedFunc registered in C++ backend.

tvm.node
~~~~~~~~
.. automodule:: tvm.node

.. autoclass:: tvm.node.NodeBase
:members:
tvm.object
~~~~~~~~~~
.. automodule:: tvm.object

.. autoclass:: tvm.node.Node
.. autoclass:: tvm.object.Object
:members:

.. autofunction:: tvm.register_node
.. autofunction:: tvm.register_object

tvm.expr
~~~~~~~~
Expand Down
8 changes: 4 additions & 4 deletions docs/dev/codebase_walkthrough.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,18 @@ We use a simple example that uses the low level TVM API directly. The example is
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")

Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/tensor.h`` and ``src/lang/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlying C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is a subclass of ``NodeBase``.
Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/tensor.h`` and ``src/lang/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlying C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is a subclass of ``Object``.

::

@register_node
class Tensor(NodeBase, _expr.ExprOp):
@register_object
class Tensor(Object, _expr.ExprOp):
"""Tensor object, to construct, see function.Tensor"""

def __call__(self, *indices):
...

The Node system is the basis of exposing C++ types to frontend languages, including Python. The way TVM implements Python wrapping is not straightforward. It is briefly covered in `this document <https://docs.tvm.ai/dev/runtime.html#tvm-node-and-compiler-stack>`_, and details are in ``python/tvm/_ffi/`` if you are interested.
The object protocol is the basis of exposing C++ types to frontend languages, including Python. The way TVM implements Python wrapping is not straightforward. It is briefly covered in `this document <https://docs.tvm.ai/dev/runtime.html#tvm-node-and-compiler-stack>`_, and details are in ``python/tvm/_ffi/`` if you are interested.

``Tensor`` is created by functions in ``python/tvm/api.py``, which in turn calls into C++ functions exposed in ``src/api/api_lang.cc``. All C++ functions that are callable from Python are exposed in the ``src/api`` subdirectory. For example, the ``tvm.compute()`` function above calls into ``_ComputeOp`` API exposed in ``src/api/api_lang.cc``:

Expand Down
105 changes: 30 additions & 75 deletions include/tvm/relay/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,11 @@
#include <tvm/build_module.h>
#include <tvm/relay/module.h>
#include <tvm/relay/expr.h>
#include <tvm/runtime/object.h>

namespace tvm {
namespace relay {

/*!
* \brief A Relay value.
*/
class Value;

/*!
*\brief Create a Interpreter function that can
* evaluate an expression and produce a value.
Expand All @@ -65,39 +61,21 @@ class Value;
* \param target Compiler target flag to compile the functions on the context.
* \return A function that takes in an expression and returns a value.
*/
runtime::TypedPackedFunc<Value(Expr)>
runtime::TypedPackedFunc<ObjectRef(Expr)>
CreateInterpreter(Module mod, DLContext context, Target target);

/*! \brief The base container type of Relay values. */
class ValueNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Value";
TVM_DECLARE_BASE_OBJECT_INFO(ValueNode, RelayNode);
};

class Value : public ObjectRef {
public:
Value() {}
explicit Value(ObjectPtr<Object> n) : ObjectRef(n) {}
const ValueNode* operator->() const {
return static_cast<const ValueNode*>(get());
}

using ContainerType = ValueNode;
};

/*! \brief A Relay closure, i.e a scope and a function. */
class Closure;

/*! \brief The container type of Closures. */
class ClosureNode : public ValueNode {
class ClosureNode : public Object {
wweic marked this conversation as resolved.
Show resolved Hide resolved
public:
/*! \brief The set of free variables in the closure.
*
* These are the captured variables which are required for
* evaluation when we call the closure.
*/
tvm::Map<Var, Value> env;
tvm::Map<Var, ObjectRef> env;
/*! \brief The function which implements the closure.
*
* \note May reference the variables contained in the env.
Expand All @@ -111,22 +89,22 @@ class ClosureNode : public ValueNode {
v->Visit("func", &func);
}

TVM_DLL static Closure make(tvm::Map<Var, Value> env, Function func);
TVM_DLL static Closure make(tvm::Map<Var, ObjectRef> env, Function func);

static constexpr const char* _type_key = "relay.Closure";
TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, Object);
};

class Closure : public Value {
class Closure : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Closure, Value, ClosureNode);
TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureNode);
};

/*! \brief A Relay Recursive Closure. A closure that has a name. */
class RecClosure;

/*! \brief The container type of RecClosure. */
class RecClosureNode : public ValueNode {
class RecClosureNode : public Object {
wweic marked this conversation as resolved.
Show resolved Hide resolved
public:
/*! \brief The closure. */
Closure clos;
Expand All @@ -143,89 +121,66 @@ class RecClosureNode : public ValueNode {
TVM_DLL static RecClosure make(Closure clos, Var bind);

static constexpr const char* _type_key = "relay.RecClosure";
TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, Object);
};

class RecClosure : public Value {
class RecClosure : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, Value, RecClosureNode);
TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureNode);
};

/*! \brief A tuple value. */
class TupleValue;

/*! \brief Tuple (x, ... y). */
struct TupleValueNode : ValueNode {
tvm::Array<Value> fields;
struct TupleValueNode : Object {
wweic marked this conversation as resolved.
Show resolved Hide resolved
tvm::Array<ObjectRef> fields;

TupleValueNode() {}

void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); }

TVM_DLL static TupleValue make(tvm::Array<Value> value);
TVM_DLL static TupleValue make(tvm::Array<ObjectRef> value);

static constexpr const char* _type_key = "relay.TupleValue";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, ValueNode);
};

class TupleValue : public Value {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, Value, TupleValueNode);
};

/*! \brief A tensor value. */
class TensorValue;

/*! \brief The tensor value container, wrapping an NDArray. */
struct TensorValueNode : ValueNode {
runtime::NDArray data;

TensorValueNode() {}

void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("data", &data); }

/*! \brief Build a value from an NDArray. */
TVM_DLL static TensorValue make(runtime::NDArray data);

static constexpr const char* _type_key = "relay.TensorValue";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorValueNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, Object);
};

class TensorValue : public Value {
class TupleValue : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TensorValue, Value, TensorValueNode);
TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, ObjectRef, TupleValueNode);
};

/*! \brief A reference value. */
class RefValue;

struct RefValueNode : ValueNode {
mutable Value value;
struct RefValueNode : Object {
mutable ObjectRef value;

RefValueNode() {}

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value);
}

TVM_DLL static RefValue make(Value val);
TVM_DLL static RefValue make(ObjectRef val);

static constexpr const char* _type_key = "relay.RefValue";
TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, Object);
};

class RefValue : public Value {
class RefValue : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RefValue, Value, RefValueNode);
TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueNode);
};

/*! \brief An ADT constructor value. */
class ConstructorValue;

struct ConstructorValueNode : ValueNode {
struct ConstructorValueNode : Object {
int32_t tag;

tvm::Array<Value> fields;
tvm::Array<ObjectRef> fields;

/*! \brief Optional field tracking ADT constructor. */
Constructor constructor;
Expand All @@ -237,16 +192,16 @@ struct ConstructorValueNode : ValueNode {
}

TVM_DLL static ConstructorValue make(int32_t tag,
tvm::Array<Value> fields,
tvm::Array<ObjectRef> fields,
Constructor construtor = {});

static constexpr const char* _type_key = "relay.ConstructorValue";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, Object);
};

class ConstructorValue : public Value {
class ConstructorValue : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, Value, ConstructorValueNode);
TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueNode);
};

} // namespace relay
Expand Down
19 changes: 0 additions & 19 deletions include/tvm/runtime/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,25 +36,6 @@ namespace tvm {
namespace runtime {
namespace vm {

/*! \brief An object containing an NDArray. */
class TensorObj : public Object {
public:
/*! \brief The NDArray. */
NDArray data;

static constexpr const uint32_t _type_index = TypeIndex::kVMTensor;
static constexpr const char* _type_key = "vm.Tensor";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorObj, Object);
};

/*! \brief reference to tensor. */
class Tensor : public ObjectRef {
public:
explicit Tensor(NDArray data);

TVM_DEFINE_OBJECT_REF_METHODS(Tensor, ObjectRef, TensorObj);
};

/*! \brief An object representing a closure. */
class ClosureObj : public Object {
public:
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from . import container
from . import schedule
from . import module
from . import node
from . import object
from . import attrs
from . import ir_builder
from . import target
Expand All @@ -55,7 +55,7 @@
from .api import *
from .intrin import *
from .tensor_intrin import decl_tensor_intrin
from .node import register_node
from .object import register_object
from .ndarray import register_extension
from .schedule import create_schedule
from .build_module import build, lower, build_config
Expand Down
13 changes: 4 additions & 9 deletions python/tvm/_ffi/_ctypes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@

from ..base import _LIB, get_last_ffi_error, py2cerror
from ..base import c_str, string_types
from ..node_generic import convert_to_node, NodeGeneric
from ..object_generic import convert_to_object, ObjectGeneric
from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext
from . import ndarray as _nd
from .ndarray import NDArrayBase, _make_array
from .types import TVMValue, TypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64
from .object import ObjectBase, _set_class_node
from .object import ObjectBase, _set_class_object
from . import object as _object

FunctionHandle = ctypes.c_void_p
Expand Down Expand Up @@ -144,8 +144,8 @@ def _make_tvm_args(args, temp_args):
elif isinstance(arg, string_types):
values[i].v_str = c_str(arg)
type_codes[i] = TypeCode.STR
elif isinstance(arg, (list, tuple, dict, NodeGeneric)):
arg = convert_to_node(arg)
elif isinstance(arg, (list, tuple, dict, ObjectGeneric)):
arg = convert_to_object(arg)
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.OBJECT_HANDLE
temp_args.append(arg)
Expand Down Expand Up @@ -256,7 +256,6 @@ def _handle_return_func(x):

_CLASS_MODULE = None
_CLASS_FUNCTION = None
_CLASS_OBJECT = None

def _set_class_module(module_class):
"""Initialize the module."""
Expand All @@ -266,7 +265,3 @@ def _set_class_module(module_class):
def _set_class_function(func_class):
global _CLASS_FUNCTION
_CLASS_FUNCTION = func_class

def _set_class_object(obj_class):
global _CLASS_OBJECT
_CLASS_OBJECT = obj_class
10 changes: 5 additions & 5 deletions python/tvm/_ffi/_ctypes/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@
"""Maps object type to its constructor"""
OBJECT_TYPE = {}

_CLASS_NODE = None
_CLASS_OBJECT = None

def _set_class_node(node_class):
global _CLASS_NODE
_CLASS_NODE = node_class
def _set_class_object(object_class):
global _CLASS_OBJECT
_CLASS_OBJECT = object_class


def _register_object(index, cls):
Expand All @@ -51,7 +51,7 @@ def _return_object(x):
handle = ObjectHandle(handle)
tindex = ctypes.c_uint()
check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex)))
cls = OBJECT_TYPE.get(tindex.value, _CLASS_NODE)
cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT)
# Avoid calling __init__ of cls, instead directly call __new__
# This allows child class to implement their own __init__
obj = cls.__new__(cls)
Expand Down
Loading