[REFACTOR][PY][API-CHANGE] establish tvm.ir, migrate corresponding files (#4862)

* [REFACTOR][PY][API-CHANGE] establish tvm.ir, migrate corresponding relay files.

This PR establishes tvm.ir and migrates the corresponding relay
files into the new folder.

API Change:
- relay.Module -> tvm.IRModule

* Update with ADT

* Migrate transform

* address comments

* Migrate module

* Migrate json_compact

* Migrate attrs

* Move LoweredFunc to stmt temporarily

* temp migrate container

* Finish migrate container
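
A minimal before/after sketch of the renamed entry point (the function body is illustrative; only the Module -> IRModule rename comes from this diff):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(3, 3), dtype="float32")
    func = relay.Function([x], relay.nn.softmax(x))

    # Before this PR:
    #   mod = relay.Module.from_expr(func)
    # After this PR, the class lives in tvm.ir and is re-exported at top level:
    mod = tvm.IRModule.from_expr(func)
    assert isinstance(mod, tvm.IRModule)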
tqchen committed Feb 12, 2020
1 parent 15df204 commit a566161
Showing 187 changed files with 1,818 additions and 1,976 deletions.
4 changes: 2 additions & 2 deletions apps/benchmark/util.py
@@ -34,7 +34,7 @@ def get_network(name, batch_size, dtype='float32'):
Returns
-------
net: relay.Module
net: tvm.IRModule
The relay function of network definition
params: dict
The random parameters for benchmark
@@ -70,7 +70,7 @@ def get_network(name, batch_size, dtype='float32'):
net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
net = net["main"]
net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
net = relay.Module.from_expr(net)
net = tvm.IRModule.from_expr(net)
else:
raise ValueError("Unsupported network: " + name)

5 changes: 1 addition & 4 deletions docs/api/python/tvm.rst
@@ -21,8 +21,6 @@ The user facing API for computation declaration.

.. autosummary::

tvm.load_json
tvm.save_json
tvm.var
tvm.size_var
tvm.const
@@ -47,8 +45,7 @@ The user facing API for computation declaration.
tvm.max
tvm.tag_scope

.. autofunction:: tvm.load_json
.. autofunction:: tvm.save_json

.. autofunction:: tvm.var
.. autofunction:: tvm.size_var
.. autofunction:: tvm.const
2 changes: 1 addition & 1 deletion include/tvm/ir/expr.h
@@ -178,7 +178,7 @@ class RelayExpr : public BaseExpr {

class GlobalVar;
/*!
* \brief Global variable that leaves in the top-level module.
* \brief Global variable that lives in the top-level module.
*
* A GlobalVar only refers to function definitions.
* This is used to enable recursive calls between functions.
7 changes: 4 additions & 3 deletions include/tvm/ir/type.h
@@ -141,11 +141,12 @@ enum TypeKind : int {
};

/*!
* \brief Type parameter in the function.
* This can be viewed as template parameter in c++ template function.
* \brief Type parameter in functions.
*
* A type variable can be viewed as template parameter in c++ template function.
*
* For example, in the following pseudo code,
* the TypeVar of f is TypeVar(kind=kShapeVar, var=n).
* the TypeVar of f is TypeVar("n", kind=kShapeVar).
* This function can take in a Tensor with shape=(3, 3) and
* returns a Tensor with shape=(9,)
*
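A hedged Python sketch of the documented example; the tvm.ir.TypeVar and tvm.ir.TypeKind spellings follow later TVM releases and are assumptions here, not part of this diff:

    import tvm

    # A type parameter "n" of kind ShapeVar, mirroring TypeVar("n", kind=kShapeVar)
    # from the doc comment above (names assumed from the post-refactor Python API).
    n = tvm.ir.TypeVar("n", tvm.ir.TypeKind.ShapeVar)
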
2 changes: 1 addition & 1 deletion include/tvm/ir/type_relation.h
@@ -165,7 +165,7 @@ using TypeRelationFn =
const TypeReporter& reporter)>;

/*!
* \brief User defined type relation, is an input-output relation on types.
* \brief User defined type relation, it is an input-output relation on types.
*
* TypeRelation is more generalized than type call as it allows inference
* of both inputs and outputs.
8 changes: 6 additions & 2 deletions python/tvm/__init__.py
@@ -33,6 +33,12 @@
from .runtime.ndarray import vpi, rocm, opengl, ext_dev, micro_dev
from .runtime import ndarray as nd

# tvm.ir
from .ir import IRModule
from .ir import transform
from .ir import container
from . import ir

# others
from . import tensor
from . import arith
@@ -41,10 +47,8 @@
from . import make
from . import ir_pass
from . import codegen
from . import container
from . import schedule

from . import attrs
from . import ir_builder
from . import target
from . import generic
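A quick sanity check of the re-exports added above (a sketch assuming the post-commit package layout):

    import tvm

    # tvm/__init__.py now does `from .ir import IRModule`, so the top-level
    # name and the tvm.ir name resolve to the same class object.
    assert tvm.IRModule is tvm.ir.IRModule
    mod = tvm.IRModule()  # construct an empty module
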
1 change: 1 addition & 0 deletions python/tvm/_ffi/_ctypes/object.py
@@ -87,6 +87,7 @@ def __init_handle_by_constructor__(self, fconstructor, *args):
instead of creating a new Node.
"""
# assign handle first to avoid error raising
# pylint: disable=not-callable
self.handle = None
handle = __init_by_constructor__(fconstructor, args)
if not isinstance(handle, ObjectHandle):
68 changes: 4 additions & 64 deletions python/tvm/api.py
@@ -19,9 +19,11 @@
from numbers import Integral as _Integral

import tvm._ffi
import tvm.runtime._ffi_node_api
import tvm.ir

from tvm.runtime import convert, const, DataType
from tvm.ir import container as _container

from ._ffi.base import string_types, TVMError
from ._ffi.registry import register_func, get_global_func, extract_ext_funcs

@@ -30,9 +32,7 @@
from . import expr as _expr
from . import tensor as _tensor
from . import schedule as _schedule
from . import container as _container
from . import tag as _tag
from . import json_compact

int8 = "int8"
int32 = "int32"
@@ -71,66 +71,6 @@ def max_value(dtype):
"""
return _api_internal._max_value(dtype)


def get_env_func(name):
"""Get an EnvFunc by a global name.
Parameters
----------
name: str
The name of the global function.
Returns
-------
env_func : EnvFunc
The result env function.
Note
----
EnvFunc is a Object wrapper around
global function that can be serialized via its name.
This can be used to serialize function field in the language.
"""
return _api_internal._EnvFuncGet(name)


def load_json(json_str):
"""Load tvm object from json_str.
Parameters
----------
json_str : str
The json string
Returns
-------
node : Object
The loaded tvm node.
"""

try:
return tvm.runtime._ffi_node_api.LoadJSON(json_str)
except TVMError:
json_str = json_compact.upgrade_json(json_str)
return tvm.runtime._ffi_node_api.LoadJSON(json_str)


def save_json(node):
"""Save tvm object as json string.
Parameters
----------
node : Object
A TVM object to be saved.
Returns
-------
json_str : str
Saved json string.
"""
return tvm.runtime._ffi_node_api.SaveJSON(node)


def var(name="tindex", dtype=int32):
"""Create a new variable with specified name and dtype
@@ -688,7 +628,7 @@ def _IterVar(dom, name, iter_type, thread_tag=''):
raise TypeError("need to be list of ranges")
dom = Range(dom[0], dom[1])

if not isinstance(dom, _container.Range):
if not isinstance(dom, tvm.ir.Range):
raise TypeError("dom need to be Range")
name = name if name else 'iter'
v = var(name)
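The helpers deleted above moved with the tvm.ir migration; a round-trip sketch using their assumed new homes, tvm.ir.save_json and tvm.ir.load_json (spelling from later TVM releases, not shown in this diff):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(4,), dtype="float32")
    json_str = tvm.ir.save_json(x)  # serialize any TVM IR object to JSON
    y = tvm.ir.load_json(json_str)  # load it back; old JSON is upgraded via json_compact
    assert isinstance(y, relay.Var)
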
2 changes: 1 addition & 1 deletion python/tvm/autotvm/graph_tuner/base_graph_tuner.py
@@ -141,7 +141,7 @@ def __init__(self, graph, input_shapes, records, target_ops,
self._logger.propagate = False

# Generate workload and schedule dictionaries.
if isinstance(graph, relay.Module):
if isinstance(graph, tvm.IRModule):
graph = graph["main"]

if isinstance(graph, relay.expr.Function):
5 changes: 3 additions & 2 deletions python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
@@ -20,6 +20,7 @@

import topi

import tvm
from tvm import relay, autotvm
from tvm.relay import transform
from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple
@@ -83,7 +84,7 @@ def expr2graph(expr, target_ops, node_dict, node_list):

def _infer_type(node):
"""A method to infer the type of a relay expression."""
mod = relay.Module.from_expr(node)
mod = tvm.IRModule.from_expr(node)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(node, relay.Function) else entry.body
@@ -136,7 +137,7 @@ def _traverse_expr(node):
free_var = relay.Var("var_%d" % i, input_type)
params.append(free_var)
call = relay.Call(node.op, params, node.attrs)
mod = relay.Module.from_expr(relay.Function(params, call))
mod = tvm.IRModule.from_expr(relay.Function(params, call))
relay.backend.compile_engine.get().clear()
build_thread = threading.Thread(target=relay.build,
args=(mod,
3 changes: 2 additions & 1 deletion python/tvm/autotvm/graph_tuner/utils/utils.py
@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=eval-used,invalid-name,too-many-arguments
"""Utility functions"""
import tvm
from tvm import relay
from tvm.relay import transform

@@ -136,7 +137,7 @@ def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
rebind_dict[var] = updated_input_dict[var.name_hint]
updated_expr = relay.expr.bind(expr, rebind_dict)

mod = relay.Module.from_expr(updated_expr)
mod = tvm.IRModule.from_expr(updated_expr)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(updated_expr, relay.Function) else entry.body
8 changes: 4 additions & 4 deletions python/tvm/autotvm/task/relay_integration.py
@@ -63,7 +63,7 @@ def extract_from_program(mod, params, ops, target, target_host=None,
Parameters
----------
mod: relay.module.Module or relay.expr.Function
mod: tvm.IRModule or relay.expr.Function
The module or function to tune
params: dict of str to numpy array
The associated parameters of the program
@@ -95,7 +95,7 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None,
Parameters
----------
mods: List[relay.module.Module] or List[relay.expr.Function]
mods: List[tvm.IRModule] or List[relay.expr.Function]
The list of modules or functions to tune
params: List of dict of str to numpy array
The associated parameters of the programs
@@ -151,8 +151,8 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None,

for mod, param in zip(mods, params):
if isinstance(mod, relay.expr.Function):
mod = relay.Module.from_expr(mod)
assert isinstance(mod, relay.module.Module), \
mod = tvm.IRModule.from_expr(mod)
assert isinstance(mod, tvm.IRModule), \
"only support relay Module or Function to be tuned"
relay.backend.compile_engine.get().clear()
# wrap build call in thread to avoid multiprocessing problems
22 changes: 12 additions & 10 deletions python/tvm/build_module.py
@@ -24,17 +24,19 @@
import tvm.runtime

from tvm.runtime import Object, ndarray
from tvm.ir import container
from . import api
from . import _api_internal
from . import tensor
from . import schedule
from . import expr
from . import ir_pass
from . import stmt as _stmt
from . import container
from . import codegen
from . import target as _target
from . import make
from .stmt import LoweredFunc


class DumpIR(object):
"""
@@ -58,16 +60,16 @@ def decorate(self, func):
def dump(*args, **kwargs):
"""dump function"""
retv = func(*args, **kwargs)
if not isinstance(retv, (_stmt.Stmt, container.LoweredFunc, container.Array)):
if not isinstance(retv, (_stmt.Stmt, LoweredFunc, container.Array)):
return retv
fname = func.func_name if hasattr(func, 'func_name') else func.__name__
pname = str(self._pass_id) + "_" + fname + "_ir.cc"
with open(pname, "a") as f:
out = retv.body if isinstance(retv, container.LoweredFunc) else retv
out = retv.body if isinstance(retv, LoweredFunc) else retv
f.write(str(out))
if isinstance(retv, container.Array):
for x in retv:
out = x.body if isinstance(x, container.LoweredFunc) else x
out = x.body if isinstance(x, LoweredFunc) else x
f.write("---------%s\n%s\n-----------\n"%(x.name, str(out)))
self._pass_id += 1
return retv
@@ -459,7 +461,7 @@ def _build_for_device(flist, target, target_host):
raise ValueError(
"Direct host side access to device memory is detected in %s. "
"Did you forget to bind?" % func.name)
if func.func_type == container.LoweredFunc.MixedFunc:
if func.func_type == LoweredFunc.MixedFunc:
if current_build_config().detect_global_barrier:
func = ir_pass.ThreadSync(func, "global")
func = ir_pass.ThreadSync(func, "shared")
@@ -471,9 +473,9 @@
fhost.append(fsplits[0])
for x in fsplits[1:]:
fdevice.append(x)
elif func.func_type == container.LoweredFunc.HostFunc:
elif func.func_type == LoweredFunc.HostFunc:
fhost.append(func)
elif func.func_type == container.LoweredFunc.DeviceFunc:
elif func.func_type == LoweredFunc.DeviceFunc:
fdevice.append(func)
else:
raise ValueError("unknown function type %d" % func.func_type)
@@ -586,9 +588,9 @@ def build(inputs,
flist = lower(inputs, args,
name=name,
binds=binds)
if isinstance(flist, container.LoweredFunc):
if isinstance(flist, LoweredFunc):
flist = [flist]
elif isinstance(inputs, container.LoweredFunc):
elif isinstance(inputs, LoweredFunc):
if args:
raise ValueError("args must be done when build from LoweredFunc.")
flist = [inputs]
@@ -612,7 +614,7 @@
"_target.Target when inputs is dict.")
fname_set = set()
for x in flist:
if not isinstance(x, container.LoweredFunc):
if not isinstance(x, LoweredFunc):
raise ValueError("inputs must be Schedule, LoweredFunc, list "
"of LoweredFunc, or dict of str to list of "
"LoweredFunc.")
2 changes: 1 addition & 1 deletion python/tvm/contrib/sparse.py
@@ -38,7 +38,7 @@ def __init__(self, arg1, ctx=None, shape=None):
The corresponding a dense numpy array,
or a tuple for constructing a sparse matrix directly.
ctx: tvm.TVMContext
ctx: tvmContext
The corresponding context.
shape : tuple of int
3 changes: 1 addition & 2 deletions python/tvm/hybrid/calls.py
@@ -16,12 +16,11 @@
# under the License.
"""Intrinsics of TVM-Python Hybrid Script for Python compilation time
semantic support."""

from tvm.ir.container import Array
from .. import api as _api
from .. import expr as _expr
from .. import make as _make
from .. import target as _tgt
from ..container import Array
from .. import ir_pass
from ..stmt import For
from .util import _internal_assert