Skip to content

Commit

Permalink
NodeBase to Object in Python
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Jan 9, 2020
1 parent 4e19231 commit 92d6b83
Show file tree
Hide file tree
Showing 36 changed files with 298 additions and 347 deletions.
13 changes: 5 additions & 8 deletions docs/api/python/dev.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,14 @@ Developer API
This page contains modules that are used by developers of TVM.
Many of these APIs are PackedFunc registered in C++ backend.

tvm.node
~~~~~~~~
.. automodule:: tvm.node

.. autoclass:: tvm.node.NodeBase
:members:
tvm.object
~~~~~~~~~~
.. automodule:: tvm.object

.. autoclass:: tvm.node.Node
.. autoclass:: tvm.object.Object
:members:

.. autofunction:: tvm.register_node
.. autofunction:: tvm.register_object

tvm.expr
~~~~~~~~
Expand Down
8 changes: 4 additions & 4 deletions docs/dev/codebase_walkthrough.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,18 @@ We use a simple example that uses the low level TVM API directly. The example is
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")

Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/tensor.h`` and ``src/lang/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlying C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is a subclass of ``NodeBase``.
Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/tensor.h`` and ``src/lang/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlying C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is a subclass of ``Object``.

::

@register_node
class Tensor(NodeBase, _expr.ExprOp):
@register_object
class Tensor(Object, _expr.ExprOp):
"""Tensor object, to construct, see function.Tensor"""

def __call__(self, *indices):
...

The Node system is the basis of exposing C++ types to frontend languages, including Python. The way TVM implements Python wrapping is not straightforward. It is briefly covered in `this document <https://docs.tvm.ai/dev/runtime.html#tvm-node-and-compiler-stack>`_, and details are in ``python/tvm/_ffi/`` if you are interested.
The object protocol is the basis of exposing C++ types to frontend languages, including Python. The way TVM implements Python wrapping is not straightforward. It is briefly covered in `this document <https://docs.tvm.ai/dev/runtime.html#tvm-node-and-compiler-stack>`_, and details are in ``python/tvm/_ffi/`` if you are interested.

``Tensor`` is created by functions in ``python/tvm/api.py``, which in turn calls into C++ functions exposed in ``src/api/api_lang.cc``. All C++ functions that are callable from Python are exposed in the ``src/api`` subdirectory. For example, the ``tvm.compute()`` function above calls into ``_ComputeOp`` API exposed in ``src/api/api_lang.cc``:

Expand Down
4 changes: 2 additions & 2 deletions python/tvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from . import container
from . import schedule
from . import module
from . import node
from . import object
from . import attrs
from . import ir_builder
from . import target
Expand All @@ -55,7 +55,7 @@
from .api import *
from .intrin import *
from .tensor_intrin import decl_tensor_intrin
from .node import register_node
from .object import register_object
from .ndarray import register_extension
from .schedule import create_schedule
from .build_module import build, lower, build_config
Expand Down
13 changes: 4 additions & 9 deletions python/tvm/_ffi/_ctypes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@

from ..base import _LIB, get_last_ffi_error, py2cerror
from ..base import c_str, string_types
from ..node_generic import convert_to_node, NodeGeneric
from ..object_generic import convert_to_object, ObjectGeneric
from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext
from . import ndarray as _nd
from .ndarray import NDArrayBase, _make_array
from .types import TVMValue, TypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64
from .object import ObjectBase, _set_class_node
from .object import ObjectBase, _set_class_object
from . import object as _object

FunctionHandle = ctypes.c_void_p
Expand Down Expand Up @@ -144,8 +144,8 @@ def _make_tvm_args(args, temp_args):
elif isinstance(arg, string_types):
values[i].v_str = c_str(arg)
type_codes[i] = TypeCode.STR
elif isinstance(arg, (list, tuple, dict, NodeGeneric)):
arg = convert_to_node(arg)
elif isinstance(arg, (list, tuple, dict, ObjectGeneric)):
arg = convert_to_object(arg)
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.OBJECT_HANDLE
temp_args.append(arg)
Expand Down Expand Up @@ -256,7 +256,6 @@ def _handle_return_func(x):

_CLASS_MODULE = None
_CLASS_FUNCTION = None
_CLASS_OBJECT = None

def _set_class_module(module_class):
"""Initialize the module."""
Expand All @@ -266,7 +265,3 @@ def _set_class_module(module_class):
def _set_class_function(func_class):
global _CLASS_FUNCTION
_CLASS_FUNCTION = func_class

def _set_class_object(obj_class):
global _CLASS_OBJECT
_CLASS_OBJECT = obj_class
10 changes: 5 additions & 5 deletions python/tvm/_ffi/_ctypes/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@
"""Maps object type to its constructor"""
OBJECT_TYPE = {}

_CLASS_NODE = None
_CLASS_OBJECT = None

def _set_class_node(node_class):
global _CLASS_NODE
_CLASS_NODE = node_class
def _set_class_object(object_class):
global _CLASS_OBJECT
_CLASS_OBJECT = object_class


