From a5661611472c8e92b20bbe4d074333b8183f2878 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 11 Feb 2020 20:01:36 -0800 Subject: [PATCH] [REFACTOR][PY][API-CHANGE] establish tvm.ir, migrate corresponding files (#4862) * [REFACTOR][PY][API-CHANGE] establish tvm.ir, migrate corresponding relay files. This PR establishes tvm.ir and migrates the corresponding relay files into the new folder. API Change: - relay.Module -> tvm.IRModule * Update with ADT * Migrate transform * address comments * Migrate module * Migrate json_compact * Migrate attrs * Move LoweredFunc to stmt temporarily * temp migrate container * Finish migrate container --- apps/benchmark/util.py | 4 +- docs/api/python/tvm.rst | 5 +- include/tvm/ir/expr.h | 2 +- include/tvm/ir/type.h | 7 +- include/tvm/ir/type_relation.h | 2 +- python/tvm/__init__.py | 8 +- python/tvm/_ffi/_ctypes/object.py | 1 + python/tvm/api.py | 68 +--- .../autotvm/graph_tuner/base_graph_tuner.py | 2 +- .../graph_tuner/utils/traverse_graph.py | 5 +- python/tvm/autotvm/graph_tuner/utils/utils.py | 3 +- python/tvm/autotvm/task/relay_integration.py | 8 +- python/tvm/build_module.py | 22 +- python/tvm/contrib/sparse.py | 2 +- python/tvm/hybrid/calls.py | 3 +- python/tvm/hybrid/parser.py | 2 +- python/tvm/hybrid/util.py | 3 +- python/tvm/ir/__init__.py | 30 ++ .../tvm/{relay/_module.py => ir/_ffi_api.py} | 6 +- .../_module.pyi => ir/_ffi_transform_api.py} | 8 +- python/tvm/ir/adt.py | 87 +++++ python/tvm/{ => ir}/attrs.py | 9 +- python/tvm/ir/base.py | 151 ++++++++ python/tvm/{ => ir}/container.py | 34 +- python/tvm/ir/expr.py | 101 ++++++ python/tvm/{ => ir}/json_compact.py | 0 python/tvm/{relay => ir}/module.py | 94 +++-- python/tvm/ir/tensor_type.py | 56 +++ python/tvm/ir/transform.py | 328 ++++++++++++++++++ python/tvm/ir/type.py | 204 +++++++++++ python/tvm/ir/type_relation.py | 73 ++++ python/tvm/ir_builder.py | 5 +- python/tvm/relay/__init__.py | 9 +- python/tvm/relay/_parser.py | 23 +- python/tvm/relay/adt.py | 79 +---- python/tvm/relay/analysis.py | 28 +- python/tvm/relay/backend/_backend.py | 2 +- python/tvm/relay/backend/interpreter.py | 11 +- python/tvm/relay/backend/vm.py | 12 +- python/tvm/relay/base.py | 60 +--- python/tvm/relay/build_module.py | 29 +- python/tvm/relay/expr.py | 74 +--- python/tvm/relay/expr.pyi | 131 ------- python/tvm/relay/frontend/caffe2.py | 10 +- python/tvm/relay/frontend/common.py | 5 +- python/tvm/relay/frontend/coreml.py | 7 +- python/tvm/relay/frontend/darknet.py | 7 +- python/tvm/relay/frontend/keras.py | 7 +- python/tvm/relay/frontend/mxnet.py | 9 +- python/tvm/relay/frontend/onnx.py | 11 +- python/tvm/relay/frontend/tensorflow.py | 8 +- python/tvm/relay/frontend/tflite.py | 8 +- python/tvm/relay/memory_alloc.py | 2 +- python/tvm/relay/op/__init__.py | 1 - python/tvm/relay/op/nn/_nn.py | 8 +- python/tvm/relay/op/nn/util.py | 4 +- python/tvm/relay/op/op.py | 4 +- python/tvm/relay/op/op_attrs.py | 2 +- python/tvm/relay/prelude.py | 6 +- python/tvm/relay/qnn/op/legalizations.py | 6 +- python/tvm/relay/qnn/transform.py | 2 +- python/tvm/relay/quantize/_calibrate.py | 4 +- python/tvm/relay/testing/__init__.py | 2 +- python/tvm/relay/testing/dcgan.py | 2 +- python/tvm/relay/testing/densenet.py | 2 +- python/tvm/relay/testing/dqn.py | 2 +- python/tvm/relay/testing/inception_v3.py | 2 +- python/tvm/relay/testing/init.py | 4 +- python/tvm/relay/testing/lstm.py | 2 +- python/tvm/relay/testing/mlp.py | 2 +- python/tvm/relay/testing/mobilenet.py | 2 +- python/tvm/relay/testing/py_converter.py | 4 +- python/tvm/relay/testing/resnet.py | 2 +- python/tvm/relay/testing/squeezenet.py | 2 +- python/tvm/relay/testing/vgg.py | 2 +- python/tvm/relay/transform.py | 320 +---------------- python/tvm/relay/transform.pyi | 71 ---- python/tvm/relay/ty.py | 292 +--------------- python/tvm/relay/ty.pyi | 200 ----------- python/tvm/runtime/_ffi_node_api.py | 2 +- python/tvm/schedule.py | 2 +- python/tvm/stmt.py | 8 + src/ir/adt.cc | 4 +- src/ir/attrs.cc | 2 +- src/ir/env_func.cc | 6 +- src/ir/expr.cc | 2 +- src/ir/module.cc | 34 +- src/ir/span.cc | 4 +- src/ir/tensor_type.cc | 2 +- src/ir/transform.cc | 33 +- src/ir/type.cc | 14 +- src/ir/type_relation.cc | 4 +- src/printer/relay_text_printer.cc | 11 +- src/relay/ir/alpha_equal.cc | 5 + src/relay/ir/base.cc | 2 +- tests/python/contrib/test_rpc_tracker.py | 2 +- tests/python/frontend/mxnet/test_graph.py | 4 +- .../frontend/mxnet/test_qnn_ops_utils.py | 4 +- tests/python/relay/test_adt.py | 28 +- tests/python/relay/test_any.py | 54 +-- .../relay/test_backend_compile_engine.py | 10 +- .../relay/test_backend_graph_runtime.py | 7 +- .../python/relay/test_backend_interpreter.py | 10 +- tests/python/relay/test_cpp_build_module.py | 4 +- tests/python/relay/test_error_reporting.py | 4 +- tests/python/relay/test_external_codegen.py | 10 +- tests/python/relay/test_external_runtime.py | 2 +- tests/python/relay/test_feature.py | 2 +- tests/python/relay/test_ir_module.py | 7 +- tests/python/relay/test_ir_nodes.py | 22 +- tests/python/relay/test_ir_parser.py | 18 +- tests/python/relay/test_ir_text_printer.py | 2 +- tests/python/relay/test_ir_well_formed.py | 2 +- tests/python/relay/test_json_compact.py | 4 +- tests/python/relay/test_memory_alloc.py | 6 +- tests/python/relay/test_op_level10.py | 8 +- tests/python/relay/test_op_level2.py | 6 +- tests/python/relay/test_op_qnn_add.py | 12 +- tests/python/relay/test_op_qnn_concatenate.py | 8 +- tests/python/relay/test_op_qnn_conv2d.py | 8 +- tests/python/relay/test_op_qnn_dense.py | 2 +- tests/python/relay/test_op_qnn_dequantize.py | 2 +- tests/python/relay/test_op_qnn_mul.py | 10 +- tests/python/relay/test_op_qnn_quantize.py | 2 +- tests/python/relay/test_op_qnn_requantize.py | 2 +- tests/python/relay/test_pass_alpha_equal.py | 44 +-- .../python/relay/test_pass_alter_op_layout.py | 6 +- tests/python/relay/test_pass_annotation.py | 2 +- .../relay/test_pass_canonicalize_cast.py | 3 +- tests/python/relay/test_pass_check_kind.py | 120 +++---- .../test_pass_combine_parallel_conv2d.py | 5 +- .../relay/test_pass_combine_parallel_dense.py | 5 +- .../relay/test_pass_convert_op_layout.py | 2 +- .../relay/test_pass_dead_code_elimination.py | 2 +- .../test_pass_eliminate_common_subexpr.py | 4 +- tests/python/relay/test_pass_fold_constant.py | 2 +- .../python/relay/test_pass_fold_scale_axis.py | 3 +- tests/python/relay/test_pass_fuse_ops.py | 18 +- tests/python/relay/test_pass_gradient.py | 2 +- tests/python/relay/test_pass_lambda_lift.py | 8 +- tests/python/relay/test_pass_legalize.py | 2 +- tests/python/relay/test_pass_mac_count.py | 2 +- tests/python/relay/test_pass_manager.py | 38 +- tests/python/relay/test_pass_partial_eval.py | 24 +- .../python/relay/test_pass_partition_graph.py | 14 +- tests/python/relay/test_pass_qnn_legalize.py | 6 +- .../test_pass_remove_unused_functions.py | 12 +- .../relay/test_pass_simplify_inference.py | 3 +- .../relay/test_pass_to_a_normal_form.py | 8 +- tests/python/relay/test_pass_to_cps.py | 2 +- .../relay/test_pass_to_graph_normal_form.py | 4 +- .../python/relay/test_pass_unmatched_cases.py | 10 +- tests/python/relay/test_pass_vars.py | 2 +- tests/python/relay/test_py_converter.py | 24 +- tests/python/relay/test_type_functor.py | 2 +- tests/python/relay/test_type_infer.py | 17 +- tests/python/relay/test_type_solver.py | 2 +- tests/python/relay/test_typecall.py | 5 +- tests/python/relay/test_vm.py | 58 ++-- tests/python/relay/test_vm_serialization.py | 14 +- .../unittest/test_codegen_cross_llvm.py | 2 +- tests/python/unittest/test_container.py | 2 +- .../python/unittest/test_graph_tuner_core.py | 2 +- tests/python/unittest/test_lang_container.py | 8 +- tests/python/unittest/test_lang_group.py | 2 +- tests/python/unittest/test_lang_operator.py | 2 +- tests/python/unittest/test_lang_reflection.py | 30 +- tests/python/unittest/test_lang_schedule.py | 6 +- tests/python/unittest/test_lang_tag.py | 2 +- tests/python/unittest/test_lang_tensor.py | 8 +- tests/python/unittest/test_pass_inline.py | 2 +- tests/python/unittest/test_runtime_rpc.py | 2 +- topi/python/topi/arm_cpu/bitserial_conv2d.py | 2 +- topi/python/topi/arm_cpu/conv2d.py | 2 +- topi/python/topi/cuda/conv2d_winograd.py | 2 +- topi/python/topi/nn/bitserial_conv2d.py | 2 +- topi/python/topi/nn/conv2d.py | 4 +- topi/python/topi/nn/conv2d_transpose.py | 2 +- topi/python/topi/x86/conv2d_alter_op.py | 2 +- tutorials/autotvm/tune_relay_arm.py | 2 +- tutorials/autotvm/tune_relay_cuda.py | 2 +- tutorials/autotvm/tune_relay_mobile_gpu.py | 2 +- tutorials/autotvm/tune_relay_x86.py | 2 +- tutorials/dev/relay_pass_infra.py | 8 +- vta/python/vta/build_module.py | 2 +- vta/python/vta/top/graphpack.py | 3 +- vta/tutorials/autotvm/tune_relay_vta.py | 2 +- 187 files changed, 1818 insertions(+), 1976 deletions(-) create mode 100644 python/tvm/ir/__init__.py rename python/tvm/{relay/_module.py => ir/_ffi_api.py} (82%) rename python/tvm/{relay/_module.pyi => ir/_ffi_transform_api.py} (80%) create mode 100644 python/tvm/ir/adt.py rename python/tvm/{ => ir}/attrs.py (94%) create mode 100644 python/tvm/ir/base.py rename python/tvm/{ => ir}/container.py (75%) create mode 100644 python/tvm/ir/expr.py rename python/tvm/{ => ir}/json_compact.py (100%) rename python/tvm/{relay => ir}/module.py (68%) create mode 100644 python/tvm/ir/tensor_type.py create mode 100644 python/tvm/ir/transform.py create mode 100644 python/tvm/ir/type.py create mode 100644 python/tvm/ir/type_relation.py delete mode 100644 python/tvm/relay/expr.pyi delete mode 100644 python/tvm/relay/transform.pyi delete mode 100644 python/tvm/relay/ty.pyi diff --git a/apps/benchmark/util.py b/apps/benchmark/util.py index c7de3a1dda31..86d139f1c851 100644 --- a/apps/benchmark/util.py +++ b/apps/benchmark/util.py @@ -34,7 +34,7 @@ def get_network(name, batch_size, dtype='float32'): Returns ------- - net: relay.Module + net: tvm.IRModule The relay function of network definition params: dict The random parameters for benchmark @@ -70,7 +70,7 @@ def get_network(name, batch_size, dtype='float32'): net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) net = net["main"] net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs) - net = relay.Module.from_expr(net) + net = tvm.IRModule.from_expr(net) else: raise ValueError("Unsupported network: " + name) diff --git a/docs/api/python/tvm.rst b/docs/api/python/tvm.rst index 19762fb20d97..07c2dbc44765 100644 --- a/docs/api/python/tvm.rst +++ b/docs/api/python/tvm.rst @@ -21,8 +21,6 @@ The user facing API for computation declaration. .. autosummary:: - tvm.load_json - tvm.save_json tvm.var tvm.size_var tvm.const @@ -47,8 +45,7 @@ The user facing API for computation declaration. tvm.max tvm.tag_scope -.. autofunction:: tvm.load_json -.. autofunction:: tvm.save_json + .. autofunction:: tvm.var .. autofunction:: tvm.size_var .. autofunction:: tvm.const diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 61b3e13c1630..eceafec75fa1 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -178,7 +178,7 @@ class RelayExpr : public BaseExpr { class GlobalVar; /*! - * \brief Global variable that leaves in the top-level module. + * \brief Global variable that lives in the top-level module. * * A GlobalVar only refers to function definitions. * This is used to enable recursive calls between function. diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 56f2389ad385..9e87731dae72 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -141,11 +141,12 @@ enum TypeKind : int { }; /*! - * \brief Type parameter in the function. - * This can be viewed as template parameter in c++ template function. + * \brief Type parameter in functions. + * + * A type variable can be viewed as template parameter in c++ template function. * * For example, in the following pesudo code, - * the TypeVar of f is TypeVar(kind=kShapeVar, var=n). + * the TypeVar of f is TypeVar("n", kind=kShapeVar). * This function can take in a Tensor with shape=(3, 3) and * returns a Tensor with shape=(9,) * diff --git a/include/tvm/ir/type_relation.h b/include/tvm/ir/type_relation.h index 6d4e75a23f6b..ff36b9671fa8 100644 --- a/include/tvm/ir/type_relation.h +++ b/include/tvm/ir/type_relation.h @@ -165,7 +165,7 @@ using TypeRelationFn = const TypeReporter& reporter)>; /*! - * \brief User defined type relation, is an input-output relation on types. + * \brief User defined type relation, it is an input-output relation on types. * * TypeRelation is more generalized than type call as it allows inference * of both inputs and outputs. diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 580a0714558c..69c24008c10f 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -33,6 +33,12 @@ from .runtime.ndarray import vpi, rocm, opengl, ext_dev, micro_dev from .runtime import ndarray as nd +# tvm.ir +from .ir import IRModule +from .ir import transform +from .ir import container +from . import ir + # others from . import tensor from . import arith @@ -41,10 +47,8 @@ from . import make from . import ir_pass from . import codegen -from . import container from . import schedule -from . import attrs from . import ir_builder from . import target from . import generic diff --git a/python/tvm/_ffi/_ctypes/object.py b/python/tvm/_ffi/_ctypes/object.py index 934e33ff5891..263a76d414a8 100644 --- a/python/tvm/_ffi/_ctypes/object.py +++ b/python/tvm/_ffi/_ctypes/object.py @@ -87,6 +87,7 @@ def __init_handle_by_constructor__(self, fconstructor, *args): instead of creating a new Node. """ # assign handle first to avoid error raising + # pylint: disable=not-callable self.handle = None handle = __init_by_constructor__(fconstructor, args) if not isinstance(handle, ObjectHandle): diff --git a/python/tvm/api.py b/python/tvm/api.py index d27cd20574a5..e7778d6cc5df 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -19,9 +19,11 @@ from numbers import Integral as _Integral import tvm._ffi -import tvm.runtime._ffi_node_api +import tvm.ir from tvm.runtime import convert, const, DataType +from tvm.ir import container as _container + from ._ffi.base import string_types, TVMError from ._ffi.registry import register_func, get_global_func, extract_ext_funcs @@ -30,9 +32,7 @@ from . import expr as _expr from . import tensor as _tensor from . import schedule as _schedule -from . import container as _container from . import tag as _tag -from . import json_compact int8 = "int8" int32 = "int32" @@ -71,66 +71,6 @@ def max_value(dtype): """ return _api_internal._max_value(dtype) - -def get_env_func(name): - """Get an EnvFunc by a global name. - - Parameters - ---------- - name: str - The name of the global function. - - Returns - ------- - env_func : EnvFunc - The result env function. - - Note - ---- - EnvFunc is a Object wrapper around - global function that can be serialized via its name. - This can be used to serialize function field in the language. - """ - return _api_internal._EnvFuncGet(name) - - -def load_json(json_str): - """Load tvm object from json_str. - - Parameters - ---------- - json_str : str - The json string - - Returns - ------- - node : Object - The loaded tvm node. - """ - - try: - return tvm.runtime._ffi_node_api.LoadJSON(json_str) - except TVMError: - json_str = json_compact.upgrade_json(json_str) - return tvm.runtime._ffi_node_api.LoadJSON(json_str) - - -def save_json(node): - """Save tvm object as json string. - - Parameters - ---------- - node : Object - A TVM object to be saved. - - Returns - ------- - json_str : str - Saved json string. - """ - return tvm.runtime._ffi_node_api.SaveJSON(node) - - def var(name="tindex", dtype=int32): """Create a new variable with specified name and dtype @@ -688,7 +628,7 @@ def _IterVar(dom, name, iter_type, thread_tag=''): raise TypeError("need to be list of ranges") dom = Range(dom[0], dom[1]) - if not isinstance(dom, _container.Range): + if not isinstance(dom, tvm.ir.Range): raise TypeError("dom need to be Range") name = name if name else 'iter' v = var(name) diff --git a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py index bdff057c5a7e..b02c289cb10f 100644 --- a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py +++ b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py @@ -141,7 +141,7 @@ def __init__(self, graph, input_shapes, records, target_ops, self._logger.propagate = False # Generate workload and schedule dictionaries. - if isinstance(graph, relay.Module): + if isinstance(graph, tvm.IRModule): graph = graph["main"] if isinstance(graph, relay.expr.Function): diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py index d3a27cbc1ecd..7648322d3b18 100644 --- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -20,6 +20,7 @@ import topi +import tvm from tvm import relay, autotvm from tvm.relay import transform from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple @@ -83,7 +84,7 @@ def expr2graph(expr, target_ops, node_dict, node_list): def _infer_type(node): """A method to infer the type of a relay expression.""" - mod = relay.Module.from_expr(node) + mod = tvm.IRModule.from_expr(node) mod = transform.InferType()(mod) entry = mod["main"] return entry if isinstance(node, relay.Function) else entry.body @@ -136,7 +137,7 @@ def _traverse_expr(node): free_var = relay.Var("var_%d" % i, input_type) params.append(free_var) call = relay.Call(node.op, params, node.attrs) - mod = relay.Module.from_expr(relay.Function(params, call)) + mod = tvm.IRModule.from_expr(relay.Function(params, call)) relay.backend.compile_engine.get().clear() build_thread = threading.Thread(target=relay.build, args=(mod, diff --git a/python/tvm/autotvm/graph_tuner/utils/utils.py b/python/tvm/autotvm/graph_tuner/utils/utils.py index d73f2c35f50e..137ccbed2bbd 100644 --- a/python/tvm/autotvm/graph_tuner/utils/utils.py +++ b/python/tvm/autotvm/graph_tuner/utils/utils.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=eval-used,invalid-name,too-many-arguments """Utility functions""" +import tvm from tvm import relay from tvm.relay import transform @@ -136,7 +137,7 @@ def bind_inputs(expr, input_shapes=None, input_dtypes="float32"): rebind_dict[var] = updated_input_dict[var.name_hint] updated_expr = relay.expr.bind(expr, rebind_dict) - mod = relay.Module.from_expr(updated_expr) + mod = tvm.IRModule.from_expr(updated_expr) mod = transform.InferType()(mod) entry = mod["main"] return entry if isinstance(updated_expr, relay.Function) else entry.body diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index 7471ca3d6c8f..87d28b7a810a 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -63,7 +63,7 @@ def extract_from_program(mod, params, ops, target, target_host=None, Parameters ---------- - mod: relay.module.Module or relay.expr.Function + mod: tvm.IRModule or relay.expr.Function The module or function to tune params: dict of str to numpy array The associated parameters of the program @@ -95,7 +95,7 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None, Parameters ---------- - mods: List[relay.module.Module] or List[relay.expr.Function] + mods: List[tvm.IRModule] or List[relay.expr.Function] The list of modules or functions to tune params: List of dict of str to numpy array The associated parameters of the programs @@ -151,8 +151,8 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None, for mod, param in zip(mods, params): if isinstance(mod, relay.expr.Function): - mod = relay.Module.from_expr(mod) - assert isinstance(mod, relay.module.Module), \ + mod = tvm.IRModule.from_expr(mod) + assert isinstance(mod, tvm.IRModule), \ "only support relay Module or Function to be tuned" relay.backend.compile_engine.get().clear() # wrap build call in thread to avoid multiprocessing problems diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index 9346d7d5a627..768f43884418 100644 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -24,6 +24,7 @@ import tvm.runtime from tvm.runtime import Object, ndarray +from tvm.ir import container from . import api from . import _api_internal from . import tensor @@ -31,10 +32,11 @@ from . import expr from . import ir_pass from . import stmt as _stmt -from . import container from . import codegen from . import target as _target from . import make +from .stmt import LoweredFunc + class DumpIR(object): """ @@ -58,16 +60,16 @@ def decorate(self, func): def dump(*args, **kwargs): """dump function""" retv = func(*args, **kwargs) - if not isinstance(retv, (_stmt.Stmt, container.LoweredFunc, container.Array)): + if not isinstance(retv, (_stmt.Stmt, LoweredFunc, container.Array)): return retv fname = func.func_name if hasattr(func, 'func_name') else func.__name__ pname = str(self._pass_id) + "_" + fname + "_ir.cc" with open(pname, "a") as f: - out = retv.body if isinstance(retv, container.LoweredFunc) else retv + out = retv.body if isinstance(retv, LoweredFunc) else retv f.write(str(out)) if isinstance(retv, container.Array): for x in retv: - out = x.body if isinstance(x, container.LoweredFunc) else x + out = x.body if isinstance(x, LoweredFunc) else x f.write("---------%s\n%s\n-----------\n"%(x.name, str(out))) self._pass_id += 1 return retv @@ -459,7 +461,7 @@ def _build_for_device(flist, target, target_host): raise ValueError( "Direct host side access to device memory is detected in %s. " "Did you forget to bind?" % func.name) - if func.func_type == container.LoweredFunc.MixedFunc: + if func.func_type == LoweredFunc.MixedFunc: if current_build_config().detect_global_barrier: func = ir_pass.ThreadSync(func, "global") func = ir_pass.ThreadSync(func, "shared") @@ -471,9 +473,9 @@ def _build_for_device(flist, target, target_host): fhost.append(fsplits[0]) for x in fsplits[1:]: fdevice.append(x) - elif func.func_type == container.LoweredFunc.HostFunc: + elif func.func_type == LoweredFunc.HostFunc: fhost.append(func) - elif func.func_type == container.LoweredFunc.DeviceFunc: + elif func.func_type == LoweredFunc.DeviceFunc: fdevice.append(func) else: raise ValueError("unknown function type %d" % func.func_type) @@ -586,9 +588,9 @@ def build(inputs, flist = lower(inputs, args, name=name, binds=binds) - if isinstance(flist, container.LoweredFunc): + if isinstance(flist, LoweredFunc): flist = [flist] - elif isinstance(inputs, container.LoweredFunc): + elif isinstance(inputs, LoweredFunc): if args: raise ValueError("args must be done when build from LoweredFunc.") flist = [inputs] @@ -612,7 +614,7 @@ def build(inputs, "_target.Target when inputs is dict.") fname_set = set() for x in flist: - if not isinstance(x, container.LoweredFunc): + if not isinstance(x, LoweredFunc): raise ValueError("inputs must be Schedule, LoweredFunc, list " "of LoweredFunc, or dict of str to list of " "LoweredFunc.") diff --git a/python/tvm/contrib/sparse.py b/python/tvm/contrib/sparse.py index 2a51637fe6ce..966e180ec2b8 100644 --- a/python/tvm/contrib/sparse.py +++ b/python/tvm/contrib/sparse.py @@ -38,7 +38,7 @@ def __init__(self, arg1, ctx=None, shape=None): The corresponding a dense numpy array, or a tuple for constructing a sparse matrix directly. - ctx: tvm.TVMContext + ctx: tvmContext The corresponding context. shape : tuple of int diff --git a/python/tvm/hybrid/calls.py b/python/tvm/hybrid/calls.py index e873e1974d21..630c10fcf2dd 100644 --- a/python/tvm/hybrid/calls.py +++ b/python/tvm/hybrid/calls.py @@ -16,12 +16,11 @@ # under the License. """Intrinsics of TVM-Python Hybrid Script for Python compilation time semantic support.""" - +from tvm.ir.container import Array from .. import api as _api from .. import expr as _expr from .. import make as _make from .. import target as _tgt -from ..container import Array from .. import ir_pass from ..stmt import For from .util import _internal_assert diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index cd2433e64a8c..a0b2dfea6062 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -24,6 +24,7 @@ import numbers from enum import Enum +from tvm.ir.container import Array from .util import _internal_assert from . import calls @@ -32,7 +33,6 @@ from ..api import all as _all from ..api import any as _any -from ..container import Array from ..tensor import Tensor, Operation from .. import _api_internal as _tvm_internal from .. import expr as _expr diff --git a/python/tvm/hybrid/util.py b/python/tvm/hybrid/util.py index 0883960fabfd..8ef200a02b2c 100644 --- a/python/tvm/hybrid/util.py +++ b/python/tvm/hybrid/util.py @@ -21,13 +21,14 @@ import logging import sys import numpy + +from tvm.ir.container import Array from .. import api as _api from .. import make as _make from .. import expr as _expr from .. import stmt as _stmt from .._ffi.base import numeric_types from ..tensor import Tensor -from ..container import Array #pylint: disable=invalid-name diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py new file mode 100644 index 000000000000..e3552b5fe047 --- /dev/null +++ b/python/tvm/ir/__init__.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import +"""Common data structures across all IR variants.""" +from .base import SourceName, Span, Node, EnvFunc, load_json, save_json +from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, BaseFunc, Range +from .type import Type, TypeKind, TypeVar, GlobalTypeVar, TupleType +from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType +from .type_relation import TypeCall, TypeRelation +from .tensor_type import TensorType +from .adt import Constructor, TypeData +from .module import IRModule +from .attrs import Attrs +from .container import Array, Map + +from . import transform diff --git a/python/tvm/relay/_module.py b/python/tvm/ir/_ffi_api.py similarity index 82% rename from python/tvm/relay/_module.py rename to python/tvm/ir/_ffi_api.py index aedb74a05486..d3a9505c38d0 100644 --- a/python/tvm/relay/_module.py +++ b/python/tvm/ir/_ffi_api.py @@ -14,8 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable -"""The interface to the Module exposed from C++.""" +"""FFI APIs for tvm.ir""" import tvm._ffi -tvm._ffi._init_api("relay._module", __name__) + +tvm._ffi._init_api("ir", __name__) diff --git a/python/tvm/relay/_module.pyi b/python/tvm/ir/_ffi_transform_api.py similarity index 80% rename from python/tvm/relay/_module.pyi rename to python/tvm/ir/_ffi_transform_api.py index 66c994e4400e..76a4c337ae75 100644 --- a/python/tvm/relay/_module.pyi +++ b/python/tvm/ir/_ffi_transform_api.py @@ -1,3 +1,4 @@ + # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -14,9 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""FFI APIs for tvm.transform""" +import tvm._ffi -from typing import Union, Tuple, Dict, List -from relay.ir import GlobalId, OperatorId, Item, Object, Span, FileId -from relay.ir import ShapeExtension, Operator, Defn -class Module(Object): ... +tvm._ffi._init_api("transform", __name__) diff --git a/python/tvm/ir/adt.py b/python/tvm/ir/adt.py new file mode 100644 index 000000000000..d126f286475d --- /dev/null +++ b/python/tvm/ir/adt.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Algebraic data type definitions.""" +import tvm._ffi + +from .type import Type +from .expr import RelayExpr +from . import _ffi_api + + +@tvm._ffi.register_object("relay.Constructor") +class Constructor(RelayExpr): + """Relay ADT constructor. + + Parameters + ---------- + name_hint : str + Name of constructor (only a hint). + + inputs : List[Type] + Input types. + + belong_to : GlobalTypeVar + Denotes which ADT the constructor belongs to. + """ + def __init__(self, name_hint, inputs, belong_to): + self.__init_handle_by_constructor__( + _ffi_api.Constructor, name_hint, inputs, belong_to) + + def __call__(self, *args): + """Call the constructor. + + Parameters + ---------- + args: List[RelayExpr] + The arguments to the constructor. + + Returns + ------- + call: RelayExpr + A call to the constructor. + """ + # pylint: disable=import-outside-toplevel + from tvm import relay + return relay.Call(self, args) + + +@tvm._ffi.register_object("relay.TypeData") +class TypeData(Type): + """Stores the definition for an Algebraic Data Type (ADT) in Relay. + + Note that ADT definitions are treated as type-level functions because + the type parameters need to be given for an instance of the ADT. Thus, + any global type var that is an ADT header needs to be wrapped in a + type call that passes in the type params. + + Parameters + ---------- + header: GlobalTypeVar + The name of the ADT. + ADTs with the same constructors but different names are + treated as different types. + + type_vars: List[TypeVar] + Type variables that appear in constructors. + + constructors: List[Constructor] + The constructors for the ADT. + """ + def __init__(self, header, type_vars, constructors): + self.__init_handle_by_constructor__( + _ffi_api.TypeData, header, type_vars, constructors) diff --git a/python/tvm/attrs.py b/python/tvm/ir/attrs.py similarity index 94% rename from python/tvm/attrs.py rename to python/tvm/ir/attrs.py index dc6ca72f5a93..a5967394e72e 100644 --- a/python/tvm/attrs.py +++ b/python/tvm/ir/attrs.py @@ -14,11 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""" TVM Attribute module, which is mainly used for defining attributes of operators""" +""" TVM Attribute module, which is mainly used for defining attributes of operators.""" import tvm._ffi from tvm.runtime import Object -from . import _api_internal +from . import _ffi_api @tvm._ffi.register_object @@ -36,7 +36,7 @@ def list_field_info(self): infos: list of AttrFieldInfo List of field information """ - return _api_internal._AttrsListFieldInfo(self) + return _ffi_api.AttrsListFieldInfo(self) def keys(self): """Get list of names in the attribute. @@ -91,6 +91,3 @@ def get_str(self, key): def __getitem__(self, item): return self.__getattr__(item) - - -tvm._ffi._init_api("tvm.attrs") diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py new file mode 100644 index 000000000000..3314ef130e25 --- /dev/null +++ b/python/tvm/ir/base.py @@ -0,0 +1,151 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Common base structures.""" +import tvm._ffi + +import tvm.error +import tvm.runtime._ffi_node_api +from tvm.runtime import Object + +from . import _ffi_api +from . import json_compact + +class Node(Object): + """Base class of all IR Nodes, implements astext function.""" + def astext(self, show_meta_data=True, annotate=None): + """Get the text format of the expression. + + Parameters + ---------- + show_meta_data : bool + Whether to include meta data section in the text + if there is meta data. + + annotate: Optional[Object->str] + Optionally annotate function to provide additional + information in the comment block. + + Note + ---- + The meta data section is necessary to fully parse the text format. + However, it can contain dumps that are big (e.g constant weights), + so it can be helpful to skip printing the meta data section. + + Returns + ------- + text : str + The text format of the expression. + """ + return _ffi_api.AsText(self, show_meta_data, annotate) + + def __str__(self): + return self.astext(show_meta_data=False) + + +@tvm._ffi.register_object("relay.SourceName") +class SourceName(Object): + """A identifier for a source location. + + Parameters + ---------- + name : str + The name of the source. + """ + def __init__(self, name): + self.__init_handle_by_constructor__(_ffi_api.SourceName, name) + + +@tvm._ffi.register_object("relay.Span") +class Span(Object): + """Specifies a location in a source program. + + Parameters + ---------- + source : SourceName + The source name. + + lineno : int + The line number. + + col_offset : int + The column offset of the location. + """ + def __init__(self, source, lineno, col_offset): + self.__init_handle_by_constructor__( + _ffi_api.Span, source, lineno, col_offset) + + +@tvm._ffi.register_object +class EnvFunc(Object): + """Environment function. + + This is a global function object that can be serialized by its name. + """ + def __call__(self, *args): + return _ffi_api.EnvFuncCall(self, *args) + + @property + def func(self): + return _ffi_api.EnvFuncGetPackedFunc(self) + + @staticmethod + def get(name): + """Get a static env function + + Parameters + ---------- + name : str + The name of the function. + """ + return _ffi_api.EnvFuncGet(name) + + +def load_json(json_str): + """Load tvm object from json_str. + + Parameters + ---------- + json_str : str + The json string + + Returns + ------- + node : Object + The loaded tvm node. + """ + + try: + return tvm.runtime._ffi_node_api.LoadJSON(json_str) + except tvm.error.TVMError: + json_str = json_compact.upgrade_json(json_str) + return tvm.runtime._ffi_node_api.LoadJSON(json_str) + + +def save_json(node): + """Save tvm object as json string. + + Parameters + ---------- + node : Object + A TVM object to be saved. + + Returns + ------- + json_str : str + Saved json string. + """ + return tvm.runtime._ffi_node_api.SaveJSON(node) diff --git a/python/tvm/container.py b/python/tvm/ir/container.py similarity index 75% rename from python/tvm/container.py rename to python/tvm/ir/container.py index 111257eec1fb..11ef107f5514 100644 --- a/python/tvm/container.py +++ b/python/tvm/ir/container.py @@ -14,13 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Container data structures used in TVM DSL.""" +"""Additional container data structures used across IR variants.""" import tvm._ffi from tvm.runtime import Object from tvm.runtime.container import getitem_helper from tvm.runtime import _ffi_node_api -from . import _api_internal @tvm._ffi.register_object @@ -40,20 +39,6 @@ def __len__(self): return _ffi_node_api.ArraySize(self) -@tvm._ffi.register_object -class EnvFunc(Object): - """Environment function. - - This is a global function object that can be serialized by its name. - """ - def __call__(self, *args): - return _api_internal._EnvFuncCall(self, *args) - - @property - def func(self): - return _api_internal._EnvFuncGetPackedFunc(self) - - @tvm._ffi.register_object class Map(Object): """Map container of TVM. @@ -87,20 +72,3 @@ def items(self): """Get the items from the map""" akvs = _ffi_node_api.MapItems(self) return [(akvs[i].value, akvs[i+1]) for i in range(0, len(akvs), 2)] - - -@tvm._ffi.register_object -class Range(Object): - """Represent a range in TVM. - - You do not need to create a Range explicitly. - Python lists and tuples will be converted automatically to a Range in API functions. - """ - - -@tvm._ffi.register_object -class LoweredFunc(Object): - """Represent a LoweredFunc in TVM.""" - MixedFunc = 0 - HostFunc = 1 - DeviceFunc = 2 diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py new file mode 100644 index 000000000000..46acd16b8031 --- /dev/null +++ b/python/tvm/ir/expr.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Common expressions data structures in the IR.""" +import tvm._ffi + + +from .base import Node +from . import _ffi_api + +class BaseExpr(Node): + """Base class of all the expressions.""" + + +class PrimExpr(BaseExpr): + """Base class of all primitive expressions. + + PrimExpr is used in the low-level code + optimizations and integer analysis. + """ + + +class RelayExpr(BaseExpr): + """Base class of all non-primitive expressions.""" + @property + def checked_type(self): + """Get the checked type of tvm.relay.Expr. + + Returns + ------- + checked_type : tvm.relay.Type + The checked type. + """ + ret = self._checked_type_ + if ret is None: + raise ValueError("The type checker has not populated" + " the checked_type for this node") + return ret + + +class BaseFunc(RelayExpr): + """Base class of all functions.""" + + +@tvm._ffi.register_object("relay.GlobalVar") +class GlobalVar(RelayExpr): + """A global variable in the IR. + + GlobalVar is used to refer to the global functions + stored in the IRModule. + + Parameters + ---------- + name_hint: str + The name of the variable. + """ + def __init__(self, name_hint): + self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint) + + def __call__(self, *args): + """Call the global variable. + + Parameters + ---------- + args: List[RelayExpr] + The arguments to the call. + + Returns + ------- + call: BaseExpr + A call taking the variable as a function. + """ + # pylint: disable=import-outside-toplevel + if all(isinstance(x, RelayExpr) for x in args): + from tvm import relay + return relay.Call(self, args) + arg_types = [type(x) for x in args] + raise RuntimeError( + "Do not know how to handle GlobalVar.__call__ for types {}".format(arg_types)) + + +@tvm._ffi.register_object +class Range(Node): + """Represent a range in TVM. + + You do not need to create a Range explicitly. + Python lists and tuples will be converted automatically to a Range in API functions. + """ diff --git a/python/tvm/json_compact.py b/python/tvm/ir/json_compact.py similarity index 100% rename from python/tvm/json_compact.py rename to python/tvm/ir/json_compact.py diff --git a/python/tvm/relay/module.py b/python/tvm/ir/module.py similarity index 68% rename from python/tvm/relay/module.py rename to python/tvm/ir/module.py index 5513bd711c4f..ae1564b27105 100644 --- a/python/tvm/relay/module.py +++ b/python/tvm/ir/module.py @@ -14,36 +14,26 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import -"""A global module storing everything needed to interpret or compile a Relay program.""" -import os -from .base import register_relay_node, RelayNode -from .. import register_func -from .._ffi import base as _base -from . import _make -from . import _module -from . import expr as _expr -from . import ty as _ty +"""IRModule that holds the functions and type definitions.""" +from tvm._ffi.base import string_types +import tvm._ffi -__STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std") +from .base import Node +from . import expr as _expr +from . import type as _ty +from . import _ffi_api -@register_func("tvm.relay.std_path") -def _std_path(): - global __STD_PATH__ - return __STD_PATH__ -@register_relay_node -class Module(RelayNode): - """The global Relay module containing collection of functions. +@tvm._ffi.register_object("relay.Module") +class IRModule(Node): + """IRModule that holds functions and type definitions. - Each global function is identified by an unique tvm.relay.GlobalVar. - tvm.relay.GlobalVar and Module is necessary in order to enable - recursions in function to avoid cyclic reference in the function.x + IRModule is the basic unit for all IR transformations across the stack. Parameters ---------- functions: Optional[dict]. - Map of global var to Function + Map of global var to BaseFunc """ def __init__(self, functions=None, type_definitions=None): if functions is None: @@ -51,7 +41,7 @@ def __init__(self, functions=None, type_definitions=None): elif isinstance(functions, dict): mapped_funcs = {} for k, v in functions.items(): - if isinstance(k, _base.string_types): + if isinstance(k, string_types): k = _expr.GlobalVar(k) if not isinstance(k, _expr.GlobalVar): raise TypeError("Expect functions to be Dict[GlobalVar, Function]") @@ -62,13 +52,13 @@ def __init__(self, functions=None, type_definitions=None): elif isinstance(type_definitions, dict): mapped_type_defs = {} for k, v in type_definitions.items(): - if isinstance(k, _base.string_types): + if isinstance(k, string_types): k = _ty.GlobalTypeVar(k) if not isinstance(k, _ty.GlobalTypeVar): raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]") mapped_type_defs[k] = v type_definitions = mapped_type_defs - self.__init_handle_by_constructor__(_make.Module, functions, type_definitions) + self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions) def __setitem__(self, var, val): @@ -85,18 +75,18 @@ def __setitem__(self, var, val): return self._add(var, val) def _add(self, var, val, update=False): - if isinstance(val, _expr.Expr): - if isinstance(var, _base.string_types): - if _module.Module_ContainGlobalVar(self, var): - var = _module.Module_GetGlobalVar(self, var) + if isinstance(val, _expr.RelayExpr): + if isinstance(var, string_types): + if _ffi_api.Module_ContainGlobalVar(self, var): + var = _ffi_api.Module_GetGlobalVar(self, var) else: var = _expr.GlobalVar(var) - _module.Module_Add(self, var, val, update) + _ffi_api.Module_Add(self, var, val, update) else: assert isinstance(val, _ty.Type) - if isinstance(var, _base.string_types): + if isinstance(var, string_types): var = _ty.GlobalTypeVar(var) - _module.Module_AddDef(self, var, val, update) + _ffi_api.Module_AddDef(self, var, val, update) def __getitem__(self, var): """Lookup a global definition by name or by variable. @@ -111,12 +101,11 @@ def __getitem__(self, var): val: Union[Function, Type] The definition referenced by :code:`var` (either a function or type). """ - if isinstance(var, _base.string_types): - return _module.Module_Lookup_str(self, var) - elif isinstance(var, _expr.GlobalVar): - return _module.Module_Lookup(self, var) - else: - return _module.Module_LookupDef(self, var) + if isinstance(var, string_types): + return _ffi_api.Module_Lookup_str(self, var) + if isinstance(var, _expr.GlobalVar): + return _ffi_api.Module_Lookup(self, var) + return _ffi_api.Module_LookupDef(self, var) def update(self, other): """Insert functions in another Module to current one. @@ -128,7 +117,7 @@ def update(self, other): """ if isinstance(other, dict): other = Module(other) - return _module.Module_Update(self, other) + return _ffi_api.Module_Update(self, other) def get_global_var(self, name): """Get a global variable in the function by name. @@ -145,9 +134,9 @@ def get_global_var(self, name): Raises ------ - tvm.TVMError if we cannot find corresponding global var. + tvm.error.TVMError if we cannot find corresponding global var. """ - return _module.Module_GetGlobalVar(self, name) + return _ffi_api.Module_GetGlobalVar(self, name) def get_global_vars(self): """Collect all global vars defined in this module. @@ -157,7 +146,7 @@ def get_global_vars(self): global_vars: tvm.Array[GlobalVar] An array of global vars. """ - return _module.Module_GetGlobalVars(self) + return _ffi_api.Module_GetGlobalVars(self) def get_global_type_vars(self): """Collect all global type vars defined in this module. @@ -167,7 +156,7 @@ def get_global_type_vars(self): global_type_vars: tvm.Array[GlobalTypeVar] An array of global type vars. """ - return _module.Module_GetGlobalTypeVars(self) + return _ffi_api.Module_GetGlobalTypeVars(self) def get_global_type_var(self, name): """Get a global type variable in the function by name. @@ -184,9 +173,9 @@ def get_global_type_var(self, name): Raises ------ - tvm.TVMError if we cannot find corresponding global type var. + tvm.error.TVMError if we cannot find corresponding global type var. """ - return _module.Module_GetGlobalTypeVar(self, name) + return _ffi_api.Module_GetGlobalTypeVar(self, name) def get_constructor(self, tag): """Look up an ADT constructor by tag. @@ -203,9 +192,9 @@ def get_constructor(self, tag): Raises ------ - tvm.TVMError if the corresponding constructor cannot be found. + tvm.error.TVMError if the corresponding constructor cannot be found. """ - return _module.Module_LookupTag(self, tag) + return _ffi_api.Module_LookupTag(self, tag) @staticmethod def from_expr(expr, functions=None, type_defs=None): @@ -213,14 +202,15 @@ def from_expr(expr, functions=None, type_defs=None): Parameters ---------- - expr: Expr + expr: RelayExpr The starting expression + global_funcs: Optional[dict] Map of global vars to function definitions + type_defs: Optional[dict] Map of global type vars to type definitions - Returns ------- mod: Module @@ -230,10 +220,10 @@ def from_expr(expr, functions=None, type_defs=None): """ funcs = functions if functions is not None else {} defs = type_defs if type_defs is not None else {} - return _module.Module_FromExpr(expr, funcs, defs) + return _ffi_api.Module_FromExpr(expr, funcs, defs) def _import(self, file_to_import): - return _module.Module_Import(self, file_to_import) + return _ffi_api.Module_Import(self, file_to_import) def import_from_std(self, file_to_import): - return _module.Module_ImportFromStd(self, file_to_import) + return _ffi_api.Module_ImportFromStd(self, file_to_import) diff --git a/python/tvm/ir/tensor_type.py b/python/tvm/ir/tensor_type.py new file mode 100644 index 000000000000..99286ed13fd2 --- /dev/null +++ b/python/tvm/ir/tensor_type.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Type relation and function for type checking.""" +import tvm._ffi + +from .type import Type +from . import _ffi_api + + +@tvm._ffi.register_object("relay.TensorType") +class TensorType(Type): + """A concrete TensorType in Relay. + + This is the type assigned to tensors with a known dtype and shape. + For example, a tensor of `float32` and `(5, 5)`. + + Parameters + ---------- + shape : List[tvm.ir.PrimExpr] + The shape of the Tensor + + dtype : Optional[str] + The content data type. + """ + def __init__(self, shape, dtype="float32"): + self.__init_handle_by_constructor__( + _ffi_api.TensorType, shape, dtype) + + @property + def concrete_shape(self): + """Get shape of the type as concrete tuple of int. + + Returns + ------- + shape : List[int] + The concrete shape of the Type. + + Raises + ------ + TypeError : If the shape is symbolic + """ + return tuple(int(x) for x in self.shape) diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py new file mode 100644 index 000000000000..619250459b5c --- /dev/null +++ b/python/tvm/ir/transform.py @@ -0,0 +1,328 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-argument +"""Common pass infrastructure across IR variants.""" +import types +import inspect +import functools + +import tvm._ffi + +from tvm._ffi.runtime_ctypes import TVMContext +from tvm.runtime import Object, ndarray as _nd + +from . import _ffi_transform_api + +@tvm._ffi.register_object("relay.PassInfo") +class PassInfo(Object): + """The class contains the meta data required by a pass. It is the + container of information needed by running an optimization or analysis. + This class can be extended by adding new members when more meta data is + needed. + + Parameters + ---------- + opt_level : int + The optimization level of this pass. + + name : str + The pass name. + + required : List[str] + The list of passes that are required by a certain pass. + """ + + def __init__(self, opt_level, name, required=None): + self.__init_handle_by_constructor__( + _ffi_transform_api.PassInfo, opt_level, name, required) + + +@tvm._ffi.register_object("relay.PassContext") +class PassContext(Object): + """The basis where a Relay optimization/analysis runs on. + Each pass context contains a number of auxiliary information that is used + to help an optimization pass. Such information includes the error reporter + to record the errors of during the optimization, etc. + + opt_level : Optional[int] + The optimization level of this pass. + + fallback_device : Optional[Union[int, str, TVMContext]] + The fallback device type. It is also used as the default device for + operators that are not annotated during heterogeneous execution. + + required_pass : Optional[Union[List[str], Set[str], Tuple[str]]] + The list of passes that are required by a certain pass. + + disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]] + The list of passes that are disabled. + """ + def __init__(self, + opt_level=2, + fallback_device=_nd.cpu(), + required_pass=None, + disabled_pass=None, + trace=None): + if isinstance(fallback_device, str): + fallback_device = _nd.context(fallback_device).device_type + elif isinstance(fallback_device, TVMContext): + fallback_device = fallback_device.device_type + if not isinstance(fallback_device, int): + raise TypeError("required_pass is expected to be the type of " + + "int/str/TVMContext.") + + required = list(required_pass) if required_pass else [] + if not isinstance(required, (list, tuple)): + raise TypeError("required_pass is expected to be the type of " + + "list/tuple/set.") + + disabled = list(disabled_pass) if disabled_pass else [] + if not isinstance(disabled, (list, tuple)): + raise TypeError("disabled_pass is expected to be the type of " + + "list/tuple/set.") + + self.__init_handle_by_constructor__(_ffi_transform_api.PassContext, opt_level, + fallback_device, required, + disabled, trace) + + def __enter__(self): + _ffi_transform_api.EnterPassContext(self) + return self + + def __exit__(self, ptype, value, trace): + _ffi_transform_api.ExitPassContext(self) + + @staticmethod + def current(): + """Return the current pass context.""" + return _ffi_transform_api.GetCurrentPassContext() + + +@tvm._ffi.register_object("relay.Pass") +class Pass(Object): + """The base class of all passes. All methods here are just simple wrappers + that are implemented in the backend. They are defined for users to + conveniently interact with the base class. + """ + + @property + def info(self): + """Get the pass meta.""" + return _ffi_transform_api.Info(self) + + def __call__(self, mod): + """Execute the pass. Note that for sequential pass, the dependency among + different passes will be resolved in the backend. + + Parameters + ---------- + mod : tvm.IRModule + The module that a certain optimization is performed on. + + Returns + ------- + mod : tvm.IRModule + The updated module after applying this pass. + """ + return _ffi_transform_api.RunPass(self, mod) + + +@tvm._ffi.register_object("relay.ModulePass") +class ModulePass(Pass): + """A pass that works on tvm.IRModule. Users don't need to interact with + this class directly. Instead, a module pass should be created through + `module_pass`, because the design of the `module_pass` API is flexible + enough to handle the creation of a module pass in different manners. In + addition, all members of a module pass can be accessed from the base class. + The same rule applies to FunctionPass as well. + """ + + +@tvm._ffi.register_object("relay.Sequential") +class Sequential(Pass): + """A pass that works on a sequence of pass objects. Multiple passes can be + executed sequentially using this class. + + Some typical usage of the sequential pass are: + 1. Users provide a list of passes for optimization. + 2. Only an optimization level is provided so that the backend system has + to glob all passes at this level and below to perform the optimizations. + Note that users can also provide a series of passes that they don't want to + apply when running a sequential pass. Pass dependency will be resolved in + the backend as well. + + Parameters + ---------- + passes : Optional[List[Pass]] + A sequence of passes candidate for optimization. + + opt_level : Optional[int] + The optimization level of this sequential pass. + + name : Optional[str] + The name of the sequential pass. + + required : Optional[List[str]] + The list of passes that the sequential pass is dependent on. + """ + def __init__(self, + passes=None, + opt_level=2, + name="sequential", + required=None): + passes = passes if passes else [] + if not isinstance(passes, (list, tuple)): + raise TypeError("passes must be a list of Pass objects.") + + required = required if required else [] + if not isinstance(required, (list, tuple)): + raise TypeError("Required is expected to be the type of list/tuple.") + + self.__init_handle_by_constructor__(_ffi_transform_api.Sequential, + passes, opt_level, name, required) + + +def _wrap_class_module_pass(pass_cls, pass_info): + """Wrap a python class as function pass""" + class PyModulePass(ModulePass): + """Internal wrapper class to create a class instance.""" + def __init__(self, *args, **kwargs): + # initialize handle in cass pass_cls creation failed.fg + self.handle = None + inst = pass_cls(*args, **kwargs) + # it is important not to capture self to + # avoid a cyclic dependency + def _pass_func(mod, ctx): + return inst.transform_module(mod, ctx) + self.__init_handle_by_constructor__( + _ffi_transform_api.MakeModulePass, _pass_func, pass_info) + self._inst = inst + + def __getattr__(self, name): + # fall back to instance attribute if there is not any + return self._inst.__getattribute__(name) + + functools.update_wrapper(PyModulePass.__init__, pass_cls.__init__) + PyModulePass.__name__ = pass_cls.__name__ + PyModulePass.__doc__ = pass_cls.__doc__ + PyModulePass.__module__ = pass_cls.__module__ + return PyModulePass + + +def module_pass(pass_func=None, opt_level=None, name=None, required=None): + """Decorate a module pass. + + This function returns a callback when pass_func is provided. + Otherwise, it serves a decorator function. + + pass_func can also be a class type with a method transform_module. + This function will create a decorated ModulePass using transform_module + as the pass function. + + Parameters + ---------- + pass_func : Optional[Callable[(Module, PassContext) ->Module]] + The transformation function or class. + + opt_level : int + The optimization level of this module pass. + + name : Optional[str] + The name of the module pass. The name could be empty. In this case, the + name of the optimization function will be used as the pass name. + + required : Optional[List[str]] + The list of passes that the module pass is dependent on. + + Returns + ------- + create_module_pass : Union[Callable, ModulePass] + A decorator will be returned if pass_func is not provided, + otherwise return the decorated result. + The returned decorator has two behaviors depending on the input: + A new ModulePass will be returned when we decorate a pass function. + A new ModulePass class will be returned when we decorate a class type. + + Examples + -------- + The following code block decorates a module pass class. + + .. code-block:: python + + @relay.transform.module_pass + class CustomPipeline: + def __init__(self, enable_fold): + self.enable_fold = enable_fold + self.cse = relay.transform.EliminateCommonSubexpr() + self.const_fold = relay.transform.FoldConstant() + + def transform_module(self, mod, ctx): + mod = self.cse(mod, ctx) + if self.enable_fold: + mod = self.const_fold(mod, ctx) + return mod + + # create an instance of customized pipeline + pipeline = CustomPipeline(enable_fold=False) + assert isinstance(pipeline, transform.ModulePass) + # run the pipeline. + output_module = pipeline(input_module) + + The following code creates a module pass by decorating + a user defined transform function. + + .. code-block:: python + + @relay.transform.module_pass(opt_level=2) + def transform(mod, ctx): + tp = relay.TensorType((10,), "float32") + x = relay.var("x", tp) + gv = relay.GlobalVar("var") + func = relay.Function([x], relay.abs(x)) + new_mod = tvm.IRModule({gv: func}) + new_mod.update(mod) + return new_mod + + module_pass = transform + assert isinstance(module_pass, transform.ModulePass) + assert module_pass.info.opt_level == 2 + + # Given a module m, the optimization could be invoked as the follwoing: + updated_mod = module_pass(m) + # Now a function abs should be added to the module m. + """ + if opt_level is None: + raise ValueError("Please provide opt_level for the module pass.") + + required = required if required else [] + if not isinstance(required, (list, tuple)): + raise TypeError("Required is expected to be the type of " + + "list/tuple.") + + def create_module_pass(pass_arg): + """Internal function that creates a module pass""" + fname = name if name else pass_arg.__name__ + info = PassInfo(opt_level, fname, required) + if inspect.isclass(pass_arg): + return _wrap_class_module_pass(pass_arg, info) + if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): + raise TypeError("pass_func must be a callable for Module pass") + return _ffi_transform_api.MakeModulePass(pass_arg, info) + + if pass_func: + return create_module_pass(pass_func) + return create_module_pass diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py new file mode 100644 index 000000000000..ebe2aae047e5 --- /dev/null +++ b/python/tvm/ir/type.py @@ -0,0 +1,204 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unified type system in the project.""" +from enum import IntEnum +import tvm._ffi + +from .base import Node +from . import _ffi_api + + +class Type(Node): + """The base class of all types.""" + def __eq__(self, other): + """Compare two types for structural equivalence.""" + return bool(_ffi_api.type_alpha_equal(self, other)) + + def __ne__(self, other): + return not self.__eq__(other) + + def same_as(self, other): + """Compares two Relay types by referential equality.""" + return super().__eq__(other) + + +class TypeKind(IntEnum): + """Possible kinds of TypeVars.""" + Type = 0 + ShapeVar = 1 + BaseType = 2 + Constraint = 4 + AdtHandle = 5 + TypeData = 6 + + +@tvm._ffi.register_object("relay.TypeVar") +class TypeVar(Type): + """Type parameter in functions. + + A type variable represents a type placeholder which will + be filled in later on. This allows the user to write + functions which are generic over types. + + Parameters + ---------- + name_hint: str + The name of the type variable. This name only acts as a hint, and + is not used for equality. + + kind : Optional[TypeKind] + The kind of the type parameter. + """ + def __init__(self, name_hint, kind=TypeKind.Type): + self.__init_handle_by_constructor__( + _ffi_api.TypeVar, name_hint, kind) + + def __call__(self, *args): + """Create a type call from this type. + + Parameters + ---------- + args: List[Type] + The arguments to the type call. + + Returns + ------- + call: Type + The result type call. + """ + # pylint: disable=import-outside-toplevel + from .type_relation import TypeCall + return TypeCall(self, args) + + +@tvm._ffi.register_object("relay.GlobalTypeVar") +class GlobalTypeVar(Type): + """A global type variable that is used for defining new types or type aliases. + + Parameters + ---------- + name_hint: str + The name of the type variable. This name only acts as a hint, and + is not used for equality. + + kind : Optional[TypeKind] + The kind of the type parameter. + """ + def __init__(self, name_hint, kind=TypeKind.AdtHandle): + self.__init_handle_by_constructor__( + _ffi_api.GlobalTypeVar, name_hint, kind) + + def __call__(self, *args): + """Create a type call from this type. + + Parameters + ---------- + args: List[Type] + The arguments to the type call. + + Returns + ------- + call: Type + The result type call. + """ + # pylint: disable=import-outside-toplevel + from .type_relation import TypeCall + return TypeCall(self, args) + + +@tvm._ffi.register_object("relay.TupleType") +class TupleType(Type): + """The type of tuple values. + + Parameters + ---------- + fields : List[Type] + The fields in the tuple + """ + + def __init__(self, fields): + self.__init_handle_by_constructor__( + _ffi_api.TupleType, fields) + + +@tvm._ffi.register_object("relay.TypeConstraint") +class TypeConstraint(Type): + """Abstract class representing a type constraint.""" + + +@tvm._ffi.register_object("relay.FuncType") +class FuncType(Type): + """Function type. + + A function type consists of a list of type parameters to enable + the definition of generic functions, + a set of type constraints which we omit for the time being, + a sequence of argument types, and a return type. + + We can informally write them as: + `forall (type_params), (arg_types) -> ret_type where type_constraints` + + Parameters + ---------- + arg_types : List[tvm.relay.Type] + The argument types + + ret_type : tvm.relay.Type + The return type. + + type_params : Optional[List[tvm.relay.TypeVar]] + The type parameters + + type_constraints : Optional[List[tvm.relay.TypeConstraint]] + The type constraints. + """ + def __init__(self, + arg_types, + ret_type, + type_params=None, + type_constraints=None): + if type_params is None: + type_params = [] + if type_constraints is None: + type_constraints = [] + self.__init_handle_by_constructor__( + _ffi_api.FuncType, arg_types, ret_type, type_params, type_constraints) + + +@tvm._ffi.register_object("relay.IncompleteType") +class IncompleteType(Type): + """Incomplete type during type inference. + + kind : Optional[TypeKind] + The kind of the incomplete type. + """ + def __init__(self, kind=TypeKind.Type): + self.__init_handle_by_constructor__( + _ffi_api.IncompleteType, kind) + + +@tvm._ffi.register_object("relay.RefType") +class RelayRefType(Type): + """Reference Type in relay. + + Parameters + ---------- + value: Type + The value type. + """ + def __init__(self, value): + self.__init_handle_by_constructor__(_ffi_api.RelayRefType, value) diff --git a/python/tvm/ir/type_relation.py b/python/tvm/ir/type_relation.py new file mode 100644 index 000000000000..63c83d9af042 --- /dev/null +++ b/python/tvm/ir/type_relation.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Type relation and function for type checking.""" +import tvm._ffi + +from .type import Type, TypeConstraint +from . import _ffi_api + + +class TypeCall(Type): + """Type function application. + + Parameters + ---------- + func: tvm.ir.Type + The function. + + args: List[tvm.ir.Type] + The arguments. + + Returns + ------- + type_call: TypeCall + The type function application. + """ + def __init__(self, func, args): + self.__init_handle_by_constructor__(_ffi_api.TypeCall, func, args) + + +@tvm._ffi.register_object("relay.TypeRelation") +class TypeRelation(TypeConstraint): + """User defined type relation, it is an input-output relation on types. + + TypeRelation is more generalized than TypeCall as it allows inference + of both inputs and outputs. + + Parameters + ---------- + func : EnvFunc + User defined relation function. + + args : [tvm.ir.Type] + List of types to the func. + + num_inputs : int + Number of input arguments in args, + this act as a hint for type inference. + + attrs : Attrs + The attribute attached to the relation information + + Returns + ------- + type_relation : tvm.ir.TypeRelation + The type relation. + """ + def __init__(self, func, args, num_inputs, attrs): + self.__init_handle_by_constructor__( + _ffi_api.TypeRelation, func, args, num_inputs, attrs) diff --git a/python/tvm/ir_builder.py b/python/tvm/ir_builder.py index c08b3a54f1ac..4cc7f4f8082d 100644 --- a/python/tvm/ir_builder.py +++ b/python/tvm/ir_builder.py @@ -15,16 +15,15 @@ # specific language governing permissions and limitations # under the License. """Developer API of IR node builder make function.""" +from tvm._ffi.base import string_types from tvm.runtime import ObjectGeneric, DataType - -from ._ffi.base import string_types +from tvm.ir import container as _container from . import api as _api from . import stmt as _stmt from . import expr as _expr from . import make as _make from . import ir_pass as _pass -from . import container as _container from .expr import Call as _Call class WithScope(object): diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 25956b4656fd..0df3747a93b1 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -16,7 +16,6 @@ # under the License. # pylint: disable=wildcard-import, redefined-builtin, invalid-name """The Relay IR namespace containing the IR definition and compiler.""" -from __future__ import absolute_import import os from sys import setrecursionlimit from ..api import register_func @@ -25,7 +24,6 @@ from . import expr from . import type_functor from . import expr_functor -from . import module from . import adt from . import analysis from . import transform @@ -66,14 +64,11 @@ # Span Span = base.Span -# Env -Module = module.Module - # Type Type = ty.Type TupleType = ty.TupleType TensorType = ty.TensorType -Kind = ty.Kind +TypeKind = ty.TypeKind TypeVar = ty.TypeVar ShapeVar = ty.ShapeVar TypeConstraint = ty.TypeConstraint @@ -87,7 +82,7 @@ Any = ty.Any # Expr -Expr = expr.Expr +Expr = expr.RelayExpr Constant = expr.Constant Tuple = expr.Tuple Var = expr.Var diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index 0fd1c105a3d1..c1a413098e84 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -37,8 +37,9 @@ def __new__(cls, *args, **kwds): return deque.__new__(cls, *args, **kwds) import tvm +import tvm.ir._ffi_api +from tvm.ir import IRModule -from . import module from .base import Span, SourceName from . import adt from . import expr @@ -190,7 +191,7 @@ def _wrapper(*args, **kwargs): sp = Span(sn, line, col) if isinstance(ast, tvm.relay.expr.TupleWrapper): ast = ast.astuple() - ast.set_span(sp) + tvm.ir._ffi_api.NodeSetSpan(ast, sp) return ast return _wrapper @@ -201,7 +202,7 @@ class ParseTreeToRelayIR(RelayVisitor): def __init__(self, source_name: str) -> None: self.source_name = source_name - self.module = module.Module({}) # type: module.Module + self.module = IRModule({}) # type: IRModule # Adding an empty scope allows naked lets without pain. self.var_scopes = deque([deque()]) # type: Scopes[expr.Var] @@ -243,7 +244,7 @@ def exit_type_param_scope(self) -> Scope[ty.TypeVar]: """Pop off the current TypeVar scope and return it.""" return self.type_var_scopes.popleft() - def mk_typ(self, name: str, kind: ty.Kind) -> ty.TypeVar: + def mk_typ(self, name: str, kind: ty.TypeKind) -> ty.TypeVar: """Create a new TypeVar and add it to the TypeVar scope.""" typ = ty.TypeVar(name, kind) self.type_var_scopes[0].append((name, typ)) @@ -274,7 +275,7 @@ def _type_expr_name(self, e): if isinstance(e, adt.Constructor): return "`{0}` ADT constructor".format(e.belong_to.name_hint) if isinstance(e, ty.GlobalTypeVar): - if e.kind == ty.Kind.AdtHandle: + if e.kind == ty.TypeKind.AdtHandle: return "ADT definition" return "function definition" @@ -352,12 +353,12 @@ def getTypeExpr(self, ctx: Optional[RelayParser.TypeExprContext]) -> Optional[ty return self.visit(ctx) - def visitProg(self, ctx: RelayParser.ProgContext) -> Union[expr.Expr, module.Module]: + def visitProg(self, ctx: RelayParser.ProgContext) -> Union[expr.Expr, IRModule]: self.meta = None if ctx.METADATA(): header, data = str(ctx.METADATA()).split("\n", 1) assert header == "METADATA:" - self.meta = tvm.load_json(data) + self.meta = tvm.ir.load_json(data) if ctx.defn(): self.visit_list(ctx.defn()) return self.module @@ -492,7 +493,7 @@ def mk_func( assert type_params for ty_param in type_params: name = ty_param.getText() - self.mk_typ(name, ty.Kind.Type) + self.mk_typ(name, ty.TypeKind.Type) var_list, attr_list = self.visit(ctx.argList()) if var_list is None: @@ -528,13 +529,13 @@ def handle_adt_header( ctx: Union[RelayParser.ExternAdtDefnContext, RelayParser.AdtDefnContext]): """Handles parsing of the name and type params of an ADT definition.""" adt_name = ctx.generalIdent().getText() - adt_var = self.mk_global_typ_var(adt_name, ty.Kind.AdtHandle) + adt_var = self.mk_global_typ_var(adt_name, ty.TypeKind.AdtHandle) # parse type params type_params = ctx.typeParamList() if type_params is None: type_params = [] else: - type_params = [self.mk_typ(type_ident.getText(), ty.Kind.Type) + type_params = [self.mk_typ(type_ident.getText(), ty.TypeKind.Type) for type_ident in type_params.typeExpr()] return adt_var, type_params @@ -746,7 +747,7 @@ def reportAttemptingFullContext(self, def reportContextSensitivity(self, recognizer, dfa, startIndex, stopIndex, prediction, configs): raise Exception("Context Sensitivity in:\n" + self.text) -def fromtext(data: str, source_name: str = None) -> Union[expr.Expr, module.Module]: +def fromtext(data: str, source_name: str = None) -> Union[expr.Expr, IRModule]: """Parse a Relay program.""" if data == "": raise ParseError("cannot parse the empty string.") diff --git a/python/tvm/relay/adt.py b/python/tvm/relay/adt.py index 7f7496b1a407..9c5dac6362e2 100644 --- a/python/tvm/relay/adt.py +++ b/python/tvm/relay/adt.py @@ -14,12 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import """Algebraic data types in Relay.""" +from tvm.ir import Constructor, TypeData + from .base import RelayNode, register_relay_node, Object from . import _make from .ty import Type -from .expr import Expr, Call +from .expr import ExprWithOp, RelayExpr, Call class Pattern(RelayNode): @@ -112,77 +114,6 @@ def __init__(self, patterns=None): self.__init_handle_by_constructor__(_make.PatternTuple, patterns) -@register_relay_node -class Constructor(Expr): - """Relay ADT constructor.""" - - def __init__(self, name_hint, inputs, belong_to): - """Defines an ADT constructor. - - Parameters - ---------- - name_hint : str - Name of constructor (only a hint). - inputs : List[Type] - Input types. - belong_to : tvm.relay.GlobalTypeVar - Denotes which ADT the constructor belongs to. - - Returns - ------- - con: Constructor - A constructor. - """ - self.__init_handle_by_constructor__(_make.Constructor, name_hint, inputs, belong_to) - - def __call__(self, *args): - """Call the constructor. - - Parameters - ---------- - args: List[relay.Expr] - The arguments to the constructor. - - Returns - ------- - call: relay.Call - A call to the constructor. - """ - return Call(self, args) - - -@register_relay_node -class TypeData(Type): - """Stores the definition for an Algebraic Data Type (ADT) in Relay. - - Note that ADT definitions are treated as type-level functions because - the type parameters need to be given for an instance of the ADT. Thus, - any global type var that is an ADT header needs to be wrapped in a - type call that passes in the type params. - """ - - def __init__(self, header, type_vars, constructors): - """Defines a TypeData object. - - Parameters - ---------- - header: tvm.relay.GlobalTypeVar - The name of the ADT. - ADTs with the same constructors but different names are - treated as different types. - type_vars: List[TypeVar] - Type variables that appear in constructors. - constructors: List[tvm.relay.Constructor] - The constructors for the ADT. - - Returns - ------- - type_data: TypeData - The adt declaration. - """ - self.__init_handle_by_constructor__(_make.TypeData, header, type_vars, constructors) - - @register_relay_node class Clause(Object): """Clause for pattern matching in Relay.""" @@ -206,7 +137,7 @@ def __init__(self, lhs, rhs): @register_relay_node -class Match(Expr): +class Match(ExprWithOp): """Pattern matching expression in Relay.""" def __init__(self, data, clauses, complete=True): diff --git a/python/tvm/relay/analysis.py b/python/tvm/relay/analysis.py index 5220b8650179..d0c172418fba 100644 --- a/python/tvm/relay/analysis.py +++ b/python/tvm/relay/analysis.py @@ -20,11 +20,11 @@ This file contains the set of passes for Relay, which exposes an interface for configuring the passes and scripting them in Python. """ +from tvm.ir import RelayExpr, IRModule + from . import _analysis from . import _make -from .expr import Expr from .ty import Type -from .module import Module from .feature import Feature @@ -70,7 +70,7 @@ def check_kind(t, mod=None): t : tvm.relay.Type The type to check - mod : Optional[tvm.relay.Module] + mod : Optional[tvm.IRModule] The global module. Returns @@ -169,7 +169,7 @@ def free_type_vars(expr, mod=None): expr : Union[tvm.relay.Expr,tvm.relay.Type] The input expression/type - mod : Optional[tvm.relay.Module] + mod : Optional[tvm.IRModule] The global module Returns @@ -177,7 +177,7 @@ def free_type_vars(expr, mod=None): free : List[tvm.relay.TypeVar] The list of free type variables in post-DFS order """ - use_mod = mod if mod is not None else Module() + use_mod = mod if mod is not None else IRModule() return _analysis.free_type_vars(expr, use_mod) @@ -189,7 +189,7 @@ def bound_type_vars(expr, mod=None): expr : Union[tvm.relay.Expr,tvm.relay.Type] The input expression/type - mod : Optional[tvm.relay.Module] + mod : Optional[tvm.IRModule] The global module Returns @@ -197,7 +197,7 @@ def bound_type_vars(expr, mod=None): free : List[tvm.relay.TypeVar] The list of bound type variables in post-DFS order """ - use_mod = mod if mod is not None else Module() + use_mod = mod if mod is not None else IRModule() return _analysis.bound_type_vars(expr, use_mod) @@ -209,7 +209,7 @@ def all_type_vars(expr, mod=None): expr : Union[tvm.relay.Expr,tvm.relay.Type] The input expression/type - mod : Optional[tvm.relay.Module] + mod : Optional[tvm.IRModule] The global module Returns @@ -217,7 +217,7 @@ def all_type_vars(expr, mod=None): free : List[tvm.relay.TypeVar] The list of all type variables in post-DFS order """ - use_mod = mod if mod is not None else Module() + use_mod = mod if mod is not None else IRModule() return _analysis.all_type_vars(expr, use_mod) @@ -353,7 +353,7 @@ def unmatched_cases(match, mod=None): match : tvm.relay.Match The match expression - mod : Optional[tvm.relay.Module] + mod : Optional[tvm.IRModule] The module (defaults to an empty module) Returns @@ -370,10 +370,10 @@ def detect_feature(a, b=None): Parameters ---------- - a : Union[tvm.relay.Expr, tvm.relay.Module] + a : Union[tvm.relay.Expr, tvm.IRModule] The input expression or module. - b : Optional[Union[tvm.relay.Expr, tvm.relay.Module]] + b : Optional[Union[tvm.relay.Expr, tvm.IRModule]] The input expression or module. The two arguments cannot both be expression or module. @@ -382,7 +382,7 @@ def detect_feature(a, b=None): features : Set[Feature] Features used in the program. """ - if isinstance(a, Module): + if isinstance(a, IRModule): a, b = b, a return {Feature(int(x)) for x in _analysis.detect_feature(a, b)} @@ -400,7 +400,7 @@ def structural_hash(value): result : int The hash value """ - if isinstance(value, Expr): + if isinstance(value, RelayExpr): return int(_analysis._expr_hash(value)) elif isinstance(value, Type): return int(_analysis._type_hash(value)) diff --git a/python/tvm/relay/backend/_backend.py b/python/tvm/relay/backend/_backend.py index 270c38e4f523..c2f1df915509 100644 --- a/python/tvm/relay/backend/_backend.py +++ b/python/tvm/relay/backend/_backend.py @@ -16,9 +16,9 @@ # under the License. """The interface of expr function exposed from C++.""" import tvm._ffi +from tvm.ir import container as _container from ... import build_module as _build -from ... import container as _container @tvm._ffi.register_func("relay.backend.lower") diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index 58596ec4f247..18f848c212b2 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -21,9 +21,10 @@ import numpy as np from tvm.runtime import container +from tvm.ir import IRModule + from . import _backend from .. import _make, analysis, transform -from .. import module from ... import nd from ..base import Object, register_relay_node from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const @@ -186,10 +187,10 @@ class Interpreter(Executor): Parameters ---------- - mod : tvm.relay.Module + mod : tvm.IRModule The module to support the execution. - ctx : tvm.TVMContext + ctx : tvmContext The runtime context to run the code on. target : tvm.Target @@ -205,7 +206,7 @@ def optimize(self): Returns ------- - opt_mod : tvm.relay.Module + opt_mod : tvm.IRModule The optimized module. """ seq = transform.Sequential([transform.SimplifyInference(), @@ -239,7 +240,7 @@ def _interp_wrapper(*args, **kwargs): if self.mod: self.mod["main"] = func else: - self.mod = module.Module.from_expr(func) + self.mod = IRModule.from_expr(func) mod = self.optimize() opt_expr = Call(mod["main"], relay_args) diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index b61fafda1d1d..557b9fd6e46d 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -36,7 +36,7 @@ def compile(mod, target=None, target_host=None, params=None): Parameters ---------- - mod : relay.Module + mod : tvm.IRModule The Relay module to build. target : str, :any:`tvm.target.Target`, or dict of str(i.e. @@ -110,7 +110,7 @@ def lower(self, mod, target=None, target_host=None): Parameters ---------- - mod : relay.Module + mod : tvm.IRModule The Relay module to build. target : str, :any:`tvm.target.Target`, or dict of str(i.e. @@ -142,7 +142,7 @@ def optimize(self, mod, target=None, params=None): Parameters ---------- - mod : relay.Module + mod : tvm.IRModule target : str, :any:`tvm.target.Target`, or dict of str (i.e. device/context name) to str/tvm.target.Target, optional @@ -153,7 +153,7 @@ def optimize(self, mod, target=None, params=None): Returns ------- - mod : relay.Module + mod : tvm.IRModule The optimized relay module. params : dict @@ -229,10 +229,10 @@ class VMExecutor(Executor): Parameters ---------- - mod : :py:class:`~tvm.relay.module.Module` + mod : :py:class:`~tvm.IRModule` The module to support the execution. - ctx : :py:class:`~tvm.TVMContext` + ctx : :py:class:`~tvmContext` The runtime context to run the code on. target : :py:class:`Target` diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index bc041252a668..0d6f22f446cd 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -14,16 +14,25 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=no-else-return, unidiomatic-typecheck +# pylint: disable=no-else-return, unidiomatic-typecheck, unused-import """The base node types for the Relay language.""" +import os import tvm._ffi from tvm.runtime import Object +from tvm.ir import SourceName, Span, Node as RelayNode from . import _make from . import _expr from . import _base +__STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std") + +@tvm._ffi.register_func("tvm.relay.std_path") +def _std_path(): + return __STD_PATH__ + + def register_relay_node(type_key=None): """Register a Relay node type. @@ -52,55 +61,6 @@ def register_relay_attr_node(type_key=None): return tvm._ffi.register_object(type_key) -class RelayNode(Object): - """Base class of all Relay nodes.""" - def astext(self, show_meta_data=True, annotate=None): - """Get the text format of the expression. - - Parameters - ---------- - show_meta_data : bool - Whether to include meta data section in the text - if there is meta data. - - annotate: Optional[relay.Expr->str] - Optional annotate function to provide additional - information in the comment block. - - Note - ---- - The meta data section is necessary to fully parse the text format. - However, it can contain dumps that are big (e.g constant weights), - so it can be helpful to skip printing the meta data section. - - Returns - ------- - text : str - The text format of the expression. - """ - return _expr.AsText(self, show_meta_data, annotate) - - def set_span(self, span): - _base.set_span(self, span) - - def __str__(self): - return self.astext(show_meta_data=False) - - -@register_relay_node -class Span(RelayNode): - """Specifies a location in a source program.""" - - def __init__(self, source, lineno, col_offset): - self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset) - -@register_relay_node -class SourceName(RelayNode): - """A identifier for a source location""" - - def __init__(self, name): - self.__init_handle_by_constructor__(_make.SourceName, name) - @register_relay_node class Id(Object): """Unique identifier(name) used in Var. diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index ea7a4cacfc60..fa812cb35703 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -21,13 +21,14 @@ import warnings import numpy as np +from tvm.ir import IRModule + from tvm import expr as tvm_expr from .. import nd as _nd, target as _target, autotvm from ..contrib import graph_runtime as _graph_rt from . import _build_module from . import ty as _ty from . import expr as _expr -from .module import Module as _Module from .backend import interpreter as _interpreter from .backend.vm import VMExecutor @@ -141,7 +142,7 @@ def optimize(self, func, target=None, params=None): Returns ------- - mod : relay.Module + mod : tvm.IRModule The optimized relay module. params : dict @@ -185,7 +186,7 @@ def build(mod, target=None, target_host=None, params=None): Parameters ---------- - mod : relay.Module + mod : tvm.IRModule The module to build. Using relay.Function is deprecated. target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context @@ -217,16 +218,16 @@ def build(mod, target=None, target_host=None, params=None): params : dict The parameters of the final graph. """ - if isinstance(mod, _Module): + if isinstance(mod, IRModule): func = mod["main"] elif isinstance(mod, _expr.Function): func = mod warnings.warn( - "Please use input parameter mod (tvm.relay.module.Module) " + "Please use input parameter mod (tvm.IRModule) " "instead of deprecated parameter func (tvm.relay.expr.Function)", DeprecationWarning) else: - raise ValueError("Type of input parameter mod must be tvm.relay.module.Module") + raise ValueError("Type of input parameter mod must be tvm.IRModule") target = _update_target(target) @@ -254,7 +255,7 @@ def optimize(mod, target=None, params=None): Parameters ---------- - mod : relay.Module + mod : tvm.IRModule The module to build. Using relay.Function is deprecated. target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context @@ -268,7 +269,7 @@ def optimize(mod, target=None, params=None): Returns ------- - mod : relay.Module + mod : tvm.IRModule The optimized relay module. params : dict @@ -279,11 +280,11 @@ def optimize(mod, target=None, params=None): elif isinstance(mod, _expr.Function): func = mod warnings.warn( - "Please use input parameter mod (tvm.relay.module.Module) " + "Please use input parameter mod (tvm.IRModule) " "instead of deprecated parameter func (tvm.relay.expr.Function)", DeprecationWarning) else: - raise ValueError("Type of input parameter mod must be tvm.relay.module.Module") + raise ValueError("Type of input parameter mod must be tvm.IRModule") target = _update_target(target) @@ -330,7 +331,7 @@ class GraphExecutor(_interpreter.Executor): Parameters ---------- - mod : :py:class:`~tvm.relay.module.Module` + mod : :py:class:`~tvm.IRModule` The module to support the execution. ctx : :py:class:`TVMContext` @@ -385,17 +386,17 @@ def create_executor(kind="debug", kind : str The type of executor - mod : :py:class:`~tvm.relay.module.Module` + mod : :py:class:`~tvm.IRModule` The Relay module containing collection of functions - ctx : :py:class:`tvm.TVMContext` + ctx : :py:class:`tvmContext` The context to execute the code. target : :py:class:`tvm.Target` The corresponding context """ if mod is None: - mod = _Module() + mod = IRModule() if ctx is not None: assert ctx.device_type == _nd.context(str(target), 0).device_type else: diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 5add5e76a680..e5259fbc0da8 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +# pylint: disable=no-else-return, invalid-name, unused-import """The expression nodes of Relay.""" from __future__ import absolute_import from numbers import Number as _Number @@ -22,33 +22,21 @@ import numpy as _np from tvm._ffi import base as _base from tvm.runtime import NDArray, convert, ndarray as _nd +from tvm.ir import RelayExpr, GlobalVar, BaseFunc from .base import RelayNode, register_relay_node from . import _make from . import _expr from . import ty as _ty +# alias relay expr as Expr. +Expr = RelayExpr # will be registered afterwards _op_make = None -class Expr(RelayNode): - """The base type for all Relay expressions.""" - @property - def checked_type(self): - """Get the checked type of tvm.relay.Expr. - - Returns - ------- - checked_type : tvm.relay.Type - The checked type. - """ - ret = self._checked_type_ - if ret is None: - raise ValueError("The type checker has not populated" - " the checked_type for this node") - return ret - +class ExprWithOp(RelayExpr): + """Basetype of all relay expressions that defines op overloading.""" def astype(self, dtype): """Cast the content type of the current data to dtype. @@ -173,7 +161,7 @@ def __call__(self, *args): return Call(self, args) @register_relay_node -class Constant(Expr): +class Constant(ExprWithOp): """A constant expression in Relay. Parameters @@ -186,7 +174,7 @@ def __init__(self, data): @register_relay_node -class Tuple(Expr): +class Tuple(ExprWithOp): """Tuple expression that groups several fields together. Parameters @@ -210,7 +198,7 @@ def astype(self, _): @register_relay_node -class Var(Expr): +class Var(ExprWithOp): """A local variable in Relay. Local variable can be used to declare input @@ -238,33 +226,7 @@ def name_hint(self): @register_relay_node -class GlobalVar(Expr): - """A global variable in Tvm.Relay. - - GlobalVar is used to refer to the global functions - stored in the module. - - Parameters - ---------- - name_hint: str - The name of the variable. - """ - def __init__(self, name_hint): - self.__init_handle_by_constructor__(_make.GlobalVar, name_hint) - - def __call__(self, *args): - """Invoke the gobal function. - - Parameters - ---------- - args: List[relay.Expr] - Arguments. - """ - return Call(self, args, None, None) - - -@register_relay_node -class Function(Expr): +class Function(BaseFunc): """A function declaration expression. Parameters @@ -320,7 +282,7 @@ def set_attribute(self, name, ref): @register_relay_node -class Call(Expr): +class Call(ExprWithOp): """Function call node in Relay. Call node corresponds the operator application node @@ -349,7 +311,7 @@ def __init__(self, op, args, attrs=None, type_args=None): @register_relay_node -class Let(Expr): +class Let(ExprWithOp): """Let variable binding expression. Parameters @@ -369,7 +331,7 @@ def __init__(self, variable, value, body): @register_relay_node -class If(Expr): +class If(ExprWithOp): """A conditional expression in Relay. Parameters @@ -389,7 +351,7 @@ def __init__(self, cond, true_branch, false_branch): @register_relay_node -class TupleGetItem(Expr): +class TupleGetItem(ExprWithOp): """Get index-th item from a tuple. Parameters @@ -406,7 +368,7 @@ def __init__(self, tuple_value, index): @register_relay_node -class RefCreate(Expr): +class RefCreate(ExprWithOp): """Create a new reference from initial value. Parameters ---------- @@ -418,7 +380,7 @@ def __init__(self, value): @register_relay_node -class RefRead(Expr): +class RefRead(ExprWithOp): """Get the value inside the reference. Parameters ---------- @@ -430,7 +392,7 @@ def __init__(self, ref): @register_relay_node -class RefWrite(Expr): +class RefWrite(ExprWithOp): """ Update the value inside the reference. The whole expression will evaluate to an empty tuple. @@ -445,7 +407,7 @@ def __init__(self, ref, value): self.__init_handle_by_constructor__(_make.RefWrite, ref, value) -class TempExpr(Expr): +class TempExpr(ExprWithOp): """Baseclass of all TempExpr. TempExprs are pass specific expression that can be diff --git a/python/tvm/relay/expr.pyi b/python/tvm/relay/expr.pyi deleted file mode 100644 index d2d01720f5ff..000000000000 --- a/python/tvm/relay/expr.pyi +++ /dev/null @@ -1,131 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from typing import List -import tvm -from .base import Span, Object -from .ty import Type, TypeParam -from ._analysis import _get_checked_type - - -class Expr(Object): - def checked_type(self): - ... - - def __call__(self, *args): - ... - - -class Constant(Expr): - data = ... # type: tvm.nd.NDArray - - def __init__(self, data): - # type: (tvm.nd.NDArray) -> None - ... - - -class Tuple(Expr): - fields = ... # type: List[Expr] - - def __init__(self, fields): - # type: (List[Expr]) -> None - ... - - -class Var(Expr): - """A local variable in Relay.""" - name_hint = ... # type: str - - def __init__(self, name_hint): - # type: (str) -> None - ... - - -class GlobalVar(Expr): - name_hint = ... # type: str - - def __init__(self, name_hint): - # type: (str) -> None - ... - - -class Param(Expr): - var = ... # type: Var - type = ... # type: Type - - def __init__(self, var, ty): - # type: (Var, Type) -> None - ... - - -class Function(Expr): - """A function in Relay, see tvm/relay/expr.h for more details.""" - type_params = ... # type: List[TypeParam] - params = ... # type: List[Param] - ret_type = ... # type: Type - body = ... # type: Expr - - def __init__(self, - params, # type: List[Param], - ret_type, # type: Type, - body, # type: Expr, - type_params=None, # type: List[TypeParam] - ): - # type: (...) -> None - ... - - -@register_relay_node -class Call(Expr): - """A function call in Relay, see tvm/relay/expr.h for more details.""" - op = ... # type: Expr - args = ... # type: List[Expr] - # todo(@jroesch): add attrs. revise attrs type in __init__ - - def __init__(self, op, args, attrs=None, ty_args=None): - # type: (Expr, List[Expr], Optional[List[Any]], Optional[List[Type]]) -> None - if not ty_args: - ty_args = [] - - self.__init_handle_by_constructor__( - _make.Call, op, args, attrs, ty_args) - - -@register_relay_node -class Let(Expr): - """A variable bindings in Relay, see tvm/relay/expr.h for more details.""" - var = ... # type: Var - value = ... # type: Expr - body = ... # type: Expr - value_type = ... # type: Type - - def __init__(self, var, value, body, value_type): - # type: (Var, Expr, Expr, Type) -> None - ... - - -@register_relay_node -class If(Expr): - """A conditional expression in Relay, see tvm/relay/expr.h for more details.""" - cond = ... # type: Expr - true_value = ... # type: Expr - false_value = ... # type: Expr - span = ... # type: Span - - def __init__(self, cond, true_value, false_value): - # type: (Expr, Expr, Expr) -> None - ... diff --git a/python/tvm/relay/frontend/caffe2.py b/python/tvm/relay/frontend/caffe2.py index 566851d7f7ed..da0cc6479818 100644 --- a/python/tvm/relay/frontend/caffe2.py +++ b/python/tvm/relay/frontend/caffe2.py @@ -16,11 +16,11 @@ # under the License. # pylint: disable=import-self, invalid-name, line-too-long, unused-argument """Caffe2 frontend""" -from __future__ import absolute_import as _abs import tvm +from tvm.ir import IRModule + from .. import analysis from .. import expr as _expr -from .. import module as _module from .. import op as _op from ... import nd as _nd from .common import AttrCvt, Renamer @@ -383,7 +383,7 @@ def __init__(self, shape, dtype): self._ops = {} self._shape = shape self._dtype = dtype - self._mod = _module.Module({}) + self._mod = IRModule({}) def from_caffe2(self, init_net, predict_net): """Construct Relay expression from caffe2 graph. @@ -395,7 +395,7 @@ def from_caffe2(self, init_net, predict_net): Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The module that optimizations will be performed on. params : dict @@ -565,7 +565,7 @@ def from_caffe2(init_net, predict_net, shape=None, dtype="float32"): Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The module that optimizations will be performed on. params : dict of str to tvm.nd.NDArray diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index a0af826de32b..d427fe953085 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -20,9 +20,10 @@ import numpy as np import tvm +from tvm.ir import IRModule from topi.util import get_const_tuple + from .. import expr as _expr -from .. import module as _module from .. import transform as _transform from .. import op as _op from .. import analysis @@ -453,7 +454,7 @@ def get_name(node): def infer_type(node, mod=None): """A method to infer the type of an intermediate node in the relay graph.""" - new_mod = _module.Module.from_expr(node) + new_mod = IRModule.from_expr(node) if mod is not None: new_mod.update(mod) new_mod = _transform.InferType()(new_mod) diff --git a/python/tvm/relay/frontend/coreml.py b/python/tvm/relay/frontend/coreml.py index 719a2783fd3b..99a3930a4ea1 100644 --- a/python/tvm/relay/frontend/coreml.py +++ b/python/tvm/relay/frontend/coreml.py @@ -21,9 +21,10 @@ import math import numpy as np import tvm +from tvm.ir import IRModule + from .. import analysis from .. import expr as _expr -from .. import module as _module from .. import op as _op from ... import nd as _nd from ..._ffi import base as _base @@ -449,7 +450,7 @@ def from_coreml(model, shape=None): Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The relay module for compilation. params : dict of str to tvm.nd.NDArray @@ -505,4 +506,4 @@ def from_coreml(model, shape=None): outexpr = outexpr[0] func = _expr.Function(analysis.free_vars(outexpr), outexpr) params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()} - return _module.Module.from_expr(func), params + return IRModule.from_expr(func), params diff --git a/python/tvm/relay/frontend/darknet.py b/python/tvm/relay/frontend/darknet.py index 0ed7b2112383..7623df293cb9 100644 --- a/python/tvm/relay/frontend/darknet.py +++ b/python/tvm/relay/frontend/darknet.py @@ -23,9 +23,10 @@ from enum import Enum import numpy as np import tvm +from tvm.ir import IRModule + from .. import analysis from .. import expr as _expr -from .. import module as _module from .common import get_relay_op, new_var __all__ = ['from_darknet'] @@ -822,7 +823,7 @@ def from_darknet(self): outputs = _as_list(sym) + self._outs outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) sym = _expr.Function(analysis.free_vars(outputs), outputs) - return _module.Module.from_expr(sym), self._tvmparams + return IRModule.from_expr(sym), self._tvmparams def from_darknet(net, shape=None, @@ -840,7 +841,7 @@ def from_darknet(net, Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The relay module for compilation. params : dict of str to tvm.nd.NDArray diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 740d60073906..d21f1af124ca 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -19,9 +19,10 @@ import sys import numpy as np import tvm +from tvm.ir import IRModule + from .. import analysis from .. import expr as _expr -from .. import module as _module from .. import op as _op from ... import nd as _nd from .common import ExprTable, new_var @@ -752,7 +753,7 @@ def from_keras(model, shape=None): Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The relay module for compilation. params : dict of str to tvm.nd.NDArray @@ -837,4 +838,4 @@ def _convert_input_layer(keras_layer): outexpr = outexpr[0] if len(outexpr) == 1 else _expr.Tuple(outexpr) func = _expr.Function(analysis.free_vars(outexpr), outexpr) params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()} - return _module.Module.from_expr(func), params + return IRModule.from_expr(func), params diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 97e28a933c89..d74277bbe402 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -21,12 +21,13 @@ import json import numpy as np import tvm +from tvm.ir import IRModule + from tvm import relay from topi.util import get_const_tuple from .. import analysis from .. import expr as _expr from .. import op as _op -from .. import module as _module from .. import scope_builder as _scope_builder from ... import nd as _nd @@ -1902,7 +1903,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, params=None, mod=None): dtype_info : dict or str. Known parameter dtypes - mod : tvm.relay.Module + mod : tvm.IRModule The module that contains global information. It will be used for converting ops that need global information, e.g. control-flow ops. @@ -2009,7 +2010,7 @@ def from_mxnet(symbol, Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The relay module for compilation params : dict of str to tvm.nd.NDArray @@ -2020,7 +2021,7 @@ def from_mxnet(symbol, except ImportError as e: raise ImportError("{}. MXNet is required to parse symbols.".format(e)) - mod = _module.Module() + mod = IRModule() if isinstance(symbol, mx.sym.Symbol): params = {} arg_params = arg_params if arg_params else {} diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 9ecd950e3a3c..38ead20d1c9d 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -17,14 +17,13 @@ # pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines # pylint: disable=import-outside-toplevel """ONNX: Open Neural Network Exchange frontend for Relay.""" -from __future__ import absolute_import as _abs - import numpy as np import tvm +from tvm.ir import IRModule + from ... import nd as _nd from .. import analysis from .. import expr as _expr -from .. import module as _module from .. import op as _op from .common import AttrCvt, Renamer from .common import get_relay_op, new_var, infer_shape, infer_channels @@ -1615,7 +1614,7 @@ def from_onnx(self, graph, opset): Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The returned relay module params : dict @@ -1708,7 +1707,7 @@ def from_onnx(self, graph, opset): outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) func = _expr.Function(analysis.free_vars(outputs), outputs) - return _module.Module.from_expr(func), self._params + return IRModule.from_expr(func), self._params def _parse_value_proto(self, value_proto): """Parse ValueProto or raw str.""" @@ -1836,7 +1835,7 @@ def from_onnx(model, Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The relay module for compilation params : dict of str to tvm.nd.NDArray diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 3aeb1d4f3d6d..ac2ea9d0b1bb 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -29,13 +29,13 @@ import tvm +from tvm.ir import IRModule from tvm.relay.prelude import Prelude from .. import analysis from .. import expr as _expr from .. import op as _op from ..expr_functor import ExprMutator -from .. import module as _module from .common import AttrCvt, get_relay_op from .common import infer_type as _infer_type from .common import infer_shape as _infer_shape @@ -2136,7 +2136,7 @@ def __init__(self): self._input_shapes = {} self._loops = {} self._branches = {} - self._mod = _module.Module({}) + self._mod = IRModule({}) self._prelude = Prelude(self._mod) def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): @@ -2171,7 +2171,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The module that optimizations will be performed on. params : dict @@ -2653,7 +2653,7 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None): Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The module that optimizations will be performed on. params : dict of str to tvm.nd.NDArray diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index ab630472c372..a0b0c0fce526 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -17,14 +17,14 @@ # pylint: disable=invalid-name, unused-argument, too-many-lines, import-outside-toplevel """Tensorflow lite frontend.""" -from __future__ import absolute_import as _abs import math import numpy as np import tvm +from tvm.ir import IRModule + from tvm import relay from .. import analysis from .. import expr as _expr -from .. import module as _module from .. import op as _op from .. import qnn as _qnn from ..util import get_scalar_from_constant @@ -1901,7 +1901,7 @@ def from_tflite(model, shape_dict, dtype_dict): Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The relay module for compilation. params : dict of str to tvm.nd.NDArray @@ -1940,5 +1940,5 @@ def from_tflite(model, shape_dict, dtype_dict): outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in model_outputs] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) func = _expr.Function(analysis.free_vars(outputs), outputs) - mod = _module.Module.from_expr(func) + mod = IRModule.from_expr(func) return mod, params diff --git a/python/tvm/relay/memory_alloc.py b/python/tvm/relay/memory_alloc.py index f93aa9eeaf2f..d61c6f1d6fba 100644 --- a/python/tvm/relay/memory_alloc.py +++ b/python/tvm/relay/memory_alloc.py @@ -176,7 +176,7 @@ def visit_call(self, call): view = LinearizeRetType(ret_type) out_types = view.unpack() - is_dynamic = ret_type.is_dynamic() + is_dynamic = ty.type_has_any(ret_type) # TODO(@jroesch): restore this code, more complex then it seems # for arg in call.args: # is_dynamic = is_dynamic or arg.checked_type.is_dynamic() diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index c2ec6ad2d22d..bcd58ba5b1b1 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -41,7 +41,6 @@ from . import _transform from . import _reduce from . import _algorithm -from ..expr import Expr from ..base import register_relay_node diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index fcd9e99d1440..3fdafd5b8628 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -275,7 +275,7 @@ def legalize_conv2d(attrs, inputs, types): Parameters ---------- - attrs : tvm.attrs.Attrs + attrs : tvm.ir.Attrs Attributes of current convolution inputs : list of tvm.relay.Expr The args of the Relay expr to be legalized @@ -296,7 +296,7 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layout): Parameters ---------- - attrs : tvm.attrs.Attrs + attrs : tvm.ir.Attrs Attributes of current convolution inputs : list of tvm.relay.Expr The args of the Relay expr to be legalized @@ -413,7 +413,7 @@ def legalize_conv2d_transpose(attrs, inputs, types): Parameters ---------- - attrs : tvm.attrs.Attrs + attrs : tvm.ir.Attrs Attributes of current Transposed convolution inputs : list of tvm.relay.Expr The args of the Relay expr to be legalized @@ -947,7 +947,7 @@ def legalize_bitserial_conv2d(attrs, inputs, types): Parameters ---------- - attrs : tvm.attrs.Attrs + attrs : tvm.ir.Attrs Attributes of current convolution inputs : list of tvm.relay.Expr The args of the Relay expr to be legalized diff --git a/python/tvm/relay/op/nn/util.py b/python/tvm/relay/op/nn/util.py index ba536ad39936..323ef7f9310e 100644 --- a/python/tvm/relay/op/nn/util.py +++ b/python/tvm/relay/op/nn/util.py @@ -16,8 +16,8 @@ # under the License. # pylint: disable=invalid-name, unused-variable """NN operator common utilities""" -from __future__ import absolute_import -from .... import container +from tvm.ir import container + def get_pad_tuple2d(padding): """Common code to get the pad option diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index f9bc853282bb..c74201ef9c1f 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -20,13 +20,13 @@ import tvm._ffi from ..base import register_relay_node -from ..expr import Expr +from ..expr import RelayExpr from ...api import register_func from ...build_module import lower, build from . import _make @register_relay_node -class Op(Expr): +class Op(RelayExpr): """A Relay operator definition.""" def __init__(self): diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 2da35daba225..12abf4a787db 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -16,7 +16,7 @@ # under the License. """The attributes node used for Relay operators""" -from ...attrs import Attrs +from tvm.ir import Attrs from ..base import register_relay_attr_node diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 94a75749ce5c..5288a2e08011 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -16,13 +16,15 @@ # under the License. # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """A prelude containing useful global functions and ADT definitions.""" +from tvm.ir import IRModule + from .ty import GlobalTypeVar, TensorType, Any, scalar_type from .expr import Var, Function, GlobalVar, If, const from .op.tensor import add, subtract, equal from .adt import Constructor, TypeData, Clause, Match from .adt import PatternConstructor, PatternVar, PatternWildcard from . import op -from .module import Module + class TensorArrayOps(object): """Contains tensor array related ops""" @@ -648,7 +650,7 @@ class Prelude: def __init__(self, mod=None): if mod is None: - mod = Module() + mod = IRModule() self.mod = mod self.load_prelude() diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index f57fef233a1f..22785eec6b41 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -63,7 +63,7 @@ def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op): Parameters ---------- - attrs : tvm.attrs.Attrs + attrs : tvm.ir.Attrs Attributes of current convolution inputs : list of tvm.relay.Expr The args of the Relay expr to be legalized @@ -106,7 +106,7 @@ def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op): Parameters ---------- - attrs : tvm.attrs.Attrs + attrs : tvm.ir.Attrs Attributes of current convolution inputs : list of tvm.relay.Expr The args of the Relay expr to be legalized @@ -169,7 +169,7 @@ def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op): Parameters ---------- - attrs : tvm.attrs.Attrs + attrs : tvm.ir.Attrs Attributes of current convolution inputs : list of tvm.relay.Expr The args of the Relay expr to be legalized diff --git a/python/tvm/relay/qnn/transform.py b/python/tvm/relay/qnn/transform.py index a76bdaf6310f..6d38490b2d19 100644 --- a/python/tvm/relay/qnn/transform.py +++ b/python/tvm/relay/qnn/transform.py @@ -42,7 +42,7 @@ def CanonicalizeOps(): # We want to utilize all the existing Relay infrastructure. So, instead of supporting this # QNN requantize op, we convert it into a sequence of existing Relay operators. - mod = relay.Module.from_expr(qnn_expr) + mod = tvm.IRModule.from_expr(qnn_expr) mod = relay.qnn.transform.CanonicalizeOps()(mod) relay_expr = mod['main'] print(relay_expr) diff --git a/python/tvm/relay/quantize/_calibrate.py b/python/tvm/relay/quantize/_calibrate.py index d904fed489bc..482a6f292f54 100644 --- a/python/tvm/relay/quantize/_calibrate.py +++ b/python/tvm/relay/quantize/_calibrate.py @@ -20,12 +20,12 @@ import multiprocessing as mp import numpy as np import tvm +from tvm.ir import IRModule from . import _quantize from . import quantize from .. import op as _op from .. import expr as _expr -from .. import module as _module from .. import analysis as _analysis from .. import transform as _transform from .. import build_module as _build_module @@ -141,7 +141,7 @@ def _make_const(val): func = mod['main'] _analysis.post_order_visit(func, visit_func) func = _expr.bind(func, const_params) - return _module.Module.from_expr(func) + return IRModule.from_expr(func) # weight scale functions diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index bcf8985657da..bff01e859a50 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -47,7 +47,7 @@ def run_opt_pass(expr, opt_pass): assert isinstance(opt_pass, transform.Pass) - mod = relay.Module.from_expr(expr) + mod = tvm.IRModule.from_expr(expr) mod = opt_pass(mod) entry = mod["main"] return entry if isinstance(expr, relay.Function) else entry.body diff --git a/python/tvm/relay/testing/dcgan.py b/python/tvm/relay/testing/dcgan.py index 6907eb01c88c..9d7bdaaf8c06 100644 --- a/python/tvm/relay/testing/dcgan.py +++ b/python/tvm/relay/testing/dcgan.py @@ -103,7 +103,7 @@ def get_workload(batch_size, oshape=(3, 64, 64), ngf=128, random_len=100, dtype= Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The relay module that contains a DCGAN network. params : dict of str to NDArray The parameters. diff --git a/python/tvm/relay/testing/densenet.py b/python/tvm/relay/testing/densenet.py index 9818f446cf75..de140fbc15ab 100644 --- a/python/tvm/relay/testing/densenet.py +++ b/python/tvm/relay/testing/densenet.py @@ -105,7 +105,7 @@ def get_workload(densenet_size=121, classes=1000, batch_size=4, Returns ------- - mod: tvm.relay.Module + mod: tvm.IRModule The relay module that contains a DenseNet network. params : dict of str to NDArray diff --git a/python/tvm/relay/testing/dqn.py b/python/tvm/relay/testing/dqn.py index cdf9d24af996..10da37001f12 100644 --- a/python/tvm/relay/testing/dqn.py +++ b/python/tvm/relay/testing/dqn.py @@ -72,7 +72,7 @@ def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="flo The data type Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The relay module that contains a DQN network. params : dict of str to NDArray The parameters. diff --git a/python/tvm/relay/testing/inception_v3.py b/python/tvm/relay/testing/inception_v3.py index fa4233d67b31..8a540e598b77 100644 --- a/python/tvm/relay/testing/inception_v3.py +++ b/python/tvm/relay/testing/inception_v3.py @@ -290,7 +290,7 @@ def get_workload(batch_size=1, num_classes=1000, Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The relay module that contains an Inception V3 network. params : dict of str to NDArray diff --git a/python/tvm/relay/testing/init.py b/python/tvm/relay/testing/init.py index 0b8ab2b42029..352230a6150f 100644 --- a/python/tvm/relay/testing/init.py +++ b/python/tvm/relay/testing/init.py @@ -144,13 +144,13 @@ def create_workload(net, initializer=None, seed=0): Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The created relay module. params : dict of str to NDArray The parameters. """ - mod = relay.Module.from_expr(net) + mod = tvm.IRModule.from_expr(net) mod = relay.transform.InferType()(mod) shape_dict = { v.name_hint : v.checked_type for v in mod["main"].params} diff --git a/python/tvm/relay/testing/lstm.py b/python/tvm/relay/testing/lstm.py index d0134c1a864d..2480d15f79bb 100644 --- a/python/tvm/relay/testing/lstm.py +++ b/python/tvm/relay/testing/lstm.py @@ -173,7 +173,7 @@ def get_workload(iterations, num_hidden, batch_size=1, dtype="float32"): The data type Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The relay module that contains a LSTM network. params : dict of str to NDArray The parameters. diff --git a/python/tvm/relay/testing/mlp.py b/python/tvm/relay/testing/mlp.py index 337bde5d5889..d11873165097 100644 --- a/python/tvm/relay/testing/mlp.py +++ b/python/tvm/relay/testing/mlp.py @@ -84,7 +84,7 @@ def get_workload(batch_size, Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The relay module that contains a mlp network. params : dict of str to NDArray diff --git a/python/tvm/relay/testing/mobilenet.py b/python/tvm/relay/testing/mobilenet.py index 1b3ce03d19d9..9aaefdfdb02d 100644 --- a/python/tvm/relay/testing/mobilenet.py +++ b/python/tvm/relay/testing/mobilenet.py @@ -151,7 +151,7 @@ def get_workload(batch_size=1, num_classes=1000, image_shape=(3, 224, 224), Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The relay module that contains a MobileNet network. params : dict of str to NDArray diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index e2825f815b67..eacfe379137f 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -584,7 +584,7 @@ def visit_op(self, _): def to_python(expr: Expr, mod=None, target=tvm.target.create('llvm')): """Converts the given Relay expression into a Python script (as a Python AST object). For easiest debugging, import the astor package and use to_source().""" - mod = mod if mod is not None else relay.Module() + mod = mod if mod is not None else tvm.IRModule() converter = PythonConverter(mod, target) return converter.convert(expr) @@ -592,7 +592,7 @@ def to_python(expr: Expr, mod=None, target=tvm.target.create('llvm')): def run_as_python(expr: Expr, mod=None, target=tvm.target.create('llvm')): """Converts the given Relay expression into a Python script and executes it.""" - mod = mod if mod is not None else relay.Module() + mod = mod if mod is not None else tvm.IRModule() py_ast = to_python(expr, mod, target) code = compile(py_ast, '', 'exec') var_map = { diff --git a/python/tvm/relay/testing/resnet.py b/python/tvm/relay/testing/resnet.py index bde788e1f9b9..97b6bdc7e617 100644 --- a/python/tvm/relay/testing/resnet.py +++ b/python/tvm/relay/testing/resnet.py @@ -262,7 +262,7 @@ def get_workload(batch_size=1, Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The relay module that contains a ResNet network. params : dict of str to NDArray diff --git a/python/tvm/relay/testing/squeezenet.py b/python/tvm/relay/testing/squeezenet.py index 1e9ea73e9360..1a946b6eaa9a 100644 --- a/python/tvm/relay/testing/squeezenet.py +++ b/python/tvm/relay/testing/squeezenet.py @@ -149,7 +149,7 @@ def get_workload(batch_size=1, Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The relay module that contains a SqueezeNet network. params : dict of str to NDArray diff --git a/python/tvm/relay/testing/vgg.py b/python/tvm/relay/testing/vgg.py index 205c5b1fa8e3..686230b9fbaf 100644 --- a/python/tvm/relay/testing/vgg.py +++ b/python/tvm/relay/testing/vgg.py @@ -124,7 +124,7 @@ def get_workload(batch_size, Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The relay module that contains a VGG network. params : dict of str to NDArray diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index cfca4a6ed3b2..4c2bf873778a 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name,arguments-differ,no-else-return,unused-argument,missing-docstring +# pylint: disable=invalid-name, unused-argument, missing-docstring, unused-import """ Relay pass transformation infrastructure. """ @@ -23,96 +23,12 @@ import functools import tvm -from tvm._ffi.runtime_ctypes import TVMContext +from tvm.runtime import ndarray as _nd +from tvm.ir.transform import PassInfo, PassContext, Pass, ModulePass, Sequential, module_pass + from tvm import relay from . import _transform -from .base import RelayNode, register_relay_node -from .. import nd as _nd - - -@register_relay_node -class PassInfo(RelayNode): - """The class contains the meta data required by a pass. It is the - container of information needed by running an optimization or analysis. - This class can be extended by adding new members when more meta data is - needed. - - Parameters - ---------- - opt_level : int - The optimization level of this pass. - - name : str - The pass name. - - required : List[str] - The list of passes that are required by a certain pass. - """ - - def __init__(self, opt_level, name, required=None): - self.__init_handle_by_constructor__( - _transform.PassInfo, opt_level, name, required) - - -@register_relay_node -class PassContext(RelayNode): - """The basis where a Relay optimization/analysis runs on. - Each pass context contains a number of auxiliary information that is used - to help an optimization pass. Such information includes the error reporter - to record the errors of during the optimization, etc. - - opt_level : Optional[int] - The optimization level of this pass. - - fallback_device : Optional[Union[int, str, TVMContext]] - The fallback device type. It is also used as the default device for - operators that are not annotated during heterogeneous execution. - - required_pass : Optional[Union[List[str], Set[str], Tuple[str]]] - The list of passes that are required by a certain pass. - - disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]] - The list of passes that are disabled. - """ - def __init__(self, - opt_level=2, - fallback_device=_nd.cpu(), - required_pass=None, - disabled_pass=None, - trace=None): - if isinstance(fallback_device, str): - fallback_device = _nd.context(fallback_device).device_type - elif isinstance(fallback_device, TVMContext): - fallback_device = fallback_device.device_type - if not isinstance(fallback_device, int): - raise TypeError("required_pass is expected to be the type of " + - "int/str/TVMContext.") - - required = list(required_pass) if required_pass else [] - if not isinstance(required, (list, tuple)): - raise TypeError("required_pass is expected to be the type of " + - "list/tuple/set.") - - disabled = list(disabled_pass) if disabled_pass else [] - if not isinstance(disabled, (list, tuple)): - raise TypeError("disabled_pass is expected to be the type of " + - "list/tuple/set.") - - self.__init_handle_by_constructor__(_transform.PassContext, opt_level, - fallback_device, required, - disabled, trace) - - def __enter__(self): - _transform.EnterPassContext(self) - return self - - def __exit__(self, ptype, value, trace): - _transform.ExitPassContext(self) - - @staticmethod - def current(): - """Return the current pass context.""" - return _transform.GetCurrentPassContext() +from .base import register_relay_node def build_config(opt_level=2, @@ -143,7 +59,7 @@ def build_config(opt_level=2, "CombineParallelDense": 4 } - fallback_device : int, str, or tvm.TVMContext, optional + fallback_device : int, str, or tvmContext, optional The fallback device. It is also used as the default device for operators without specified device during heterogeneous execution. @@ -165,46 +81,6 @@ def build_config(opt_level=2, disabled_pass, trace) -@register_relay_node -class Pass(RelayNode): - """The base class of all passes. All methods here are just simple wrappers - that are implemented in the backend. They are defined for users to - conveniently interact with the base class. - """ - - @property - def info(self): - """Get the pass meta.""" - return _transform.Info(self) - - def __call__(self, mod): - """Execute the pass. Note that for sequential pass, the dependency among - different passes will be resolved in the backend. - - Parameters - ---------- - mod : tvm.relay.Module - The module that a certain optimization is performed on. - - Returns - ------- - mod : tvm.relay.Module - The updated module after applying this pass. - """ - return _transform.RunPass(self, mod) - - -@register_relay_node -class ModulePass(Pass): - """A pass that works on tvm.relay.Module. Users don't need to interact with - this class directly. Instead, a module pass should be created through - `module_pass`, because the design of the `module_pass` API is flexible - enough to handle the creation of a module pass in different manners. In - addition, all members of a module pass can be accessed from the base class. - The same rule applies to FunctionPass as well. - """ - - @register_relay_node class FunctionPass(Pass): """A pass that works on each tvm.relay.Function in a module. A function @@ -212,51 +88,6 @@ class FunctionPass(Pass): """ -@register_relay_node -class Sequential(Pass): - """A pass that works on a sequence of pass objects. Multiple passes can be - executed sequentially using this class. - - Some typical usage of the sequential pass are: - 1. Users provide a list of passes for optimization. - 2. Only an optimization level is provided so that the backend system has - to glob all passes at this level and below to perform the optimizations. - Note that users can also provide a series of passes that they don't want to - apply when running a sequential pass. Pass dependency will be resolved in - the backend as well. - - Parameters - ---------- - passes : Optional[List[Pass]] - A sequence of passes candidate for optimization. - - opt_level : Optional[int] - The optimization level of this sequential pass. - - name : Optional[str] - The name of the sequential pass. - - required : Optional[List[str]] - The list of passes that the sequential pass is dependent on. - """ - - def __init__(self, - passes=None, - opt_level=2, - name="sequential", - required=None): - passes = passes if passes else [] - if not isinstance(passes, (list, tuple)): - raise TypeError("passes must be a list of Pass objects.") - - required = required if required else [] - if not isinstance(required, (list, tuple)): - raise TypeError("Required is expected to be the type of list/tuple.") - - self.__init_handle_by_constructor__(_transform.Sequential, - passes, opt_level, name, required) - - def InferType(): """Infer the type of an expr. @@ -716,7 +547,7 @@ def gradient(expr, mod=None, mode='higher_order'): expr : tvm.relay.Expr The input expression, which is a Function or a GlobalVar. - mod : Optional[tvm.relay.Module] + mod : Optional[tvm.IRModule] mode : Optional[String] The mode of the automatic differentiation algorithm. @@ -747,7 +578,7 @@ def to_cps(func, mod=None): func: tvm.relay.Function The input function. - mod: Optional[tvm.relay.Module] + mod: Optional[tvm.IRModule] The global module. Returns @@ -778,138 +609,6 @@ def un_cps(func): return _transform.un_cps(func) -def _wrap_class_module_pass(pass_cls, pass_info): - """Wrap a python class as function pass""" - class PyModulePass(ModulePass): - """Internal wrapper class to create a class instance.""" - def __init__(self, *args, **kwargs): - # initialize handle in cass pass_cls creation failed.fg - self.handle = None - inst = pass_cls(*args, **kwargs) - # it is important not to capture self to - # avoid a cyclic dependency - def _pass_func(mod, ctx): - return inst.transform_module(mod, ctx) - self.__init_handle_by_constructor__( - _transform.MakeModulePass, _pass_func, pass_info) - self._inst = inst - - def __getattr__(self, name): - # fall back to instance attribute if there is not any - return self._inst.__getattribute__(name) - - functools.update_wrapper(PyModulePass.__init__, pass_cls.__init__) - PyModulePass.__name__ = pass_cls.__name__ - PyModulePass.__doc__ = pass_cls.__doc__ - PyModulePass.__module__ = pass_cls.__module__ - return PyModulePass - - -def module_pass(pass_func=None, opt_level=None, name=None, required=None): - """Decorate a module pass. - - This function returns a callback when pass_func is provided. - Otherwise, it serves a decorator function. - - pass_func can also be a class type with a method transform_module. - This function will create a decorated ModulePass using transform_module - as the pass function. - - Parameters - ---------- - pass_func : Optional[Callable[(Module, PassContext) ->Module]] - The transformation function or class. - - opt_level : int - The optimization level of this module pass. - - name : Optional[str] - The name of the module pass. The name could be empty. In this case, the - name of the optimization function will be used as the pass name. - - required : Optional[List[str]] - The list of passes that the module pass is dependent on. - - Returns - ------- - create_module_pass : Union[Callable, ModulePass] - A decorator will be returned if pass_func is not provided, - otherwise return the decorated result. - The returned decorator has two behaviors depending on the input: - A new ModulePass will be returned when we decorate a pass function. - A new ModulePass class will be returned when we decorate a class type. - - Examples - -------- - The following code block decorates a module pass class. - - .. code-block:: python - - @relay.transform.module_pass - class CustomPipeline: - def __init__(self, enable_fold): - self.enable_fold = enable_fold - self.cse = relay.transform.EliminateCommonSubexpr() - self.const_fold = relay.transform.FoldConstant() - - def transform_module(self, mod, ctx): - mod = self.cse(mod, ctx) - if self.enable_fold: - mod = self.const_fold(mod, ctx) - return mod - - # create an instance of customized pipeline - pipeline = CustomPipeline(enable_fold=False) - assert isinstance(pipeline, transform.ModulePass) - # run the pipeline. - output_module = pipeline(input_module) - - The following code creates a module pass by decorating - a user defined transform function. - - .. code-block:: python - - @relay.transform.module_pass(opt_level=2) - def transform(mod, ctx): - tp = relay.TensorType((10,), "float32") - x = relay.var("x", tp) - gv = relay.GlobalVar("var") - func = relay.Function([x], relay.abs(x)) - new_mod = relay.Module({gv: func}) - new_mod.update(mod) - return new_mod - - module_pass = transform - assert isinstance(module_pass, transform.ModulePass) - assert module_pass.info.opt_level == 2 - - # Given a module m, the optimization could be invoked as the follwoing: - updated_mod = module_pass(m) - # Now a function abs should be added to the module m. - """ - if opt_level is None: - raise ValueError("Please provide opt_level for the module pass.") - - required = required if required else [] - if not isinstance(required, (list, tuple)): - raise TypeError("Required is expected to be the type of " + - "list/tuple.") - - def create_module_pass(pass_arg): - """Internal function that creates a module pass""" - fname = name if name else pass_arg.__name__ - info = PassInfo(opt_level, fname, required) - if inspect.isclass(pass_arg): - return _wrap_class_module_pass(pass_arg, info) - if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): - raise TypeError("pass_func must be a callable for Module pass") - return _transform.MakeModulePass(pass_arg, info) - - if pass_func: - return create_module_pass(pass_func) - return create_module_pass - - def _wrap_class_function_pass(pass_cls, pass_info): """Wrap a python class as function pass""" class PyFunctionPass(FunctionPass): @@ -1071,6 +770,5 @@ def visit_var(self, var): new_shape = list(ty.shape) new_shape[change_batch.data[var]] = change_batch.batch_size return relay.Var(var.name_hint, relay.TensorType(new_shape, ty.dtype)) - else: - return var + return var return ChangeBatchMutator().visit(func) diff --git a/python/tvm/relay/transform.pyi b/python/tvm/relay/transform.pyi deleted file mode 100644 index 2c466b0576a7..000000000000 --- a/python/tvm/relay/transform.pyi +++ /dev/null @@ -1,71 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import tvm -from .base import Object - - -class PassContext(Object): - def __init__(self): - ... - -class PassInfo(Object): - name = ... # type: str - opt_level = ... # type: int - required = ... # type: list - - def __init__(self, name, opt_level, required) - # type: (str, int, list) -> None - - -class Pass(Object): - def __init__(self): - ... - - -class ModulePass(Pass): - name = ... # type: str - opt_level = ... # type: int - pass_func = ... # type: Callable - required = ... # type: list - - def __init__(self, name, opt_level, pass_func, required): - # type: (str, int, Callable, list) -> None - ... - - -class FunctionPass(Pass): - name = ... # type: str - opt_level = ... # type: int - pass_func = ... # type: Callable - required = ... # type: list - - def __init__(self, name, opt_level, pass_func, required): - # type: (str, int, Callable, list) -> None - ... - - -class Sequential(Pass): - name = ... # type: str - opt_level = ... # type: int - passes = ... # type: list - required = ... # type: list - disabled = ... # type: list - - def __init__(self, name, opt_level, passes, required, disabled): - # type: (str, int, list, list, list) -> None - ... diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py index 356fe0beb0da..13d7f9197e79 100644 --- a/python/tvm/relay/ty.py +++ b/python/tvm/relay/ty.py @@ -14,133 +14,30 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +# pylint: disable=invalid-name, unused-import """The type nodes of the Relay language.""" -from enum import IntEnum +from tvm.ir import Type, TypeKind, TypeVar, GlobalTypeVar +from tvm.ir import TypeConstraint, FuncType, TupleType, IncompleteType +from tvm.ir import TypeCall, TypeRelation, TensorType, RelayRefType as RefType + from .base import RelayNode, register_relay_node from . import _make Any = _make.Any -class Type(RelayNode): - """The base type for all Relay types.""" - - def __eq__(self, other): - """Compare two Relay types for structural equivalence using - alpha equivalence. - """ - return bool(_make._alpha_equal(self, other)) - - def __ne__(self, other): - return not self.__eq__(other) - - def same_as(self, other): - """Compares two Relay types by referential equality.""" - return super().__eq__(other) - - def __call__(self, *args): - """Create a type call from this type. +def type_has_any(tensor_type): + """Check whether type has any as a shape. - Parameters - ---------- - args: List[relay.Type] - The arguments to the type call. - - Returns - ------- - call: relay.TypeCall - """ - return TypeCall(self, args) - - def is_dynamic(self): - return _make.IsDynamic(self) - -@register_relay_node -class TensorType(Type): - """A concrete TensorType in Relay. - - This is the type assigned to tensors with a known dtype and shape. For - example, a tensor of `float32` and `(5, 5)`. - - Parameters - ---------- - shape : List[tvm.Expr] - The shape of the Tensor - - dtype : Optional[str] - The content data type. - Default to "float32". + tensor_type : Type + The type to be inspected Returns ------- - tensor_type : tvm.relay.TensorType - The tensor type. - """ - def __init__(self, shape, dtype="float32"): - self.__init_handle_by_constructor__( - _make.TensorType, shape, dtype) - - @property - def concrete_shape(self): - """Get shape of the type as concrete tuple of int. - - Returns - ------- - shape : List[int] - The concrete shape of the Type. - - Raises - ------ - TypeError : If the shape is symbolic - """ - return tuple(int(x) for x in self.shape) - - -class Kind(IntEnum): - """The kind of a type parameter, represents a variable shape, - base type, type, or dimension. - - This controls what a type parameter is allowed to be instantiated - with. For example one's of kind BaseType can only be `float32`, `int32`, - and so on. - """ - Type = 0 - ShapeVar = 1 - BaseType = 2 - Shape = 3 - Constraint = 4 - AdtHandle = 5 - TypeData = 6 - -@register_relay_node -class TypeVar(Type): - """A type variable used for generic types in Relay, - see tvm/relay/type.h for more details. - - A type variable represents a type placeholder which will - be filled in later on. This allows the user to write - functions which are generic over types. + has_any : bool + The check result. """ + return _make.IsDynamic(tensor_type) - def __init__(self, name_hint, kind=Kind.Type): - """Construct a TypeVar. - - Parameters - ---------- - name_hint: str - The name of the type variable. This name only acts as a hint, and - is not used for equality. - - kind : Optional[Kind] - The kind of the type parameter. - Default to Kind.Type. - - Returns - ------- - type_var : tvm.relay.TypeVar - The type variable. - """ - self.__init_handle_by_constructor__(_make.TypeVar, name_hint, kind) def ShapeVar(name): """A helper which constructs a type var of which the shape kind. @@ -154,172 +51,9 @@ def ShapeVar(name): type_var : tvm.relay.TypeVar The shape variable. """ - return TypeVar(name, kind=Kind.ShapeVar) - -@register_relay_node -class GlobalTypeVar(Type): - """A global type variable in Relay. - GlobalTypeVar is used to refer to the global type-level definitions - stored in the environment. - """ - - def __init__(self, name_hint, kind=Kind.AdtHandle): - """Construct a GlobalTypeVar. - - Parameters - ---------- - name_hint: str - The name of the global type variable. This name only acts as a - hint, and is not used for equality. - - kind: Kind, optional - The kind of the type parameter, Kind.AdtHandle by default. - - Returns - ------- - type_var: GlobalTypeVar - The global type variable. - """ - self.__init_handle_by_constructor__(_make.GlobalTypeVar, name_hint, kind) - - -@register_relay_node -class TypeCall(Type): - """Type-level function application in Relay. - A type call applies argument types to a constructor (type-level function). - """ - - def __init__(self, func, args): - """Construct a TypeCall. - Parameters - ---------- - func: tvm.relay.Type - The function. - args: List[tvm.expr.Type] - The arguments. - Returns - ------- - type_call: TypeCall - The type function application. - """ - self.__init_handle_by_constructor__(_make.TypeCall, func, args) - - -@register_relay_node -class TypeConstraint(Type): - """Abstract class representing a type constraint.""" - - -@register_relay_node -class TupleType(Type): - """A tuple type in Relay, see tvm/relay/type.h for more details. - - Lists the type of each field in the tuple. - """ - - def __init__(self, fields): - """Constructs a tuple type - - Parameters - ---------- - fields : List[tvm.relay.Type] - The fields in the tuple - - Returns - ------- - tuple_type : tvm.relay.TupleType - the tuple type - """ - self.__init_handle_by_constructor__(_make.TupleType, fields) - - -@register_relay_node -class FuncType(Type): - """A function type in Relay, see tvm/relay/type.h for more details. - - This is the type assigned to functions in Relay. They consist of - a list of type parameters which enable the definition of generic - functions, a set of type constraints which we omit for the time - being, a sequence of argument types, and a return type. - - We informally write them as: - `forall (type_params), (arg_types) -> ret_type where type_constraints` - - Parameters - ---------- - arg_types : List[tvm.relay.Type] - The argument types - - ret_type : tvm.relay.Type - The return type. - - type_params : Optional[List[tvm.relay.TypeVar]] - The type parameters - - type_constraints : Optional[List[tvm.relay.TypeConstraint]] - The type constraints. - """ - def __init__(self, - arg_types, - ret_type, - type_params=None, - type_constraints=None): - if type_params is None: - type_params = [] - if type_constraints is None: - type_constraints = [] - self.__init_handle_by_constructor__( - _make.FuncType, arg_types, ret_type, type_params, type_constraints) + return TypeVar(name, kind=TypeKind.ShapeVar) -@register_relay_node -class IncompleteType(Type): - """An incomplete type.""" - def __init__(self, kind=Kind.Type): - self.__init_handle_by_constructor__(_make.IncompleteType, kind) - - -@register_relay_node -class TypeRelation(TypeConstraint): - """Type relation in relay. - - Parameters - ---------- - func : EnvFunc - User defined relation function. - - args : [tvm.relay.Type] - List of types to the func. - - num_inputs : int - Number of input arguments in args, - this act as a hint for type inference. - - attrs : Attrs - The attribute attached to the relation information - - Returns - ------- - type_relation : tvm.relay.TypeRelation - The type relation. - """ - def __init__(self, func, args, num_inputs, attrs): - self.__init_handle_by_constructor__(_make.TypeRelation, - func, args, num_inputs, attrs) - - -@register_relay_node -class RefType(Type): - """Reference Type in relay. - - Parameters - ---------- - value: Type - The value type. - """ - def __init__(self, value): - self.__init_handle_by_constructor__(_make.RefType, value) - def scalar_type(dtype): """Creates a scalar type. diff --git a/python/tvm/relay/ty.pyi b/python/tvm/relay/ty.pyi deleted file mode 100644 index cde851160167..000000000000 --- a/python/tvm/relay/ty.pyi +++ /dev/null @@ -1,200 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name -"""The type nodes of the Relay language.""" -from enum import IntEnum -from .base import Object, register_relay_node -from . import _make - - -class Type(Object): - """The base type for all Relay types.""" - - def __eq__(self, other): - """Compare two Relay types for structural equivalence using - alpha equivalence. - """ - return bool(_make._type_alpha_eq(self, other)) - - def __ne__(self, other): - return not self.__eq__(other) - - def same_as(self, other): - """Compares two Relay types by referential equality.""" - return super().__eq__(other) - - -@register_relay_node -class TensorType(Type): - """A concrete TensorType in Relay, see tvm/relay/type.h for more details. - - This is the type assigned to tensor's with a known dype and shape. For - example a tensor of `float32` and `(5, 5)`. - """ - - def __init__(self, shape, dtype): - """Construct a tensor type. - - Parameters - ---------- - shape: list of tvm.Expr - dtype: str - - Returns - ------- - tensor_type: The TensorType - """ - self.__init_handle_by_constructor__(_make.TensorType, shape, dtype) - - -class Kind(IntEnum): - """The kind of a type parameter, represents a variable shape, - base type, type, or dimension. - - This controls what a type parameter is allowed to be instantiated - with. For example one's of kind BaseType can only be `float32`, `int32`, - and so on. - """ - ShapeVar = 0 - Shape = 1 - BaseType = 2 - Type = 3 - - -@register_relay_node -class TypeParam(Type): - """A type parameter used for generic types in Relay, - see tvm/relay/type.h for more details. - - A type parameter represents a type placeholder which will - be filled in later on. This allows the user to write - functions which are generic over types. - """ - - def __init__(self, var, kind): - """Construct a TypeParam. - - Parameters - ---------- - var: tvm.expr.Var - The tvm.Var which backs the type parameter. - - kind: Kind - The kind of the type parameter. - - Returns - ------- - type_param: TypeParam - The type parameter. - """ - self.__init_handle_by_constructor__(_make.TypeParam, var, kind) - - -@register_relay_node -class TypeConstraint(Type): - """Abstract class representing a type constraint.""" - pass - - -@register_relay_node -class TupleType(Type): - """A tuple type in Relay, see tvm/relay/type.h for more details. - - Lists the type of each field in the tuple. - """ - - def __init__(self, fields): - """Constructs a tuple type - - Parameters - ---------- - fields: list of tvm.Type - - Returns - ------- - tuple_type: the tuple type - """ - self.__init_handle_by_constructor__(_make.TupleType, fields) - - -@register_relay_node -class FuncType(Type): - """A function type in Relay, see tvm/relay/type.h for more details. - - This is the type assigned to functions in Relay. They consist of - a list of type parameters which enable the definition of generic - functions, a set of type constraints which we omit for the time - being, a sequence of argument types, and a return type. - - We informally write them as: - `forall (type_params), (arg_types) -> ret_type where type_constraints` - """ - - def __init__(self, - arg_types, - ret_type, - type_params, - type_constraints, - ): - """Construct a function type. - - Parameters - ---------- - arg_types: list of Type - ret_type: Type - type_params: list of TypeParam - type_constraints: list of TypeConstraint - - Returns - ------- - func_type: FuncType - The function type. - """ - self.__init_handle_by_constructor__( - _make.FuncType, arg_types, ret_type, type_params, type_constraints) - - -@register_relay_node -class IncompleteType(Type): - """An incomplete type.""" - - def __init__(self, kind=Kind.Type): - self.__init_handle_by_constructor__(_make.IncompleteType, kind) - -@register_relay_node -class TypeRelation(TypeConstraint): - """Type relation in relay. - - Parameters - ---------- - func : EnvFunc - User defined relation function. - - args : list of types - List of types to the func. - - num_inputs: int - Number of input arguments in args, - this act as a hint for type inference. - - attrs : Attrs - The attribute attached to the relation information - """ - def __init__(self, func, args, num_inputs, attrs): - self.__init_handle_by_constructor__(_make.TypeRelation, - func, args, num_inputs, attrs) diff --git a/python/tvm/runtime/_ffi_node_api.py b/python/tvm/runtime/_ffi_node_api.py index adaa376b8a8d..64e0a939f97c 100644 --- a/python/tvm/runtime/_ffi_node_api.py +++ b/python/tvm/runtime/_ffi_node_api.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=invalid-name, unused-argument -"""FFI for tvm.runtime.extra""" +"""FFI for tvm.node""" import tvm._ffi # The implementations below are default ones when the corresponding diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index d779b5979b18..650bf9d1aab1 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -19,11 +19,11 @@ from tvm._ffi.base import string_types from tvm.runtime import Object, convert +from tvm.ir import container as _container from . import _api_internal from . import tensor as _tensor from . import expr as _expr -from . import container as _container @tvm._ffi.register_object diff --git a/python/tvm/stmt.py b/python/tvm/stmt.py index 27bb4db76ac0..e5feb50ddf6f 100644 --- a/python/tvm/stmt.py +++ b/python/tvm/stmt.py @@ -366,6 +366,14 @@ def __init__(self, func, value_index, dtype, bounds): _make.Prefetch, func, value_index, dtype, bounds) +@tvm._ffi.register_object +class LoweredFunc(Object): + """Represent a LoweredFunc in TVM.""" + MixedFunc = 0 + HostFunc = 1 + DeviceFunc = 2 + + def stmt_seq(*args): """Make sequence of statements diff --git a/src/ir/adt.cc b/src/ir/adt.cc index f94284090e26..91d655a062a1 100644 --- a/src/ir/adt.cc +++ b/src/ir/adt.cc @@ -38,7 +38,7 @@ Constructor::Constructor(std::string name_hint, TVM_REGISTER_NODE_TYPE(ConstructorNode); -TVM_REGISTER_GLOBAL("relay._make.Constructor") +TVM_REGISTER_GLOBAL("ir.Constructor") .set_body_typed([](std::string name_hint, tvm::Array inputs, GlobalTypeVar belong_to) { @@ -64,7 +64,7 @@ TypeData::TypeData(GlobalTypeVar header, TVM_REGISTER_NODE_TYPE(TypeDataNode); -TVM_REGISTER_GLOBAL("relay._make.TypeData") +TVM_REGISTER_GLOBAL("ir.TypeData") .set_body_typed([](GlobalTypeVar header, tvm::Array type_vars, tvm::Array constructors) { diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index c5d7446d2955..60a543109c50 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -334,7 +334,7 @@ bool DictAttrsNode::ContentEqual(const Object* other, AttrsEqual equal) const { return equal(this->dict, static_cast(other)->dict); } -TVM_REGISTER_GLOBAL("_AttrsListFieldInfo") +TVM_REGISTER_GLOBAL("ir.AttrsListFieldInfo") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = args[0].operator Attrs()->ListFieldInfo(); }); diff --git a/src/ir/env_func.cc b/src/ir/env_func.cc index b125c0318853..3e85c5f47b52 100644 --- a/src/ir/env_func.cc +++ b/src/ir/env_func.cc @@ -50,10 +50,10 @@ EnvFunc EnvFunc::Get(const std::string& name) { return EnvFunc(CreateEnvNode(name)); } -TVM_REGISTER_GLOBAL("_EnvFuncGet") +TVM_REGISTER_GLOBAL("ir.EnvFuncGet") .set_body_typed(EnvFunc::Get); -TVM_REGISTER_GLOBAL("_EnvFuncCall") +TVM_REGISTER_GLOBAL("ir.EnvFuncCall") .set_body([](TVMArgs args, TVMRetValue* rv) { EnvFunc env = args[0]; CHECK_GE(args.size(), 1); @@ -62,7 +62,7 @@ TVM_REGISTER_GLOBAL("_EnvFuncCall") args.size() - 1), rv); }); -TVM_REGISTER_GLOBAL("_EnvFuncGetPackedFunc") +TVM_REGISTER_GLOBAL("ir.EnvFuncGetPackedFunc") .set_body_typed([](const EnvFunc&n) { return n->func; }); diff --git a/src/ir/expr.cc b/src/ir/expr.cc index c061587ba360..78c6879d8ced 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -154,7 +154,7 @@ GlobalVar::GlobalVar(std::string name_hint) { TVM_REGISTER_NODE_TYPE(GlobalVarNode); -TVM_REGISTER_GLOBAL("relay._make.GlobalVar") +TVM_REGISTER_GLOBAL("ir.GlobalVar") .set_body_typed([](std::string name){ return GlobalVar(name); }); diff --git a/src/ir/module.cc b/src/ir/module.cc index 7f3796ed07f5..04fe5d55bceb 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -338,13 +338,13 @@ IRModule IRModule::FromText(const std::string& text, const std::string& source_p TVM_REGISTER_NODE_TYPE(IRModuleNode); -TVM_REGISTER_GLOBAL("relay._make.Module") +TVM_REGISTER_GLOBAL("ir.IRModule") .set_body_typed([](tvm::Map funcs, tvm::Map types) { return IRModule(funcs, types, {}); }); -TVM_REGISTER_GLOBAL("relay._module.Module_Add") +TVM_REGISTER_GLOBAL("ir.Module_Add") .set_body([](TVMArgs args, TVMRetValue* ret) { IRModule mod = args[0]; GlobalVar var = args[1]; @@ -369,67 +369,67 @@ TVM_REGISTER_GLOBAL("relay._module.Module_Add") *ret = mod; }); -TVM_REGISTER_GLOBAL("relay._module.Module_AddDef") +TVM_REGISTER_GLOBAL("ir.Module_AddDef") .set_body_method(&IRModuleNode::AddTypeDef); -TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalVar") +TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar") .set_body_method(&IRModuleNode::GetGlobalVar); -TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalVars") +TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVars") .set_body_method(&IRModuleNode::GetGlobalVars); -TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalTypeVars") +TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVars") .set_body_method(&IRModuleNode::GetGlobalTypeVars); -TVM_REGISTER_GLOBAL("relay._module.Module_ContainGlobalVar") +TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar") .set_body_method(&IRModuleNode::ContainGlobalVar); -TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalTypeVar") +TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVar") .set_body_method(&IRModuleNode::GetGlobalTypeVar); -TVM_REGISTER_GLOBAL("relay._module.Module_Lookup") +TVM_REGISTER_GLOBAL("ir.Module_Lookup") .set_body_typed([](IRModule mod, GlobalVar var) { return mod->Lookup(var); }); -TVM_REGISTER_GLOBAL("relay._module.Module_Lookup_str") +TVM_REGISTER_GLOBAL("ir.Module_Lookup_str") .set_body_typed([](IRModule mod, std::string var) { return mod->Lookup(var); }); -TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef") +TVM_REGISTER_GLOBAL("ir.Module_LookupDef") .set_body_typed([](IRModule mod, GlobalTypeVar var) { return mod->LookupTypeDef(var); }); -TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef_str") +TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str") .set_body_typed([](IRModule mod, std::string var) { return mod->LookupTypeDef(var); }); -TVM_REGISTER_GLOBAL("relay._module.Module_LookupTag") +TVM_REGISTER_GLOBAL("ir.Module_LookupTag") .set_body_typed([](IRModule mod, int32_t tag) { return mod->LookupTag(tag); }); -TVM_REGISTER_GLOBAL("relay._module.Module_FromExpr") +TVM_REGISTER_GLOBAL("ir.Module_FromExpr") .set_body_typed([](RelayExpr e, tvm::Map funcs, tvm::Map type_defs) { return IRModule::FromExpr(e, funcs, type_defs); }); -TVM_REGISTER_GLOBAL("relay._module.Module_Update") +TVM_REGISTER_GLOBAL("ir.Module_Update") .set_body_typed([](IRModule mod, IRModule from) { mod->Update(from); }); -TVM_REGISTER_GLOBAL("relay._module.Module_Import") +TVM_REGISTER_GLOBAL("ir.Module_Import") .set_body_typed([](IRModule mod, std::string path) { mod->Import(path); }); -TVM_REGISTER_GLOBAL("relay._module.Module_ImportFromStd") +TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd") .set_body_typed([](IRModule mod, std::string path) { mod->ImportFromStd(path); });; diff --git a/src/ir/span.cc b/src/ir/span.cc index 2ea7095c89ac..d03903c2d3a5 100644 --- a/src/ir/span.cc +++ b/src/ir/span.cc @@ -45,7 +45,7 @@ SourceName SourceName::Get(const std::string& name) { return SourceName(GetSourceNameNode(name)); } -TVM_REGISTER_GLOBAL("relay._make.SourceName") +TVM_REGISTER_GLOBAL("ir.SourceName") .set_body_typed(SourceName::Get); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -70,7 +70,7 @@ Span SpanNode::make(SourceName source, int lineno, int col_offset) { TVM_REGISTER_NODE_TYPE(SpanNode); -TVM_REGISTER_GLOBAL("relay._make.Span") +TVM_REGISTER_GLOBAL("ir.Span") .set_body_typed(SpanNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) diff --git a/src/ir/tensor_type.cc b/src/ir/tensor_type.cc index 5e7c51c72d9b..57cdebc931fb 100644 --- a/src/ir/tensor_type.cc +++ b/src/ir/tensor_type.cc @@ -55,7 +55,7 @@ PrimExpr TensorTypeNode::Size() const { TVM_REGISTER_NODE_TYPE(TensorTypeNode); -TVM_REGISTER_GLOBAL("relay._make.TensorType") +TVM_REGISTER_GLOBAL("ir.TensorType") .set_body_typed([](Array shape, DataType dtype) { return TensorType(shape, dtype); }); diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 14bd063b0169..2b5010b5ffd5 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -300,10 +300,15 @@ bool SequentialNode::PassEnabled(const PassInfo& info) const { Pass GetPass(const std::string& pass_name) { using tvm::runtime::Registry; - std::string fpass_name = "relay._transform." + pass_name; - const auto* f = Registry::Get(fpass_name); - CHECK(f != nullptr) << "Cannot find " << fpass_name - << "to create the pass " << pass_name; + const runtime::PackedFunc* f = nullptr; + if (pass_name.find("transform.") != std::string::npos) { + f = Registry::Get(pass_name); + } else if ((f = Registry::Get("transform." + pass_name))) { + // pass + } else if ((f = Registry::Get("relay._transform." + pass_name))) { + } + CHECK(f != nullptr) << "Cannot use " << pass_name + << "to create the pass"; return (*f)(); } @@ -311,7 +316,7 @@ Pass GetPass(const std::string& pass_name) { // a Sequential without the consideration of their orders. The phase // ordering problem needs to be handled in the future. IRModule SequentialNode::operator()(const IRModule& module, - const PassContext& pass_ctx) const { + const PassContext& pass_ctx) const { IRModule mod = module; for (const Pass& pass : passes) { CHECK(pass.defined()) << "Found undefined pass for optimization."; @@ -339,12 +344,12 @@ Pass CreateModulePass( TVM_REGISTER_NODE_TYPE(PassInfoNode); -TVM_REGISTER_GLOBAL("relay._transform.PassInfo") +TVM_REGISTER_GLOBAL("transform.PassInfo") .set_body_typed([](int opt_level, std::string name, tvm::Array required) { return PassInfo(opt_level, name, required); }); -TVM_REGISTER_GLOBAL("relay._transform.Info") +TVM_REGISTER_GLOBAL("transform.Info") .set_body([](TVMArgs args, TVMRetValue* ret) { Pass pass = args[0]; *ret = pass->Info(); @@ -366,14 +371,14 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(ModulePassNode); -TVM_REGISTER_GLOBAL("relay._transform.MakeModulePass") +TVM_REGISTER_GLOBAL("transform.MakeModulePass") .set_body_typed( [](runtime::TypedPackedFunc pass_func, PassInfo pass_info) { return ModulePass(pass_func, pass_info); }); -TVM_REGISTER_GLOBAL("relay._transform.RunPass") +TVM_REGISTER_GLOBAL("transform.RunPass") .set_body([](TVMArgs args, TVMRetValue* ret) { Pass pass = args[0]; IRModule mod = args[1]; @@ -390,7 +395,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(SequentialNode); -TVM_REGISTER_GLOBAL("relay._transform.Sequential") +TVM_REGISTER_GLOBAL("transform.Sequential") .set_body([](TVMArgs args, TVMRetValue* ret) { tvm::Array passes = args[0]; int opt_level = args[1]; @@ -416,7 +421,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(PassContextNode); -TVM_REGISTER_GLOBAL("relay._transform.PassContext") +TVM_REGISTER_GLOBAL("transform.PassContext") .set_body([](TVMArgs args, TVMRetValue* ret) { auto pctx = PassContext::Create(); int opt_level = args[0]; @@ -465,13 +470,13 @@ class PassContext::Internal { } }; -TVM_REGISTER_GLOBAL("relay._transform.GetCurrentPassContext") +TVM_REGISTER_GLOBAL("transform.GetCurrentPassContext") .set_body_typed(PassContext::Current); -TVM_REGISTER_GLOBAL("relay._transform.EnterPassContext") +TVM_REGISTER_GLOBAL("transform.EnterPassContext") .set_body_typed(PassContext::Internal::EnterScope); -TVM_REGISTER_GLOBAL("relay._transform.ExitPassContext") +TVM_REGISTER_GLOBAL("transform.ExitPassContext") .set_body_typed(PassContext::Internal::ExitScope); } // namespace transform diff --git a/src/ir/type.cc b/src/ir/type.cc index 02ddfc9371fd..e0420aaf754a 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -33,7 +33,7 @@ PrimType::PrimType(runtime::DataType dtype) { TVM_REGISTER_NODE_TYPE(PrimTypeNode); -TVM_REGISTER_GLOBAL("relay._make.PrimType") +TVM_REGISTER_GLOBAL("ir.PrimType") .set_body_typed([](runtime::DataType dtype) { return PrimType(dtype); }); @@ -54,7 +54,7 @@ TypeVar::TypeVar(std::string name, TypeKind kind) { TVM_REGISTER_NODE_TYPE(TypeVarNode); -TVM_REGISTER_GLOBAL("relay._make.TypeVar") +TVM_REGISTER_GLOBAL("ir.TypeVar") .set_body_typed([](std::string name, int kind) { return TypeVar(name, static_cast(kind)); }); @@ -76,7 +76,7 @@ GlobalTypeVar::GlobalTypeVar(std::string name, TypeKind kind) { TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode); -TVM_REGISTER_GLOBAL("relay._make.GlobalTypeVar") +TVM_REGISTER_GLOBAL("ir.GlobalTypeVar") .set_body_typed([](std::string name, int kind) { return GlobalTypeVar(name, static_cast(kind)); }); @@ -102,7 +102,7 @@ FuncType::FuncType(tvm::Array arg_types, TVM_REGISTER_NODE_TYPE(FuncTypeNode); -TVM_REGISTER_GLOBAL("relay._make.FuncType") +TVM_REGISTER_GLOBAL("ir.FuncType") .set_body_typed([](tvm::Array arg_types, Type ret_type, tvm::Array type_params, @@ -131,7 +131,7 @@ TupleType TupleType::Empty() { TVM_REGISTER_NODE_TYPE(TupleTypeNode); -TVM_REGISTER_GLOBAL("relay._make.TupleType") +TVM_REGISTER_GLOBAL("ir.TupleType") .set_body_typed([](Array fields) { return TupleType(fields); }); @@ -151,7 +151,7 @@ IncompleteType::IncompleteType(TypeKind kind) { TVM_REGISTER_NODE_TYPE(IncompleteTypeNode); -TVM_REGISTER_GLOBAL("relay._make.IncompleteType") +TVM_REGISTER_GLOBAL("ir.IncompleteType") .set_body_typed([](int kind) { return IncompleteType(static_cast(kind)); }); @@ -169,7 +169,7 @@ RelayRefType::RelayRefType(Type value) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relay._make.RefType") +TVM_REGISTER_GLOBAL("ir.RelayRefType") .set_body_typed([](Type value) { return RelayRefType(value); }); diff --git a/src/ir/type_relation.cc b/src/ir/type_relation.cc index 1d80f95b10c9..bd79c9c7fd16 100644 --- a/src/ir/type_relation.cc +++ b/src/ir/type_relation.cc @@ -35,7 +35,7 @@ TypeCall::TypeCall(Type func, tvm::Array args) { TVM_REGISTER_NODE_TYPE(TypeCallNode); -TVM_REGISTER_GLOBAL("relay._make.TypeCall") +TVM_REGISTER_GLOBAL("ir.TypeCall") .set_body_typed([](Type func, Array type) { return TypeCall(func, type); }); @@ -61,7 +61,7 @@ TypeRelation::TypeRelation(TypeRelationFn func, TVM_REGISTER_NODE_TYPE(TypeRelationNode); -TVM_REGISTER_GLOBAL("relay._make.TypeRelation") +TVM_REGISTER_GLOBAL("ir.TypeRelation") .set_body_typed([](TypeRelationFn func, Array args, int num_inputs, diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 0fa4da5b5077..00bf70b5d289 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -131,6 +131,7 @@ class RelayTextPrinter : } else if (node.as()) { return PrintMod(Downcast(node)); } else { + // default module. std::ostringstream os; os << node; return Doc() << os.str(); @@ -905,20 +906,18 @@ static const char* kSemVer = "v0.0.4"; // - relay_text_printer.cc (specific printing logics for relay) // - tir_text_printer.cc (specific printing logics for TIR) std::string PrettyPrint(const ObjectRef& node) { - Doc doc; - doc << relay::RelayTextPrinter(false, nullptr).PrintFinal(node); - return doc.str(); + return AsText(node, false, nullptr); } std::string AsText(const ObjectRef& node, bool show_meta_data, runtime::TypedPackedFunc annotate) { Doc doc; - doc << kSemVer << Doc::NewLine() - << relay::RelayTextPrinter(show_meta_data, annotate).PrintFinal(node); + doc << kSemVer << Doc::NewLine(); + doc << relay::RelayTextPrinter(show_meta_data, annotate).PrintFinal(node); return doc.str(); } -TVM_REGISTER_GLOBAL("relay._expr.AsText") +TVM_REGISTER_GLOBAL("ir.AsText") .set_body_typed(AsText); } // namespace tvm diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index 2d07f6131f13..48634bafa744 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -599,6 +599,11 @@ TVM_REGISTER_GLOBAL("relay._make._alpha_equal") return AlphaEqualHandler(false, false).Equal(a, b); }); +TVM_REGISTER_GLOBAL("ir.type_alpha_equal") +.set_body_typed([](Type a, Type b) { + return AlphaEqual(a, b); +}); + TVM_REGISTER_GLOBAL("relay._make._assert_alpha_equal") .set_body_typed([](ObjectRef a, ObjectRef b) { bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b); diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index 85b17b5a33e0..22423b8dfe5f 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -33,7 +33,7 @@ using namespace tvm::runtime; TVM_REGISTER_NODE_TYPE(IdNode); -TVM_REGISTER_GLOBAL("relay._base.set_span") +TVM_REGISTER_GLOBAL("ir.NodeSetSpan") .set_body_typed([](ObjectRef node_ref, Span sp) { if (auto* rn = node_ref.as()) { rn->span = sp; diff --git a/tests/python/contrib/test_rpc_tracker.py b/tests/python/contrib/test_rpc_tracker.py index 6abfc90c352b..11e7766f374b 100644 --- a/tests/python/contrib/test_rpc_tracker.py +++ b/tests/python/contrib/test_rpc_tracker.py @@ -84,7 +84,7 @@ def myfunc(remote): f1 = remote2.get_function("rpc.test2.addone") assert f1(10) == 11 - except tvm.TVMError as e: + except tvm.error.TVMError as e: pass remote3 = tclient.request("abc") f1 = remote3.get_function("rpc.test2.addone") diff --git a/tests/python/frontend/mxnet/test_graph.py b/tests/python/frontend/mxnet/test_graph.py index 467d5529d0c1..6e870000a76b 100644 --- a/tests/python/frontend/mxnet/test_graph.py +++ b/tests/python/frontend/mxnet/test_graph.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. import mxnet as mx + +import tvm from tvm import relay from tvm.relay import transform import model_zoo @@ -99,7 +101,7 @@ def relay_compose(F, **kwargs): z = F.split(x, **kwargs) z = F.subtract(F.add(z[0], z[2]), y) func = relay.Function(relay.analysis.free_vars(z), z) - return relay.Module.from_expr(func) + return tvm.IRModule.from_expr(func) mx_sym = mx_compose(mx, num_outputs=3, axis=1) mod, _ = relay.frontend.from_mxnet( diff --git a/tests/python/frontend/mxnet/test_qnn_ops_utils.py b/tests/python/frontend/mxnet/test_qnn_ops_utils.py index 0c7374d4d8a7..4ee5f2e3c3c3 100644 --- a/tests/python/frontend/mxnet/test_qnn_ops_utils.py +++ b/tests/python/frontend/mxnet/test_qnn_ops_utils.py @@ -34,7 +34,7 @@ def dequantize_test_driver(in_dtype, quant_args, in_data, verify_output_data): max_range=max_range, in_dtype=in_dtype) mod = relay.Function(relay.analysis.free_vars(dequantized_output), dequantized_output) - mod = relay.Module.from_expr(mod) + mod = tvm.IRModule.from_expr(mod) with relay.build_config(opt_level=3): graph, lib, params = relay.build(mod, "llvm", params=None) rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) @@ -90,7 +90,7 @@ def quantize_test_driver(out_dtype, quant_args, in_data, verify_output_data): max_range=max_range, out_dtype=out_dtype) mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output) - mod = relay.Module.from_expr(mod) + mod = tvm.IRModule.from_expr(mod) with relay.build_config(opt_level=3): graph, lib, params = relay.build(mod, "llvm", params=None) rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 81594c0d04ed..8f631f8fd047 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -23,7 +23,7 @@ import numpy as np -mod = relay.Module() +mod = tvm.IRModule() p = Prelude(mod) add_nat_definitions(p) @@ -730,7 +730,7 @@ def check_tensor_array(ta_mod, ref_res, *args, dtype="float32", def test_tensor_expand_dims(): def run(dtype): x = relay.var('x') - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) expand_dims_func = p.get_var('tensor_expand_dims', dtype) tensor1 = p.get_var('tensor1', dtype) @@ -745,7 +745,7 @@ def run(dtype): def test_tensor_array_constructor(): def run(dtype): x = relay.var('x') - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) tensor_array = p.get_var('tensor_array', dtype) mod["main"] = relay.Function([x], tensor_array(x)) @@ -757,7 +757,7 @@ def run(dtype): def test_tensor_array_read(): def run(dtype): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) l = relay.var('l') i = relay.var('i') @@ -773,7 +773,7 @@ def run(dtype): def test_tensor_array_write(): def run(dtype): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) v1 = relay.var('v1') v2 = relay.var('v2') @@ -793,7 +793,7 @@ def run(dtype): def test_tensor_array_stack(): def run(dtype): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) tensor_array = p.get_var('tensor_array', dtype) tensor1 = p.get_var('tensor1', dtype) @@ -815,7 +815,7 @@ def run(dtype): def test_tensor_array_unstack(): def run(dtype): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) unstack_tensor1 = p.get_var('tensor_array_unstack_tensor1', dtype) v = relay.var('v') @@ -828,7 +828,7 @@ def run(dtype): def test_tensor_take(): def run(dtype): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) take = p.get_var('tensor_take', dtype) tensor2 = p.get_var('tensor2', dtype) @@ -847,7 +847,7 @@ def run(dtype): def test_tensor_concatenate(): def run(dtype): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) concat = p.get_var('tensor_concatenate', dtype) tensor1 = p.get_var('tensor1', dtype) @@ -865,7 +865,7 @@ def run(dtype): def test_tensor_array_concat(): def run(dtype): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) v1 = relay.var('v1') v2 = relay.var('v2') @@ -888,9 +888,9 @@ def run(dtype): def test_tensor_array_scatter(): def run(dtype): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) - + # tensor array v1 = relay.var('v1') v2 = relay.var('v2') @@ -938,9 +938,9 @@ def run(dtype): def test_tensor_array_split(): def run(dtype): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) - + # tensor array v1 = relay.var('v1') v2 = relay.var('v2') diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index fe2e9e9bb82e..3e392a8e630f 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -37,7 +37,7 @@ def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op): dtype = 'float32' x = relay.var('x', shape=x_shape, dtype=dtype) y = relay.var('y', shape=y_shape, dtype=dtype) - mod = relay.module.Module() + mod = tvm.IRModule() mod["main"] = relay.Function([x, y], op(x, y)) x_np = np.random.uniform(size=x_np_shape).astype(dtype) y_np = np.random.uniform(size=y_np_shape).astype(dtype) @@ -62,7 +62,7 @@ def test_any_broadcast(): def verify_any_elemwise(x_shape, x_np_shape, op, np_op): dtype = 'float32' x = relay.var('x', shape=x_shape, dtype=dtype) - mod = relay.module.Module() + mod = tvm.IRModule() mod["main"] = relay.Function([x], op(x)) x_np = np.random.uniform(size=x_np_shape).astype(dtype) res_np = np_op(x_np) @@ -96,7 +96,7 @@ def check_fail(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op): def verify_any_full(x_shape, x_np_shape, relay_op, np_op, dtype='float32'): x = relay.var('x', shape=x_shape, dtype=dtype) - mod = relay.module.Module() + mod = tvm.IRModule() mod['main'] = relay.Function([x], relay.zeros_like(x)) x_np = np.random.uniform(size=x_np_shape).astype(dtype) res_np = np.zeros_like(x_np) @@ -126,7 +126,7 @@ def test_any_concat(): xx = x - relay.expr.const(3.0) yy = y * relay.expr.const(5.0) z = relay.op.concatenate([xx, yy], axis=0) - mod = relay.module.Module() + mod = tvm.IRModule() mod["main"] = relay.Function([x, y], z) x_np = np.random.uniform(size=(3, 2)).astype('float32') y_np = np.random.uniform(size=(1, 2)).astype('float32') @@ -139,7 +139,7 @@ def test_any_concat(): def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape): x = relay.var('x', shape=x_shape, dtype="float32") y = relay.reshape(x, newshape=newshape) - mod = relay.module.Module() + mod = tvm.IRModule() mod["main"] = relay.Function([x], y) data = np.random.uniform(size=x_np_shape).astype('float32') for kind in ["debug", "vm"]: @@ -158,7 +158,7 @@ def test_any_reshape(): def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"): x = relay.var('x', shape=x_shape, dtype=dtype) y = relay.argwhere(x) - mod = relay.module.Module() + mod = tvm.IRModule() mod["main"] = relay.Function([x], y) data = np.random.choice([0, 1, 2, 3], size=x_np_shape).astype(dtype) expected = np.argwhere(data) @@ -186,7 +186,7 @@ def test_any_argwhere(): verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), "int8") def verify_any_take(data_shape, indices_shape, axis, data_np_shape, indices_np_shape): - mod = relay.Module() + mod = tvm.IRModule() data = relay.var('data', shape=data_shape, dtype='float32') indices = relay.var('indices', shape=indices_shape, dtype='int32') y = relay.take(data, indices, axis=axis) @@ -212,7 +212,7 @@ def test_any_take(): verify_any_take(any_dims(2), any_dims(4), -1, (4, 5), (2, 3, 4, 5)) def verify_any_tile(dshape, reps, np_dshape, np_reps): - mod = relay.Module() + mod = tvm.IRModule() x = relay.var("x", shape=dshape, dtype="float32") y = relay.tile(x, reps=reps) mod["main"] = relay.Function([x], y) @@ -233,7 +233,7 @@ def test_any_tile(): def test_any_shape_of(): x = relay.var('x', shape=any_dims(2), dtype='float32') y = relay.shape_of(x) - mod = relay.module.Module() + mod = tvm.IRModule() mod["main"] = relay.Function([x], y) data = np.random.uniform(size=(3, 4)).astype('float32') for kind in ["debug", "vm"]: @@ -244,7 +244,7 @@ def test_any_shape_of(): x = relay.var('x', shape=any_dims(3), dtype='float32') y0 = relay.shape_of(x) y1 = relay.take(y0, relay.const(1, 'int32')) - mod = relay.module.Module() + mod = tvm.IRModule() mod["main"] = relay.Function([x], y1) data = np.random.uniform(size=(2, 3, 4)).astype('float32') for kind in ["debug", "vm"]: @@ -254,7 +254,7 @@ def test_any_shape_of(): def verify_any_reduce(reduce_op, data_shape, axis, exclude, keepdims, static_data_shape, ref_out_shape): - mod = relay.Module() + mod = tvm.IRModule() dtype = "bool" if reduce_op == relay.all else "float32" data = relay.var('data', shape=data_shape, dtype=dtype) y = reduce_op(data, axis, keepdims, exclude) @@ -277,7 +277,7 @@ def test_any_reduce(): verify_any_reduce(relay.variance, any_dims(5), (2, 4), False, False, (3, 4, 5, 6, 7), (3, 4, 6)) def verify_any_layout_transform(data_shape, src_layout, dst_layout, static_data_shape, ref_out_shape): - mod = relay.Module() + mod = tvm.IRModule() dtype = "float32" data = relay.var('data', shape=data_shape, dtype=dtype) y = relay.layout_transform(data, src_layout, dst_layout) @@ -297,7 +297,7 @@ def test_any_layout_transform(): verify_any_layout_transform((16, 1), "CH", "C4cH", (16, 1), (4, 4, 1)) def verify_any_expand_dims(data_shape, axis, num_newaxis, static_data_shape, ref_out_shape): - mod = relay.Module() + mod = tvm.IRModule() dtype = "float32" data = relay.var('data', shape=data_shape, dtype=dtype) y = relay.expand_dims(data, axis=axis, num_newaxis=num_newaxis) @@ -314,7 +314,7 @@ def test_any_expand_dims(): verify_any_expand_dims(any_dims(3), -1, 2, (1, 2, 3), (1, 2, 3, 1, 1)) def verify_any_transpose(data_shape, axes, static_data_shape): - mod = relay.Module() + mod = tvm.IRModule() dtype = "float32" data = relay.var('data', shape=data_shape, dtype=dtype) y = relay.transpose(data, axes=axes) @@ -332,7 +332,7 @@ def test_any_transpose(): verify_any_transpose(any_dims(6), (0, 1, 3, 2, 5, 4), (11, 12, 2, 1, 9, 17)) def verify_any_squeeze(data_shape, axis, static_data_shape): - mod = relay.Module() + mod = tvm.IRModule() dtype = "float32" data = relay.var('data', shape=data_shape, dtype=dtype) y = relay.squeeze(data, axis=axis) @@ -349,7 +349,7 @@ def test_any_squeeze(): verify_any_squeeze((1, relay.Any(), relay.Any(), 1, relay.Any(), relay.Any()), (0, 3), (1, 12, 2, 1, 9, 17)) def test_any_reshape_like(): - mod = relay.Module() + mod = tvm.IRModule() dtype = "float32" data = relay.var('data', shape=(relay.Any(), 3, 10), dtype=dtype) shape_like = relay.var('data', shape=(relay.Any(), 5, 6), dtype=dtype) @@ -366,7 +366,7 @@ def test_any_reshape_like(): def verify_any_conv2d_NCHWc(data_shape, kernel_shape, strides, padding, dilation, data_layout, kernel_layout, out_layout, static_data_shape, ref_out_shape): - mod = relay.Module() + mod = tvm.IRModule() dtype = "float32" data = relay.var('data', shape=data_shape, dtype=dtype) kernel = relay.var('kernel', shape=kernel_shape, dtype=dtype) @@ -392,7 +392,7 @@ def test_any_conv2d_NCHWc(): def verify_any_pool2d(pool_type, data_shape, pool_size, strides, padding, layout, static_data_shape, ref_out_shape): - mod = relay.Module() + mod = tvm.IRModule() dtype = "float32" pool_func = relay.nn.max_pool2d if pool_type == "max" else relay.nn.avg_pool2d data = relay.var('data', shape=data_shape, dtype=dtype) @@ -414,7 +414,7 @@ def test_any_pool2d(): (3, 3), (2, 2), (1, 1), "NCHW4c", (2, 3, 220, 220, 4), (2, 3, 110, 110, 4)) def verify_any_global_pool2d(pool_type, data_shape, layout, static_data_shape, ref_out_shape): - mod = relay.Module() + mod = tvm.IRModule() dtype = "float32" pool_func = relay.nn.global_max_pool2d if pool_type == "max" else relay.nn.global_avg_pool2d data = relay.var('data', shape=data_shape, dtype=dtype) @@ -436,7 +436,7 @@ def test_any_global_pool2d(): "NCHW4c", (2, 3, 220, 220, 4), (2, 3, 1, 1, 4)) def verify_any_split(data_shape, indices_or_sections, axis, static_data_shape, ref_out_shape): - mod = relay.Module() + mod = tvm.IRModule() dtype = "float32" data = relay.var('data', shape=data_shape, dtype=dtype) y = relay.split(data, indices_or_sections, axis) @@ -454,7 +454,7 @@ def test_any_split(): verify_any_split((relay.Any(), 12), (1, 4, 8), 1, (7, 12), [(7, 1), (7, 3), (7, 4)]) def test_any_batch_flatten(): - mod = relay.Module() + mod = tvm.IRModule() dtype = "float32" data = relay.var('data', shape=any_dims(3), dtype=dtype) y = relay.nn.batch_flatten(data) @@ -469,7 +469,7 @@ def test_any_batch_flatten(): def verify_any_dense(data_shape, weight_shape, units, static_data_shape, static_weight_shape, ref_out_shape): - mod = relay.Module() + mod = tvm.IRModule() dtype = "float32" data = relay.var('data', shape=data_shape, dtype=dtype) weight = relay.var('weight', shape=weight_shape, dtype=dtype) @@ -488,7 +488,7 @@ def test_any_dense(): verify_any_dense(any_dims(2), (50, relay.Any()), 50, (4, 40), (50, 40), (4, 50)) def verify_any_pad(data_shape, pad_width, static_data_shape): - mod = relay.Module() + mod = tvm.IRModule() dtype = "float32" data = relay.var('data', shape=data_shape, dtype=dtype) y = relay.nn.pad(data, pad_width) @@ -505,7 +505,7 @@ def test_any_pad(): verify_any_pad(any_dims(4), ((1, 0), (1, 3), (0, 2), (9, 0)), (13, 11, 3, 1)) def verify_any_softmax(data_shape, axis, static_data_shape, ref_out_shape): - mod = relay.Module() + mod = tvm.IRModule() dtype = "float32" data = relay.var('data', shape=data_shape, dtype=dtype) y = relay.nn.softmax(data, axis) @@ -525,7 +525,7 @@ def test_fused_ops(): x = relay.var('x', shape=(relay.Any(), relay.Any()), dtype='float32') y0 = x + relay.const(1.0, 'float32') y1 = y0 * relay.const(2.0, 'float32') - mod = relay.module.Module() + mod = tvm.IRModule() mod["main"] = relay.Function([x], y1) data = np.random.uniform(size=(5, 4)).astype('float32') for kind in ["vm"]: @@ -542,7 +542,7 @@ def test_arange_with_dynamic_shape(): y2 = relay.op.arange(y1, dtype="int32") y3 = y2 + relay.const(1, dtype="int32") data = np.random.rand(10, 5, 3).astype('float32') - mod = relay.module.Module() + mod = tvm.IRModule() mod["main"] = relay.Function([x], y3) for kind in ["debug", "vm"]: ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") @@ -577,7 +577,7 @@ def _body(i, st): start = relay.var('start', shape=(), dtype='int32') body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1))) func = relay.Function([start], relay.TupleGetItem(body, 1)) - mod = relay.module.Module() + mod = tvm.IRModule() mod["main"] = func data = np.array(0.0, dtype='int32') ref = np.array([0] + list(range(10))).reshape((11, 1)).astype("int32") diff --git a/tests/python/relay/test_backend_compile_engine.py b/tests/python/relay/test_backend_compile_engine.py index 43090eea15f0..640da0bd2ebe 100644 --- a/tests/python/relay/test_backend_compile_engine.py +++ b/tests/python/relay/test_backend_compile_engine.py @@ -27,7 +27,7 @@ def get_func(shape): y = relay.add(x, x) z = relay.add(y, x) f = relay.Function([x], z) - mod = relay.Module.from_expr(f) + mod = tvm.IRModule.from_expr(f) mod = relay.transform.InferType()(mod) return mod["main"] z1 = engine.lower(get_func((10,)), "llvm") @@ -59,7 +59,7 @@ def test_compile_placeholder_bypass(): result = relay.Tuple([x, relay.op.concatenate([y, z], axis=0)]) func = relay.Function(relay.analysis.free_vars(result), result) with relay.build_config(opt_level=0): - graph, lib, params = relay.build(relay.Module.from_expr(func), 'llvm') + graph, lib, params = relay.build(tvm.IRModule.from_expr(func), 'llvm') def test_compile_injective_with_tuple(): @@ -68,7 +68,7 @@ def test_compile_injective_with_tuple(): x_transpose = relay.transpose(x) output = relay.Tuple([x_transpose, y]) func = relay.Function([x, y], output) - relay.build(relay.Module.from_expr(func), 'llvm') + relay.build(tvm.IRModule.from_expr(func), 'llvm') def test_compile_tuple_dup(): @@ -76,7 +76,7 @@ def test_compile_tuple_dup(): log = relay.log(x) output = relay.Tuple([log, log]) f = relay.Function([x], output) - relay.build(relay.Module.from_expr(f), 'llvm') + relay.build(tvm.IRModule.from_expr(f), 'llvm') def test_compile_full(): @@ -88,7 +88,7 @@ def test_compile_full(): tvm.expr.IntImm('int32', 64)) output = relay.full(relay.const(0, 'int32'), shape=shape, dtype='int32') f = relay.Function([], output) - mod = relay.Module.from_expr(f) + mod = tvm.IRModule.from_expr(f) mod = relay.qnn.transform.CanonicalizeOps()(mod) relay.build(mod, 'llvm') diff --git a/tests/python/relay/test_backend_graph_runtime.py b/tests/python/relay/test_backend_graph_runtime.py index fbccb94bc670..d5d29b645cfa 100644 --- a/tests/python/relay/test_backend_graph_runtime.py +++ b/tests/python/relay/test_backend_graph_runtime.py @@ -21,7 +21,6 @@ from tvm.contrib import graph_runtime from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.op import add -from tvm.relay.module import Module from tvm.relay.testing.config import ctx_list # @tq, @jr should we put this in testing ns? @@ -100,7 +99,7 @@ def test_with_params(): x_data = np.random.rand(10, 5).astype('float32') y_data = np.random.rand(1, 5).astype('float32') params = {"y": y_data} - graph, lib, params = relay.build(relay.Module.from_expr(func), "llvm", params=params) + graph, lib, params = relay.build(tvm.IRModule.from_expr(func), "llvm", params=params) mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) mod.set_input(**params) mod.set_input(x=x_data) @@ -123,7 +122,7 @@ def test_plan_memory(): z = relay.exp(z) z = relay.exp(z) func = relay.Function([x, y], z) - mod = relay.Module.from_expr(func) + mod = tvm.IRModule.from_expr(func) mod = relay.transform.FuseOps(0)(mod) func = mod["main"] smap = relay.backend._backend.GraphPlanMemory(func) @@ -169,7 +168,7 @@ def unit_numpy(X, W): for target, ctx in ctx_list(): with relay.build_config(opt_level=2): - graph, lib, params = relay.build(relay.Module.from_expr(z), target) + graph, lib, params = relay.build(tvm.IRModule.from_expr(z), target) m = graph_runtime.create(graph, lib, ctx) m.set_input("X", tvm.nd.array(x.astype(dtype))) m.set_input("y", tvm.nd.array(y.astype(dtype))) diff --git a/tests/python/relay/test_backend_interpreter.py b/tests/python/relay/test_backend_interpreter.py index 28906f19ea3a..9b548f12f65b 100644 --- a/tests/python/relay/test_backend_interpreter.py +++ b/tests/python/relay/test_backend_interpreter.py @@ -93,7 +93,7 @@ def test_subtract(): def test_simple_loop(): - mod = relay.module.Module({}) + mod = tvm.IRModule({}) sum_up = relay.GlobalVar('sum_up') i = relay.var('i', shape=[], dtype='int32') sb = ScopeBuilder() @@ -110,7 +110,7 @@ def test_simple_loop(): def test_loop(): - mod = relay.module.Module({}) + mod = tvm.IRModule({}) sum_up = relay.GlobalVar('sum_up') i = relay.var('i', shape=[], dtype='int32') accum = relay.var('accum', shape=[], dtype='int32') @@ -129,7 +129,7 @@ def test_loop(): def test_ref(): - mod = relay.Module() + mod = tvm.IRModule() three_with_ref = relay.GlobalVar('three_with_ref') i = relay.Var('i') iv = relay.Var('iv') @@ -168,7 +168,7 @@ def test_kwargs_params(): def test_function_taking_adt_ref_tuple(): - mod = relay.Module() + mod = tvm.IRModule() prelude = relay.prelude.Prelude(mod) intrp = create_executor("debug", mod) @@ -212,7 +212,7 @@ def test_tuple_passing(): relay.ty.TensorType((), 'int64')])) fn = relay.Function([x], relay.expr.TupleGetItem(x, 0)) - mod = relay.Module({}) + mod = tvm.IRModule({}) gv = relay.GlobalVar('main') mod[gv] = fn mod = relay.transform.InferType()(mod) diff --git a/tests/python/relay/test_cpp_build_module.py b/tests/python/relay/test_cpp_build_module.py index 165e00d9c702..2af4a2030f4f 100644 --- a/tests/python/relay/test_cpp_build_module.py +++ b/tests/python/relay/test_cpp_build_module.py @@ -43,7 +43,7 @@ def test_basic_build(): targets = { tvm.expr.IntImm("int32", ctx.device_type): tgt } - g_json, mmod, params = relay.build(relay.Module.from_expr(func), targets, "llvm", params=params) + g_json, mmod, params = relay.build(tvm.IRModule.from_expr(func), targets, "llvm", params=params) # test rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx) @@ -115,7 +115,7 @@ def check_conversion(tgt, ctx): # build with relay.build_config(opt_level=1): - g_json, mmod, params = relay.build(relay.Module.from_expr(func), tgt) + g_json, mmod, params = relay.build(tvm.IRModule.from_expr(func), tgt) # test rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx) diff --git a/tests/python/relay/test_error_reporting.py b/tests/python/relay/test_error_reporting.py index 74e651884803..aef93ad9f4dc 100644 --- a/tests/python/relay/test_error_reporting.py +++ b/tests/python/relay/test_error_reporting.py @@ -19,12 +19,12 @@ def check_type_err(expr, msg): try: - mod = relay.Module.from_expr(expr) + mod = tvm.IRModule.from_expr(expr) mod = relay.transform.InferType()(mod) entry = mod["main"] expr = entry if isinstance(expr, relay.Function) else entry.body assert False - except tvm.TVMError as err: + except tvm.error.TVMError as err: assert msg in str(err) def test_wellformed(): diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index 17585835c178..3735259280a4 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -125,7 +125,7 @@ def test_multi_node_subgraph(): r = relay.concatenate((call0, call1, q2), axis=0) f = relay.Function([x, w0, w1, w2, w3, w4, w5, w6, w7], r) - mod = relay.Module() + mod = tvm.IRModule() mod["main"] = f mod = relay.transform.InferType()(mod) @@ -154,7 +154,7 @@ def test_extern_gcc_single_op(): f = relay.Function([x0, y0], z) f = set_external_func_attr(f, "ccompiler", "ccompiler_0") call = relay.Call(f, [x, y]) - mod = relay.Module.from_expr(call) + mod = tvm.IRModule.from_expr(call) x_data = np.random.rand(8, 8).astype('float32') y_data = np.random.rand(8, 8).astype('float32') @@ -188,7 +188,7 @@ def test_extern_gcc(): sub = relay.Function([x2, y2], sub) sub = set_external_func_attr(sub, "ccompiler", "ccompiler_0") call_sub = relay.Call(sub, [call_mul, call_add]) - mod = relay.Module.from_expr(call_sub) + mod = tvm.IRModule.from_expr(call_sub) x_data = np.random.rand(2, 2).astype('float32') y_data = np.random.rand(2, 2).astype('float32') @@ -223,12 +223,12 @@ def test_extern_dnnl(): out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2) f = relay.Function([data1, weight1, weight2], out) - ref_mod = relay.Module() + ref_mod = tvm.IRModule() ref_mod['main'] = f f = set_external_func_attr(f, "dnnl", "dnnl_0") call = relay.Call(f, [data0, weight0, weight0]) - mod = relay.Module.from_expr(call) + mod = tvm.IRModule.from_expr(call) i_data = np.random.uniform(0, 1, ishape).astype(dtype) w_data = np.random.uniform(0, 1, w1shape).astype(dtype) diff --git a/tests/python/relay/test_external_runtime.py b/tests/python/relay/test_external_runtime.py index 5fc03df8a9a0..7e8c83217451 100644 --- a/tests/python/relay/test_external_runtime.py +++ b/tests/python/relay/test_external_runtime.py @@ -330,7 +330,7 @@ def get_synthetic_lib(): sub2 = relay.subtract(add2, w7) ret = relay.concatenate((subgraph0_ret, subgraph1_ret, sub2), 0) func = relay.Function([x, w0, w1, w2, w3, w4, w5, w6, w7], ret) - mod = relay.Module.from_expr(func) + mod = tvm.IRModule.from_expr(func) _, lib, _ = relay.build(mod, "llvm") return lib diff --git a/tests/python/relay/test_feature.py b/tests/python/relay/test_feature.py index 64eda9d04e7c..9066e85cf6da 100644 --- a/tests/python/relay/test_feature.py +++ b/tests/python/relay/test_feature.py @@ -50,7 +50,7 @@ def test_ad(): x = relay.var("x", t) func = relay.Function([x], x + x) func = run_infer_type(func) - mod = relay.Module.from_expr(gradient(func)) + mod = tvm.IRModule.from_expr(gradient(func)) mod = relay.transform.InferType()(mod) back_func = mod["main"] feats = detect_feature(back_func) diff --git a/tests/python/relay/test_ir_module.py b/tests/python/relay/test_ir_module.py index 72a92c8697fc..939672d42152 100644 --- a/tests/python/relay/test_ir_module.py +++ b/tests/python/relay/test_ir_module.py @@ -17,7 +17,6 @@ """Tests for module functionality.""" import tvm from tvm import relay -from tvm.relay import Module from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions @@ -30,10 +29,10 @@ def adt_list(p): def test_constructor_tag_round_trip(): - mod1 = Module() + mod1 = tvm.IRModule() p1 = Prelude(mod1) add_nat_definitions(p1) - mod2 = Module() + mod2 = tvm.IRModule() p2 = Prelude(mod2) add_nat_definitions(p2) @@ -52,7 +51,7 @@ def test_constructor_tag_differences(): # ensure that if we have the type data for a given ADT, the tags # for the constructors of the *same ADT* are simple offsets from # each other - mod = Module() + mod = tvm.IRModule() p = Prelude(mod) add_nat_definitions(p) diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index dec840a214a0..29e578b08b11 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -23,15 +23,15 @@ import numpy as np def check_json_roundtrip(node): - json_str = tvm.save_json(node) - back = tvm.load_json(json_str) + json_str = tvm.ir.save_json(node) + back = tvm.ir.load_json(json_str) assert graph_equal(back, node) def test_bad_constructor(): try: x = relay.ty.TensorType("xx", "xx") - except tvm.TVMError: + except tvm.error.TVMError: pass @@ -48,7 +48,7 @@ def test_span(): # span is not a node so we can't use graph_equal # to test the round trip - back = tvm.load_json(tvm.save_json(span)) + back = tvm.ir.load_json(tvm.ir.save_json(span)) assert back.source == span.source assert back.lineno == span.lineno assert back.col_offset == span.col_offset @@ -67,8 +67,8 @@ def test_tensor_type(): def test_type_param(): - tp = relay.TypeVar('name', relay.Kind.Type) - assert tp.kind == relay.Kind.Type + tp = relay.TypeVar('name', relay.TypeKind.Type) + assert tp.kind == relay.TypeKind.Type # assert tp.span # TODO allow us to set span str(tp) check_json_roundtrip(tp) @@ -91,7 +91,7 @@ def test_func_type(): def test_tuple_type(): - tp = relay.TypeVar('tp', relay.Kind.Type) + tp = relay.TypeVar('tp', relay.TypeKind.Type) tf = relay.FuncType(tvm.convert([]), None, tvm.convert([]), tvm.convert([])) tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') fields = tvm.convert([tp, tf, tt]) @@ -103,13 +103,13 @@ def test_tuple_type(): def test_type_relation(): - tp = relay.TypeVar('tp', relay.Kind.Type) + tp = relay.TypeVar('tp', relay.TypeKind.Type) tf = relay.FuncType(tvm.convert([]), None, tvm.convert([]), tvm.convert([])) tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') args = tvm.convert([tp, tf, tt]) num_inputs = 2 - func = tvm.get_env_func("tvm.relay.type_relation.Broadcast") + func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast") attrs = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) tr = relay.TypeRelation(func, args, num_inputs, attrs) @@ -193,8 +193,8 @@ def test_function_attrs(): assert fn.span == None str(fn) check_json_roundtrip(fn) - json_str = tvm.save_json(fn) - fn_after = tvm.load_json(json_str) + json_str = tvm.ir.save_json(fn) + fn_after = tvm.ir.load_json(json_str) model_params_after = fn_after.get_params() after_keys = [item[0] for item in model_params_after.items()] for key1, key2 in zip(model_params, after_keys): diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index a871ae144387..261cbb97c4af 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -355,7 +355,7 @@ def @id(%x: int32) -> int32 { %x } """) - assert isinstance(id_defn, relay.Module) + assert isinstance(id_defn, tvm.IRModule) def test_recursive_call(): @@ -365,7 +365,7 @@ def @id(%x: int32) -> int32 { @id(%x) } """) - assert isinstance(id_defn, relay.Module) + assert isinstance(id_defn, tvm.IRModule) def test_ifelse(): @@ -639,7 +639,7 @@ def test_tuple_type(): def test_adt_defn(): - mod = relay.Module() + mod = tvm.IRModule() glob_typ_var = relay.GlobalTypeVar("Ayy") prog = relay.TypeData( @@ -656,7 +656,7 @@ def test_adt_defn(): def test_empty_adt_defn(): - mod = relay.Module() + mod = tvm.IRModule() glob_typ_var = relay.GlobalTypeVar("Ayy") prog = relay.TypeData(glob_typ_var, [], []) @@ -670,7 +670,7 @@ def test_empty_adt_defn(): def test_multiple_cons_defn(): - mod = relay.Module() + mod = tvm.IRModule() list_var = relay.GlobalTypeVar("List") typ_var = relay.TypeVar("A") @@ -696,7 +696,7 @@ def test_multiple_type_param_defn(): relay.Constructor("Left", [typ_var_a], glob_typ_var), relay.Constructor("Right", [typ_var_b], glob_typ_var), ]) - mod = relay.Module() + mod = tvm.IRModule() mod[glob_typ_var] = prog assert parses_as( """ @@ -713,7 +713,7 @@ def test_match(): # pair each match keyword with whether it specifies a complete match or not match_keywords = [("match", True), ("match?", False)] for (match_keyword, is_complete) in match_keywords: - mod = relay.Module() + mod = tvm.IRModule() list_var = relay.GlobalTypeVar("List") typ_var = relay.TypeVar("A") @@ -773,7 +773,7 @@ def @length[A](%%xs: List[A]) -> int32 { def test_adt_cons_expr(): - mod = relay.Module() + mod = tvm.IRModule() list_var = relay.GlobalTypeVar("List") typ_var = relay.TypeVar("A") @@ -853,7 +853,7 @@ def @id[A](%x: A) -> A { x } def test_extern_adt_defn(): # TODO(weberlo): update this test once extern is implemented - mod = relay.Module() + mod = tvm.IRModule() extern_var = relay.GlobalTypeVar("T") typ_var = relay.TypeVar("A") diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index e84de6765177..e2a0bdc205d6 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -58,7 +58,7 @@ def test_env(): z = relay.add(x, y) z = relay.add(z, z) f = relay.Function([x, y], z) - env = relay.Module() + env = tvm.IRModule() env["myf"] = f text = astext(env) assert "def @myf" in text diff --git a/tests/python/relay/test_ir_well_formed.py b/tests/python/relay/test_ir_well_formed.py index bee0a021ac5b..fbbfbd23a6c2 100644 --- a/tests/python/relay/test_ir_well_formed.py +++ b/tests/python/relay/test_ir_well_formed.py @@ -50,7 +50,7 @@ def test_tuple_get_item(): def test_adt(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) x = relay.Var("x") some_case = relay.Clause(relay.PatternConstructor(p.some, diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index 143c3d23a5e3..40b686a05c5e 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -34,11 +34,11 @@ def test_type_var(): "attrs": {"tvm_version": "0.6.0"}, "b64ndarrays": [], } - tvar = tvm.load_json(json.dumps(data)) + tvar = tvm.ir.load_json(json.dumps(data)) assert isinstance(tvar, relay.TypeVar) assert tvar.name_hint == "in0" nodes[1]["type_key"] = "relay.GlobalTypeVar" - tvar = tvm.load_json(json.dumps(data)) + tvar = tvm.ir.load_json(json.dumps(data)) assert isinstance(tvar, relay.GlobalTypeVar) assert tvar.name_hint == "in0" diff --git a/tests/python/relay/test_memory_alloc.py b/tests/python/relay/test_memory_alloc.py index 5c1bbc72bf22..18b1500dfc3c 100644 --- a/tests/python/relay/test_memory_alloc.py +++ b/tests/python/relay/test_memory_alloc.py @@ -20,7 +20,7 @@ from tvm.relay import memory_alloc def check_vm_alloc(func, check_fn): - mod = relay.Module() + mod = tvm.IRModule() mod['main'] = func ex = relay.create_executor('vm', mod) args = [] @@ -37,11 +37,11 @@ def storage_type(mod): return relay.TypeCall(mod.get_global_type_var("Storage"), []) def test_tyck_alloc_storage(): - mod = relay.Module() + mod = tvm.IRModule() mod.import_from_std("core.rly") def test_tyck_alloc_tensor(): - mod = relay.Module() + mod = tvm.IRModule() mod.import_from_std("core.rly") sto = relay.Var("x", storage_type(mod)) sh = relay.const(np.array([1, 2]), dtype="int64") diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 6a6f21d9241f..c3033e9181cb 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -55,7 +55,7 @@ def test_checkpoint_alpha_equal(): with transform.PassContext(opt_level=3): passes = [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)] - mod = transform.Sequential(passes)(relay.Module.from_expr(df)) + mod = transform.Sequential(passes)(tvm.IRModule.from_expr(df)) df = mod["main"] df_parsed = relay.parser.fromtext( @@ -111,7 +111,7 @@ def test_checkpoint_alpha_equal_tuple(): with transform.PassContext(opt_level=3): passes = [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)] - mod = transform.Sequential(passes)(relay.Module.from_expr(df)) + mod = transform.Sequential(passes)(tvm.IRModule.from_expr(df)) df = mod["main"] df_parsed = relay.parser.fromtext( @@ -424,7 +424,7 @@ def _get_oshape(indices_shape, depth, axis): else: oshape.append(indices_shape[indices_index]) indices_index += 1 - + return oshape def _verify(indices_shape, depth, on_value, off_value, axis, dtype): @@ -443,7 +443,7 @@ def _verify(indices_shape, depth, on_value, off_value, axis, dtype): intrp = relay.create_executor(kind, ctx=ctx, target=target) out_relay = intrp.evaluate(func)(indices_np) tvm.testing.assert_allclose(out_relay.asnumpy(), out_np) - + _verify((3,), 3, 1, 0, -1, "int32") _verify((3,), 3, 1.0, 0.0, -1, "float32") _verify((2, 2), 5, 2, -2, 0, "int32") diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 0acd83639363..ea729618097e 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -237,7 +237,7 @@ def compile_test_conv2d_arm_cpu(dtype, out_dtype, scale, dshape, kshape, groups=groups, **attrs) func = relay.Function([x, w], y) - mod = tvm.relay.Module() + mod = tvm.IRModule() mod["main"] = func test_schedule='{"i": ["llvm -device=arm_cpu", "topi_nn_depthwise_conv2d_nchw", \ @@ -276,7 +276,7 @@ def compile_test_conv2d_arm_cpu(dtype, out_dtype, scale, dshape, kshape, dshape = (1, 512, 32, 32) kshape = (512, 1, 3, 3) compile_test_conv2d_arm_cpu("float32", "float32", 1, dshape, kshape, - padding=(1, 1), channels=512, + padding=(1, 1), channels=512, groups=512, kernel_size=(3 ,3)) # CUDA is disabled for 'direct' schedule: @@ -344,7 +344,7 @@ def run_test_conv2d_cuda(dtype, out_dtype, scale, dshape, kshape, groups=groups, **attrs) func = relay.Function([x, w], y) - mod = relay.Module() + mod = tvm.IRModule() mod['main'] = func mod = relay.transform.InferType()(mod) diff --git a/tests/python/relay/test_op_qnn_add.py b/tests/python/relay/test_op_qnn_add.py index 033a1041b579..e1f54ed4b78c 100644 --- a/tests/python/relay/test_op_qnn_add.py +++ b/tests/python/relay/test_op_qnn_add.py @@ -35,7 +35,7 @@ def test_tflite_same_io_qnn_params(): output_zero_point=relay.const(127, 'int32')) func = relay.Function([x, y], z) - mod = relay.Module.from_expr(func) + mod = tvm.IRModule.from_expr(func) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] @@ -73,7 +73,7 @@ def test_tflite_different_io_qnn_params(): output_zero_point=relay.const(128, 'int32')) func = relay.Function([x, y], z) - mod = relay.Module.from_expr(func) + mod = tvm.IRModule.from_expr(func) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] @@ -111,7 +111,7 @@ def test_saturation(): output_zero_point=relay.const(0, 'int32')) func = relay.Function([x, y], z) - mod = relay.Module.from_expr(func) + mod = tvm.IRModule.from_expr(func) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] @@ -133,7 +133,7 @@ def test_saturation(): output_zero_point=relay.const(0, 'int32')) func = relay.Function([x, y], z) - mod = relay.Module.from_expr(func) + mod = tvm.IRModule.from_expr(func) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] @@ -155,7 +155,7 @@ def test_saturation(): output_zero_point=relay.const(0, 'int32')) func = relay.Function([x, y], z) - mod = relay.Module.from_expr(func) + mod = tvm.IRModule.from_expr(func) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] @@ -177,7 +177,7 @@ def test_saturation(): output_zero_point=relay.const(0, 'int32')) func = relay.Function([x, y], z) - mod = relay.Module.from_expr(func) + mod = tvm.IRModule.from_expr(func) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] diff --git a/tests/python/relay/test_op_qnn_concatenate.py b/tests/python/relay/test_op_qnn_concatenate.py index ed496941cf8e..35c2f971a791 100644 --- a/tests/python/relay/test_op_qnn_concatenate.py +++ b/tests/python/relay/test_op_qnn_concatenate.py @@ -40,7 +40,7 @@ def test_same_io_qnn_params(): axis=axis) func = relay.Function([x, y], z) - mod = relay.Module.from_expr(func) + mod = tvm.IRModule.from_expr(func) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] @@ -71,7 +71,7 @@ def test_different_io_qnn_params(): axis=axis) func = relay.Function([x, y], z) - mod = relay.Module.from_expr(func) + mod = tvm.IRModule.from_expr(func) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] @@ -102,7 +102,7 @@ def test_few_same_io_qnn_params(): axis=axis) func = relay.Function([x, y], z) - mod = relay.Module.from_expr(func) + mod = tvm.IRModule.from_expr(func) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] @@ -133,7 +133,7 @@ def test_same_i_qnn_params(): axis=axis) func = relay.Function([x, y], z) - mod = relay.Module.from_expr(func) + mod = tvm.IRModule.from_expr(func) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] diff --git a/tests/python/relay/test_op_qnn_conv2d.py b/tests/python/relay/test_op_qnn_conv2d.py index ced12c843563..264475ca3432 100644 --- a/tests/python/relay/test_op_qnn_conv2d.py +++ b/tests/python/relay/test_op_qnn_conv2d.py @@ -98,7 +98,7 @@ def get_qnn_func(data, kernel_layout=kernel_layout) mod = relay.Function(relay.analysis.free_vars(func), func) - mod = relay.Module.from_expr(mod) + mod = tvm.IRModule.from_expr(mod) return mod def get_funcs(data_shape, @@ -138,7 +138,7 @@ def get_funcs(data_shape, groups, channels) ref_func = run_infer_type(ref_func) - ref_func = relay.Module.from_expr(ref_func) + ref_func = tvm.IRModule.from_expr(ref_func) qnn_func = get_qnn_func(data, kernel, input_zero_point, @@ -759,7 +759,7 @@ def test_broadcast_layout(): func = relay.add(bias, func) func = relay.add(func, bias) func = relay.Function(relay.analysis.free_vars(func), func) - mod = relay.Module.from_expr(func) + mod = tvm.IRModule.from_expr(func) with relay.build_config(opt_level=3): graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512") @@ -896,7 +896,7 @@ def test_per_channel_kernel_scale(): out_dtype="int32") mod = relay.Function(relay.analysis.free_vars(func), func) - mod = relay.Module.from_expr(mod) + mod = tvm.IRModule.from_expr(mod) if __name__ == "__main__": test_no_zero_point() diff --git a/tests/python/relay/test_op_qnn_dense.py b/tests/python/relay/test_op_qnn_dense.py index 11987a55b855..0e7c284653f4 100644 --- a/tests/python/relay/test_op_qnn_dense.py +++ b/tests/python/relay/test_op_qnn_dense.py @@ -201,7 +201,7 @@ def qnn_dense_driver(test_configuration): expected_out_dtype = requantize_config['out_dtype'] mod = relay.Function(relay.analysis.free_vars(mod), mod) - mod = relay.Module.from_expr(mod) + mod = tvm.IRModule.from_expr(mod) mod = relay.qnn.transform.CanonicalizeOps()(mod) with relay.build_config(opt_level=2): graph, lib, params = relay.build(mod, "llvm", params=None) diff --git a/tests/python/relay/test_op_qnn_dequantize.py b/tests/python/relay/test_op_qnn_dequantize.py index 4510c570c9ff..b1965c97ad0d 100644 --- a/tests/python/relay/test_op_qnn_dequantize.py +++ b/tests/python/relay/test_op_qnn_dequantize.py @@ -28,7 +28,7 @@ def quantize_test_driver(in_dtype, quant_args, in_data, verify_output_data): quantized_output = relay.qnn.op.dequantize(input_data, input_scale=input_scale, input_zero_point=input_zero_point) mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output) - mod = relay.Module.from_expr(mod) + mod = tvm.IRModule.from_expr(mod) with relay.build_config(opt_level=3): graph, lib, params = relay.build(mod, "llvm", params=None) rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) diff --git a/tests/python/relay/test_op_qnn_mul.py b/tests/python/relay/test_op_qnn_mul.py index 16f0be78ff0f..959a02a976ad 100644 --- a/tests/python/relay/test_op_qnn_mul.py +++ b/tests/python/relay/test_op_qnn_mul.py @@ -52,7 +52,7 @@ def test_tflite_same_io_qnn_params(): output_zero_point=relay.const(output_zero_point, 'int32')) func = relay.Function([x, y], z) - mod = relay.Module.from_expr(func) + mod = tvm.IRModule.from_expr(func) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] @@ -103,7 +103,7 @@ def test_tflite_different_io_qnn_params(): output_zero_point=relay.const(output_zero_point, 'int32')) func = relay.Function([x, y], z) - mod = relay.Module.from_expr(func) + mod = tvm.IRModule.from_expr(func) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] @@ -149,7 +149,7 @@ def test_saturation(): output_zero_point=relay.const(output_zero_point, 'int32')) func = relay.Function([x, y], z) - mod = relay.Module.from_expr(func) + mod = tvm.IRModule.from_expr(func) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] @@ -180,7 +180,7 @@ def test_saturation(): output_zero_point=relay.const(output_zero_point, 'int32')) func = relay.Function([x, y], z) - mod = relay.Module.from_expr(func) + mod = tvm.IRModule.from_expr(func) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] @@ -212,7 +212,7 @@ def test_saturation(): output_zero_point=relay.const(output_zero_point, 'int32')) func = relay.Function([x, y], z) - mod = relay.Module.from_expr(func) + mod = tvm.IRModule.from_expr(func) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] diff --git a/tests/python/relay/test_op_qnn_quantize.py b/tests/python/relay/test_op_qnn_quantize.py index 45caedaf4a44..bdc7bc04d6da 100644 --- a/tests/python/relay/test_op_qnn_quantize.py +++ b/tests/python/relay/test_op_qnn_quantize.py @@ -30,7 +30,7 @@ def quantize_test_driver(in_dtype, quant_args, axis, out_dtype, in_data, verify_ axis=axis, out_dtype=out_dtype) mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output) - mod = relay.Module.from_expr(mod) + mod = tvm.IRModule.from_expr(mod) with relay.build_config(opt_level=3): graph, lib, params = relay.build(mod, "llvm", params=None) rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) diff --git a/tests/python/relay/test_op_qnn_requantize.py b/tests/python/relay/test_op_qnn_requantize.py index b682498cb10b..8af778160ccb 100644 --- a/tests/python/relay/test_op_qnn_requantize.py +++ b/tests/python/relay/test_op_qnn_requantize.py @@ -59,7 +59,7 @@ def get_mod(data_shape, data_dtype, out_dtype, input_scale, output_scale, out_dtype=out_dtype) mod = relay.Function(relay.analysis.free_vars(mod), mod) - mod = relay.Module.from_expr(mod) + mod = tvm.IRModule.from_expr(mod) return mod def test_same_scale(): diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index 6ef435a19388..bdc032e8b65c 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -39,9 +39,9 @@ def test_tensor_type_alpha_equal(): def test_incomplete_type_alpha_equal(): - t1 = relay.IncompleteType(relay.Kind.Shape) - t2 = relay.IncompleteType(relay.Kind.Type) - t3 = relay.IncompleteType(relay.Kind.Type) + t1 = relay.IncompleteType(relay.TypeKind.ShapeVar) + t2 = relay.IncompleteType(relay.TypeKind.Type) + t3 = relay.IncompleteType(relay.TypeKind.Type) # only equal when there is pointer equality assert t2 == t2 @@ -51,9 +51,9 @@ def test_incomplete_type_alpha_equal(): def test_type_param_alpha_equal(): - t1 = relay.TypeVar("v1", relay.Kind.Type) - t2 = relay.TypeVar("v2", relay.Kind.Shape) - t3 = relay.TypeVar("v3", relay.Kind.Type) + t1 = relay.TypeVar("v1", relay.TypeKind.Type) + t2 = relay.TypeVar("v2", relay.TypeKind.ShapeVar) + t3 = relay.TypeVar("v3", relay.TypeKind.Type) # only pointer equality and eq_map allow equal params assert t1 == t1 @@ -76,13 +76,13 @@ def test_func_type_alpha_equal(): t1 = relay.TensorType((1, 2), "float32") t2 = relay.TensorType((1, 2, 3), "float32") - tp1 = relay.TypeVar("v1", relay.Kind.Type) - tp2 = relay.TypeVar("v2", relay.Kind.Type) - tp3 = relay.TypeVar("v3", relay.Kind.Shape) - tp4 = relay.TypeVar("v3", relay.Kind.Shape) + tp1 = relay.TypeVar("v1", relay.TypeKind.Type) + tp2 = relay.TypeVar("v2", relay.TypeKind.Type) + tp3 = relay.TypeVar("v3", relay.TypeKind.ShapeVar) + tp4 = relay.TypeVar("v3", relay.TypeKind.ShapeVar) - broadcast = tvm.get_env_func("tvm.relay.type_relation.Broadcast") - identity = tvm.get_env_func("tvm.relay.type_relation.Identity") + broadcast = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast") + identity = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Identity") tr1 = relay.TypeRelation(broadcast, tvm.convert([tp1, tp3]), 1, None) tr2 = relay.TypeRelation(broadcast, tvm.convert([tp2, tp4]), 1, None) @@ -135,8 +135,8 @@ def test_func_type_alpha_equal(): def test_tuple_type_alpha_equal(): t1 = relay.TensorType((1, 2, 3), "float32") t2 = relay.TensorType((1, 2, 3, 4), "float32") - tp1 = relay.TypeVar("v1", relay.Kind.Type) - tp2 = relay.TypeVar("v2", relay.Kind.Type) + tp1 = relay.TypeVar("v1", relay.TypeKind.Type) + tp2 = relay.TypeVar("v2", relay.TypeKind.Type) tup1 = relay.TupleType(tvm.convert([t1, t2, tp1])) tup2 = relay.TupleType(tvm.convert([t1, t2, tp1])) @@ -157,8 +157,8 @@ def test_type_relation_alpha_equal(): # functions are compared only by pointer equality so # we need to be sure to use the same pointers - broadcast = tvm.get_env_func("tvm.relay.type_relation.Broadcast") - identity = tvm.get_env_func("tvm.relay.type_relation.Identity") + broadcast = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast") + identity = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Identity") attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) @@ -347,10 +347,10 @@ def test_function_alpha_equal(): v4 = relay.Var("v4", tt2) vret = relay.Constant(tvm.nd.array(np.ones(1))) - tp1 = relay.TypeVar("tp1", relay.Kind.Type) - tp2 = relay.TypeVar("tp2", relay.Kind.Type) - tp3 = relay.TypeVar("tp3", relay.Kind.Shape) - tp4 = relay.TypeVar("tp4", relay.Kind.Shape) + tp1 = relay.TypeVar("tp1", relay.TypeKind.Type) + tp2 = relay.TypeVar("tp2", relay.TypeKind.Type) + tp3 = relay.TypeVar("tp3", relay.TypeKind.ShapeVar) + tp4 = relay.TypeVar("tp4", relay.TypeKind.ShapeVar) basic_args = [relay.Var("v3", tt1), relay.Var("v4", tt2)] basic_tps = [tp1, tp2] @@ -515,7 +515,7 @@ def test_if_alpha_equal(): def test_constructor_alpha_equal(): # smoke test: it should be pointer equality - mod = relay.Module() + mod = tvm.IRModule() p = relay.prelude.Prelude(mod) assert alpha_equal(p.nil, p.nil) @@ -524,7 +524,7 @@ def test_constructor_alpha_equal(): def test_match_alpha_equal(): - mod = relay.Module() + mod = tvm.IRModule() p = relay.prelude.Prelude(mod) x = relay.Var('x') diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index b01e1bbe0504..2ec3f282a6c4 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -23,7 +23,7 @@ def run_opt_pass(expr, passes): passes = passes if isinstance(passes, list) else [passes] - mod = relay.Module.from_expr(expr) + mod = tvm.IRModule.from_expr(expr) seq = transform.Sequential(passes) with transform.PassContext(opt_level=3): mod = seq(mod) @@ -1005,7 +1005,7 @@ def before(): kernel_size=(3, 3), padding=(1, 1)) y = relay.nn.relu(y) - mod = relay.Module() + mod = tvm.IRModule() foo = relay.GlobalVar('foo') mod[foo] = relay.Function([x, weight], y) mod["main"] = relay.Function([x, weight], foo(x, weight)) @@ -1024,7 +1024,7 @@ def expected(): kernel_size=(3, 3), padding=(1, 1)) y = relay.nn.relu(y) - mod = relay.Module() + mod = tvm.IRModule() foo = relay.GlobalVar('foo') mod[foo] = relay.Function([x, weight], y) mod["main"] = relay.Function([x, weight], foo(x, weight)) diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index 69ce4c5211be..3e7d916c96fa 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -27,7 +27,7 @@ def run_opt_pass(expr, passes): passes = passes if isinstance(passes, list) else [passes] - mod = relay.Module.from_expr(expr) + mod = tvm.IRModule.from_expr(expr) seq = transform.Sequential(passes) with transform.PassContext(opt_level=3): mod = seq(mod) diff --git a/tests/python/relay/test_pass_canonicalize_cast.py b/tests/python/relay/test_pass_canonicalize_cast.py index b72ded21ef52..672b4b192995 100644 --- a/tests/python/relay/test_pass_canonicalize_cast.py +++ b/tests/python/relay/test_pass_canonicalize_cast.py @@ -17,7 +17,6 @@ import tvm import tvm.relay as relay -import tvm.relay.module as _module import tvm.relay.transform as _transform @@ -53,7 +52,7 @@ def check(shape): bias1 = relay.var("bias1", shape=(16, 1, 1), dtype="int32") bias2 = relay.var("bias2", shape=(16, 1, 1), dtype="int32") y = before(data, conv_weight, bias1, bias2) - mod = _module.Module.from_expr(y) + mod = tvm.IRModule.from_expr(y) seq = _transform.Sequential([_transform.InferType(), _transform.CanonicalizeCast(), _transform.InferType()]) with _transform.PassContext(opt_level=3): diff --git a/tests/python/relay/test_pass_check_kind.py b/tests/python/relay/test_pass_check_kind.py index 16d57021f10b..62a92040ff16 100644 --- a/tests/python/relay/test_pass_check_kind.py +++ b/tests/python/relay/test_pass_check_kind.py @@ -21,30 +21,30 @@ def test_typevar_kind(): # returns the same kind - tp1 = relay.TypeVar('tp1', relay.Kind.Type) - tp2 = relay.TypeVar('tp2', relay.Kind.Shape) - tp3 = relay.TypeVar('tp3', relay.Kind.Constraint) + tp1 = relay.TypeVar('tp1', relay.TypeKind.Type) + tp2 = relay.TypeVar('tp2', relay.TypeKind.ShapeVar) + tp3 = relay.TypeVar('tp3', relay.TypeKind.Constraint) - assert check_kind(tp1) == relay.Kind.Type - assert check_kind(tp2) == relay.Kind.Shape - assert check_kind(tp3) == relay.Kind.Constraint + assert check_kind(tp1) == relay.TypeKind.Type + assert check_kind(tp2) == relay.TypeKind.ShapeVar + assert check_kind(tp3) == relay.TypeKind.Constraint def test_tuple_kind(): # only contain type kinds - tp = relay.TypeVar('tp', relay.Kind.Type) + tp = relay.TypeVar('tp', relay.TypeKind.Type) tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') tf = relay.FuncType(tvm.convert([]), tt, tvm.convert([]), tvm.convert([])) fields = tvm.convert([tp, tf, tt]) tup_ty = relay.TupleType(fields) - assert check_kind(tup_ty) == relay.Kind.Type + assert check_kind(tup_ty) == relay.TypeKind.Type def test_func_kind(): # only contain type kinds - tp1 = relay.TypeVar('tp1', relay.Kind.Type) - tp2 = relay.TypeVar('tp2', relay.Kind.Type) + tp1 = relay.TypeVar('tp1', relay.TypeKind.Type) + tp2 = relay.TypeVar('tp2', relay.TypeKind.Type) shape = tvm.convert([1, 2, 3]) dtype = 'float32' @@ -58,7 +58,7 @@ def test_func_kind(): ret_type = relay.TupleType(tvm.convert([tp2, tensor_type])) tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints) - assert check_kind(tf) == relay.Kind.Type + assert check_kind(tf) == relay.TypeKind.Type def test_ref_kind(): @@ -67,65 +67,65 @@ def test_ref_kind(): ft = relay.FuncType(tvm.convert([]), tt, tvm.convert([]), tvm.convert([])) rt1 = relay.RefType(tt) - assert check_kind(rt1) == relay.Kind.Type + assert check_kind(rt1) == relay.TypeKind.Type rt2 = relay.RefType(ft) - assert check_kind(rt2) == relay.Kind.Type + assert check_kind(rt2) == relay.TypeKind.Type rt3 = relay.RefType(relay.TupleType([rt1, rt2])) - assert check_kind(rt3) == relay.Kind.Type + assert check_kind(rt3) == relay.TypeKind.Type def test_relation_kind(): # only have type kinds for arguments - tp = relay.TypeVar('tp', relay.Kind.Type) + tp = relay.TypeVar('tp', relay.TypeKind.Type) tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') tf = relay.FuncType(tvm.convert([]), tt, tvm.convert([]), tvm.convert([])) args = tvm.convert([tf, tt, tp]) tr = relay.TypeRelation(None, args, 2, None) - assert check_kind(tr) == relay.Kind.Constraint + assert check_kind(tr) == relay.TypeKind.Constraint def test_global_typevar_kind(): - v1 = relay.GlobalTypeVar('gtv1', relay.Kind.AdtHandle) - v2 = relay.GlobalTypeVar('gtv2', relay.Kind.Type) + v1 = relay.GlobalTypeVar('gtv1', relay.TypeKind.AdtHandle) + v2 = relay.GlobalTypeVar('gtv2', relay.TypeKind.Type) - assert check_kind(v1) == relay.Kind.AdtHandle - assert check_kind(v2) == relay.Kind.Type + assert check_kind(v1) == relay.TypeKind.AdtHandle + assert check_kind(v2) == relay.TypeKind.Type def test_typecall_kind(): gtv = relay.GlobalTypeVar('gtv') - mod = relay.Module() + mod = tvm.IRModule() data = relay.TypeData(gtv, [], []) mod[gtv] = data empty_call = relay.TypeCall(gtv, []) - assert check_kind(empty_call, mod) == relay.Kind.Type + assert check_kind(empty_call, mod) == relay.TypeKind.Type - new_mod = relay.Module() + new_mod = tvm.IRModule() tv = relay.TypeVar('tv') new_data = relay.TypeData(gtv, [tv], []) new_mod[gtv] = new_data call = relay.TypeCall(gtv, [relay.TupleType([])]) - assert check_kind(call, new_mod) == relay.Kind.Type + assert check_kind(call, new_mod) == relay.TypeKind.Type -@pytest.mark.xfail(raises=tvm._ffi.base.TVMError) +@pytest.mark.xfail(raises=tvm.error.TVMError) def test_invalid_tuple_kind(): - tp1 = relay.TypeVar('tp1', relay.Kind.Shape) - tp2 = relay.TypeVar('tp2', relay.Kind.BaseType) - tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar) + tp1 = relay.TypeVar('tp1', relay.TypeKind.ShapeVar) + tp2 = relay.TypeVar('tp2', relay.TypeKind.BaseType) + tp3 = relay.TypeVar('tp3', relay.TypeKind.Constraint) fields = tvm.convert([tp1, tp2, tp3]) tup_ty = relay.TupleType(fields) check_kind(tup_ty) -@pytest.mark.xfail(raises=tvm._ffi.base.TVMError) +@pytest.mark.xfail(raises=tvm.error.TVMError) def test_invalid_func_kind(): - tp1 = relay.TypeVar('tp1', relay.Kind.Shape) - tp2 = relay.TypeVar('tp2', relay.Kind.BaseType) - tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar) + tp1 = relay.TypeVar('tp1', relay.TypeKind.ShapeVar) + tp2 = relay.TypeVar('tp2', relay.TypeKind.BaseType) + tp3 = relay.TypeVar('tp3', relay.TypeKind.Constraint) type_params = tvm.convert([tp1, tp2, tp3]) type_constraints = tvm.convert([]) @@ -136,36 +136,36 @@ def test_invalid_func_kind(): check_kind(tf) -@pytest.mark.xfail(raises=tvm._ffi.base.TVMError) +@pytest.mark.xfail(raises=tvm.error.TVMError) def test_invalid_ref_kind(): - tp = relay.TypeVar('tp', relay.Kind.Shape) + tp = relay.TypeVar('tp', relay.TypeKind.ShapeVar) rt = relay.RefType(tp) check_kind(rt) -@pytest.mark.xfail(raises=tvm._ffi.base.TVMError) +@pytest.mark.xfail(raises=tvm.error.TVMError) def test_invalid_relation_kind(): - tp1 = relay.TypeVar('tp1', relay.Kind.Shape) - tp2 = relay.TypeVar('tp2', relay.Kind.BaseType) - tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar) + tp1 = relay.TypeVar('tp1', relay.TypeKind.ShapeVar) + tp2 = relay.TypeVar('tp2', relay.TypeKind.BaseType) + tp3 = relay.TypeVar('tp3', relay.TypeKind.Constraint) args = tvm.convert([tp1, tp2, tp3]) - func = tvm.get_env_func("tvm.relay.type_relation.Broadcast") + func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast") tr = relay.TypeRelation(func, args, 2, None) check_kind(tr) -@pytest.mark.xfail(raises=tvm._ffi.base.TVMError) +@pytest.mark.xfail(raises=tvm.error.TVMError) def test_typecall_invalid_callee(): # global type var must be an ADT handle - gtv = relay.GlobalTypeVar('v1', relay.Kind.Type) + gtv = relay.GlobalTypeVar('v1', relay.TypeKind.Type) check_kind(relay.TypeCall(gtv, [])) -@pytest.mark.xfail(raises=tvm._ffi.base.TVMError) +@pytest.mark.xfail(raises=tvm.error.TVMError) def test_typecall_invalid_args(): # args must all be type kind - mod = relay.Module() + mod = tvm.IRModule() gtv = relay.GlobalTypeVar('v1') data = relay.TypeData(gtv, [], []) mod[gtv] = data @@ -173,9 +173,9 @@ def test_typecall_invalid_args(): check_kind(relay.TypeCall(gtv, [data])) -@pytest.mark.xfail(raises=tvm._ffi.base.TVMError) +@pytest.mark.xfail(raises=tvm.error.TVMError) def test_typecall_invalid_num_args(): - mod = relay.Module() + mod = tvm.IRModule() gtv = relay.GlobalTypeVar('v1') tv = relay.TypeVar('tv') data = relay.TypeData(gtv, [tv], []) @@ -183,27 +183,27 @@ def test_typecall_invalid_num_args(): check_kind(relay.TypeCall(gtv, [])) -@pytest.mark.xfail(raises=tvm._ffi.base.TVMError) +@pytest.mark.xfail(raises=tvm.error.TVMError) def test_func_with_invalid_ret_type(): - tp1 = relay.TypeVar('tp1', relay.Kind.Type) - tp2 = relay.TypeVar('tp2', relay.Kind.Shape) + tp1 = relay.TypeVar('tp1', relay.TypeKind.Type) + tp2 = relay.TypeVar('tp2', relay.TypeKind.ShapeVar) tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([])) check_kind(tf) -@pytest.mark.xfail(raises=tvm._ffi.base.TVMError) +@pytest.mark.xfail(raises=tvm.error.TVMError) def test_func_with_invalid_arg_types(): - tp1 = relay.TypeVar('tp1', relay.Kind.Shape) - tp2 = relay.TypeVar('tp2', relay.Kind.Type) + tp1 = relay.TypeVar('tp1', relay.TypeKind.ShapeVar) + tp2 = relay.TypeVar('tp2', relay.TypeKind.Type) tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([])) check_kind(tf) -@pytest.mark.xfail(raises=tvm._ffi.base.TVMError) +@pytest.mark.xfail(raises=tvm.error.TVMError) def test_func_with_invalid_tuple(): - tp1 = relay.TypeVar('tp1', relay.Kind.Shape) + tp1 = relay.TypeVar('tp1', relay.TypeKind.ShapeVar) ret_type = relay.TupleType(tvm.convert([tp1, tp1, tp1])) @@ -211,24 +211,24 @@ def test_func_with_invalid_tuple(): check_kind(tf) -@pytest.mark.xfail(raises=tvm._ffi.base.TVMError) +@pytest.mark.xfail(raises=tvm.error.TVMError) def test_func_with_invalid_relation(): - tp1 = relay.TypeVar('tp1', relay.Kind.Type) - tp2 = relay.TypeVar('tp2', relay.Kind.Shape) - tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar) + tp1 = relay.TypeVar('tp1', relay.TypeKind.Type) + tp2 = relay.TypeVar('tp2', relay.TypeKind.ShapeVar) + tp3 = relay.TypeVar('tp3', relay.TypeKind.Constraint) - func = tvm.get_env_func("tvm.relay.type_relation.Identity") + func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Identity") tr = relay.TypeRelation(func, tvm.convert([tp2, tp3]), 1, None) tf = relay.FuncType(tvm.convert([tp1]), tp1, tvm.convert([tp1, tp2, tp3]), tvm.convert([tr])) check_kind(tf) -@pytest.mark.xfail(raises=tvm._ffi.base.TVMError) +@pytest.mark.xfail(raises=tvm.error.TVMError) def test_tuple_with_invalid_func(): tensor_type = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') - tp1 = relay.TypeVar('tp1', relay.Kind.Shape) + tp1 = relay.TypeVar('tp1', relay.TypeKind.ShapeVar) tf = relay.FuncType(tvm.convert([]), tp1, tvm.convert([tp1]), tvm.convert([])) tup_ty = relay.TupleType(tvm.convert([tensor_type, tf])) diff --git a/tests/python/relay/test_pass_combine_parallel_conv2d.py b/tests/python/relay/test_pass_combine_parallel_conv2d.py index 599b308b2136..c10a7b8d1b39 100644 --- a/tests/python/relay/test_pass_combine_parallel_conv2d.py +++ b/tests/python/relay/test_pass_combine_parallel_conv2d.py @@ -14,18 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import tvm from tvm import relay from tvm.relay import transform def run_combine_parallel(expr, min_num_branches=3): - mod = relay.Module.from_expr(expr) + mod = tvm.IRModule.from_expr(expr) mod = transform.CombineParallelConv2D(min_num_branches)(mod) return mod["main"] def run_opt_pass(expr, opt_pass): assert isinstance(opt_pass, transform.Pass) - mod = relay.Module.from_expr(expr) + mod = tvm.IRModule.from_expr(expr) mod = opt_pass(mod) return mod["main"] diff --git a/tests/python/relay/test_pass_combine_parallel_dense.py b/tests/python/relay/test_pass_combine_parallel_dense.py index 070ab8658b88..f693f30060d9 100644 --- a/tests/python/relay/test_pass_combine_parallel_dense.py +++ b/tests/python/relay/test_pass_combine_parallel_dense.py @@ -14,18 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import tvm from tvm import relay from tvm.relay import transform def run_combine_parallel(expr, min_num_branches=3): - mod = relay.Module.from_expr(expr) + mod = tvm.IRModule.from_expr(expr) mod = transform.CombineParallelDense(min_num_branches)(mod) return mod["main"] def run_opt_pass(expr, opt_pass): assert isinstance(opt_pass, transform.Pass) - mod = relay.Module.from_expr(expr) + mod = tvm.IRModule.from_expr(expr) mod = opt_pass(mod) return mod["main"] diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index dfd745164069..4b80d6ca120d 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -24,7 +24,7 @@ def run_opt_pass(expr, passes): passes = passes if isinstance(passes, list) else [passes] - mod = relay.Module.from_expr(expr) + mod = tvm.IRModule.from_expr(expr) seq = transform.Sequential(passes) with transform.PassContext(opt_level=3): mod = seq(mod) diff --git a/tests/python/relay/test_pass_dead_code_elimination.py b/tests/python/relay/test_pass_dead_code_elimination.py index 89bae1f71b47..3f1ec9efd5e1 100644 --- a/tests/python/relay/test_pass_dead_code_elimination.py +++ b/tests/python/relay/test_pass_dead_code_elimination.py @@ -47,7 +47,7 @@ def __init__(self): def run_opt_pass(expr, opt_pass): assert isinstance(opt_pass, transform.Pass) - mod = relay.Module.from_expr(expr) + mod = tvm.IRModule.from_expr(expr) mod = opt_pass(mod) entry = mod["main"] return entry if isinstance(expr, relay.Function) else entry.body diff --git a/tests/python/relay/test_pass_eliminate_common_subexpr.py b/tests/python/relay/test_pass_eliminate_common_subexpr.py index 09ea7044daf5..e2fec6161c87 100644 --- a/tests/python/relay/test_pass_eliminate_common_subexpr.py +++ b/tests/python/relay/test_pass_eliminate_common_subexpr.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Test eliminate common subexpr pass""" +import tvm + from tvm import relay from tvm.relay.op import register_alter_op_layout from tvm.relay import transform, analysis @@ -22,7 +24,7 @@ def run_opt_pass(expr, opt_pass): assert isinstance(opt_pass, transform.Pass) - mod = relay.Module.from_expr(expr) + mod = tvm.IRModule.from_expr(expr) mod = opt_pass(mod) entry = mod["main"] return entry if isinstance(expr, relay.Function) else entry.body diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py index ca901b16b842..08834f14e851 100644 --- a/tests/python/relay/test_pass_fold_constant.py +++ b/tests/python/relay/test_pass_fold_constant.py @@ -25,7 +25,7 @@ def run_opt_pass(expr, opt_pass): assert isinstance(opt_pass, transform.Pass) - mod = relay.Module.from_expr(expr) + mod = tvm.IRModule.from_expr(expr) mod = opt_pass(mod) entry = mod["main"] return entry if isinstance(expr, relay.Function) else entry.body diff --git a/tests/python/relay/test_pass_fold_scale_axis.py b/tests/python/relay/test_pass_fold_scale_axis.py index d6f471bef04a..13995732d8ee 100644 --- a/tests/python/relay/test_pass_fold_scale_axis.py +++ b/tests/python/relay/test_pass_fold_scale_axis.py @@ -16,6 +16,7 @@ # under the License. import numpy as np +import tvm from tvm import relay from tvm.relay import transform @@ -25,7 +26,7 @@ def _get_positive_scale(size): def run_opt_pass(expr, opt_pass): assert isinstance(opt_pass, transform.Pass) - mod = relay.Module.from_expr(expr) + mod = tvm.IRModule.from_expr(expr) mod = opt_pass(mod) entry = mod["main"] return entry if isinstance(expr, relay.Function) else entry.body diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 7ec21eab12df..18916f758a6c 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -353,8 +353,8 @@ def expected(p0): dshape = (1, 16, 64, 64) x = relay.var("x", shape=dshape) orig = before(x) - fuse0(relay.Module.from_expr(orig)) - m = fuse2(relay.Module.from_expr(orig)) + fuse0(tvm.IRModule.from_expr(orig)) + m = fuse2(tvm.IRModule.from_expr(orig)) relay.build(m, 'llvm') after = run_opt_pass(expected(x), transform.InferType()) assert relay.analysis.alpha_equal(m["main"], after) @@ -408,8 +408,8 @@ def expected(dshape): dshape = (1, 16, 64, 64) x = relay.var("x", shape=dshape) orig = before(x) - fuse0(relay.Module.from_expr(orig)) - m = fuse2(relay.Module.from_expr(orig)) + fuse0(tvm.IRModule.from_expr(orig)) + m = fuse2(tvm.IRModule.from_expr(orig)) relay.build(m, 'llvm') after = run_opt_pass(expected(dshape), transform.InferType()) assert relay.analysis.alpha_equal(m["main"], after) @@ -475,8 +475,8 @@ def expected(dshape): dshape = (1, 16, 64, 64) orig = before(dshape) - fuse0(relay.Module.from_expr(orig)) - m = fuse2(relay.Module.from_expr(orig)) + fuse0(tvm.IRModule.from_expr(orig)) + m = fuse2(tvm.IRModule.from_expr(orig)) relay.build(m, 'llvm') after = run_opt_pass(expected(dshape), transform.InferType()) assert relay.analysis.alpha_equal(m["main"], after) @@ -519,7 +519,7 @@ def before(): y = relay.add(x, relay.const(1, "float32")) z = relay.exp(y) w = relay.squeeze(z) - mod = relay.module.Module() + mod = tvm.IRModule() mod["main"] = relay.Function([x], w) return mod @@ -531,7 +531,7 @@ def expected(): f1 = relay.Function([x], w) x = relay.var("x", shape=(10, 20)) y = relay.Call(f1, [x]) - mod = relay.module.Module() + mod = tvm.IRModule() mod["main"] = relay.Function([x], y) return mod @@ -548,7 +548,7 @@ def test_split(): a = relay.TupleGetItem(y, 0) b = relay.TupleGetItem(y, 1) c = relay.TupleGetItem(y, 2) - mod = relay.module.Module() + mod = tvm.IRModule() mod["main"] = relay.Function([x], a + relay.RefRead(relay.RefCreate(b)) + c) mod = transform.FuseOps()(mod) diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py index 8e4b7010de30..6c2ea8ffa3b3 100644 --- a/tests/python/relay/test_pass_gradient.py +++ b/tests/python/relay/test_pass_gradient.py @@ -188,7 +188,7 @@ def test_tuple(): def test_pow(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) add_nat_definitions(p) shape = (10, 10) diff --git a/tests/python/relay/test_pass_lambda_lift.py b/tests/python/relay/test_pass_lambda_lift.py index 550c85d4476b..a66c4c7d745a 100644 --- a/tests/python/relay/test_pass_lambda_lift.py +++ b/tests/python/relay/test_pass_lambda_lift.py @@ -22,7 +22,7 @@ from tvm.relay import transform def test_basic(): - mod = relay.Module() + mod = tvm.IRModule() x2 = relay.var('x2', shape=(10, 5)) y2 = relay.var('y2', shape=(1, 5)) level2_func = relay.Function([x2, y2], relay.op.add(x2, y2)) @@ -36,7 +36,7 @@ def test_basic(): assert len(new_mod.functions) == 2 def test_closure(): - mod = relay.Module() + mod = tvm.IRModule() x = relay.var('x', shape=(2,)) y = relay.var('y', shape=(2,)) @@ -47,9 +47,9 @@ def test_closure(): new_mod = transform.LambdaLift()(mod) assert len(new_mod.functions) == 3 - + def test_recursive(): - mod = relay.Module() + mod = tvm.IRModule() x = relay.var('x', shape=(2,)) i = relay.var('i', shape=(), dtype='int32') diff --git a/tests/python/relay/test_pass_legalize.py b/tests/python/relay/test_pass_legalize.py index e38c1aaa7a0e..e4e16c002abf 100644 --- a/tests/python/relay/test_pass_legalize.py +++ b/tests/python/relay/test_pass_legalize.py @@ -26,7 +26,7 @@ def run_opt_pass(expr, passes): passes = passes if isinstance(passes, list) else [passes] - mod = relay.Module.from_expr(expr) + mod = tvm.IRModule.from_expr(expr) seq = transform.Sequential(passes) with transform.PassContext(opt_level=3): mod = seq(mod) diff --git a/tests/python/relay/test_pass_mac_count.py b/tests/python/relay/test_pass_mac_count.py index 0ad1e3abe759..5ce0e41cfbac 100644 --- a/tests/python/relay/test_pass_mac_count.py +++ b/tests/python/relay/test_pass_mac_count.py @@ -23,7 +23,7 @@ def run_opt_pass(expr, opt_pass): assert isinstance(opt_pass, transform.Pass) - mod = relay.Module.from_expr(expr) + mod = tvm.IRModule.from_expr(expr) mod = opt_pass(mod) entry = mod["main"] return entry if isinstance(expr, relay.Function) else entry.body diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index bd055eebdbde..a13e5e93ea9c 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -79,9 +79,9 @@ class OptTester(): """A helper class for testing the pass manager.""" def __init__(self, mod): - if not isinstance(mod, relay.Module): + if not isinstance(mod, tvm.IRModule): raise TypeError("mod is expected to be the type of " - "relay.Module") + "tvm.IRModule") self.mod = mod def analysis(self): @@ -91,10 +91,10 @@ def analysis(self): @staticmethod def transform(node, ctx=None): """Perform optimization on node.""" - if isinstance(node, relay.Module): + if isinstance(node, tvm.IRModule): # Add a function to the module and return an updated module. gv, func = get_var_func() - mod = relay.Module({gv: func}) + mod = tvm.IRModule({gv: func}) mod.update(node) return mod if isinstance(node, relay.Function): @@ -121,7 +121,7 @@ def test_module_pass(): y = relay.var("y", tp) v_add = relay.GlobalVar("myAdd") func = relay.Function([x, y], x + y) - mod = relay.Module({v_add: func}) + mod = tvm.IRModule({v_add: func}) pass_name = "module_pass_test" opt_level = 0 @@ -150,10 +150,10 @@ def direct_transform(expr, ctx): def test_pass_run(): module_pass = transform - assert pass_name in module_pass.astext() + assert pass_name in str(module_pass) updated_mod = module_pass(mod) - assert isinstance(updated_mod, relay.Module) + assert isinstance(updated_mod, tvm.IRModule) # Check the abs function in the updated module. v_abs, myabs = get_var_func() @@ -206,10 +206,10 @@ def transform_function(self, func, mod, ctx): fpass = TestReplaceFunc(f1) assert fpass.info.opt_level == 1 assert fpass.info.name == "TestReplaceFunc" - mod = relay.Module.from_expr(f2) + mod = tvm.IRModule.from_expr(f2) mod = fpass(mod) # wrap in expr - mod2 = relay.Module.from_expr(f1) + mod2 = tvm.IRModule.from_expr(f1) assert relay.alpha_equal(mod["main"], mod2["main"]) @@ -220,7 +220,7 @@ def test_function_pass(): x = relay.var("x", tp) v_log = relay.GlobalVar("myLog") log = relay.Function([x], relay.log(x)) - mod = relay.Module({v_log: log}) + mod = tvm.IRModule({v_log: log}) pass_name = "function_pass_test" opt_level = 1 @@ -253,10 +253,10 @@ def direct_transform(expr, ctx): def test_pass_run(): function_pass = transform - assert pass_name in function_pass.astext() + assert pass_name in str(function_pass) updated_mod = function_pass(mod) - assert isinstance(updated_mod, relay.Module) + assert isinstance(updated_mod, tvm.IRModule) # Check the log function in the updated module. new_v_log = updated_mod.get_global_var(v_log.name_hint) @@ -297,8 +297,8 @@ def transform_module(self, mod, ctx): return mod x = relay.var("x", shape=(10, 20)) - m1 = relay.Module.from_expr(relay.Function([x], x)) - m2 = relay.Module.from_expr(relay.Function([x], relay.log(x))) + m1 = tvm.IRModule.from_expr(relay.Function([x], x)) + m2 = tvm.IRModule.from_expr(relay.Function([x], relay.log(x))) fpass = TestPipeline(m2, replace=True) assert fpass.info.name == "TestPipeline" mod3 = fpass(m1) @@ -326,7 +326,7 @@ def test_sequential_pass(): v_log = relay.GlobalVar("myLog") log = relay.Function([z], relay.log(z)) - mod = relay.Module({v_sub: sub, v_log: log}) + mod = tvm.IRModule({v_sub: sub, v_log: log}) def get_ref_log(): ref_log = relay.Function([x], relay.log(relay.add(x, x))) @@ -408,7 +408,7 @@ def test_only_function_pass(): def test_multiple_passes(): # Reset the current module since mod has been polluted by the previous # function pass. - mod = relay.Module({v_sub: sub, v_log: log}) + mod = tvm.IRModule({v_sub: sub, v_log: log}) passes = [module_pass, function_pass] sequential = _transform.Sequential(opt_level=1, passes=passes) required = ["mod_transform", "func_transform"] @@ -488,7 +488,7 @@ def expected(): relay.transform.AlterOpLayout() ]) - mod = relay.Module({"main": before()}) + mod = tvm.IRModule({"main": before()}) with relay.build_config(opt_level=3): with tvm.target.create("llvm"): mod = seq(mod) @@ -513,7 +513,7 @@ def test_print_ir(capfd): relay.transform.DeadCodeElimination() ]) - mod = relay.Module({"main": func}) + mod = tvm.IRModule({"main": func}) with relay.build_config(opt_level=3): mod = seq(mod) @@ -545,7 +545,7 @@ def test_print_debug_callback(): ]) assert __TRACE_COUNTER__ == 0 - mod = relay.Module({"main": func}) + mod = tvm.IRModule({"main": func}) with relay.build_config(opt_level=3, trace=_tracer): mod = seq(mod) diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py index cf4f8f6cee74..2bec98c173d9 100644 --- a/tests/python/relay/test_pass_partial_eval.py +++ b/tests/python/relay/test_pass_partial_eval.py @@ -22,7 +22,7 @@ from tvm.relay.prelude import Prelude from tvm.relay import op, create_executor, transform from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate -from tvm.relay import TensorType, Tuple, If, Module, Clause, PatternConstructor, PatternVar, Match +from tvm.relay import TensorType, Tuple, If, Clause, PatternConstructor, PatternVar, Match from tvm.relay import GlobalVar, Call from tvm.relay.transform import gradient from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type @@ -37,7 +37,7 @@ def check_eval(expr, expected_result, mod=None, rtol=1e-07): def run_opt_pass(expr, passes): passes = passes if isinstance(passes, list) else [passes] - mod = relay.Module.from_expr(expr) + mod = tvm.IRModule.from_expr(expr) seq = transform.Sequential(passes) with transform.PassContext(opt_level=3): mod = seq(mod) @@ -171,7 +171,7 @@ def test_function_invalidate(): def test_head_cons(): - mod = Module() + mod = tvm.IRModule() p = Prelude(mod) hd = p.hd t = TypeVar("t") @@ -183,7 +183,7 @@ def test_head_cons(): def test_map(): - mod = Module() + mod = tvm.IRModule() p = Prelude(mod) f = GlobalVar("f") t = TypeVar("t") @@ -200,7 +200,7 @@ def test_map(): def test_loop(): - mod = Module() + mod = tvm.IRModule() t = TypeVar("t") x = Var("x", t) loop = GlobalVar("loop") @@ -214,7 +214,7 @@ def test_loop(): def test_swap_loop(): - mod = Module() + mod = tvm.IRModule() p = Prelude(mod) add_nat_definitions(p) nat = p.nat() @@ -230,7 +230,7 @@ def test_swap_loop(): def test_abs_diff(): # TODO(@M.K.): refactor using tuple pattern (not yet implemented) - mod = Module() + mod = tvm.IRModule() p = Prelude(mod) add_nat_definitions(p) nat = p.nat() @@ -251,7 +251,7 @@ def test_abs_diff(): def test_match_nat_id(): - mod = Module() + mod = tvm.IRModule() p = Prelude(mod) add_nat_definitions(p) nat = p.nat() @@ -268,7 +268,7 @@ def test_match_nat_id(): def test_nat_id(): - mod = Module() + mod = tvm.IRModule() p = Prelude(mod) add_nat_definitions(p) nat = p.nat() @@ -283,7 +283,7 @@ def test_nat_id(): def test_global_match_nat_id(): - mod = Module() + mod = tvm.IRModule() p = Prelude(mod) add_nat_definitions(p) nat = p.nat() @@ -297,7 +297,7 @@ def test_global_match_nat_id(): def test_double(): - mod = Module() + mod = tvm.IRModule() p = Prelude(mod) add_nat_definitions(p) orig = p.double(make_nat_expr(p, 3)) @@ -324,7 +324,7 @@ def test_triangle_number(): def test_nat_update(): - m = Module() + m = tvm.IRModule() p = Prelude(m) add_nat_definitions(p) m = transform.ToANormalForm()(m) diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 869dba82bea4..27a143bc455a 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -245,7 +245,7 @@ def test_multi_node_compiler(): r = relay.concatenate((q0, q1, q2), axis=0) f = relay.Function([x, w0, w1, w2, w3, w4, w5, w6, w7], r) - mod = relay.Module() + mod = tvm.IRModule() ann = CcompilerAnnotator() mod["main"] = ann.visit(f) mod = transform.PartitionGraph()(mod) @@ -286,7 +286,7 @@ def visit_call(self, call): f = relay.Function([x, y], z) x_data = np.random.rand(8, 8).astype('float32') y_data = np.random.rand(8, 8).astype('float32') - mod = relay.Module() + mod = tvm.IRModule() mod["main"] = f mod = MyAnnotator()(mod) mod = transform.PartitionGraph()(mod) @@ -319,7 +319,7 @@ def expected(): tvm.expr.IntImm("int32", 1)) fused_call = relay.Call(fused_func, [add_call]) main = relay.Function([x, y], fused_call) - mod = relay.Module() + mod = tvm.IRModule() mod["main"] = main return mod @@ -330,7 +330,7 @@ def expected(): exp = relay.exp(add) concat = relay.concatenate([log, exp], axis=0) f = relay.Function([x, y], concat) - mod = relay.Module() + mod = tvm.IRModule() mod["main"] = f mod = WhiteListAnnotator(["add", "subtract", "multiply"], "ccompiler")(mod) mod = transform.PartitionGraph()(mod) @@ -354,7 +354,7 @@ def test_extern_ccompiler(): f = relay.Function([x, y], p - z) x_data = np.random.rand(2, 2).astype('float32') y_data = np.random.rand(2, 2).astype('float32') - mod = relay.Module() + mod = tvm.IRModule() mod["main"] = f mod = WhiteListAnnotator(["add", "subtract", "multiply"], "ccompiler")(mod) mod = transform.PartitionGraph()(mod) @@ -386,11 +386,11 @@ def test_extern_dnnl(): f = relay.Function([data, weight1], out) - mod = relay.Module() + mod = tvm.IRModule() mod['main'] = WholeGraphAnnotator('dnnl').visit(f) mod = transform.PartitionGraph()(mod) - ref_mod = relay.Module() + ref_mod = tvm.IRModule() ref_mod['main'] = f i_data = np.random.uniform(0, 1, ishape).astype(dtype) diff --git a/tests/python/relay/test_pass_qnn_legalize.py b/tests/python/relay/test_pass_qnn_legalize.py index 6992f288c454..38fdb7dd07b1 100644 --- a/tests/python/relay/test_pass_qnn_legalize.py +++ b/tests/python/relay/test_pass_qnn_legalize.py @@ -34,7 +34,7 @@ def alpha_equal(x, y): def run_opt_pass(expr, passes): passes = passes if isinstance(passes, list) else [passes] - mod = relay.Module.from_expr(expr) + mod = tvm.IRModule.from_expr(expr) seq = transform.Sequential(passes) with transform.PassContext(opt_level=3): mod = seq(mod) @@ -114,7 +114,7 @@ def _get_mod(data_dtype, kernel_dtype): kernel_layout='OIHW') mod = relay.Function(relay.analysis.free_vars(func), func) - mod = relay.Module.from_expr(mod) + mod = tvm.IRModule.from_expr(mod) return mod # Check uint8 x uint8 and int8 x int8 transformation @@ -193,7 +193,7 @@ def _get_mod(data_dtype, kernel_dtype): out_dtype='int32') mod = relay.Function(relay.analysis.free_vars(func), func) - mod = relay.Module.from_expr(mod) + mod = tvm.IRModule.from_expr(mod) return mod # Check uint8 x uint8 and int8 x int8 transformation diff --git a/tests/python/relay/test_pass_remove_unused_functions.py b/tests/python/relay/test_pass_remove_unused_functions.py index 2a4cbd2579e7..bacc3126c7c4 100644 --- a/tests/python/relay/test_pass_remove_unused_functions.py +++ b/tests/python/relay/test_pass_remove_unused_functions.py @@ -22,7 +22,7 @@ def test_remove_all_prelude_functions(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) x = relay.var("x", shape=(1, 16)) mod["main"] = relay.Function([x], x) @@ -32,7 +32,7 @@ def test_remove_all_prelude_functions(): def test_remove_all_prelude_functions_but_referenced_functions(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) x = relay.var("x", shape=(1, 16)) id_func = relay.Function([x], x) @@ -46,7 +46,7 @@ def test_remove_all_prelude_functions_but_referenced_functions(): def test_keep_only_referenced_prelude_functions(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) l = p.nil() for i in [4, 3, 2, 1, 0]: @@ -59,7 +59,7 @@ def test_keep_only_referenced_prelude_functions(): def test_multiple_entry_functions(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) l = p.nil() for i in [4, 3, 2, 1, 0]: @@ -78,7 +78,7 @@ def test_multiple_entry_functions(): def test_globalvar_as_call_arg(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) tensor_array = p.get_var('tensor_array', 'int32') tensor1 = p.get_var('tensor1', 'int32') @@ -96,7 +96,7 @@ def test_globalvar_as_call_arg(): def test_call_globalvar_without_args(): def get_mod(): - mod = relay.Module({}) + mod = tvm.IRModule({}) fn1 = relay.Function([], relay.const(1)) fn2 = relay.Function([], relay.const(2)) g1 = relay.GlobalVar('g1') diff --git a/tests/python/relay/test_pass_simplify_inference.py b/tests/python/relay/test_pass_simplify_inference.py index 4e62fa6dcb08..bb398939156e 100644 --- a/tests/python/relay/test_pass_simplify_inference.py +++ b/tests/python/relay/test_pass_simplify_inference.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from tvm.ir import IRModule from tvm import relay as rly from tvm.relay.transform import SimplifyInference @@ -50,7 +51,7 @@ def check(dim, axis, nstep): gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis, shape=ttype1.shape) - mod = rly.Module.from_expr(y1) + mod = IRModule.from_expr(y1) simplify = SimplifyInference() mod = simplify(mod) y1 = mod["main"].body diff --git a/tests/python/relay/test_pass_to_a_normal_form.py b/tests/python/relay/test_pass_to_a_normal_form.py index 865729002745..46bde4f490b8 100644 --- a/tests/python/relay/test_pass_to_a_normal_form.py +++ b/tests/python/relay/test_pass_to_a_normal_form.py @@ -26,7 +26,7 @@ def run_opt_pass(expr, passes): passes = passes if isinstance(passes, list) else [passes] - mod = relay.Module.from_expr(expr) + mod = tvm.IRModule.from_expr(expr) seq = transform.Sequential(passes) with transform.PassContext(opt_level=3): mod = seq(mod) @@ -110,7 +110,7 @@ def test_recursion(): } f(5); """ - mod = relay.Module() + mod = tvm.IRModule() i64 = relay.TensorType((), 'int64') f = relay.GlobalVar("f") n = relay.Var("n", i64) @@ -143,7 +143,7 @@ def test_ref(): def test_nat_add(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) add_nat_definitions(p) nat = p.nat @@ -192,7 +192,7 @@ def test_gradient_if(): net = relay.If(cond, x, x) net = relay.add(x, net) net = relay.Function([cond,x,y], net) - mod = relay.Module.from_expr(net) + mod = tvm.IRModule.from_expr(net) mod = relay.transform.ToANormalForm()(mod) mod["main"] = relay.transform.gradient(mod["main"], mode='higher_order') mod = relay.transform.ToANormalForm()(mod) diff --git a/tests/python/relay/test_pass_to_cps.py b/tests/python/relay/test_pass_to_cps.py index 1d09c0d67f5b..4645e20c7468 100644 --- a/tests/python/relay/test_pass_to_cps.py +++ b/tests/python/relay/test_pass_to_cps.py @@ -42,7 +42,7 @@ def test_double(): # make sure cps work for recursion. def test_recursion(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) add_nat_definitions(p) shape = (10, 10) diff --git a/tests/python/relay/test_pass_to_graph_normal_form.py b/tests/python/relay/test_pass_to_graph_normal_form.py index a29172471d48..5c5221f65a46 100644 --- a/tests/python/relay/test_pass_to_graph_normal_form.py +++ b/tests/python/relay/test_pass_to_graph_normal_form.py @@ -22,7 +22,7 @@ def run_opt_pass(expr, opt_pass): - mod = relay.Module.from_expr(expr) + mod = tvm.IRModule.from_expr(expr) mod = opt_pass(mod) entry = mod["main"] return entry if isinstance(expr, relay.Function) else entry.body @@ -30,7 +30,7 @@ def run_opt_pass(expr, opt_pass): def check_eval(expr, args, expected_result, mod=None, rtol=1e-07): if mod is None: - mod = relay.Module() + mod = tvm.IRModule() ctx = tvm.context("llvm", 0) intrp = create_executor(mod=mod, ctx=ctx, target="llvm") diff --git a/tests/python/relay/test_pass_unmatched_cases.py b/tests/python/relay/test_pass_unmatched_cases.py index b06de4c8e384..615d4e092291 100644 --- a/tests/python/relay/test_pass_unmatched_cases.py +++ b/tests/python/relay/test_pass_unmatched_cases.py @@ -47,7 +47,7 @@ def test_trivial_matches(): def test_single_constructor_adt(): - mod = relay.Module() + mod = tvm.IRModule() box = relay.GlobalTypeVar('box') a = relay.TypeVar('a') box_ctor = relay.Constructor('box', [a], box) @@ -76,7 +76,7 @@ def test_single_constructor_adt(): def test_too_specific_match(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) v = relay.Var('v') @@ -117,7 +117,7 @@ def test_too_specific_match(): def test_multiple_constructor_clauses(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) v = relay.Var('v') @@ -147,7 +147,7 @@ def test_multiple_constructor_clauses(): def test_missing_in_the_middle(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) v = relay.Var('v') @@ -185,7 +185,7 @@ def test_missing_in_the_middle(): def test_mixed_adt_constructors(): - mod = relay.Module() + mod = tvm.IRModule() box = relay.GlobalTypeVar('box') a = relay.TypeVar('a') box_ctor = relay.Constructor('box', [a], box) diff --git a/tests/python/relay/test_pass_vars.py b/tests/python/relay/test_pass_vars.py index 70eb047ad03e..d8b77ba35612 100644 --- a/tests/python/relay/test_pass_vars.py +++ b/tests/python/relay/test_pass_vars.py @@ -82,7 +82,7 @@ def test_bound_vars(): def test_match_vars(): - mod = relay.Module() + mod = tvm.IRModule() p = relay.prelude.Prelude(mod) x = relay.Var('x') diff --git a/tests/python/relay/test_py_converter.py b/tests/python/relay/test_py_converter.py index e46b6d41eeb5..f489e9fcb04b 100644 --- a/tests/python/relay/test_py_converter.py +++ b/tests/python/relay/test_py_converter.py @@ -199,7 +199,7 @@ def test_local_function(): def test_global_function(): - mod = relay.Module() + mod = tvm.IRModule() ident = relay.GlobalVar('ident') a = relay.TypeVar('a') v = relay.Var('v', a) @@ -218,7 +218,7 @@ def test_global_function(): def test_constructor(): - mod = relay.Module() + mod = tvm.IRModule() box, box_ctor = init_box_adt(mod) init_box_int = box_ctor(relay.const(1)) @@ -235,7 +235,7 @@ def test_constructor(): def test_match_wildcard(): - mod = relay.Module() + mod = tvm.IRModule() box, box_ctor = init_box_adt(mod) v = relay.Var('v') match = relay.Let( @@ -249,7 +249,7 @@ def test_match_wildcard(): def test_match_var(): - mod = relay.Module() + mod = tvm.IRModule() box, box_ctor = init_box_adt(mod) v = relay.Var('v') w = relay.Var('w') @@ -265,7 +265,7 @@ def test_match_var(): def test_match_pattern(): - mod = relay.Module() + mod = tvm.IRModule() box, box_ctor = init_box_adt(mod) v = relay.Var('v') w = relay.Var('w') @@ -279,7 +279,7 @@ def test_match_pattern(): def test_nested_match_pattern(): - mod = relay.Module() + mod = tvm.IRModule() box, box_ctor = init_box_adt(mod) v = relay.Var('v') w = relay.Var('w') @@ -296,7 +296,7 @@ def test_nested_match_pattern(): assert_tensor_value(match_val, 2) def test_match_order(): - mod = relay.Module() + mod = tvm.IRModule() box, box_ctor = init_box_adt(mod) v = relay.Var('v') w = relay.Var('w') @@ -316,7 +316,7 @@ def test_match_order(): def test_local_recursion(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) v = relay.Var('v') @@ -342,11 +342,11 @@ def test_local_recursion(): assert_tensor_value(val.fields[1].fields[0], 2) assert_constructor_value(val.fields[1].fields[1], p.cons, 2) assert_tensor_value(val.fields[1].fields[1].fields[0], 3) - assert_constructor_value(val.fields[1].fields[1].fields[1], p.nil, 0) + assert_constructor_value(val.fields[1].fields[1].fields[1], p.nil, 0) def test_global_recursion(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) copy = relay.GlobalVar('copy') # same as above: it copies the given list @@ -398,7 +398,7 @@ def test_higher_order_call(): def test_match_effect_exactly_once(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) # the list should be of length 1! @@ -423,7 +423,7 @@ def test_match_effect_exactly_once(): def test_arbitrary_let_nesting(): # something that is tricky to do in Python but comes naturally in Relay - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) x = relay.Var('x') r = relay.Var('r') diff --git a/tests/python/relay/test_type_functor.py b/tests/python/relay/test_type_functor.py index d09a8938bb54..b84a49ae3fda 100644 --- a/tests/python/relay/test_type_functor.py +++ b/tests/python/relay/test_type_functor.py @@ -64,7 +64,7 @@ def test_tuple_type(): def test_type_relation(): - func = tvm.get_env_func('tvm.relay.type_relation.Broadcast') + func = tvm.ir.EnvFunc.get('tvm.relay.type_relation.Broadcast') attrs = tvm.make.node('attrs.TestAttrs', name='attr', padding=(3,4)) tp = TypeVar('tp') tf = FuncType([], TupleType([]), [], []) diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 3f6b0d2eb895..892c91d9c43a 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -17,6 +17,7 @@ """Test that type checker correcly computes types for expressions. """ +import tvm from tvm import relay from tvm.relay import op, transform, analysis from tvm.relay.analysis import assert_alpha_equal @@ -24,7 +25,7 @@ def run_infer_type(expr, mod=None): if not mod: - mod = relay.Module.from_expr(expr) + mod = tvm.IRModule.from_expr(expr) mod = transform.InferType()(mod) entry = mod["main"] return entry if isinstance(expr, relay.Function) else entry.body @@ -44,7 +45,7 @@ def run_infer_type(expr, mod=None): return mod[gv].body -def assert_has_type(expr, typ, mod=relay.module.Module({})): +def assert_has_type(expr, typ, mod=tvm.IRModule({})): checked_expr = run_infer_type(expr, mod) checked_type = checked_expr.checked_type if checked_type != typ: @@ -152,7 +153,7 @@ def @f(%n: int32, %data: float32) -> float32 { sb.ret(data) with sb.else_scope(): sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data))) - mod = relay.Module() + mod = tvm.IRModule() mod[f] = relay.Function([n, data], sb.get()) assert "@f(%1, %2) /* ty=float32 */" in mod.astext() assert mod[f].checked_type == relay.FuncType([ti32, tf32], tf32) @@ -267,7 +268,7 @@ def test_type_args(): def test_global_var_recursion(): - mod = relay.Module({}) + mod = tvm.IRModule({}) gv = relay.GlobalVar("main") x = relay.var('x', shape=[]) tt = relay.scalar_type('float32') @@ -289,7 +290,7 @@ def test_equal(): def test_constructor_type(): - mod = relay.Module() + mod = tvm.IRModule() box, constructor = initialize_box_adt(mod) a = relay.TypeVar('a') @@ -300,7 +301,7 @@ def test_constructor_type(): def test_constructor_call(): - mod = relay.Module() + mod = tvm.IRModule() box, constructor = initialize_box_adt(mod) box_unit = constructor(relay.Tuple([])) @@ -313,7 +314,7 @@ def test_constructor_call(): def test_adt_match(): - mod = relay.Module() + mod = tvm.IRModule() box, constructor = initialize_box_adt(mod) v = relay.Var('v', relay.TensorType((), 'float32')) @@ -331,7 +332,7 @@ def test_adt_match(): def test_adt_match_type_annotations(): - mod = relay.Module() + mod = tvm.IRModule() box, constructor = initialize_box_adt(mod) # the only type annotation is inside the match pattern var diff --git a/tests/python/relay/test_type_solver.py b/tests/python/relay/test_type_solver.py index 11274181af73..118066e7cf52 100644 --- a/tests/python/relay/test_type_solver.py +++ b/tests/python/relay/test_type_solver.py @@ -20,7 +20,7 @@ def make_rel(name, args, num_inputs=None, attrs=None): - func = tvm.get_env_func("tvm.relay.type_relation." + name) + func = tvm.ir.EnvFunc.get("tvm.relay.type_relation." + name) if num_inputs is None: num_inputs = len(args) - 1 return relay.ty.TypeRelation(func, args, num_inputs, attrs) diff --git a/tests/python/relay/test_typecall.py b/tests/python/relay/test_typecall.py index 1c663d2301e9..fa2601f30af1 100644 --- a/tests/python/relay/test_typecall.py +++ b/tests/python/relay/test_typecall.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import tvm from tvm import relay from tvm.relay import transform @@ -23,14 +24,14 @@ def test_dup_type(): make_id = relay.Function([av], relay.Tuple([av, av]), None, [a]) t = relay.scalar_type("float32") b = relay.Var("b", t) - mod = relay.Module.from_expr(make_id(b)) + mod = tvm.IRModule.from_expr(make_id(b)) mod = transform.InferType()(mod) inferred = mod["main"].body assert inferred.checked_type == relay.TupleType([t, t]) def test_id_type(): - mod = relay.Module() + mod = tvm.IRModule() id_type = relay.GlobalTypeVar("id") a = relay.TypeVar("a") mod[id_type] = relay.TypeData(id_type, [a], []) diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index cd8a875f5bf7..c4cd616cdec0 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -46,10 +46,10 @@ def check_result(args, expected_result, mod=None): def veval(f, *args, ctx=tvm.cpu(), target="llvm"): if isinstance(f, relay.Expr): - mod = relay.Module() + mod = tvm.IRModule() mod["main"] = f else: - assert isinstance(f, relay.Module), "expected expression or module" + assert isinstance(f, tvm.IRModule), "expected expression or module" mod = f exe = relay.vm.compile(mod, target) vm = runtime.vm.VirtualMachine(exe) @@ -92,7 +92,7 @@ def test_id(): x = relay.var('x', shape=(10, 10), dtype='float64') f = relay.Function([x], x) x_data = np.random.rand(10, 10).astype('float64') - mod = relay.Module() + mod = tvm.IRModule() mod["main"] = f check_result([x_data], x_data, mod=mod) @@ -100,7 +100,7 @@ def test_op(): x = relay.var('x', shape=(10, 10)) f = relay.Function([x], x + x) x_data = np.random.rand(10, 10).astype('float32') - mod = relay.Module() + mod = tvm.IRModule() mod["main"] = f check_result([x_data], 2 * x_data, mod=mod) @@ -116,7 +116,7 @@ def test_cond(): x_data = np.random.rand(10, 10).astype('float32') y_data = np.random.rand(10, 10).astype('float32') - mod = relay.Module() + mod = tvm.IRModule() mod["main"] = f # same check_result([x_data, x_data], True, mod=mod) @@ -132,7 +132,7 @@ def test_simple_if(): x_data = np.random.rand(10, 10).astype('float32') y_data = np.random.rand(10, 10).astype('float32') - mod = relay.Module() + mod = tvm.IRModule() mod["main"] = f # same check_result([x_data, x_data], x_data, mod=mod) @@ -141,7 +141,7 @@ def test_simple_if(): check_result([x_data, y_data], y_data, mod=mod) def test_simple_call(): - mod = relay.module.Module({}) + mod = tvm.IRModule({}) sum_up = relay.GlobalVar('sum_up') i = relay.var('i', shape=[], dtype='int32') sb = ScopeBuilder() @@ -154,7 +154,7 @@ def test_simple_call(): check_result([i_data], i_data, mod=mod) def test_count_loop(): - mod = relay.module.Module({}) + mod = tvm.IRModule({}) sum_up = relay.GlobalVar('sum_up') i = relay.var('i', shape=[], dtype='int32') sb = ScopeBuilder() @@ -174,7 +174,7 @@ def test_count_loop(): check_result([i_data], i_data, mod=mod) def test_sum_loop(): - mod = relay.module.Module({}) + mod = tvm.IRModule({}) sum_up = relay.GlobalVar('sum_up') i = relay.var('i', shape=[], dtype='int32') accum = relay.var('accum', shape=[], dtype='int32') @@ -201,7 +201,7 @@ def test_tuple_fst(): f = relay.Function([tup], relay.TupleGetItem(tup, 0)) i_data = np.random.rand(41).astype('float32') j_data = np.random.rand(10).astype('float32') - mod = relay.Module() + mod = tvm.IRModule() mod["main"] = f check_result([(i_data, j_data)], i_data, mod=mod) @@ -211,12 +211,12 @@ def test_tuple_second(): f = relay.Function([tup], relay.TupleGetItem(tup, 1)) i_data = np.random.rand(41).astype('float32') j_data = np.random.rand(10).astype('float32') - mod = relay.Module() + mod = tvm.IRModule() mod["main"] = f check_result([(i_data, j_data)], j_data, mod=mod) def test_list_constructor(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) nil = p.nil @@ -251,7 +251,7 @@ def test_let_tensor(): f = relay.Function([x], body) x_data = np.random.rand(*shape).astype('float32') - mod = relay.Module() + mod = tvm.IRModule() mod["main"] = f check_result([x_data], x_data + 42.0, mod=mod) @@ -267,12 +267,12 @@ def test_let_scalar(): f = relay.Function([x], body) x_data = np.array(np.random.rand()).astype('float32') - mod = relay.Module() + mod = tvm.IRModule() mod["main"] = f check_result([x_data], x_data + 42.0, mod=mod) def test_compose(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) compose = p.compose @@ -305,7 +305,7 @@ def test_compose(): tvm.testing.assert_allclose(result.asnumpy(), x_data + 2.0) def test_list_hd(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) nil = p.nil @@ -326,7 +326,7 @@ def test_list_hd(): @pytest.mark.xfail def test_list_tl_empty_list(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) nil = p.nil @@ -341,7 +341,7 @@ def test_list_tl_empty_list(): print(result) def test_list_tl(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) nil = p.nil @@ -364,7 +364,7 @@ def test_list_nth(): expected = list(range(10)) for i in range(len(expected)): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) nil = p.nil @@ -382,7 +382,7 @@ def test_list_nth(): def test_list_update(): expected = list(range(10)) - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) nil = p.nil @@ -406,7 +406,7 @@ def test_list_update(): def test_list_length(): expected = list(range(10)) - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) nil = p.nil @@ -426,7 +426,7 @@ def test_list_length(): tvm.testing.assert_allclose(result.asnumpy(), 10) def test_list_map(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) x = relay.var('x', 'int32') @@ -444,7 +444,7 @@ def test_list_map(): tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 2])) def test_list_foldl(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) nil = p.nil @@ -462,7 +462,7 @@ def test_list_foldl(): tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 3, 2, 2, 1, 1])) def test_list_foldr(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) nil = p.nil @@ -480,7 +480,7 @@ def test_list_foldr(): tvm.testing.assert_allclose(vmobj_to_list(result), np.array([1, 2, 3])) def test_list_sum(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) nil = p.nil @@ -494,7 +494,7 @@ def test_list_sum(): tvm.testing.assert_allclose(result.asnumpy(), 6) def test_list_filter(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) nil = p.nil @@ -530,7 +530,7 @@ def test_add_op_scalar(): return x + y; } """ - mod = relay.Module() + mod = tvm.IRModule() x = relay.var('x', shape=()) y = relay.var('y', shape=()) func = relay.Function([x, y], relay.op.add(x, y)) @@ -546,7 +546,7 @@ def test_add_op_tensor(): return x + y; } """ - mod = relay.Module() + mod = tvm.IRModule() x = relay.var('x', shape=(10, 5)) y = relay.var('y', shape=(10, 5)) func = relay.Function([x, y], relay.op.add(x, y)) @@ -562,7 +562,7 @@ def test_add_op_broadcast(): return x + y; } """ - mod = relay.Module() + mod = tvm.IRModule() x = relay.var('x', shape=(10, 5)) y = relay.var('y', shape=(1, 5)) func = relay.Function([x, y], relay.op.add(x, y)) diff --git a/tests/python/relay/test_vm_serialization.py b/tests/python/relay/test_vm_serialization.py index 515baa2ef6ce..9fed4955705f 100644 --- a/tests/python/relay/test_vm_serialization.py +++ b/tests/python/relay/test_vm_serialization.py @@ -22,7 +22,7 @@ from tvm.runtime import vm as _vm from tvm.relay import vm as rly_vm from tvm import relay -from tvm.relay.module import Module as rly_module + from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.prelude import Prelude from tvm.contrib import util @@ -30,12 +30,12 @@ def create_exec(f, target="llvm", params=None): if isinstance(f, relay.Expr): - mod = relay.Module() + mod = tvm.IRModule() mod["main"] = f executable = rly_vm.compile(mod, target=target, params=params) return executable else: - assert isinstance(f, relay.Module), "expected mod as relay.Module" + assert isinstance(f, tvm.IRModule), "expected mod as tvm.IRModule" executable = rly_vm.compile(f, target=target, params=params) return executable @@ -76,7 +76,7 @@ def get_serialized_output(mod, data, params, target, ctx, dtype='float32'): def test_serializer(): - mod = rly_module({}) + mod = tvm.IRModule({}) a = relay.const(1.0, "float32") x = relay.var('x', shape=(10, 10), dtype='float32') f1 = relay.Function([x], x + a) @@ -187,7 +187,7 @@ def test_if(): def test_loop(): - mod = relay.module.Module({}) + mod = tvm.IRModule({}) sum_up = relay.GlobalVar('sum_up') i = relay.var('i', shape=[], dtype='int32') accum = relay.var('accum', shape=[], dtype='int32') @@ -235,7 +235,7 @@ def test_tuple(): def test_adt_list(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) l1 = p.cons(relay.const(1), p.nil()) @@ -263,7 +263,7 @@ def test_adt_list(): def test_adt_compose(): - mod = relay.Module() + mod = tvm.IRModule() p = Prelude(mod) compose = p.compose diff --git a/tests/python/unittest/test_codegen_cross_llvm.py b/tests/python/unittest/test_codegen_cross_llvm.py index 6604038ab2ef..1827ccf63d79 100644 --- a/tests/python/unittest/test_codegen_cross_llvm.py +++ b/tests/python/unittest/test_codegen_cross_llvm.py @@ -71,7 +71,7 @@ def build_arm(): port = int(os.environ['TVM_RPC_ARM_PORT']) try: remote = rpc.connect(host, port) - except tvm.TVMError as e: + except tvm.error.TVMError as e: pass if remote: diff --git a/tests/python/unittest/test_container.py b/tests/python/unittest/test_container.py index 5ed6e04e0b45..f7ffd0288f1b 100644 --- a/tests/python/unittest/test_container.py +++ b/tests/python/unittest/test_container.py @@ -42,7 +42,7 @@ def test_tuple_object(): ])) fn = relay.Function([x], relay.expr.TupleGetItem(x, 0)) - mod = relay.Module.from_expr(fn) + mod = tvm.IRModule.from_expr(fn) exe = relay.create_executor( kind="vm", mod=mod, ctx=nd.cpu(), target="llvm") diff --git a/tests/python/unittest/test_graph_tuner_core.py b/tests/python/unittest/test_graph_tuner_core.py index 32c16e239461..a8b22fd787ee 100644 --- a/tests/python/unittest/test_graph_tuner_core.py +++ b/tests/python/unittest/test_graph_tuner_core.py @@ -159,7 +159,7 @@ def test_DPTuner_run(): target_ops = [relay.nn.conv2d] g, records, ltf_records, ltf_keys, tasks = _create_data(target, dshape, dtype, layout) - mod = relay.module.Module() + mod = tvm.IRModule() mod["main"] = g costs = [0.02, 0.02, 0.045] config_list = [] diff --git a/tests/python/unittest/test_lang_container.py b/tests/python/unittest/test_lang_container.py index 92edbee9072f..4f8a93b8fbd3 100644 --- a/tests/python/unittest/test_lang_container.py +++ b/tests/python/unittest/test_lang_container.py @@ -26,8 +26,8 @@ def test_array(): def test_array_save_load_json(): a = tvm.convert([1,2,3]) - json_str = tvm.save_json(a) - a_loaded = tvm.load_json(json_str) + json_str = tvm.ir.save_json(a) + a_loaded = tvm.ir.load_json(json_str) assert(a_loaded[1].value == 2) @@ -59,8 +59,8 @@ def test_map_save_load_json(): b = tvm.var('b') amap = tvm.convert({a: 2, b: 3}) - json_str = tvm.save_json(amap) - amap = tvm.load_json(json_str) + json_str = tvm.ir.save_json(amap) + amap = tvm.ir.load_json(json_str) assert len(amap) == 2 dd = {kv[0].name : kv[1].value for kv in amap.items()} assert(dd == {"a": 2, "b": 3}) diff --git a/tests/python/unittest/test_lang_group.py b/tests/python/unittest/test_lang_group.py index 3efc9bc5096b..e78ffb3541d3 100644 --- a/tests/python/unittest/test_lang_group.py +++ b/tests/python/unittest/test_lang_group.py @@ -46,7 +46,7 @@ def test_scan_group(): # compute outside group error. s[s_update2].compute_at(s[s_init], s_init.op.axis[0]) assert False - except tvm.TVMError: + except tvm.error.TVMError: pass def test_compute_group(): diff --git a/tests/python/unittest/test_lang_operator.py b/tests/python/unittest/test_lang_operator.py index ac2ee6d88cc5..26783e62db13 100644 --- a/tests/python/unittest/test_lang_operator.py +++ b/tests/python/unittest/test_lang_operator.py @@ -19,7 +19,7 @@ def check_throws(f): try: f() - except tvm.TVMError: + except tvm.error.TVMError: pass else: raise AssertionError("Should have raised an exception but didn't.") diff --git a/tests/python/unittest/test_lang_reflection.py b/tests/python/unittest/test_lang_reflection.py index 18230bf5e1fa..b971e386cfc7 100644 --- a/tests/python/unittest/test_lang_reflection.py +++ b/tests/python/unittest/test_lang_reflection.py @@ -22,9 +22,9 @@ def test_const_saveload_json(): y = tvm.const(10, "int32") z = x + y z = z + z - json_str = tvm.save_json(z) - zz = tvm.load_json(json_str) - assert tvm.save_json(zz) == tvm.save_json(z) + json_str = tvm.ir.save_json(z) + zz = tvm.ir.load_json(json_str) + assert tvm.ir.save_json(zz) == tvm.ir.save_json(z) def test_make_smap(): @@ -33,8 +33,8 @@ def test_make_smap(): y = tvm.const(10, "int32") z = tvm.expr.Add(x, y) smap = tvm.convert({"z": z, "x": x}) - json_str = tvm.save_json(tvm.convert([smap])) - arr = tvm.load_json(json_str) + json_str = tvm.ir.save_json(tvm.convert([smap])) + arr = tvm.ir.load_json(json_str) assert len(arr) == 1 assert arr[0]["z"].a == arr[0]["x"] @@ -57,13 +57,13 @@ def test_make_attrs(): try: x = tvm.make.node("attrs.TestAttrs", unknown_key=1, name="xx") assert False - except tvm.TVMError as e: + except tvm.error.TVMError as e: assert str(e).find("unknown_key") != -1 try: x = tvm.make.node("attrs.TestAttrs", axis=100, name="xx") assert False - except tvm.TVMError as e: + except tvm.error.TVMError as e: assert str(e).find("upper bound") != -1 x = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3,4)) @@ -75,7 +75,7 @@ def test_make_attrs(): dattr = tvm.make.node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0)) assert dattr.x.value == 1 - datrr = tvm.load_json(tvm.save_json(dattr)) + datrr = tvm.ir.load_json(tvm.ir.save_json(dattr)) assert dattr.name.value == "xyz" @@ -84,8 +84,8 @@ def test_make_sum(): A = tvm.placeholder((2, 10), name='A') k = tvm.reduce_axis((0,10), "k") B = tvm.compute((2,), lambda i: tvm.sum(A[i, k], axis=k), name="B") - json_str = tvm.save_json(B) - BB = tvm.load_json(json_str) + json_str = tvm.ir.save_json(B) + BB = tvm.ir.load_json(json_str) assert B.op.body[0].combiner is not None assert BB.op.body[0].combiner is not None @@ -96,10 +96,10 @@ def test(x): return x + 1 f = tvm.get_global_func("test.env_func") - x = tvm.get_env_func("test.env_func") + x = tvm.ir.EnvFunc.get("test.env_func") assert x.name == "test.env_func" - json_str = tvm.save_json([x]) - y = tvm.load_json(json_str)[0] + json_str = tvm.ir.save_json([x]) + y = tvm.ir.load_json(json_str)[0] assert y.name == x.name assert y(1) == 2 assert y.func(1) == 2 @@ -109,8 +109,8 @@ def test(x): assert x.padding[0].value == 3 assert x.padding[1].value == 4 assert x.axis == 10 - x = tvm.load_json(tvm.save_json(x)) - assert isinstance(x.func, tvm.container.EnvFunc) + x = tvm.ir.load_json(tvm.ir.save_json(x)) + assert isinstance(x.func, tvm.ir.EnvFunc) assert x.func(10) == 11 diff --git a/tests/python/unittest/test_lang_schedule.py b/tests/python/unittest/test_lang_schedule.py index eeab81b965b4..6b5b7fa2be67 100644 --- a/tests/python/unittest/test_lang_schedule.py +++ b/tests/python/unittest/test_lang_schedule.py @@ -36,8 +36,8 @@ def test_schedule_create(): assert T.op.axis[1] in s[T].leaf_iter_vars # save load json - json_str = tvm.save_json(s) - s_loaded = tvm.load_json(json_str) + json_str = tvm.ir.save_json(s) + s_loaded = tvm.ir.load_json(json_str) assert isinstance(s_loaded, tvm.schedule.Schedule) assert(str(s_loaded.outputs[0].body) == str(s.outputs[0].body)) @@ -65,7 +65,7 @@ def test_reorder(): # must raise an error s[T].reorder(xi2, xi1, xi2) assert False - except tvm.TVMError: + except tvm.error.TVMError: pass def test_split(): diff --git a/tests/python/unittest/test_lang_tag.py b/tests/python/unittest/test_lang_tag.py index fc884ea5bc92..201abf193eb4 100644 --- a/tests/python/unittest/test_lang_tag.py +++ b/tests/python/unittest/test_lang_tag.py @@ -48,7 +48,7 @@ def test_with(): assert "hello" in C.op.attrs assert "xx" not in C.op.attrs assert C.op.attrs["hello"].value == 1 - CC = tvm.load_json(tvm.save_json(C)) + CC = tvm.ir.load_json(tvm.ir.save_json(C)) assert CC.op.attrs["hello"].value == 1 assert CC.op.attrs["arr"][0].value == 10 # str format happened to be json compatible diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py index e363a2cf11be..a8b5fc094cca 100644 --- a/tests/python/unittest/test_lang_tensor.py +++ b/tests/python/unittest/test_lang_tensor.py @@ -96,8 +96,8 @@ def test_tensor_reduce(): rv = tvm.reduce_axis((0, A.shape[1]), "k") C = tvm.compute((m, n), lambda i, j: tvm.sum(T(i, j, rv+1), axis=rv)) # json load save - C_json = tvm.save_json(C) - C_loaded = tvm.load_json(C_json) + C_json = tvm.ir.save_json(C) + C_loaded = tvm.ir.load_json(C_json) assert(isinstance(C_loaded, tvm.tensor.Tensor)) assert(str(C_loaded) == str(C)) @@ -201,8 +201,8 @@ def test_scan_multi_out(): [s1, s2]) assert(r0.value_index == 0) assert(r1.value_index == 1) - json_str = tvm.save_json(r0.op) - zz = tvm.load_json(json_str) + json_str = tvm.ir.save_json(r0.op) + zz = tvm.ir.load_json(json_str) assert isinstance(zz, tvm.tensor.ScanOp) def test_extern(): diff --git a/tests/python/unittest/test_pass_inline.py b/tests/python/unittest/test_pass_inline.py index 511a1438f4be..e87353ed98a1 100644 --- a/tests/python/unittest/test_pass_inline.py +++ b/tests/python/unittest/test_pass_inline.py @@ -32,7 +32,7 @@ def test_inline(): stmt = tvm.ir_pass.Inline( T.op, [1,2,3], T.op.body, stmt) assert False - except tvm.TVMError: + except tvm.error.TVMError: pass def test_inline2(): diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index 43bb79cf0363..ff5f46536d83 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -78,7 +78,7 @@ def remotethrow(name): try: f3("abc") assert False - except tvm.TVMError as e: + except tvm.error.TVMError as e: assert "abc" in str(e) f2 = client.get_function("rpc.test.strcat") diff --git a/topi/python/topi/arm_cpu/bitserial_conv2d.py b/topi/python/topi/arm_cpu/bitserial_conv2d.py index 9b8360dd1427..4de2b1438a92 100644 --- a/topi/python/topi/arm_cpu/bitserial_conv2d.py +++ b/topi/python/topi/arm_cpu/bitserial_conv2d.py @@ -366,7 +366,7 @@ def _bitserial_conv2d_legalize(attrs, inputs, arg_types): Parameters ---------- - attrs : tvm.attrs.Attrs + attrs : tvm.ir.Attrs Attributes of current convolution inputs : list of tvm.relay.Expr The args of the Relay expr to be legalized diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py index 57819c0850aa..b0d4d1361ccc 100644 --- a/topi/python/topi/arm_cpu/conv2d.py +++ b/topi/python/topi/arm_cpu/conv2d.py @@ -514,7 +514,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): Parameters ---------- - attrs : tvm.attrs.Attrs + attrs : tvm.ir.Attrs Attributes of current convolution inputs : tvm.relay.Expr Grouped input symbols diff --git a/topi/python/topi/cuda/conv2d_winograd.py b/topi/python/topi/cuda/conv2d_winograd.py index cc165983234d..dfa569a556ce 100644 --- a/topi/python/topi/cuda/conv2d_winograd.py +++ b/topi/python/topi/cuda/conv2d_winograd.py @@ -311,7 +311,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F): Parameters ---------- - attrs : tvm.attrs.Attrs + attrs : tvm.ir.Attrs Attributes of current convolution inputs : tvm.relay.Expr Grouped input symbols diff --git a/topi/python/topi/nn/bitserial_conv2d.py b/topi/python/topi/nn/bitserial_conv2d.py index 932c141450ac..e1f8f819968f 100644 --- a/topi/python/topi/nn/bitserial_conv2d.py +++ b/topi/python/topi/nn/bitserial_conv2d.py @@ -435,7 +435,7 @@ def bitserial_conv2d_legalize(attrs, inputs, types): Parameters ---------- - attrs : tvm.attrs.Attrs + attrs : tvm.ir.Attrs Attributes of current convolution inputs : list of tvm.relay.Expr The args of the Relay expr to be legalized diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index 046c48e7d87c..abdb5f22e369 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -79,7 +79,7 @@ def conv2d_legalize(attrs, inputs, types): Parameters ---------- - attrs : tvm.attrs.Attrs + attrs : tvm.ir.Attrs Attributes of current convolution inputs : list of tvm.relay.Expr The args of the Relay expr to be legalized @@ -101,7 +101,7 @@ def conv2d_alter_layout(attrs, inputs, tinfos, F): Parameters ---------- - attrs : tvm.attrs.Attrs + attrs : tvm.ir.Attrs Attributes of current convolution inputs : tvm.relay.Expr Grouped input symbols diff --git a/topi/python/topi/nn/conv2d_transpose.py b/topi/python/topi/nn/conv2d_transpose.py index a240b687c86d..e635f43cdbc4 100644 --- a/topi/python/topi/nn/conv2d_transpose.py +++ b/topi/python/topi/nn/conv2d_transpose.py @@ -111,7 +111,7 @@ def conv2d_transpose_legalize(attrs, inputs, types): Parameters ---------- - attrs : tvm.attrs.Attrs + attrs : tvm.ir.Attrs Attributes of current Transposed 2D convolution inputs : list of tvm.relay.Expr The args of the Relay expr to be legalized diff --git a/topi/python/topi/x86/conv2d_alter_op.py b/topi/python/topi/x86/conv2d_alter_op.py index cd612c34e5a2..9374387fb23a 100644 --- a/topi/python/topi/x86/conv2d_alter_op.py +++ b/topi/python/topi/x86/conv2d_alter_op.py @@ -173,7 +173,7 @@ def _conv2d_legalize(attrs, inputs, arg_types): Parameters ---------- - attrs : tvm.attrs.Attrs + attrs : tvm.ir.Attrs Attributes of current convolution inputs : list of tvm.relay.Expr The args of the Relay expr to be legalized diff --git a/tutorials/autotvm/tune_relay_arm.py b/tutorials/autotvm/tune_relay_arm.py index 67b7d96f3894..4cbdf52163d6 100644 --- a/tutorials/autotvm/tune_relay_arm.py +++ b/tutorials/autotvm/tune_relay_arm.py @@ -99,7 +99,7 @@ def get_network(name, batch_size): mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) net = mod["main"] net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs) - mod = relay.Module.from_expr(net) + mod = tvm.IRModule.from_expr(net) else: raise ValueError("Unsupported network: " + name) diff --git a/tutorials/autotvm/tune_relay_cuda.py b/tutorials/autotvm/tune_relay_cuda.py index 2cd99497259d..72fc2bed3d0e 100644 --- a/tutorials/autotvm/tune_relay_cuda.py +++ b/tutorials/autotvm/tune_relay_cuda.py @@ -99,7 +99,7 @@ def get_network(name, batch_size): mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) net = mod["main"] net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs) - mod = relay.Module.from_expr(net) + mod = tvm.IRModule.from_expr(net) else: raise ValueError("Unsupported network: " + name) diff --git a/tutorials/autotvm/tune_relay_mobile_gpu.py b/tutorials/autotvm/tune_relay_mobile_gpu.py index eb7b96e6972b..3c56524078c2 100644 --- a/tutorials/autotvm/tune_relay_mobile_gpu.py +++ b/tutorials/autotvm/tune_relay_mobile_gpu.py @@ -100,7 +100,7 @@ def get_network(name, batch_size): mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) net = mod["main"] net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs) - mod = relay.Module.from_expr(net) + mod = tvm.IRModule.from_expr(net) else: raise ValueError("Unsupported network: " + name) diff --git a/tutorials/autotvm/tune_relay_x86.py b/tutorials/autotvm/tune_relay_x86.py index 93a073170388..5e26f5858bbc 100644 --- a/tutorials/autotvm/tune_relay_x86.py +++ b/tutorials/autotvm/tune_relay_x86.py @@ -69,7 +69,7 @@ def get_network(name, batch_size): mod, params = relay.frontend.from_mxnet(block, shape={input_name: input_shape}, dtype=dtype) net = mod["main"] net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs) - mod = relay.Module.from_expr(net) + mod = tvm.IRModule.from_expr(net) else: raise ValueError("Unsupported network: " + name) diff --git a/tutorials/dev/relay_pass_infra.py b/tutorials/dev/relay_pass_infra.py index 87a3bf1c3ca7..d27e236a2572 100644 --- a/tutorials/dev/relay_pass_infra.py +++ b/tutorials/dev/relay_pass_infra.py @@ -99,7 +99,7 @@ def alter_conv2d(attrs, inputs, tinfos): # Let's first create a relay Module which contains one or multiple Relay # functions for optimization. f = example() -mod = relay.Module.from_expr(f) +mod = tvm.IRModule.from_expr(f) # Now we can apply constant folding on the module. # fold_const here is a callback that doesn't take any parameters. @@ -151,7 +151,7 @@ def alter_conv2d(attrs, inputs, tinfos): # Now let's execute some passes through `Sequential`_ f = example() -mod = relay.Module.from_expr(f) +mod = tvm.IRModule.from_expr(f) # Glob the interested passes. seq = relay.transform.Sequential([relay.transform.FoldConstant(), relay.transform.EliminateCommonSubexpr(), @@ -228,7 +228,7 @@ def visit_const(self, c): return ReplaceConstant().visit(func) f = example() -mod = relay.Module.from_expr(f) +mod = tvm.IRModule.from_expr(f) custom_pass = CustomPipeline(multiplier=relay.const(3, "float")) assert custom_pass.info.name == "CustomPipeline" mod3 = custom_pass(mod) @@ -243,7 +243,7 @@ def visit_const(self, c): # them. f = example() -mod = relay.Module.from_expr(f) +mod = tvm.IRModule.from_expr(f) seq = relay.transform.Sequential([relay.transform.FoldConstant(), relay.transform.PrintIR(), relay.transform.EliminateCommonSubexpr(), diff --git a/vta/python/vta/build_module.py b/vta/python/vta/build_module.py index cec217cbd393..df67faaac2bf 100644 --- a/vta/python/vta/build_module.py +++ b/vta/python/vta/build_module.py @@ -33,7 +33,7 @@ def early_rewrite(stmt): """Try to do storage rewrite in early pass.""" try: return tvm.ir_pass.StorageRewrite(stmt) - except tvm.TVMError: + except tvm.error.TVMError: return stmt diff --git a/vta/python/vta/top/graphpack.py b/vta/python/vta/top/graphpack.py index b14f937b35df..76b3dc54b113 100644 --- a/vta/python/vta/top/graphpack.py +++ b/vta/python/vta/top/graphpack.py @@ -17,6 +17,7 @@ # pylint: disable=unused-argument """A Relay implementation of graph packing.""" +import tvm from tvm import relay from tvm.relay import op, transform from tvm.relay import ExprMutator @@ -24,7 +25,7 @@ def run_opt_pass(expr, opt_pass): """Exectue a relay pass.""" assert isinstance(opt_pass, transform.Pass) - mod = relay.Module.from_expr(expr) + mod = tvm.IRModule.from_expr(expr) mod = opt_pass(mod) entry = mod["main"] return entry if isinstance(expr, relay.Function) else entry.body diff --git a/vta/tutorials/autotvm/tune_relay_vta.py b/vta/tutorials/autotvm/tune_relay_vta.py index 4cf08e93ba14..3221c3b77b1f 100644 --- a/vta/tutorials/autotvm/tune_relay_vta.py +++ b/vta/tutorials/autotvm/tune_relay_vta.py @@ -353,7 +353,7 @@ def tune_and_evaluate(tuning_opt): # Perform task extraction on Relay program print("Extract tasks...") relay_prog, params = compile_network(env, target, network, start_pack, stop_pack) - mod = relay.Module.from_expr(relay_prog) + mod = tvm.IRModule.from_expr(relay_prog) tasks = autotvm.task.extract_from_program(mod, params=params, ops=(tvm.relay.op.nn.conv2d, ),