Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR][PY][API-CHANGE] establish tvm.ir, migrate corresponding files #4862

Merged
merged 10 commits into from
Feb 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions apps/benchmark/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def get_network(name, batch_size, dtype='float32'):

Returns
-------
net: relay.Module
net: tvm.IRModule
The relay function of network definition
params: dict
The random parameters for benchmark
Expand Down Expand Up @@ -70,7 +70,7 @@ def get_network(name, batch_size, dtype='float32'):
net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
net = net["main"]
net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
net = relay.Module.from_expr(net)
net = tvm.IRModule.from_expr(net)
else:
raise ValueError("Unsupported network: " + name)

Expand Down
5 changes: 1 addition & 4 deletions docs/api/python/tvm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ The user facing API for computation declaration.

.. autosummary::

tvm.load_json
tvm.save_json
tvm.var
tvm.size_var
tvm.const
Expand All @@ -47,8 +45,7 @@ The user facing API for computation declaration.
tvm.max
tvm.tag_scope

.. autofunction:: tvm.load_json
.. autofunction:: tvm.save_json

.. autofunction:: tvm.var
.. autofunction:: tvm.size_var
.. autofunction:: tvm.const
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ class RelayExpr : public BaseExpr {

class GlobalVar;
/*!
* \brief Global variable that leaves in the top-level module.
* \brief Global variable that lives in the top-level module.
*
* A GlobalVar only refers to function definitions.
* This is used to enable recursive calls between function.
Expand Down
7 changes: 4 additions & 3 deletions include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,12 @@ enum TypeKind : int {
};

/*!
* \brief Type parameter in the function.
* This can be viewed as template parameter in c++ template function.
* \brief Type parameter in functions.
*
* A type variable can be viewed as template parameter in c++ template function.
*
* For example, in the following pesudo code,
* the TypeVar of f is TypeVar(kind=kShapeVar, var=n).
* the TypeVar of f is TypeVar("n", kind=kShapeVar).
* This function can take in a Tensor with shape=(3, 3) and
* returns a Tensor with shape=(9,)
*
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/ir/type_relation.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ using TypeRelationFn =
const TypeReporter& reporter)>;

/*!
* \brief User defined type relation, is an input-output relation on types.
* \brief User defined type relation, it is an input-output relation on types.
*
* TypeRelation is more generalized than type call as it allows inference
* of both inputs and outputs.
Expand Down
8 changes: 6 additions & 2 deletions python/tvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@
from .runtime.ndarray import vpi, rocm, opengl, ext_dev, micro_dev
from .runtime import ndarray as nd

# tvm.ir
from .ir import IRModule
from .ir import transform
from .ir import container
from . import ir

# others
from . import tensor
from . import arith
Expand All @@ -41,10 +47,8 @@
from . import make
from . import ir_pass
from . import codegen
from . import container
from . import schedule

from . import attrs
from . import ir_builder
from . import target
from . import generic
Expand Down
1 change: 1 addition & 0 deletions python/tvm/_ffi/_ctypes/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __init_handle_by_constructor__(self, fconstructor, *args):
instead of creating a new Node.
"""
# assign handle first to avoid error raising
# pylint: disable=not-callable
self.handle = None
handle = __init_by_constructor__(fconstructor, args)
if not isinstance(handle, ObjectHandle):
Expand Down
68 changes: 4 additions & 64 deletions python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
from numbers import Integral as _Integral

import tvm._ffi
import tvm.runtime._ffi_node_api
import tvm.ir

from tvm.runtime import convert, const, DataType
from tvm.ir import container as _container

from ._ffi.base import string_types, TVMError
from ._ffi.registry import register_func, get_global_func, extract_ext_funcs

Expand All @@ -30,9 +32,7 @@
from . import expr as _expr
from . import tensor as _tensor
from . import schedule as _schedule
from . import container as _container
from . import tag as _tag
from . import json_compact

int8 = "int8"
int32 = "int32"
Expand Down Expand Up @@ -71,66 +71,6 @@ def max_value(dtype):
"""
return _api_internal._max_value(dtype)


def get_env_func(name):
"""Get an EnvFunc by a global name.

Parameters
----------
name: str
The name of the global function.

Returns
-------
env_func : EnvFunc
The result env function.

Note
----
EnvFunc is a Object wrapper around
global function that can be serialized via its name.
This can be used to serialize function field in the language.
"""
return _api_internal._EnvFuncGet(name)


def load_json(json_str):
"""Load tvm object from json_str.

Parameters
----------
json_str : str
The json string

Returns
-------
node : Object
The loaded tvm node.
"""

try:
return tvm.runtime._ffi_node_api.LoadJSON(json_str)
except TVMError:
json_str = json_compact.upgrade_json(json_str)
return tvm.runtime._ffi_node_api.LoadJSON(json_str)


def save_json(node):
"""Save tvm object as json string.

Parameters
----------
node : Object
A TVM object to be saved.

Returns
-------
json_str : str
Saved json string.
"""
return tvm.runtime._ffi_node_api.SaveJSON(node)


def var(name="tindex", dtype=int32):
"""Create a new variable with specified name and dtype

Expand Down Expand Up @@ -688,7 +628,7 @@ def _IterVar(dom, name, iter_type, thread_tag=''):
raise TypeError("need to be list of ranges")
dom = Range(dom[0], dom[1])

if not isinstance(dom, _container.Range):
if not isinstance(dom, tvm.ir.Range):
raise TypeError("dom need to be Range")
name = name if name else 'iter'
v = var(name)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/autotvm/graph_tuner/base_graph_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def __init__(self, graph, input_shapes, records, target_ops,
self._logger.propagate = False

# Generate workload and schedule dictionaries.
if isinstance(graph, relay.Module):
if isinstance(graph, tvm.IRModule):
graph = graph["main"]

if isinstance(graph, relay.expr.Function):
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import topi

import tvm
from tvm import relay, autotvm
from tvm.relay import transform
from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple
Expand Down Expand Up @@ -83,7 +84,7 @@ def expr2graph(expr, target_ops, node_dict, node_list):

def _infer_type(node):
"""A method to infer the type of a relay expression."""
mod = relay.Module.from_expr(node)
mod = tvm.IRModule.from_expr(node)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(node, relay.Function) else entry.body
Expand Down Expand Up @@ -136,7 +137,7 @@ def _traverse_expr(node):
free_var = relay.Var("var_%d" % i, input_type)
params.append(free_var)
call = relay.Call(node.op, params, node.attrs)
mod = relay.Module.from_expr(relay.Function(params, call))
mod = tvm.IRModule.from_expr(relay.Function(params, call))
relay.backend.compile_engine.get().clear()
build_thread = threading.Thread(target=relay.build,
args=(mod,
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/autotvm/graph_tuner/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
# pylint: disable=eval-used,invalid-name,too-many-arguments
"""Utility functions"""
import tvm
from tvm import relay
from tvm.relay import transform

Expand Down Expand Up @@ -136,7 +137,7 @@ def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
rebind_dict[var] = updated_input_dict[var.name_hint]
updated_expr = relay.expr.bind(expr, rebind_dict)

mod = relay.Module.from_expr(updated_expr)
mod = tvm.IRModule.from_expr(updated_expr)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(updated_expr, relay.Function) else entry.body
8 changes: 4 additions & 4 deletions python/tvm/autotvm/task/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def extract_from_program(mod, params, ops, target, target_host=None,

Parameters
----------
mod: relay.module.Module or relay.expr.Function
mod: tvm.IRModule or relay.expr.Function
The module or function to tune
params: dict of str to numpy array
The associated parameters of the program
Expand Down Expand Up @@ -95,7 +95,7 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None,

Parameters
----------
mods: List[relay.module.Module] or List[relay.expr.Function]
mods: List[tvm.IRModule] or List[relay.expr.Function]
The list of modules or functions to tune
params: List of dict of str to numpy array
The associated parameters of the programs
Expand Down Expand Up @@ -151,8 +151,8 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None,

for mod, param in zip(mods, params):
if isinstance(mod, relay.expr.Function):
mod = relay.Module.from_expr(mod)
assert isinstance(mod, relay.module.Module), \
mod = tvm.IRModule.from_expr(mod)
assert isinstance(mod, tvm.IRModule), \
"only support relay Module or Function to be tuned"
relay.backend.compile_engine.get().clear()
# wrap build call in thread to avoid multiprocessing problems
Expand Down
22 changes: 12 additions & 10 deletions python/tvm/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,19 @@
import tvm.runtime

from tvm.runtime import Object, ndarray
from tvm.ir import container
from . import api
from . import _api_internal
from . import tensor
from . import schedule
from . import expr
from . import ir_pass
from . import stmt as _stmt
from . import container
from . import codegen
from . import target as _target
from . import make
from .stmt import LoweredFunc


class DumpIR(object):
"""
Expand All @@ -58,16 +60,16 @@ def decorate(self, func):
def dump(*args, **kwargs):
"""dump function"""
retv = func(*args, **kwargs)
if not isinstance(retv, (_stmt.Stmt, container.LoweredFunc, container.Array)):
if not isinstance(retv, (_stmt.Stmt, LoweredFunc, container.Array)):
return retv
fname = func.func_name if hasattr(func, 'func_name') else func.__name__
pname = str(self._pass_id) + "_" + fname + "_ir.cc"
with open(pname, "a") as f:
out = retv.body if isinstance(retv, container.LoweredFunc) else retv
out = retv.body if isinstance(retv, LoweredFunc) else retv
f.write(str(out))
if isinstance(retv, container.Array):
for x in retv:
out = x.body if isinstance(x, container.LoweredFunc) else x
out = x.body if isinstance(x, LoweredFunc) else x
f.write("---------%s\n%s\n-----------\n"%(x.name, str(out)))
self._pass_id += 1
return retv
Expand Down Expand Up @@ -459,7 +461,7 @@ def _build_for_device(flist, target, target_host):
raise ValueError(
"Direct host side access to device memory is detected in %s. "
"Did you forget to bind?" % func.name)
if func.func_type == container.LoweredFunc.MixedFunc:
if func.func_type == LoweredFunc.MixedFunc:
if current_build_config().detect_global_barrier:
func = ir_pass.ThreadSync(func, "global")
func = ir_pass.ThreadSync(func, "shared")
Expand All @@ -471,9 +473,9 @@ def _build_for_device(flist, target, target_host):
fhost.append(fsplits[0])
for x in fsplits[1:]:
fdevice.append(x)
elif func.func_type == container.LoweredFunc.HostFunc:
elif func.func_type == LoweredFunc.HostFunc:
fhost.append(func)
elif func.func_type == container.LoweredFunc.DeviceFunc:
elif func.func_type == LoweredFunc.DeviceFunc:
fdevice.append(func)
else:
raise ValueError("unknown function type %d" % func.func_type)
Expand Down Expand Up @@ -586,9 +588,9 @@ def build(inputs,
flist = lower(inputs, args,
name=name,
binds=binds)
if isinstance(flist, container.LoweredFunc):
if isinstance(flist, LoweredFunc):
flist = [flist]
elif isinstance(inputs, container.LoweredFunc):
elif isinstance(inputs, LoweredFunc):
if args:
raise ValueError("args must be done when build from LoweredFunc.")
flist = [inputs]
Expand All @@ -612,7 +614,7 @@ def build(inputs,
"_target.Target when inputs is dict.")
fname_set = set()
for x in flist:
if not isinstance(x, container.LoweredFunc):
if not isinstance(x, LoweredFunc):
raise ValueError("inputs must be Schedule, LoweredFunc, list "
"of LoweredFunc, or dict of str to list of "
"LoweredFunc.")
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/contrib/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, arg1, ctx=None, shape=None):
The corresponding a dense numpy array,
or a tuple for constructing a sparse matrix directly.

ctx: tvm.TVMContext
ctx: tvmContext
The corresponding context.

shape : tuple of int
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/hybrid/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@
# under the License.
"""Intrinsics of TVM-Python Hybrid Script for Python compilation time
semantic support."""

from tvm.ir.container import Array
from .. import api as _api
from .. import expr as _expr
from .. import make as _make
from .. import target as _tgt
from ..container import Array
from .. import ir_pass
from ..stmt import For
from .util import _internal_assert
Expand Down
Loading