def _register_object(index, cls):
Expand All @@ -51,7 +51,7 @@ def _return_object(x):
handle = ObjectHandle(handle)
tindex = ctypes.c_uint()
check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex)))
cls = OBJECT_TYPE.get(tindex.value, _CLASS_NODE)
cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT)
# Avoid calling __init__ of cls, instead directly call __new__
# This allows child class to implement their own __init__
obj = cls.__new__(cls)
Expand Down
11 changes: 3 additions & 8 deletions python/tvm/_ffi/_cython/function.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import traceback
from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral
from ..base import string_types, py2cerror
from ..node_generic import convert_to_node, NodeGeneric
from ..object_generic import convert_to_object, ObjectGeneric
from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray


Expand Down Expand Up @@ -149,8 +149,8 @@ cdef inline int make_arg(object arg,
value[0].v_str = tstr
tcode[0] = kStr
temp_args.append(tstr)
elif isinstance(arg, (list, tuple, dict, NodeGeneric)):
arg = convert_to_node(arg)
elif isinstance(arg, (list, tuple, dict, ObjectGeneric)):
arg = convert_to_object(arg)
value[0].v_handle = (<ObjectBase>arg).chandle
tcode[0] = kObjectHandle
temp_args.append(arg)
Expand Down Expand Up @@ -308,7 +308,6 @@ cdef class FunctionBase:
_CLASS_FUNCTION = None
_CLASS_MODULE = None
_CLASS_OBJECT = None
_CLASS_NODE = None

def _set_class_module(module_class):
"""Initialize the module."""
Expand All @@ -322,7 +321,3 @@ def _set_class_function(func_class):
def _set_class_object(obj_class):
global _CLASS_OBJECT
_CLASS_OBJECT = obj_class

def _set_class_node(node_class):
global _CLASS_NODE
_CLASS_NODE = node_class
8 changes: 3 additions & 5 deletions python/tvm/_ffi/_cython/object.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _register_object(int index, object cls):

cdef inline object make_ret_object(void* chandle):
global OBJECT_TYPE
global _CLASS_NODE
global _CLASS_OBJECT
cdef unsigned tindex
cdef object cls
cdef object handle
Expand All @@ -44,11 +44,9 @@ cdef inline object make_ret_object(void* chandle):
if cls is not None:
obj = cls.__new__(cls)
else:
# default use node base class
# TODO(tqchen) change to object after Node unifies with Object
obj = _CLASS_NODE.__new__(_CLASS_NODE)
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
else:
obj = _CLASS_NODE.__new__(_CLASS_NODE)
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
(<ObjectBase>obj).chandle = chandle
return obj

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/_ffi/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import sys
import ctypes
from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE
from .node_generic import _set_class_objects
from .object_generic import _set_class_objects

IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError

Expand Down
89 changes: 0 additions & 89 deletions python/tvm/_ffi/node.py

This file was deleted.

66 changes: 61 additions & 5 deletions python/tvm/_ffi/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
# pylint: disable=invalid-name, unused-import
"""Runtime Object API"""
from __future__ import absolute_import

import sys
import ctypes
from .. import _api_internal
from .base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str
from .object_generic import ObjectGeneric, convert_to_object, const

IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError

Expand All @@ -29,23 +31,77 @@
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_object, _set_class_node
from ._cy3.core import _set_class_object
from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import _register_object
else:
from ._cy2.core import _set_class_object, _set_class_node
from ._cy2.core import _set_class_object
from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import _register_object
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position,unused-import
from ._ctypes.function import _set_class_object, _set_class_node
from ._ctypes.function import _set_class_object
from ._ctypes.object import ObjectBase as _ObjectBase
from ._ctypes.object import _register_object


def _new_object(cls):
"""Helper function for pickle"""
return cls.__new__(cls)


class Object(_ObjectBase):
"""Base class for all tvm's runtime objects."""
pass
def __repr__(self):
return _api_internal._format_str(self)

def __dir__(self):
fnames = _api_internal._NodeListAttrNames(self)
size = fnames(-1)
return [fnames(i) for i in range(size)]

def __getattr__(self, name):
try:
return _api_internal._NodeGetAttr(self, name)
except AttributeError:
raise AttributeError(
"%s has no attribute %s" % (str(type(self)), name))

def __hash__(self):
return _api_internal._raw_ptr(self)

def __eq__(self, other):
return self.same_as(other)

def __ne__(self, other):
return not self.__eq__(other)

def __reduce__(self):
cls = type(self)
return (_new_object, (cls, ), self.__getstate__())

def __getstate__(self):
handle = self.handle
if handle is not None:
return {'handle': _api_internal._save_json(self)}
return {'handle': None}

def __setstate__(self, state):
# pylint: disable=assigning-non-slot
handle = state['handle']
if handle is not None:
json_str = handle
other = _api_internal._load_json(json_str)
self.handle = other.handle
other.handle = None
else:
self.handle = None

def same_as(self, other):
"""check object identity equality"""
if not isinstance(other, Object):
return False
return self.__hash__() == other.__hash__()


def register_object(type_key=None):
Expand Down
Loading

0 comments on commit 92d6b83

Please sign in to comment.