From 92d6b832a6b79fbf206ae5696dc8f892701fa7d0 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Wed, 8 Jan 2020 21:38:00 +0000 Subject: [PATCH] NodeBase to Object in Python --- docs/api/python/dev.rst | 13 ++- docs/dev/codebase_walkthrough.rst | 8 +- python/tvm/__init__.py | 4 +- python/tvm/_ffi/_ctypes/function.py | 13 +-- python/tvm/_ffi/_ctypes/object.py | 10 +-- python/tvm/_ffi/_cython/function.pxi | 11 +-- python/tvm/_ffi/_cython/object.pxi | 8 +- python/tvm/_ffi/function.py | 2 +- python/tvm/_ffi/node.py | 89 ------------------- python/tvm/_ffi/object.py | 66 ++++++++++++-- .../{node_generic.py => object_generic.py} | 30 +++---- python/tvm/api.py | 19 ++-- python/tvm/arith.py | 16 ++-- python/tvm/attrs.py | 6 +- python/tvm/build_module.py | 18 ++-- python/tvm/container.py | 28 +++--- python/tvm/expr.py | 84 ++++++++--------- python/tvm/ir_builder.py | 6 +- python/tvm/{node.py => object.py} | 4 +- python/tvm/relay/_module.pyi | 4 +- python/tvm/relay/adt.py | 4 +- python/tvm/relay/backend/compile_engine.py | 10 +-- python/tvm/relay/backend/interpreter.py | 14 +-- python/tvm/relay/base.py | 9 +- python/tvm/relay/expr.pyi | 4 +- python/tvm/relay/quantize/quantize.py | 4 +- python/tvm/relay/transform.pyi | 8 +- python/tvm/relay/ty.pyi | 4 +- python/tvm/schedule.py | 40 ++++----- python/tvm/stmt.py | 32 +++---- python/tvm/target.py | 12 +-- python/tvm/tensor.py | 43 ++++----- python/tvm/tensor_intrin.py | 6 +- .../test_pass_inject_double_buffer.py | 4 +- .../unittest/test_pass_inject_vthread.py | 8 +- .../unittest/test_pass_storage_flatten.py | 4 +- 36 files changed, 298 insertions(+), 347 deletions(-) delete mode 100644 python/tvm/_ffi/node.py rename python/tvm/_ffi/{node_generic.py => object_generic.py} (82%) rename python/tvm/{node.py => object.py} (93%) diff --git a/docs/api/python/dev.rst b/docs/api/python/dev.rst index 7bb938ca7517..8a0a70588bc3 100644 --- a/docs/api/python/dev.rst +++ b/docs/api/python/dev.rst @@ -20,17 +20,14 @@ Developer API This page contains modules that are used by developers of TVM. Many of these APIs are PackedFunc registered in C++ backend. -tvm.node -~~~~~~~~ -.. automodule:: tvm.node - -.. autoclass:: tvm.node.NodeBase - :members: +tvm.object +~~~~~~~~~~ +.. automodule:: tvm.object -.. autoclass:: tvm.node.Node +.. autoclass:: tvm.object.Object :members: -.. autofunction:: tvm.register_node +.. autofunction:: tvm.register_object tvm.expr ~~~~~~~~ diff --git a/docs/dev/codebase_walkthrough.rst b/docs/dev/codebase_walkthrough.rst index 19f185edca98..0732c26f0c58 100644 --- a/docs/dev/codebase_walkthrough.rst +++ b/docs/dev/codebase_walkthrough.rst @@ -55,18 +55,18 @@ We use a simple example that uses the low level TVM API directly. The example is B = tvm.placeholder((n,), name='B') C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C") -Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/tensor.h`` and ``src/lang/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlying C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is a subclass of ``NodeBase``. +Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/tensor.h`` and ``src/lang/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlying C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is a subclass of ``Object``. :: - @register_node - class Tensor(NodeBase, _expr.ExprOp): + @register_object + class Tensor(Object, _expr.ExprOp): """Tensor object, to construct, see function.Tensor""" def __call__(self, *indices): ... -The Node system is the basis of exposing C++ types to frontend languages, including Python. The way TVM implements Python wrapping is not straightforward. It is briefly covered in `this document `_, and details are in ``python/tvm/_ffi/`` if you are interested. +The object protocol is the basis of exposing C++ types to frontend languages, including Python. The way TVM implements Python wrapping is not straightforward. It is briefly covered in `this document `_, and details are in ``python/tvm/_ffi/`` if you are interested. ``Tensor`` is created by functions in ``python/tvm/api.py``, which in turn calls into C++ functions exposed in ``src/api/api_lang.cc``. All C++ functions that are callable from Python are exposed in the ``src/api`` subdirectory. For example, the ``tvm.compute()`` function above calls into ``_ComputeOp`` API exposed in ``src/api/api_lang.cc``: diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 9e3eb0faefb8..b2a4ca3ccf13 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -34,7 +34,7 @@ from . import container from . import schedule from . import module -from . import node +from . import object from . import attrs from . import ir_builder from . import target @@ -55,7 +55,7 @@ from .api import * from .intrin import * from .tensor_intrin import decl_tensor_intrin -from .node import register_node +from .object import register_object from .ndarray import register_extension from .schedule import create_schedule from .build_module import build, lower, build_config diff --git a/python/tvm/_ffi/_ctypes/function.py b/python/tvm/_ffi/_ctypes/function.py index 2f0b5babda4d..45048c5768a9 100644 --- a/python/tvm/_ffi/_ctypes/function.py +++ b/python/tvm/_ffi/_ctypes/function.py @@ -25,14 +25,14 @@ from ..base import _LIB, get_last_ffi_error, py2cerror from ..base import c_str, string_types -from ..node_generic import convert_to_node, NodeGeneric +from ..object_generic import convert_to_object, ObjectGeneric from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext from . import ndarray as _nd from .ndarray import NDArrayBase, _make_array from .types import TVMValue, TypeCode from .types import TVMPackedCFunc, TVMCFuncFinalizer from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64 -from .object import ObjectBase, _set_class_node +from .object import ObjectBase, _set_class_object from . import object as _object FunctionHandle = ctypes.c_void_p @@ -144,8 +144,8 @@ def _make_tvm_args(args, temp_args): elif isinstance(arg, string_types): values[i].v_str = c_str(arg) type_codes[i] = TypeCode.STR - elif isinstance(arg, (list, tuple, dict, NodeGeneric)): - arg = convert_to_node(arg) + elif isinstance(arg, (list, tuple, dict, ObjectGeneric)): + arg = convert_to_object(arg) values[i].v_handle = arg.handle type_codes[i] = TypeCode.OBJECT_HANDLE temp_args.append(arg) @@ -256,7 +256,6 @@ def _handle_return_func(x): _CLASS_MODULE = None _CLASS_FUNCTION = None -_CLASS_OBJECT = None def _set_class_module(module_class): """Initialize the module.""" @@ -266,7 +265,3 @@ def _set_class_module(module_class): def _set_class_function(func_class): global _CLASS_FUNCTION _CLASS_FUNCTION = func_class - -def _set_class_object(obj_class): - global _CLASS_OBJECT - _CLASS_OBJECT = obj_class diff --git a/python/tvm/_ffi/_ctypes/object.py b/python/tvm/_ffi/_ctypes/object.py index b8b8aefea131..8a2fb1b5363e 100644 --- a/python/tvm/_ffi/_ctypes/object.py +++ b/python/tvm/_ffi/_ctypes/object.py @@ -30,11 +30,11 @@ """Maps object type to its constructor""" OBJECT_TYPE = {} -_CLASS_NODE = None +_CLASS_OBJECT = None -def _set_class_node(node_class): - global _CLASS_NODE - _CLASS_NODE = node_class +def _set_class_object(object_class): + global _CLASS_OBJECT + _CLASS_OBJECT = object_class def _register_object(index, cls): @@ -51,7 +51,7 @@ def _return_object(x): handle = ObjectHandle(handle) tindex = ctypes.c_uint() check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex))) - cls = OBJECT_TYPE.get(tindex.value, _CLASS_NODE) + cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT) # Avoid calling __init__ of cls, instead directly call __new__ # This allows child class to implement their own __init__ obj = cls.__new__(cls) diff --git a/python/tvm/_ffi/_cython/function.pxi b/python/tvm/_ffi/_cython/function.pxi index a2360427b6c7..7789769a3901 100644 --- a/python/tvm/_ffi/_cython/function.pxi +++ b/python/tvm/_ffi/_cython/function.pxi @@ -20,7 +20,7 @@ import traceback from cpython cimport Py_INCREF, Py_DECREF from numbers import Number, Integral from ..base import string_types, py2cerror -from ..node_generic import convert_to_node, NodeGeneric +from ..object_generic import convert_to_object, ObjectGeneric from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray @@ -149,8 +149,8 @@ cdef inline int make_arg(object arg, value[0].v_str = tstr tcode[0] = kStr temp_args.append(tstr) - elif isinstance(arg, (list, tuple, dict, NodeGeneric)): - arg = convert_to_node(arg) + elif isinstance(arg, (list, tuple, dict, ObjectGeneric)): + arg = convert_to_object(arg) value[0].v_handle = (arg).chandle tcode[0] = kObjectHandle temp_args.append(arg) @@ -308,7 +308,6 @@ cdef class FunctionBase: _CLASS_FUNCTION = None _CLASS_MODULE = None _CLASS_OBJECT = None -_CLASS_NODE = None def _set_class_module(module_class): """Initialize the module.""" @@ -322,7 +321,3 @@ def _set_class_function(func_class): def _set_class_object(obj_class): global _CLASS_OBJECT _CLASS_OBJECT = obj_class - -def _set_class_node(node_class): - global _CLASS_NODE - _CLASS_NODE = node_class diff --git a/python/tvm/_ffi/_cython/object.pxi b/python/tvm/_ffi/_cython/object.pxi index 6d20723fd188..1392f9944835 100644 --- a/python/tvm/_ffi/_cython/object.pxi +++ b/python/tvm/_ffi/_cython/object.pxi @@ -32,7 +32,7 @@ def _register_object(int index, object cls): cdef inline object make_ret_object(void* chandle): global OBJECT_TYPE - global _CLASS_NODE + global _CLASS_OBJECT cdef unsigned tindex cdef object cls cdef object handle @@ -44,11 +44,9 @@ cdef inline object make_ret_object(void* chandle): if cls is not None: obj = cls.__new__(cls) else: - # default use node base class - # TODO(tqchen) change to object after Node unifies with Object - obj = _CLASS_NODE.__new__(_CLASS_NODE) + obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) else: - obj = _CLASS_NODE.__new__(_CLASS_NODE) + obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) (obj).chandle = chandle return obj diff --git a/python/tvm/_ffi/function.py b/python/tvm/_ffi/function.py index 23d95ebbf66b..22e03563976b 100644 --- a/python/tvm/_ffi/function.py +++ b/python/tvm/_ffi/function.py @@ -22,7 +22,7 @@ import sys import ctypes from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE -from .node_generic import _set_class_objects +from .object_generic import _set_class_objects IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError diff --git a/python/tvm/_ffi/node.py b/python/tvm/_ffi/node.py deleted file mode 100644 index c6c151af9053..000000000000 --- a/python/tvm/_ffi/node.py +++ /dev/null @@ -1,89 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Node namespace""" -# pylint: disable=unused-import -from __future__ import absolute_import - -import ctypes -import sys -from .. import _api_internal -from .object import Object, register_object, _set_class_node -from .node_generic import NodeGeneric, convert_to_node, const - - -def _new_object(cls): - """Helper function for pickle""" - return cls.__new__(cls) - - -class NodeBase(Object): - """NodeBase is the base class of all TVM language AST object.""" - def __repr__(self): - return _api_internal._format_str(self) - - def __dir__(self): - fnames = _api_internal._NodeListAttrNames(self) - size = fnames(-1) - return [fnames(i) for i in range(size)] - - def __getattr__(self, name): - try: - return _api_internal._NodeGetAttr(self, name) - except AttributeError: - raise AttributeError( - "%s has no attribute %s" % (str(type(self)), name)) - - def __hash__(self): - return _api_internal._raw_ptr(self) - - def __eq__(self, other): - return self.same_as(other) - - def __ne__(self, other): - return not self.__eq__(other) - - def __reduce__(self): - cls = type(self) - return (_new_object, (cls, ), self.__getstate__()) - - def __getstate__(self): - handle = self.handle - if handle is not None: - return {'handle': _api_internal._save_json(self)} - return {'handle': None} - - def __setstate__(self, state): - # pylint: disable=assigning-non-slot - handle = state['handle'] - if handle is not None: - json_str = handle - other = _api_internal._load_json(json_str) - self.handle = other.handle - other.handle = None - else: - self.handle = None - - def same_as(self, other): - """check object identity equality""" - if not isinstance(other, NodeBase): - return False - return self.__hash__() == other.__hash__() - - -# pylint: disable=invalid-name -register_node = register_object -_set_class_node(NodeBase) diff --git a/python/tvm/_ffi/object.py b/python/tvm/_ffi/object.py index 002fd27af0fd..83d4129a7140 100644 --- a/python/tvm/_ffi/object.py +++ b/python/tvm/_ffi/object.py @@ -14,13 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name +# pylint: disable=invalid-name, unused-import """Runtime Object API""" from __future__ import absolute_import import sys import ctypes +from .. import _api_internal from .base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str +from .object_generic import ObjectGeneric, convert_to_object, const IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError @@ -29,23 +31,77 @@ if _FFI_MODE == "ctypes": raise ImportError() if sys.version_info >= (3, 0): - from ._cy3.core import _set_class_object, _set_class_node + from ._cy3.core import _set_class_object from ._cy3.core import ObjectBase as _ObjectBase from ._cy3.core import _register_object else: - from ._cy2.core import _set_class_object, _set_class_node + from ._cy2.core import _set_class_object from ._cy2.core import ObjectBase as _ObjectBase from ._cy2.core import _register_object except IMPORT_EXCEPT: # pylint: disable=wrong-import-position,unused-import - from ._ctypes.function import _set_class_object, _set_class_node + from ._ctypes.function import _set_class_object from ._ctypes.object import ObjectBase as _ObjectBase from ._ctypes.object import _register_object +def _new_object(cls): + """Helper function for pickle""" + return cls.__new__(cls) + + class Object(_ObjectBase): """Base class for all tvm's runtime objects.""" - pass + def __repr__(self): + return _api_internal._format_str(self) + + def __dir__(self): + fnames = _api_internal._NodeListAttrNames(self) + size = fnames(-1) + return [fnames(i) for i in range(size)] + + def __getattr__(self, name): + try: + return _api_internal._NodeGetAttr(self, name) + except AttributeError: + raise AttributeError( + "%s has no attribute %s" % (str(type(self)), name)) + + def __hash__(self): + return _api_internal._raw_ptr(self) + + def __eq__(self, other): + return self.same_as(other) + + def __ne__(self, other): + return not self.__eq__(other) + + def __reduce__(self): + cls = type(self) + return (_new_object, (cls, ), self.__getstate__()) + + def __getstate__(self): + handle = self.handle + if handle is not None: + return {'handle': _api_internal._save_json(self)} + return {'handle': None} + + def __setstate__(self, state): + # pylint: disable=assigning-non-slot + handle = state['handle'] + if handle is not None: + json_str = handle + other = _api_internal._load_json(json_str) + self.handle = other.handle + other.handle = None + else: + self.handle = None + + def same_as(self, other): + """check object identity equality""" + if not isinstance(other, Object): + return False + return self.__hash__() == other.__hash__() def register_object(type_key=None): diff --git a/python/tvm/_ffi/node_generic.py b/python/tvm/_ffi/object_generic.py similarity index 82% rename from python/tvm/_ffi/node_generic.py rename to python/tvm/_ffi/object_generic.py index 8ee7fc5f2b5b..92e73ad79e88 100644 --- a/python/tvm/_ffi/node_generic.py +++ b/python/tvm/_ffi/object_generic.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Common implementation of Node generic related logic""" +"""Common implementation of object generic related logic""" # pylint: disable=unused-import from __future__ import absolute_import @@ -22,7 +22,7 @@ from .. import _api_internal from .base import string_types -# Node base class +# Object base class _CLASS_OBJECTS = None def _set_class_objects(cls): @@ -47,15 +47,15 @@ def _scalar_type_inference(value): return dtype -class NodeGeneric(object): - """Base class for all classes that can be converted to node.""" - def asnode(self): - """Convert value to node""" +class ObjectGeneric(object): + """Base class for all classes that can be converted to object.""" + def asobject(self): + """Convert value to object""" raise NotImplementedError() -def convert_to_node(value): - """Convert a python value to corresponding node type. +def convert_to_object(value): + """Convert a python value to corresponding object type. Parameters ---------- @@ -64,8 +64,8 @@ def convert_to_node(value): Returns ------- - node : Node - The corresponding node value. + obj : Object + The corresponding object value. """ if isinstance(value, _CLASS_OBJECTS): return value @@ -76,7 +76,7 @@ def convert_to_node(value): if isinstance(value, string_types): return _api_internal._str(value) if isinstance(value, (list, tuple)): - value = [convert_to_node(x) for x in value] + value = [convert_to_object(x) for x in value] return _api_internal._Array(*value) if isinstance(value, dict): vlist = [] @@ -85,14 +85,14 @@ def convert_to_node(value): not isinstance(item[0], string_types)): raise ValueError("key of map must already been a container type") vlist.append(item[0]) - vlist.append(convert_to_node(item[1])) + vlist.append(convert_to_object(item[1])) return _api_internal._Map(*vlist) - if isinstance(value, NodeGeneric): - return value.asnode() + if isinstance(value, ObjectGeneric): + return value.asobject() if value is None: return None - raise ValueError("don't know how to convert type %s to node" % type(value)) + raise ValueError("don't know how to convert type %s to object" % type(value)) def const(value, dtype=None): diff --git a/python/tvm/api.py b/python/tvm/api.py index 4d0e3472683c..7395d3524709 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -22,9 +22,8 @@ from ._ffi.base import string_types from ._ffi.object import register_object, Object -from ._ffi.node import register_node, NodeBase -from ._ffi.node import convert_to_node as _convert_to_node -from ._ffi.node_generic import _scalar_type_inference +from ._ffi.object import convert_to_object as _convert_to_object +from ._ffi.object_generic import _scalar_type_inference from ._ffi.function import Function from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs from ._ffi.function import convert_to_tvm_func as _convert_tvm_func @@ -111,7 +110,7 @@ def get_env_func(name): Note ---- - EnvFunc is a Node wrapper around + EnvFunc is a Object wrapper around global function that can be serialized via its name. This can be used to serialize function field in the language. """ @@ -127,16 +126,16 @@ def convert(value): Returns ------- - tvm_val : Node or Function + tvm_val : Object or Function Converted value in TVM """ - if isinstance(value, (Function, NodeBase)): + if isinstance(value, (Function, Object)): return value if callable(value): return _convert_tvm_func(value) - return _convert_to_node(value) + return _convert_to_object(value) def load_json(json_str): @@ -149,7 +148,7 @@ def load_json(json_str): Returns ------- - node : Node + node : Object The loaded tvm node. """ return _api_internal._load_json(json_str) @@ -160,8 +159,8 @@ def save_json(node): Parameters ---------- - node : Node - A TVM Node object to be saved. + node : Object + A TVM object to be saved. Returns ------- diff --git a/python/tvm/arith.py b/python/tvm/arith.py index 4c3c05f75796..81f478c66b92 100644 --- a/python/tvm/arith.py +++ b/python/tvm/arith.py @@ -17,11 +17,11 @@ """Arithmetic data structure and utility""" from __future__ import absolute_import as _abs -from ._ffi.node import NodeBase, register_node +from ._ffi.object import Object, register_object from ._ffi.function import _init_api from . import _api_internal -class IntSet(NodeBase): +class IntSet(Object): """Represent a set of integer in one dimension.""" def is_nothing(self): """Whether the set represent nothing""" @@ -32,7 +32,7 @@ def is_everything(self): return _api_internal._IntSetIsEverything(self) -@register_node("arith.IntervalSet") +@register_object("arith.IntervalSet") class IntervalSet(IntSet): """Represent set of continuous interval [min_value, max_value] @@ -49,16 +49,16 @@ def __init__(self, min_value, max_value): _make_IntervalSet, min_value, max_value) -@register_node("arith.ModularSet") -class ModularSet(NodeBase): +@register_object("arith.ModularSet") +class ModularSet(Object): """Represent range of (coeff * x + base) for x in Z """ def __init__(self, coeff, base): self.__init_handle_by_constructor__( _make_ModularSet, coeff, base) -@register_node("arith.ConstIntBound") -class ConstIntBound(NodeBase): +@register_object("arith.ConstIntBound") +class ConstIntBound(Object): """Represent constant integer bound Parameters @@ -245,7 +245,7 @@ def update(self, var, info, override=False): var : tvm.Var The variable. - info : tvm.NodeBase + info : tvm.Object Related information. override : bool diff --git a/python/tvm/attrs.py b/python/tvm/attrs.py index e2a27328fdcc..2963a0e21734 100644 --- a/python/tvm/attrs.py +++ b/python/tvm/attrs.py @@ -15,13 +15,13 @@ # specific language governing permissions and limitations # under the License. """ TVM Attribute module, which is mainly used for defining attributes of operators""" -from ._ffi.node import NodeBase, register_node as _register_tvm_node +from ._ffi.object import Object, register_object from ._ffi.function import _init_api from . import _api_internal -@_register_tvm_node -class Attrs(NodeBase): +@register_object +class Attrs(Object): """Attribute node, which is mainly use for defining attributes of relay operators. Used by function registered in python side, such as compute, schedule and alter_layout. diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index f96e28323595..85d2b8514779 100644 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -23,7 +23,7 @@ import warnings from ._ffi.function import Function -from ._ffi.node import NodeBase, register_node +from ._ffi.object import Object, register_object from . import api from . import _api_internal from . import tensor @@ -115,22 +115,22 @@ def exit(self): DumpIR.scope_level -= 1 -@register_node -class BuildConfig(NodeBase): +@register_object +class BuildConfig(Object): """Configuration scope to set a build config option. Note ---- - This object is backed by node system in C++, with arguments that can be + This object is backed by object protocol in C++, with arguments that can be exchanged between python and C++. Do not construct directly, use build_config instead. - The fields that are backed by the C++ node are immutable once an instance - is constructed. See _node_defaults for the fields. + The fields that are backed by the C++ object are immutable once an instance + is constructed. See _object_defaults for the fields. """ - _node_defaults = { + _object_defaults = { "auto_unroll_max_step": 0, "auto_unroll_max_depth": 8, "auto_unroll_max_extent": 0, @@ -191,7 +191,7 @@ def __exit__(self, ptype, value, trace): _api_internal._ExitBuildConfigScope(self) def __setattr__(self, name, value): - if name in BuildConfig._node_defaults: + if name in BuildConfig._object_defaults: raise AttributeError( "'%s' object cannot set attribute '%s'" % (str(type(self)), name)) return super(BuildConfig, self).__setattr__(name, value) @@ -257,7 +257,7 @@ def build_config(**kwargs): The build configuration """ node_args = {k: v if k not in kwargs else kwargs[k] - for k, v in BuildConfig._node_defaults.items()} + for k, v in BuildConfig._object_defaults.items()} config = make.node("BuildConfig", **node_args) if "add_lower_pass" in kwargs: diff --git a/python/tvm/container.py b/python/tvm/container.py index aedbe95b01b2..274fc1f4027c 100644 --- a/python/tvm/container.py +++ b/python/tvm/container.py @@ -16,11 +16,11 @@ # under the License. """Container data structures used in TVM DSL.""" from __future__ import absolute_import as _abs -from ._ffi.node import NodeBase, register_node +from ._ffi.object import Object, register_object from . import _api_internal -@register_node -class Array(NodeBase): +@register_object +class Array(Object): """Array container of TVM. You do not need to create Array explicitly. @@ -50,8 +50,8 @@ def __len__(self): return _api_internal._ArraySize(self) -@register_node -class EnvFunc(NodeBase): +@register_object +class EnvFunc(Object): """Environment function. This is a global function object that can be serialized by its name. @@ -64,13 +64,13 @@ def func(self): return _api_internal._EnvFuncGetPackedFunc(self) -@register_node -class Map(NodeBase): +@register_object +class Map(Object): """Map container of TVM. You do not need to create Map explicitly. Normally python dict will be converted automaticall to Map during tvm function call. - You can use convert to create a dict[NodeBase-> NodeBase] into a Map + You can use convert to create a dict[Object-> Object] into a Map """ def __getitem__(self, k): return _api_internal._MapGetItem(self, k) @@ -87,11 +87,11 @@ def __len__(self): return _api_internal._MapSize(self) -@register_node +@register_object class StrMap(Map): """A special map container that has str as key. - You can use convert to create a dict[str->NodeBase] into a Map. + You can use convert to create a dict[str->Object] into a Map. """ def items(self): """Get the items from the map""" @@ -99,8 +99,8 @@ def items(self): return [(akvs[i].value, akvs[i+1]) for i in range(0, len(akvs), 2)] -@register_node -class Range(NodeBase): +@register_object +class Range(Object): """Represent a range in TVM. You do not need to create a Range explicitly. @@ -108,8 +108,8 @@ class Range(NodeBase): """ -@register_node -class LoweredFunc(NodeBase): +@register_object +class LoweredFunc(Object): """Represent a LoweredFunc in TVM.""" MixedFunc = 0 HostFunc = 1 diff --git a/python/tvm/expr.py b/python/tvm/expr.py index c6b3d9b866e2..d147dd622fd3 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -32,7 +32,7 @@ """ # pylint: disable=missing-docstring from __future__ import absolute_import as _abs -from ._ffi.node import NodeBase, NodeGeneric, register_node +from ._ffi.object import Object, register_object, ObjectGeneric from ._ffi.runtime_ctypes import TVMType, TypeCode from . import make as _make from . import generic as _generic @@ -178,11 +178,11 @@ def astype(self, dtype): return _generic.cast(self, dtype) -class EqualOp(NodeGeneric, ExprOp): +class EqualOp(ObjectGeneric, ExprOp): """Deferred equal operator. This is used to support sugar that a == b can either - mean NodeBase.same_as or NodeBase.equal. + mean Object.same_as or Object.equal. Parameters ---------- @@ -205,16 +205,16 @@ def __nonzero__(self): def __bool__(self): return self.__nonzero__() - def asnode(self): - """Convert node.""" + def asobject(self): + """Convert object.""" return _make._OpEQ(self.a, self.b) -class NotEqualOp(NodeGeneric, ExprOp): +class NotEqualOp(ObjectGeneric, ExprOp): """Deferred NE operator. This is used to support sugar that a != b can either - mean not NodeBase.same_as or make.NE. + mean not Object.same_as or make.NE. Parameters ---------- @@ -237,8 +237,8 @@ def __nonzero__(self): def __bool__(self): return self.__nonzero__() - def asnode(self): - """Convert node.""" + def asobject(self): + """Convert object.""" return _make._OpNE(self.a, self.b) @@ -246,7 +246,7 @@ class PrimExpr(ExprOp, NodeBase): """Base class of all tvm Expressions""" # In Python3, We have to explicitly tell interpreter to retain __hash__ if we overide __eq__ # https://docs.python.org/3.1/reference/datamodel.html#object.__hash__ - __hash__ = NodeBase.__hash__ + __hash__ = Object.__hash__ class ConstExpr(PrimExpr): @@ -261,7 +261,7 @@ class CmpExpr(PrimExpr): class LogicalExpr(PrimExpr): pass -@register_node("Variable") +@register_object("Variable") class Var(PrimExpr): """Symbolic variable. @@ -278,7 +278,7 @@ def __init__(self, name, dtype): _api_internal._Var, name, dtype) -@register_node +@register_object class Reduce(PrimExpr): """Reduce node. @@ -305,7 +305,7 @@ def __init__(self, combiner, src, rdom, condition, value_index): condition, value_index) -@register_node +@register_object class FloatImm(ConstExpr): """Float constant. @@ -321,7 +321,7 @@ def __init__(self, dtype, value): self.__init_handle_by_constructor__( _make.FloatImm, dtype, value) -@register_node +@register_object class IntImm(ConstExpr): """Int constant. @@ -341,7 +341,7 @@ def __int__(self): return self.value -@register_node +@register_object class UIntImm(ConstExpr): """UInt constant. @@ -358,7 +358,7 @@ def __init__(self, dtype, value): _make.UIntImm, dtype, value) -@register_node +@register_object class StringImm(ConstExpr): """String constant. @@ -382,7 +382,7 @@ def __ne__(self, other): return self.value != other -@register_node +@register_object class Cast(PrimExpr): """Cast expression. @@ -399,7 +399,7 @@ def __init__(self, dtype, value): _make.Cast, dtype, value) -@register_node +@register_object class Add(BinaryOpExpr): """Add node. @@ -416,7 +416,7 @@ def __init__(self, a, b): _make.Add, a, b) -@register_node +@register_object class Sub(BinaryOpExpr): """Sub node. @@ -433,7 +433,7 @@ def __init__(self, a, b): _make.Sub, a, b) -@register_node +@register_object class Mul(BinaryOpExpr): """Mul node. @@ -450,7 +450,7 @@ def __init__(self, a, b): _make.Mul, a, b) -@register_node +@register_object class Div(BinaryOpExpr): """Div node. @@ -467,7 +467,7 @@ def __init__(self, a, b): _make.Div, a, b) -@register_node +@register_object class Mod(BinaryOpExpr): """Mod node. @@ -484,7 +484,7 @@ def __init__(self, a, b): _make.Mod, a, b) -@register_node +@register_object class FloorDiv(BinaryOpExpr): """FloorDiv node. @@ -501,7 +501,7 @@ def __init__(self, a, b): _make.FloorDiv, a, b) -@register_node +@register_object class FloorMod(BinaryOpExpr): """FloorMod node. @@ -518,7 +518,7 @@ def __init__(self, a, b): _make.FloorMod, a, b) -@register_node +@register_object class Min(BinaryOpExpr): """Min node. @@ -535,7 +535,7 @@ def __init__(self, a, b): _make.Min, a, b) -@register_node +@register_object class Max(BinaryOpExpr): """Max node. @@ -552,7 +552,7 @@ def __init__(self, a, b): _make.Max, a, b) -@register_node +@register_object class EQ(CmpExpr): """EQ node. @@ -569,7 +569,7 @@ def __init__(self, a, b): _make.EQ, a, b) -@register_node +@register_object class NE(CmpExpr): """NE node. @@ -586,7 +586,7 @@ def __init__(self, a, b): _make.NE, a, b) -@register_node +@register_object class LT(CmpExpr): """LT node. @@ -603,7 +603,7 @@ def __init__(self, a, b): _make.LT, a, b) -@register_node +@register_object class LE(CmpExpr): """LE node. @@ -620,7 +620,7 @@ def __init__(self, a, b): _make.LE, a, b) -@register_node +@register_object class GT(CmpExpr): """GT node. @@ -637,7 +637,7 @@ def __init__(self, a, b): _make.GT, a, b) -@register_node +@register_object class GE(CmpExpr): """GE node. @@ -654,7 +654,7 @@ def __init__(self, a, b): _make.GE, a, b) -@register_node +@register_object class And(LogicalExpr): """And node. @@ -671,7 +671,7 @@ def __init__(self, a, b): _make.And, a, b) -@register_node +@register_object class Or(LogicalExpr): """Or node. @@ -688,7 +688,7 @@ def __init__(self, a, b): _make.Or, a, b) -@register_node +@register_object class Not(LogicalExpr): """Not node. @@ -702,7 +702,7 @@ def __init__(self, a): _make.Not, a) -@register_node +@register_object class Select(PrimExpr): """Select node. @@ -730,7 +730,7 @@ def __init__(self, condition, true_value, false_value): _make.Select, condition, true_value, false_value) -@register_node +@register_object class Load(PrimExpr): """Load node. @@ -753,7 +753,7 @@ def __init__(self, dtype, buffer_var, index, predicate): _make.Load, dtype, buffer_var, index, predicate) -@register_node +@register_object class Ramp(PrimExpr): """Ramp node. @@ -773,7 +773,7 @@ def __init__(self, base, stride, lanes): _make.Ramp, base, stride, lanes) -@register_node +@register_object class Broadcast(PrimExpr): """Broadcast node. @@ -790,7 +790,7 @@ def __init__(self, value, lanes): _make.Broadcast, value, lanes) -@register_node +@register_object class Shuffle(PrimExpr): """Shuffle node. @@ -807,7 +807,7 @@ def __init__(self, vectors, indices): _make.Shuffle, vectors, indices) -@register_node +@register_object class Call(PrimExpr): """Call node. @@ -842,7 +842,7 @@ def __init__(self, dtype, name, args, call_type, func, value_index): _make.Call, dtype, name, args, call_type, func, value_index) -@register_node +@register_object class Let(PrimExpr): """Let node. diff --git a/python/tvm/ir_builder.py b/python/tvm/ir_builder.py index bf41c98a7bdd..ede17a154285 100644 --- a/python/tvm/ir_builder.py +++ b/python/tvm/ir_builder.py @@ -24,7 +24,7 @@ from . import ir_pass as _pass from . import container as _container from ._ffi.base import string_types -from ._ffi.node import NodeGeneric +from ._ffi.object import ObjectGeneric from ._ffi.runtime_ctypes import TVMType from .expr import Call as _Call @@ -41,7 +41,7 @@ def __exit__(self, ptype, value, trace): self._exit_cb() -class BufferVar(NodeGeneric): +class BufferVar(ObjectGeneric): """Buffer variable with content type, makes load store easily. Do not create it directly, create use IRBuilder. @@ -70,7 +70,7 @@ def __init__(self, builder, buffer_var, content_type): self._buffer_var = buffer_var self._content_type = content_type - def asnode(self): + def asobject(self): return self._buffer_var @property diff --git a/python/tvm/node.py b/python/tvm/object.py similarity index 93% rename from python/tvm/node.py rename to python/tvm/object.py index 1d5b506fabe7..9659d3c89067 100644 --- a/python/tvm/node.py +++ b/python/tvm/object.py @@ -20,6 +20,4 @@ """ # pylint: disable=unused-import from __future__ import absolute_import as _abs -from ._ffi.node import NodeBase, register_node - -Node = NodeBase +from ._ffi.object import Object, register_object diff --git a/python/tvm/relay/_module.pyi b/python/tvm/relay/_module.pyi index ae2d199de257..66c994e4400e 100644 --- a/python/tvm/relay/_module.pyi +++ b/python/tvm/relay/_module.pyi @@ -16,7 +16,7 @@ # under the License. from typing import Union, Tuple, Dict, List -from relay.ir import GlobalId, OperatorId, Item, NodeBase, Span, FileId +from relay.ir import GlobalId, OperatorId, Item, Object, Span, FileId from relay.ir import ShapeExtension, Operator, Defn -class Module(NodeBase): ... +class Module(Object): ... diff --git a/python/tvm/relay/adt.py b/python/tvm/relay/adt.py index 30db22cf8314..7f7496b1a407 100644 --- a/python/tvm/relay/adt.py +++ b/python/tvm/relay/adt.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """Algebraic data types in Relay.""" -from .base import RelayNode, register_relay_node, NodeBase +from .base import RelayNode, register_relay_node, Object from . import _make from .ty import Type from .expr import Expr, Call @@ -184,7 +184,7 @@ def __init__(self, header, type_vars, constructors): @register_relay_node -class Clause(NodeBase): +class Clause(Object): """Clause for pattern matching in Relay.""" def __init__(self, lhs, rhs): diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index 6c690a9b71de..956ad55404bf 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -17,19 +17,19 @@ """Backend code generation engine.""" from __future__ import absolute_import -from ..base import register_relay_node, NodeBase +from ..base import register_relay_node, Object from ... import target as _target from .. import expr as _expr from . import _backend @register_relay_node -class CachedFunc(NodeBase): +class CachedFunc(Object): """Low-level tensor function to back a relay primitive function. """ @register_relay_node -class CCacheKey(NodeBase): +class CCacheKey(Object): """Key in the CompileEngine. Parameters @@ -46,7 +46,7 @@ def __init__(self, source_func, target): @register_relay_node -class CCacheValue(NodeBase): +class CCacheValue(Object): """Value in the CompileEngine, including usage statistics. """ @@ -64,7 +64,7 @@ def _get_cache_key(source_func, target): @register_relay_node -class CompileEngine(NodeBase): +class CompileEngine(Object): """CompileEngine to get lowered code. """ def __init__(self): diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index 128edfca0fe1..59d9a8fae43c 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -24,12 +24,12 @@ from .. import _make, analysis, transform from .. import module from ... import nd -from ..base import NodeBase, register_relay_node +from ..base import Object, register_relay_node from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const from ..scope_builder import ScopeBuilder @register_relay_node -class TupleValue(NodeBase): +class TupleValue(Object): """A tuple value produced by the interpreter.""" def __init__(self, *fields): self.__init_handle_by_constructor__( @@ -54,24 +54,24 @@ def __iter__(self): @register_relay_node -class Closure(NodeBase): +class Closure(Object): """A closure produced by the interpreter.""" @register_relay_node -class RecClosure(NodeBase): +class RecClosure(Object): """A recursive closure produced by the interpreter.""" @register_relay_node -class ConstructorValue(NodeBase): +class ConstructorValue(Object): def __init__(self, tag, fields, constructor): self.__init_handle_by_constructor__( _make.ConstructorValue, tag, fields, constructor) @register_relay_node -class RefValue(NodeBase): +class RefValue(Object): def __init__(self, value): self.__init_handle_by_constructor__( _make.RefValue, value) @@ -189,7 +189,7 @@ def evaluate(self, expr=None, binds=None): Returns ------- - val : Union[function, NodeBase] + val : Union[function, Object] The evaluation result. """ if binds: diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index eb604a405410..d389803bfeea 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -17,12 +17,13 @@ # pylint: disable=no-else-return, unidiomatic-typecheck """The base node types for the Relay language.""" from __future__ import absolute_import as _abs -from .._ffi.node import NodeBase, register_node as _register_tvm_node +from .._ffi.object import register_object as _register_tvm_node +from .._ffi.object import Object from . import _make from . import _expr from . import _base -NodeBase = NodeBase +Object = Object def register_relay_node(type_key=None): """Register a Relay node type. @@ -52,7 +53,7 @@ def register_relay_attr_node(type_key=None): return _register_tvm_node(type_key) -class RelayNode(NodeBase): +class RelayNode(Object): """Base class of all Relay nodes.""" def astext(self, show_meta_data=True, annotate=None): """Get the text format of the expression. @@ -102,7 +103,7 @@ def __init__(self, name): self.__init_handle_by_constructor__(_make.SourceName, name) @register_relay_node -class Id(NodeBase): +class Id(Object): """Unique identifier(name) used in Var. Guaranteed to be stable across all passes. """ diff --git a/python/tvm/relay/expr.pyi b/python/tvm/relay/expr.pyi index d264e99e0577..d2d01720f5ff 100644 --- a/python/tvm/relay/expr.pyi +++ b/python/tvm/relay/expr.pyi @@ -17,12 +17,12 @@ from typing import List import tvm -from .base import Span, NodeBase +from .base import Span, Object from .ty import Type, TypeParam from ._analysis import _get_checked_type -class Expr(NodeBase): +class Expr(Object): def checked_type(self): ... diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index ac5387cf2512..a9d877cecd51 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -22,7 +22,7 @@ from .. import expr as _expr from .. import transform as _transform from ... import make as _make -from ..base import NodeBase, register_relay_node +from ..base import Object, register_relay_node class QAnnotateKind(object): @@ -53,7 +53,7 @@ def _forward_op(ref_call, args): @register_relay_node("relay.quantize.QConfig") -class QConfig(NodeBase): +class QConfig(Object): """Configure the quantization behavior by setting config variables. Note diff --git a/python/tvm/relay/transform.pyi b/python/tvm/relay/transform.pyi index 343e89976b09..2c466b0576a7 100644 --- a/python/tvm/relay/transform.pyi +++ b/python/tvm/relay/transform.pyi @@ -16,14 +16,14 @@ # under the License. import tvm -from .base import NodeBase +from .base import Object -class PassContext(NodeBase): +class PassContext(Object): def __init__(self): ... -class PassInfo(NodeBase): +class PassInfo(Object): name = ... # type: str opt_level = ... # type: int required = ... # type: list @@ -32,7 +32,7 @@ class PassInfo(NodeBase): # type: (str, int, list) -> None -class Pass(NodeBase): +class Pass(Object): def __init__(self): ... diff --git a/python/tvm/relay/ty.pyi b/python/tvm/relay/ty.pyi index 5a7ecffb372e..cde851160167 100644 --- a/python/tvm/relay/ty.pyi +++ b/python/tvm/relay/ty.pyi @@ -18,11 +18,11 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """The type nodes of the Relay language.""" from enum import IntEnum -from .base import NodeBase, register_relay_node +from .base import Object, register_relay_node from . import _make -class Type(NodeBase): +class Type(Object): """The base type for all Relay types.""" def __eq__(self, other): diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index 6b577c456fac..c8fcd7cbd52d 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -17,8 +17,8 @@ """The computation schedule api of TVM.""" from __future__ import absolute_import as _abs from ._ffi.base import string_types -from ._ffi.node import NodeBase, register_node -from ._ffi.node import convert_to_node as _convert_to_node +from ._ffi.object import Object, register_object +from ._ffi.object import convert_to_object as _convert_to_object from ._ffi.function import _init_api, Function from ._ffi.function import convert_to_tvm_func as _convert_tvm_func from . import _api_internal @@ -27,7 +27,7 @@ from . import container as _container def convert(value): - """Convert value to TVM node or function. + """Convert value to TVM object or function. Parameters ---------- @@ -35,19 +35,19 @@ def convert(value): Returns ------- - tvm_val : Node or Function + tvm_val : Object or Function Converted value in TVM """ - if isinstance(value, (Function, NodeBase)): + if isinstance(value, (Function, Object)): return value if callable(value): return _convert_tvm_func(value) - return _convert_to_node(value) + return _convert_to_object(value) -@register_node -class Buffer(NodeBase): +@register_object +class Buffer(Object): """Symbolic data buffer in TVM. Buffer provide a way to represent data layout @@ -156,23 +156,23 @@ def vstore(self, begin, value): return _api_internal._BufferVStore(self, begin, value) -@register_node -class Split(NodeBase): +@register_object +class Split(Object): """Split operation on axis.""" -@register_node -class Fuse(NodeBase): +@register_object +class Fuse(Object): """Fuse operation on axis.""" -@register_node -class Singleton(NodeBase): +@register_object +class Singleton(Object): """Singleton axis.""" -@register_node -class IterVar(NodeBase, _expr.ExprOp): +@register_object +class IterVar(Object, _expr.ExprOp): """Represent iteration variable. IterVar is normally created by Operation, to represent @@ -214,8 +214,8 @@ def create_schedule(ops): return _api_internal._CreateSchedule(ops) -@register_node -class Schedule(NodeBase): +@register_object +class Schedule(Object): """Schedule for all the stages.""" def __getitem__(self, k): if isinstance(k, _tensor.Tensor): @@ -348,8 +348,8 @@ def rfactor(self, tensor, axis, factor_axis=0): return factored[0] if len(factored) == 1 else factored -@register_node -class Stage(NodeBase): +@register_object +class Stage(Object): """A Stage represents schedule for one operation.""" def split(self, parent, factor=None, nparts=None): """Split the stage either by factor providing outer scope, or both diff --git a/python/tvm/stmt.py b/python/tvm/stmt.py index 64628d1d4198..6b87fcb1b885 100644 --- a/python/tvm/stmt.py +++ b/python/tvm/stmt.py @@ -30,14 +30,14 @@ assert(st.buffer_var == a) """ from __future__ import absolute_import as _abs -from ._ffi.node import NodeBase, register_node +from ._ffi.object import Object, register_object from . import make as _make -class Stmt(NodeBase): +class Stmt(Object): pass -@register_node +@register_object class LetStmt(Stmt): """LetStmt node. @@ -57,7 +57,7 @@ def __init__(self, var, value, body): _make.LetStmt, var, value, body) -@register_node +@register_object class AssertStmt(Stmt): """AssertStmt node. @@ -77,7 +77,7 @@ def __init__(self, condition, message, body): _make.AssertStmt, condition, message, body) -@register_node +@register_object class ProducerConsumer(Stmt): """ProducerConsumer node. @@ -97,7 +97,7 @@ def __init__(self, func, is_producer, body): _make.ProducerConsumer, func, is_producer, body) -@register_node +@register_object class For(Stmt): """For node. @@ -137,7 +137,7 @@ def __init__(self, for_type, device_api, body) -@register_node +@register_object class Store(Stmt): """Store node. @@ -160,7 +160,7 @@ def __init__(self, buffer_var, value, index, predicate): _make.Store, buffer_var, value, index, predicate) -@register_node +@register_object class Provide(Stmt): """Provide node. @@ -183,7 +183,7 @@ def __init__(self, func, value_index, value, args): _make.Provide, func, value_index, value, args) -@register_node +@register_object class Allocate(Stmt): """Allocate node. @@ -215,7 +215,7 @@ def __init__(self, extents, condition, body) -@register_node +@register_object class AttrStmt(Stmt): """AttrStmt node. @@ -238,7 +238,7 @@ def __init__(self, node, attr_key, value, body): _make.AttrStmt, node, attr_key, value, body) -@register_node +@register_object class Free(Stmt): """Free node. @@ -252,7 +252,7 @@ def __init__(self, buffer_var): _make.Free, buffer_var) -@register_node +@register_object class Realize(Stmt): """Realize node. @@ -288,7 +288,7 @@ def __init__(self, bounds, condition, body) -@register_node +@register_object class SeqStmt(Stmt): """Sequence of statements. @@ -308,7 +308,7 @@ def __len__(self): return len(self.seq) -@register_node +@register_object class IfThenElse(Stmt): """IfThenElse node. @@ -328,7 +328,7 @@ def __init__(self, condition, then_case, else_case): _make.IfThenElse, condition, then_case, else_case) -@register_node +@register_object class Evaluate(Stmt): """Evaluate node. @@ -342,7 +342,7 @@ def __init__(self, value): _make.Evaluate, value) -@register_node +@register_object class Prefetch(Stmt): """Prefetch node. diff --git a/python/tvm/target.py b/python/tvm/target.py index afddd5f1fd59..c2d37529040b 100644 --- a/python/tvm/target.py +++ b/python/tvm/target.py @@ -59,7 +59,7 @@ import warnings from ._ffi.base import _LIB_NAME -from ._ffi.node import NodeBase, register_node +from ._ffi.object import Object, register_object from . import _api_internal try: @@ -80,8 +80,8 @@ def _merge_opts(opts, new_opts): return opts -@register_node -class Target(NodeBase): +@register_object +class Target(Object): """Target device information, use through TVM API. Note @@ -97,7 +97,7 @@ class Target(NodeBase): """ def __new__(cls): # Always override new to enable class - obj = NodeBase.__new__(cls) + obj = Object.__new__(cls) obj._keys = None obj._options = None obj._libs = None @@ -146,8 +146,8 @@ def __exit__(self, ptype, value, trace): _api_internal._ExitTargetScope(self) -@register_node -class GenericFunc(NodeBase): +@register_object +class GenericFunc(Object): """GenericFunc node reference. This represents a generic function that may be specialized for different targets. When this object is called, a specialization is chosen based on the current target. diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py index e4a2f4f76e7b..e4c36c11120b 100644 --- a/python/tvm/tensor.py +++ b/python/tvm/tensor.py @@ -17,13 +17,14 @@ """Tensor and Operation class for computation declaration.""" # pylint: disable=invalid-name from __future__ import absolute_import as _abs -from ._ffi.node import NodeBase, NodeGeneric, register_node, convert_to_node +from ._ffi.object import Object, register_object, ObjectGeneric, \ + convert_to_object from . import _api_internal from . import make as _make from . import expr as _expr -class TensorSlice(NodeGeneric, _expr.ExprOp): +class TensorSlice(ObjectGeneric, _expr.ExprOp): """Auxiliary data structure for enable slicing syntax from tensor.""" def __init__(self, tensor, indices): @@ -37,8 +38,8 @@ def __getitem__(self, indices): indices = (indices,) return TensorSlice(self.tensor, self.indices + indices) - def asnode(self): - """Convert slice to node.""" + def asobject(self): + """Convert slice to object.""" return self.tensor(*self.indices) @property @@ -46,23 +47,23 @@ def dtype(self): """Data content of the tensor.""" return self.tensor.dtype -@register_node -class TensorIntrinCall(NodeBase): +@register_object +class TensorIntrinCall(Object): """Intermediate structure for calling a tensor intrinsic.""" itervar_cls = None -@register_node -class Tensor(NodeBase, _expr.ExprOp): +@register_object +class Tensor(Object, _expr.ExprOp): """Tensor object, to construct, see function.Tensor""" def __call__(self, *indices): ndim = self.ndim if len(indices) != ndim: raise ValueError("Need to provide %d index in tensor slice" % ndim) - indices = convert_to_node(indices) + indices = convert_to_object(indices) args = [] for x in indices: if isinstance(x, _expr.PrimExpr): @@ -127,7 +128,7 @@ def name(self): -class Operation(NodeBase): +class Operation(Object): """Represent an operation that generates a tensor""" def output(self, index): @@ -156,12 +157,12 @@ def input_tensors(self): return _api_internal._OpInputTensors(self) -@register_node +@register_object class PlaceholderOp(Operation): """Placeholder operation.""" -@register_node +@register_object class BaseComputeOp(Operation): """Compute operation.""" @property @@ -175,18 +176,18 @@ def reduce_axis(self): return self.__getattr__("reduce_axis") -@register_node +@register_object class ComputeOp(BaseComputeOp): """Scalar operation.""" pass -@register_node +@register_object class TensorComputeOp(BaseComputeOp): """Tensor operation.""" -@register_node +@register_object class ScanOp(Operation): """Scan operation.""" @property @@ -195,12 +196,12 @@ def scan_axis(self): return self.__getattr__("scan_axis") -@register_node +@register_object class ExternOp(Operation): """External operation.""" -@register_node +@register_object class HybridOp(Operation): """Hybrid operation.""" @property @@ -209,8 +210,8 @@ def axis(self): return self.__getattr__("axis") -@register_node -class Layout(NodeBase): +@register_object +class Layout(Object): """Layout is composed of upper cases, lower cases and numbers, where upper case indicates a primal axis and the corresponding lower case with factor size indicates the subordinate axis. @@ -269,8 +270,8 @@ def factor_of(self, axis): return _api_internal._LayoutFactorOf(self, axis) -@register_node -class BijectiveLayout(NodeBase): +@register_object +class BijectiveLayout(Object): """Bijective mapping for two layouts (src-layout and dst-layout). It provides shape and index conversion between each other. diff --git a/python/tvm/tensor_intrin.py b/python/tvm/tensor_intrin.py index 378cfe51a7b7..4665ccfd6204 100644 --- a/python/tvm/tensor_intrin.py +++ b/python/tvm/tensor_intrin.py @@ -24,7 +24,7 @@ from . import tensor as _tensor from . import schedule as _schedule from .build_module import current_build_config -from ._ffi.node import NodeBase, register_node +from ._ffi.object import Object, register_object def _get_region(tslice): @@ -41,8 +41,8 @@ def _get_region(tslice): region.append(_make.range_by_min_extent(begin, 1)) return region -@register_node -class TensorIntrin(NodeBase): +@register_object +class TensorIntrin(Object): """Tensor intrinsic functions for certain computation. See Also diff --git a/tests/python/unittest/test_pass_inject_double_buffer.py b/tests/python/unittest/test_pass_inject_double_buffer.py index dc517e2ee28b..aa569cea8665 100644 --- a/tests/python/unittest/test_pass_inject_double_buffer.py +++ b/tests/python/unittest/test_pass_inject_double_buffer.py @@ -28,7 +28,7 @@ def test_double_buffer(): with ib.for_range(0, n) as i: B = ib.allocate("float32", m, name="B", scope="shared") with ib.new_scope(): - ib.scope_attr(B.asnode(), "double_buffer_scope", 1) + ib.scope_attr(B.asobject(), "double_buffer_scope", 1) with ib.for_range(0, m) as j: B[j] = A[i * 4 + j] with ib.for_range(0, m) as j: @@ -39,7 +39,7 @@ def test_double_buffer(): stmt = tvm.ir_pass.Simplify(stmt) assert isinstance(stmt.body.body, tvm.stmt.Allocate) assert stmt.body.body.extents[0].value == 2 - f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asnode(), C.asnode()], 2, True) + f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True) f = tvm.ir_pass.ThreadSync(f, "shared") count = [0] def count_sync(op): diff --git a/tests/python/unittest/test_pass_inject_vthread.py b/tests/python/unittest/test_pass_inject_vthread.py index a3d059787ab8..08e261b68f6d 100644 --- a/tests/python/unittest/test_pass_inject_vthread.py +++ b/tests/python/unittest/test_pass_inject_vthread.py @@ -32,7 +32,7 @@ def get_vthread(name): ib.scope_attr(ty, "virtual_thread", nthread) B = ib.allocate("float32", m, name="B", scope="shared") B[i] = A[i * nthread + tx] - bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asnode()) + bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asobject()) ib.emit(tvm.call_extern("int32", "Run", bbuffer.access_ptr("r"), tvm.call_pure_intrin("int32", "tvm_context_id"))) @@ -60,9 +60,9 @@ def get_vthread(name): A = ib.allocate("float32", m, name="A", scope="shared") B = ib.allocate("float32", m, name="B", scope="shared") C = ib.allocate("float32", m, name="C", scope="shared") - cbuffer = tvm.decl_buffer((m,), dtype=C.dtype, data=C.asnode()) - abuffer = tvm.decl_buffer((m,), dtype=A.dtype, data=A.asnode()) - bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asnode()) + cbuffer = tvm.decl_buffer((m,), dtype=C.dtype, data=C.asobject()) + abuffer = tvm.decl_buffer((m,), dtype=A.dtype, data=A.asobject()) + bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asobject()) A[tx] = tx + 1.0 B[ty] = ty + 1.0 ib.emit(tvm.call_extern("int32", "Run", diff --git a/tests/python/unittest/test_pass_storage_flatten.py b/tests/python/unittest/test_pass_storage_flatten.py index 02edfe7d3261..da32f60f69fb 100644 --- a/tests/python/unittest/test_pass_storage_flatten.py +++ b/tests/python/unittest/test_pass_storage_flatten.py @@ -79,7 +79,7 @@ def test_flatten_double_buffer(): with ib.for_range(0, n) as i: B = ib.allocate("float32", m, name="B", scope="shared") with ib.new_scope(): - ib.scope_attr(B.asnode(), "double_buffer_scope", 1) + ib.scope_attr(B.asobject(), "double_buffer_scope", 1) with ib.for_range(0, m) as j: B[j] = A[i * 4 + j] with ib.for_range(0, m) as j: @@ -91,7 +91,7 @@ def test_flatten_double_buffer(): stmt = tvm.ir_pass.Simplify(stmt) assert isinstance(stmt.body.body, tvm.stmt.Allocate) assert stmt.body.body.extents[0].value == 2 - f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asnode(), C.asnode()], 2, True) + f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True) f = tvm.ir_pass.ThreadSync(f, "shared") count = [0] def count_sync(op):