diff --git a/docs/api/python/ir.rst b/docs/api/python/ir.rst index 1f0dc0c5e23ca..c2a1a1e106d57 100644 --- a/docs/api/python/ir.rst +++ b/docs/api/python/ir.rst @@ -21,3 +21,11 @@ tvm.ir :members: :imported-members: :autosummary: + + +tvm.transform +------------- +.. automodule:: tvm.transform + :members: + :imported-members: + :autosummary: diff --git a/docs/dev/convert_layout.rst b/docs/dev/convert_layout.rst index 715d810321985..7345c15b6702a 100644 --- a/docs/dev/convert_layout.rst +++ b/docs/dev/convert_layout.rst @@ -227,7 +227,7 @@ ConvertLayout pass is extremely easy to use. The pass is not a part of default r # Convert the layout to NCHW # RemoveUnusedFunctions is used to clean up the graph. - seq = relay.transform.Sequential([relay.transform.RemoveUnusedFunctions(), + seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(), relay.transform.ConvertLayout('NCHW')]) with relay.transform.PassContext(opt_level=3): mod = seq(mod) diff --git a/docs/dev/relay_pass_infra.rst b/docs/dev/relay_pass_infra.rst index b42e12852f711..3b443fab9e572 100644 --- a/docs/dev/relay_pass_infra.rst +++ b/docs/dev/relay_pass_infra.rst @@ -582,7 +582,7 @@ using ``Sequential`` associated with other types of passes. func = relay.Function([x], z2) # Customize the optimization pipeline. - seq = _transform.Sequential([ + seq = tvm.transform.Sequential([ relay.transform.InferType(), relay.transform.FoldConstant(), relay.transform.EliminateCommonSubexpr(), @@ -609,7 +609,7 @@ sequential pass example could be like the following to enable IR dumping for .. code:: python - seq = _transform.Sequential([ + seq = tvm.transform.Sequential([ relay.transform.InferType(), relay.transform.FoldConstant(), relay.transform.PrintIR(), diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 4c55204547b99..3680f6db9afec 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -361,9 +361,11 @@ TVM_DLL Pass CreateModulePass( /*! * \brief A special trace pass that prints the header and IR to LOG(INFO). + * \param header The header to be attached to the output. + * \param show_meta_data Whether to show the meta data. * \return The pass. */ -TVM_DLL Pass PrintIR(std::string header); +TVM_DLL Pass PrintIR(std::string header = "", bool show_meta_data = false); } // namespace transform } // namespace tvm diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index e091cd12a2085..9a881cfa6d5be 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -106,7 +106,7 @@ def _update_global_key(item, _): "relay.PassInfo": _rename("transform.PassInfo"), "relay.PassContext": _rename("transform.PassContext"), "relay.ModulePass": _rename("transform.ModulePass"), - "relay.Sequantial": _rename("transform.Sequantial"), + "relay.Sequential": _rename("transform.Sequential"), # TIR "Variable": _update_tir_var("tir.Var"), "SizeVar": _update_tir_var("tir.SizeVar"), diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index da74fb227a2e1..614f9690903ab 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -329,7 +329,7 @@ def create_module_pass(pass_arg): return create_module_pass -def PrintIR(header): +def PrintIR(header="", show_meta_data=False): """A special trace pass that prints the header and IR. Parameters ---------- header : str The header to be displayed along with the dump. + show_meta_data : bool + A boolean flag to indicate if meta data should be printed. 
+ Returns -------- The pass """ - return _ffi_transform_api.PrintIR(header) + return _ffi_transform_api.PrintIR(header, show_meta_data) diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 1517cf9484cf6..4e520198664c1 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -128,20 +128,9 @@ # Scope builder ScopeBuilder = scope_builder.ScopeBuilder -module_pass = transform.module_pass -function_pass = transform.function_pass - # Parser fromtext = parser.fromtext # Param Serialization save_param_dict = param_dict.save_param_dict load_param_dict = param_dict.load_param_dict - -# Pass manager -PassInfo = transform.PassInfo -PassContext = transform.PassContext -Pass = transform.Pass -ModulePass = transform.ModulePass -FunctionPass = transform.FunctionPass -Sequential = transform.Sequential diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index 9c4be2975d6cd..213a6c6f4c258 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -210,10 +210,10 @@ def optimize(self): opt_mod : tvm.IRModule The optimized module. """ - seq = transform.Sequential([transform.SimplifyInference(), - transform.FuseOps(0), - transform.ToANormalForm(), - transform.InferType()]) + seq = tvm.transform.Sequential([transform.SimplifyInference(), + transform.FuseOps(0), + transform.ToANormalForm(), + transform.InferType()]) return seq(self.mod) def _make_executor(self, expr=None): diff --git a/python/tvm/relay/qnn/transform.py b/python/tvm/relay/qnn/transform.py index 6d38490b2d194..492c73993bae4 100644 --- a/python/tvm/relay/qnn/transform.py +++ b/python/tvm/relay/qnn/transform.py @@ -60,7 +60,7 @@ def @main(%quantized_data: Tensor[(200), int32]) -> Tensor[(200), int8] { Returns ------- - ret : tvm.relay.Pass + ret : tvm.transform.Pass The registered pass that canonicalizes QNN ops to Relay ops. """ @@ -108,7 +108,7 @@ def Legalize(): Returns ------- - ret : tvm.relay.Pass + ret : tvm.transform.Pass The registered pass that legalizes QNN ops. """ diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 2ad4e18771d78..958d0dc5d6ceb 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -17,6 +17,7 @@ #pylint: disable=unused-argument, not-context-manager """Automatic quantization toolkit.""" import tvm.ir +import tvm from tvm.runtime import Object from . import _quantize @@ -240,7 +241,7 @@ def partition(): Returns ------- - ret: tvm.relay.Pass + ret: tvm.transform.Pass The registered pass for VTA rewrite. """ return _quantize.QuantizePartition() @@ -253,7 +254,7 @@ def annotate(): Returns ------- - ret: tvm.relay.Pass + ret: tvm.transform.Pass The registered pass for quantization annotation. """ return _quantize.QuantizeAnnotate() @@ -267,7 +268,7 @@ def realize(): Returns ------- - ret: tvm.relay.Pass + ret: tvm.transform.Pass The registered pass for quantization realization. """ return _quantize.QuantizeRealize() @@ -298,11 +299,12 @@ def prerequisite_optimize(mod, params=None): """ Prerequisite optimization passes for quantization. Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and "CanonicalizeOps" optimization before quantization. 
""" - optimize = _transform.Sequential([_transform.SimplifyInference(), - _transform.FoldConstant(), - _transform.FoldScaleAxis(), - _transform.CanonicalizeOps(), - _transform.FoldConstant()]) + optimize = tvm.transform.Sequential( + [_transform.SimplifyInference(), + _transform.FoldConstant(), + _transform.FoldScaleAxis(), + _transform.CanonicalizeOps(), + _transform.FoldConstant()]) if params: mod['main'] = _bind_params(mod['main'], params) @@ -336,19 +338,20 @@ def quantize(mod, params=None, dataset=None): """ mod = prerequisite_optimize(mod, params) - calibrate_pass = _transform.module_pass(calibrate(dataset), opt_level=1, - name="QuantizeCalibrate") + calibrate_pass = tvm.transform.module_pass( + calibrate(dataset), opt_level=1, + name="QuantizeCalibrate") quant_passes = [partition(), annotate(), calibrate_pass] if not current_qconfig().do_simulation: quant_passes.append(realize()) quant_passes.append(_transform.FoldConstant()) - quantize_seq = _transform.Sequential(quant_passes) - with _transform.PassContext(opt_level=3, - required_pass=["QuantizeAnnotate", - "QuantizeCalibrate", - "QuantizeRealize"]): + quantize_seq = tvm.transform.Sequential(quant_passes) + with tvm.transform.PassContext(opt_level=3, + required_pass=["QuantizeAnnotate", + "QuantizeCalibrate", + "QuantizeRealize"]): with quantize_context(): mod = quantize_seq(mod) diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index 54c909179e4f0..58c6fe89831a2 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -47,7 +47,7 @@ from ..transform import gradient def run_opt_pass(expr, opt_pass): - assert isinstance(opt_pass, transform.Pass) + assert isinstance(opt_pass, tvm.transform.Pass) mod = tvm.IRModule.from_expr(expr) mod = opt_pass(mod) entry = mod["main"] diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index eec5e16fdd139..61a04ec392dd4 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -95,8 +95,8 @@ def optimize(self, prog: Expr): # necessary pass: SimplifyInference (otherwise we can't generate code for some operators) # and fusion (to get primitive functions) - opts = relay.transform.Sequential([relay.transform.SimplifyInference(), - relay.transform.FuseOps(fuse_opt_level=0)]) + opts = tvm.transform.Sequential([relay.transform.SimplifyInference(), + relay.transform.FuseOps(fuse_opt_level=0)]) mod = opts(mod) optimized = mod['main'] return optimized if isinstance(unwrapped, Function) else optimized.body diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 918894f69603f..292c5fd39acb9 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -22,10 +22,9 @@ import inspect import functools -import tvm +import tvm.ir from tvm import te from tvm.runtime import ndarray as _nd -from tvm.ir.transform import PassInfo, PassContext, Pass, ModulePass, Sequential, module_pass from tvm import relay from . import _ffi_api @@ -78,12 +77,13 @@ def build_config(opt_level=2, pass_context: PassContext The pass context for optimizations. 
""" - return PassContext(opt_level, fallback_device, required_pass, - disabled_pass, trace) + return tvm.ir.transform.PassContext( + opt_level, fallback_device, required_pass, + disabled_pass, trace) @tvm._ffi.register_object("relay.FunctionPass") -class FunctionPass(Pass): +class FunctionPass(tvm.ir.transform.Pass): """A pass that works on each tvm.relay.Function in a module. A function pass class should be created through `function_pass`. """ @@ -94,7 +94,7 @@ def InferType(): Returns ------- - ret : tvm.relay.Pass + ret : tvm.transform.Pass The registered type inference pass. """ return _ffi_api.InferType() @@ -106,7 +106,7 @@ def FoldScaleAxis(): Returns ------- - ret : tvm.relay.Pass + ret : tvm.transform.Pass The registered pass to fold expressions. Note @@ -123,7 +123,7 @@ def BackwardFoldScaleAxis(): Returns ------- - ret : tvm.relay.Pass + ret : tvm.transform.Pass The registered pass to backward fold expressions. Note @@ -144,7 +144,7 @@ def RemoveUnusedFunctions(entry_functions=None): Returns ------- - ret : tvm.relay.Pass + ret : tvm.transform.Pass The registered pass to remove unused functions. """ if entry_functions is None: @@ -156,7 +156,7 @@ def ForwardFoldScaleAxis(): Returns ------- - ret : tvm.relay.Pass + ret : tvm.transform.Pass The registered pass to forward fold expressions. Note @@ -174,7 +174,7 @@ def SimplifyInference(): Returns ------- - ret: tvm.relay.Pass + ret: tvm.transform.Pass The registered pass to perform operator simplification. """ return _ffi_api.SimplifyInference() @@ -185,7 +185,7 @@ def FastMath(): Returns ------- - ret: tvm.relay.Pass + ret: tvm.transform.Pass The registered pass to perform fast math operations. """ return _ffi_api.FastMath() @@ -198,7 +198,7 @@ def CanonicalizeOps(): Returns ------- - ret: tvm.relay.Pass + ret: tvm.transform.Pass The registered pass performing the canonicalization. """ return _ffi_api.CanonicalizeOps() @@ -214,7 +214,7 @@ def DeadCodeElimination(inline_once=False): Returns ------- - ret: tvm.relay.Pass + ret: tvm.transform.Pass The registered pass that eliminates the dead code in a Relay program. """ return _ffi_api.DeadCodeElimination(inline_once) @@ -227,7 +227,7 @@ def LazyGradientInit(): Returns ------- - ret: tvm.relay.Pass + ret: tvm.transform.Pass A pass which delays and/or reduces memory allocation, by lazily allocating 0 or one filled tensors. """ @@ -238,7 +238,7 @@ def FoldConstant(): Returns ------- - ret : tvm.relay.Pass + ret : tvm.transform.Pass The registered pass for constant folding. """ return _ffi_api.FoldConstant() @@ -255,7 +255,7 @@ def FuseOps(fuse_opt_level=-1): Returns ------- - ret : tvm.relay.Pass + ret : tvm.transform.Pass The registered pass for operator fusion. """ return _ffi_api.FuseOps(fuse_opt_level) @@ -272,7 +272,7 @@ def CombineParallelConv2D(min_num_branches=3): Returns ------- - ret: tvm.relay.Pass + ret: tvm.transform.Pass The registered pass that combines parallel conv2d operators. """ return _ffi_api.CombineParallelConv2D(min_num_branches) @@ -304,7 +304,7 @@ def CombineParallelDense(min_num_branches=3): Returns ------- - ret: tvm.relay.Pass + ret: tvm.transform.Pass The registered pass that combines parallel dense operators. """ return _ffi_api.CombineParallelDense(min_num_branches) @@ -318,7 +318,7 @@ def AlterOpLayout(): Returns ------- - ret : tvm.relay.Pass + ret : tvm.transform.Pass The registered pass that alters the layout of operators. 
""" return _ffi_api.AlterOpLayout() @@ -366,7 +366,7 @@ def Legalize(legalize_map_attr_name="FTVMLegalize"): Returns ------- - ret : tvm.relay.Pass + ret : tvm.transform.Pass The registered pass that rewrites an expr. """ return _ffi_api.Legalize(legalize_map_attr_name) @@ -387,7 +387,7 @@ def MergeComposite(pattern_table): Returns ------- - ret : tvm.relay.Pass + ret : tvm.transform.Pass The registered pass that merges operators into a single composite relay function. """ @@ -413,7 +413,7 @@ def MergeCompilerRegions(): Returns ------- - ret : tvm.relay.Pass + ret : tvm.transform.Pass The registered pass that merges compiler regions. """ return _ffi_api.MergeCompilerRegions() @@ -433,7 +433,7 @@ def RewriteAnnotatedOps(fallback_device): Returns ------- - ret: tvm.relay.Pass + ret: tvm.transform.Pass The registered pass that rewrites an expression with annotated `on_device` operators. """ @@ -448,7 +448,7 @@ def ToANormalForm(): Returns ------- - ret: Union[tvm.relay.Pass, tvm.relay.Expr] + ret: Union[tvm.transform.Pass, tvm.relay.Expr] The registered pass that transforms an expression into A Normal Form. """ return _ffi_api.ToANormalForm() @@ -462,7 +462,7 @@ def ToCPS(expr, mod=None): Returns ------- - result: tvm.relay.Pass + result: tvm.transform.Pass The registered pass that transforms an expression into CPS. """ return _ffi_api.to_cps(expr, mod) @@ -481,7 +481,7 @@ def EtaExpand(expand_constructor=False, expand_global_var=False): Returns ------- - ret: tvm.relay.Pass + ret: tvm.transform.Pass The registered pass that eta expands an expression. """ return _ffi_api.EtaExpand(expand_constructor, expand_global_var) @@ -492,7 +492,7 @@ def ToGraphNormalForm(): Returns ------- - ret : tvm.relay.Pass + ret : tvm.transform.Pass The registered pass that transforms an expression into Graph Normal Form. """ return _ffi_api.ToGraphNormalForm() @@ -509,7 +509,7 @@ def EliminateCommonSubexpr(fskip=None): Returns ------- - ret : tvm.relay.Pass + ret : tvm.transform.Pass The registered pass that eliminates common subexpressions. """ return _ffi_api.EliminateCommonSubexpr(fskip) @@ -527,7 +527,7 @@ def PartialEvaluate(): Returns ------- - ret: tvm.relay.Pass + ret: tvm.transform.Pass The registered pass that performs partial evaluation on an expression. """ return _ffi_api.PartialEvaluate() @@ -539,7 +539,7 @@ def CanonicalizeCast(): Returns ------- - ret : tvm.relay.Pass + ret : tvm.transform.Pass The registered pass that canonicalizes cast expression. """ return _ffi_api.CanonicalizeCast() @@ -551,36 +551,19 @@ def LambdaLift(): Returns ------- - ret : tvm.relay.Pass + ret : tvm.transform.Pass The registered pass that lifts the lambda function. """ return _ffi_api.LambdaLift() -def PrintIR(show_meta_data=True): - """ - Print the IR for a module to help debugging. - - Parameters - ---------- - show_meta_data : bool - A boolean flag to indicate if meta data should be printed. - - Returns - ------- - ret : tvm.relay.Pass - The registered pass that prints the module IR. - """ - return _ffi_api.PrintIR(show_meta_data) - - def PartitionGraph(): """Partition a Relay program into regions that can be executed on different backends. Returns ------- - ret: tvm.relay.Pass + ret: tvm.transform.Pass The registered pass that partitions the Relay program. """ return _ffi_api.PartitionGraph() @@ -598,7 +581,7 @@ def AnnotateTarget(targets): Returns ------- - ret : tvm.relay.Pass + ret : tvm.transform.Pass The annotated pass that wrapps ops with subgraph_start and subgraph_end. 
""" @@ -614,7 +597,7 @@ def Inline(): Returns ------- - ret: tvm.relay.Pass + ret: tvm.transform.Pass The registered pass that performs inlining for a Relay IR module. """ return _ffi_api.Inline() @@ -809,7 +792,7 @@ def transform(func, mod, ctx): def create_function_pass(pass_arg): """Internal function that creates a function pass""" fname = name if name else pass_arg.__name__ - info = PassInfo(opt_level, fname, required) + info = tvm.transform.PassInfo(opt_level, fname, required) if inspect.isclass(pass_arg): return _wrap_class_function_pass(pass_arg, info) if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 0161cb377f0d9..c1547d5205a4d 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -474,10 +474,10 @@ TVM_REGISTER_GLOBAL("transform.ExitPassContext") .set_body_typed(PassContext::Internal::ExitScope); -Pass PrintIR(std::string header) { - auto pass_func =[header](IRModule mod, const PassContext& ctx) { +Pass PrintIR(std::string header, bool show_meta_data) { + auto pass_func =[header, show_meta_data](IRModule mod, const PassContext& ctx) { LOG(INFO) << "PrintIR(" << header << "):\n" - << mod; + << AsText(mod, show_meta_data); return mod; }; return CreateModulePass(pass_func, 0, "PrintIR", {}); diff --git a/src/relay/transforms/print_ir.cc b/src/relay/transforms/print_ir.cc deleted file mode 100644 index cf06b5004d3d7..0000000000000 --- a/src/relay/transforms/print_ir.cc +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * - * \file src/relay/transforms/print_ir.cc - * - * \brief Print the module IR to help debugging. 
- */ -#include -#include - -namespace tvm { -namespace relay { - -namespace transform { - -Pass PrintIR(bool show_meta_data) { - runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = - [=](IRModule m, PassContext pc) { - LOG(INFO) << "Dumping the module IR: " << std::endl << AsText(m, show_meta_data); - return m; - }; - return CreateModulePass(pass_func, 0, "PrintIR", {}); -} - -TVM_REGISTER_GLOBAL("relay._transform.PrintIR") -.set_body_typed(PrintIR); - -} // namespace transform - -} // namespace relay -} // namespace tvm diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 30e25067fb010..5e57c802e6519 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -53,10 +53,10 @@ def test_checkpoint_alpha_equal(): df = transform.gradient(run_infer_type(f)) # run PE and DCE - with transform.PassContext(opt_level=3): + with tvm.transform.PassContext(opt_level=3): passes = [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)] - mod = transform.Sequential(passes)(tvm.IRModule.from_expr(df)) + mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df)) df = mod["main"] df_parsed = relay.parser.fromtext( @@ -109,10 +109,10 @@ def test_checkpoint_alpha_equal_tuple(): df = transform.gradient(run_infer_type(f)) # run PE and DCE - with transform.PassContext(opt_level=3): + with tvm.transform.PassContext(opt_level=3): passes = [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)] - mod = transform.Sequential(passes)(tvm.IRModule.from_expr(df)) + mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df)) df = mod["main"] df_parsed = relay.parser.fromtext( diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index a30492f116343..2a2e265dbe5bb 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -26,8 +26,8 @@ def run_opt_pass(expr, passes): passes = passes if isinstance(passes, list) else [passes] mod = tvm.IRModule.from_expr(expr) - seq = transform.Sequential(passes) - with transform.PassContext(opt_level=3): + seq = tvm.transform.Sequential(passes) + with tvm.transform.PassContext(opt_level=3): mod = seq(mod) entry = mod["main"] return entry if isinstance(expr, relay.Function) else entry.body diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index ea92546fa1d28..582d46aa40cf8 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -28,8 +28,8 @@ def run_opt_pass(expr, passes): passes = passes if isinstance(passes, list) else [passes] mod = tvm.IRModule.from_expr(expr) - seq = transform.Sequential(passes) - with transform.PassContext(opt_level=3): + seq = tvm.transform.Sequential(passes) + with tvm.transform.PassContext(opt_level=3): mod = seq(mod) return mod["main"] diff --git a/tests/python/relay/test_pass_canonicalize_cast.py b/tests/python/relay/test_pass_canonicalize_cast.py index 7b6617a3c7f4c..e13547b0aca49 100644 --- a/tests/python/relay/test_pass_canonicalize_cast.py +++ b/tests/python/relay/test_pass_canonicalize_cast.py @@ -54,9 +54,9 @@ def check(shape): bias2 = relay.var("bias2", shape=(16, 1, 1), dtype="int32") y = before(data, conv_weight, bias1, bias2) mod = tvm.IRModule.from_expr(y) - seq = _transform.Sequential([_transform.InferType(), _transform.CanonicalizeCast(), + seq = tvm.transform.Sequential([_transform.InferType(), 
_transform.CanonicalizeCast(), _transform.InferType()]) - with _transform.PassContext(opt_level=3): + with tvm.transform.PassContext(opt_level=3): mod = seq(mod) y = mod["main"] y_expected = expected(data, conv_weight, bias1, bias2) diff --git a/tests/python/relay/test_pass_combine_parallel_conv2d.py b/tests/python/relay/test_pass_combine_parallel_conv2d.py index 345f068e39d25..7f7f18598589d 100644 --- a/tests/python/relay/test_pass_combine_parallel_conv2d.py +++ b/tests/python/relay/test_pass_combine_parallel_conv2d.py @@ -26,7 +26,7 @@ def run_combine_parallel(expr, min_num_branches=3): return mod["main"] def run_opt_pass(expr, opt_pass): - assert isinstance(opt_pass, transform.Pass) + assert isinstance(opt_pass, tvm.transform.Pass) mod = tvm.IRModule.from_expr(expr) mod = opt_pass(mod) return mod["main"] diff --git a/tests/python/relay/test_pass_combine_parallel_dense.py b/tests/python/relay/test_pass_combine_parallel_dense.py index f0f2e1858fb11..12beafb2c578d 100644 --- a/tests/python/relay/test_pass_combine_parallel_dense.py +++ b/tests/python/relay/test_pass_combine_parallel_dense.py @@ -26,7 +26,7 @@ def run_combine_parallel(expr, min_num_branches=3): return mod["main"] def run_opt_pass(expr, opt_pass): - assert isinstance(opt_pass, transform.Pass) + assert isinstance(opt_pass, tvm.transform.Pass) mod = tvm.IRModule.from_expr(expr) mod = opt_pass(mod) return mod["main"] diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index c783971c05686..c5a7b0e0c14db 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -26,8 +26,8 @@ def run_opt_pass(expr, passes): passes = passes if isinstance(passes, list) else [passes] mod = tvm.IRModule.from_expr(expr) - seq = transform.Sequential(passes) - with transform.PassContext(opt_level=3): + seq = tvm.transform.Sequential(passes) + with tvm.transform.PassContext(opt_level=3): mod = seq(mod) entry = mod["main"] return entry if isinstance(expr, relay.Function) else entry.body diff --git a/tests/python/relay/test_pass_dead_code_elimination.py b/tests/python/relay/test_pass_dead_code_elimination.py index 60dfa622ba8b4..35fd444239785 100644 --- a/tests/python/relay/test_pass_dead_code_elimination.py +++ b/tests/python/relay/test_pass_dead_code_elimination.py @@ -47,7 +47,7 @@ def __init__(self): def run_opt_pass(expr, opt_pass): - assert isinstance(opt_pass, transform.Pass) + assert isinstance(opt_pass, tvm.transform.Pass) mod = tvm.IRModule.from_expr(expr) mod = opt_pass(mod) entry = mod["main"] diff --git a/tests/python/relay/test_pass_eliminate_common_subexpr.py b/tests/python/relay/test_pass_eliminate_common_subexpr.py index 89e3b6784a70c..7af524d3ae018 100644 --- a/tests/python/relay/test_pass_eliminate_common_subexpr.py +++ b/tests/python/relay/test_pass_eliminate_common_subexpr.py @@ -24,7 +24,7 @@ def run_opt_pass(expr, opt_pass): - assert isinstance(opt_pass, transform.Pass) + assert isinstance(opt_pass, tvm.transform.Pass) mod = tvm.IRModule.from_expr(expr) mod = opt_pass(mod) entry = mod["main"] diff --git a/tests/python/relay/test_pass_eta_expand.py b/tests/python/relay/test_pass_eta_expand.py index 84ff54a3b21ed..e0a189b5c2eea 100644 --- a/tests/python/relay/test_pass_eta_expand.py +++ b/tests/python/relay/test_pass_eta_expand.py @@ -33,8 +33,8 @@ def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) { @aux } """) - seq = _transform.Sequential([_transform.EtaExpand(expand_global_var=True)]) - with 
_transform.PassContext(opt_level=3): + seq = tvm.transform.Sequential([_transform.EtaExpand(expand_global_var=True)]) + with tvm.transform.PassContext(opt_level=3): mod = seq(mod) expected = relay.fromtext(r""" v0.0.4 @@ -62,8 +62,8 @@ def @main[A]() -> (fn(A, List[A]) -> List[A]) { Cons } """) - seq = _transform.Sequential([_transform.EtaExpand(expand_constructor=True)]) - with _transform.PassContext(opt_level=3): + seq = tvm.transform.Sequential([_transform.EtaExpand(expand_constructor=True)]) + with tvm.transform.PassContext(opt_level=3): mod = seq(mod) expected = relay.fromtext(r""" v0.0.4 diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py index 3ddafd73b114c..4f44d2b3043fb 100644 --- a/tests/python/relay/test_pass_fold_constant.py +++ b/tests/python/relay/test_pass_fold_constant.py @@ -24,7 +24,7 @@ def run_opt_pass(expr, opt_pass): - assert isinstance(opt_pass, transform.Pass) + assert isinstance(opt_pass, tvm.transform.Pass) mod = tvm.IRModule.from_expr(expr) mod = opt_pass(mod) @@ -174,7 +174,7 @@ def expected(): add = relay.add(conv, bias) return relay.Function(relay.analysis.free_vars(add), add) - remove_bn_pass = transform.Sequential([ + remove_bn_pass = tvm.transform.Sequential([ relay.transform.InferType(), relay.transform.SimplifyInference(), relay.transform.FoldConstant(), diff --git a/tests/python/relay/test_pass_fold_scale_axis.py b/tests/python/relay/test_pass_fold_scale_axis.py index bf2a708ceea9d..d7c437adcc994 100644 --- a/tests/python/relay/test_pass_fold_scale_axis.py +++ b/tests/python/relay/test_pass_fold_scale_axis.py @@ -26,7 +26,7 @@ def _get_positive_scale(size): def run_opt_pass(expr, opt_pass): - assert isinstance(opt_pass, transform.Pass) + assert isinstance(opt_pass, tvm.transform.Pass) mod = tvm.IRModule.from_expr(expr) mod = opt_pass(mod) entry = mod["main"] diff --git a/tests/python/relay/test_pass_lazy_gradient_init.py b/tests/python/relay/test_pass_lazy_gradient_init.py index f9c762e5f9055..414926802870a 100644 --- a/tests/python/relay/test_pass_lazy_gradient_init.py +++ b/tests/python/relay/test_pass_lazy_gradient_init.py @@ -80,7 +80,7 @@ def test_add_tuple(): mod["main"] = y mod = transform.LazyGradientInit()(mod) - mod = transform.PrintIR(show_meta_data=True)(mod) + mod = tvm.transform.PrintIR(show_meta_data=True)(mod) y = mod["main"] assert mod["main"].checked_type == relay.FuncType([t], tensor_type) @@ -116,7 +116,7 @@ def test_mult(): def test_ret_tuple(): """Test tuple return type. Check types and semantic equivalence.""" mod = tvm.IRModule() - + shape = (10, 10) dtype = 'float32' t = relay.TensorType(shape, dtype) @@ -141,7 +141,7 @@ def test_ret_tuple(): def test_add_broadcast(): """Test adding matrices of different size. 
Check types and semantic equivalence.""" mod = tvm.IRModule() - + shape1 = (3, 4, 1) shape2 = (1, 5) dtype = 'float32' @@ -173,7 +173,7 @@ def test_reverse_ad_identity(): """Simple test with reverse mode ad.""" # of f(x) = x mod = tvm.IRModule() - + shape = (10, 10) dtype = 'float32' t = relay.TensorType(shape, dtype) @@ -201,7 +201,7 @@ def test_reverse_ad_identity(): def test_multivar_reverse_ad(): """Simple test with multivariate reverse mode ad.""" mod = tvm.IRModule() - + shape = (10, 10) dtype = 'float32' t = relay.TensorType(shape, dtype) @@ -232,7 +232,7 @@ def test_multivar_reverse_ad(): def test_after_partial_eval(): """Test transformation following reverse mode ad and PartialEval""" mod = tvm.IRModule() - + shape = (10, 10) dtype = 'float32' t = relay.TensorType(shape, dtype) @@ -248,7 +248,7 @@ def test_after_partial_eval(): mod["main"] = back_func back_func = mod["main"] - seq = transform.Sequential([ + seq = tvm.transform.Sequential([ transform.PartialEvaluate(), transform.LazyGradientInit(), transform.DeadCodeElimination() @@ -270,7 +270,7 @@ def test_after_partial_eval(): def test_before_partial_eval(): """Test transformation before PartialEval""" mod = tvm.IRModule() - + shape = (10, 10) dtype = 'float32' t = relay.TensorType(shape, dtype) @@ -284,7 +284,7 @@ def test_before_partial_eval(): back_func = run_infer_type(back_func) mod["main"] = back_func - seq = transform.Sequential([ + seq = tvm.transform.Sequential([ transform.LazyGradientInit(), transform.PartialEvaluate(), transform.DeadCodeElimination() @@ -306,7 +306,7 @@ def test_before_partial_eval(): def test_zeros(): """Simple test using "zeros" op""" mod = tvm.IRModule() - + shape = (10, 10) dtype = 'float32' t = relay.TensorType(shape, dtype) @@ -328,7 +328,7 @@ def test_zeros(): def test_ones(): """Simple test using "ones" op""" mod = tvm.IRModule() - + shape = (10, 10) dtype = 'float32' t = relay.TensorType(shape, dtype) @@ -350,7 +350,7 @@ def test_ones(): def test_zeros_like(): """Simple test using "zeros_like" op""" mod = tvm.IRModule() - + shape = (10, 10) dtype = 'float32' t = relay.TensorType(shape, dtype) @@ -372,7 +372,7 @@ def test_zeros_like(): def test_ones_like(): """Simple test using "ones_like" op""" mod = tvm.IRModule() - + shape = (10, 10) dtype = 'float32' t = relay.TensorType(shape, dtype) diff --git a/tests/python/relay/test_pass_legalize.py b/tests/python/relay/test_pass_legalize.py index 1456700c46276..0882149d575df 100644 --- a/tests/python/relay/test_pass_legalize.py +++ b/tests/python/relay/test_pass_legalize.py @@ -28,8 +28,8 @@ def run_opt_pass(expr, passes): passes = passes if isinstance(passes, list) else [passes] mod = tvm.IRModule.from_expr(expr) - seq = transform.Sequential(passes) - with transform.PassContext(opt_level=3): + seq = tvm.transform.Sequential(passes) + with tvm.transform.PassContext(opt_level=3): mod = seq(mod) entry = mod["main"] return entry if isinstance(expr, relay.Function) else entry.body diff --git a/tests/python/relay/test_pass_mac_count.py b/tests/python/relay/test_pass_mac_count.py index 697aad8eedb7d..d490ac7505ece 100644 --- a/tests/python/relay/test_pass_mac_count.py +++ b/tests/python/relay/test_pass_mac_count.py @@ -23,7 +23,7 @@ def run_opt_pass(expr, opt_pass): - assert isinstance(opt_pass, transform.Pass) + assert isinstance(opt_pass, tvm.transform.Pass) mod = tvm.IRModule.from_expr(expr) mod = opt_pass(mod) entry = mod["main"] diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index 
0a6555b5c5be5..28ccf6f5f9411 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -129,13 +129,13 @@ def test_module_pass(): opt_tester = OptTester(mod) pass_ctx = None - @_transform.module_pass(opt_level=opt_level, name=pass_name) + @tvm.transform.module_pass(opt_level=opt_level, name=pass_name) def transform(expr, ctx): return opt_tester.transform(expr, ctx) def test_pass_registration(): mod_pass = transform - assert isinstance(mod_pass, _transform.ModulePass) + assert isinstance(mod_pass, tvm.transform.ModulePass) pass_info = mod_pass.info assert pass_info.name == pass_name assert pass_info.opt_level == opt_level @@ -143,8 +143,8 @@ def test_pass_registration(): def test_pass_registration_no_decorator(): def direct_transform(expr, ctx): return opt_tester.transform(expr, ctx) - mod_pass = _transform.module_pass(direct_transform, opt_level=3) - assert isinstance(mod_pass, _transform.ModulePass) + mod_pass = tvm.transform.module_pass(direct_transform, opt_level=3) + assert isinstance(mod_pass, tvm.transform.ModulePass) pass_info = mod_pass.info assert pass_info.name == "direct_transform" assert pass_info.opt_level == 3 @@ -285,7 +285,7 @@ def test_pass_run(): def test_module_class_pass(): - @relay.transform.module_pass(opt_level=1) + @tvm.transform.module_pass(opt_level=1) class TestPipeline: """Simple test function to replace one argument to another.""" def __init__(self, new_mod, replace): @@ -309,7 +309,7 @@ def transform_module(self, mod, ctx): def test_pass_info(): - info = relay.transform.PassInfo(opt_level=1, name="xyz") + info = tvm.transform.PassInfo(opt_level=1, name="xyz") assert info.opt_level == 1 assert info.name == "xyz" @@ -350,7 +350,7 @@ def get_ref_abs(): opt_tester = OptTester(mod) pass_ctx = None - @_transform.module_pass(opt_level=1) + @tvm.transform.module_pass(opt_level=1) def mod_transform(expr, ctx): return opt_tester.transform(expr, ctx) @@ -367,21 +367,21 @@ def test_pass_registration(): passes = [module_pass, function_pass] opt_level = 2 pass_name = "sequential" - sequential = _transform.Sequential(passes=passes, opt_level=opt_level) + sequential = tvm.transform.Sequential(passes=passes, opt_level=opt_level) pass_info = sequential.info assert pass_info.name == pass_name assert pass_info.opt_level == opt_level def test_no_pass(): passes = [] - sequential = _transform.Sequential(opt_level=1, passes=passes) + sequential = tvm.transform.Sequential(opt_level=1, passes=passes) ret_mod = sequential(mod) mod_func = ret_mod[v_sub] check_func(sub, mod_func) def test_only_module_pass(): passes = [module_pass] - sequential = _transform.Sequential(opt_level=1, passes=passes) + sequential = tvm.transform.Sequential(opt_level=1, passes=passes) with relay.build_config(required_pass=["mod_transform"]): ret_mod = sequential(mod) # Check the subtract function. @@ -396,7 +396,7 @@ def test_only_module_pass(): def test_only_function_pass(): # Check the subtract function. passes = [function_pass] - sequential = _transform.Sequential(opt_level=1, passes=passes) + sequential = tvm.transform.Sequential(opt_level=1, passes=passes) with relay.build_config(required_pass=["func_transform"]): ret_mod = sequential(mod) _, new_sub = extract_var_func(ret_mod, v_sub.name_hint) @@ -411,7 +411,7 @@ def test_multiple_passes(): # function pass. 
mod = tvm.IRModule({v_sub: sub, v_log: log}) passes = [module_pass, function_pass] - sequential = _transform.Sequential(opt_level=1, passes=passes) + sequential = tvm.transform.Sequential(opt_level=1, passes=passes) required = ["mod_transform", "func_transform"] with relay.build_config(required_pass=required): ret_mod = sequential(mod) @@ -482,7 +482,7 @@ def expected(): z1 = relay.add(z, z) return relay.Function([x], z1) - seq = _transform.Sequential([ + seq = tvm.transform.Sequential([ relay.transform.InferType(), relay.transform.FoldConstant(), relay.transform.EliminateCommonSubexpr(), @@ -507,10 +507,10 @@ def test_print_ir(capfd): y = relay.multiply(y, relay.const(2, "float32")) func = relay.Function([x], y) - seq = _transform.Sequential([ + seq = tvm.transform.Sequential([ relay.transform.InferType(), relay.transform.FoldConstant(), - relay.transform.PrintIR(), + tvm.transform.PrintIR(), relay.transform.DeadCodeElimination() ]) @@ -520,7 +520,7 @@ def test_print_ir(capfd): out = capfd.readouterr().err - assert "Dumping the module IR" in out + assert "PrintIR" in out assert "multiply" in out __TRACE_COUNTER__ = 0 @@ -539,7 +539,7 @@ def test_print_debug_callback(): y = relay.multiply(y, relay.const(2, "float32")) func = relay.Function([x], y) - seq = _transform.Sequential([ + seq = tvm.transform.Sequential([ relay.transform.InferType(), relay.transform.FoldConstant(), relay.transform.DeadCodeElimination() diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py index 0f3eea663f694..45593b43ecb12 100644 --- a/tests/python/relay/test_pass_partial_eval.py +++ b/tests/python/relay/test_pass_partial_eval.py @@ -38,8 +38,8 @@ def check_eval(expr, expected_result, mod=None, rtol=1e-07): def run_opt_pass(expr, passes): passes = passes if isinstance(passes, list) else [passes] mod = tvm.IRModule.from_expr(expr) - seq = transform.Sequential(passes) - with transform.PassContext(opt_level=3): + seq = tvm.transform.Sequential(passes) + with tvm.transform.PassContext(opt_level=3): mod = seq(mod) entry = mod["main"] return entry if isinstance(expr, relay.Function) else entry.body @@ -58,7 +58,7 @@ def dcpe(expr, mod=None, grad=False): if mod: assert isinstance(expr, Function) mod["main"] = expr - seq = transform.Sequential(passes) + seq = tvm.transform.Sequential(passes) mod = seq(mod) return mod["main"] return run_opt_pass(expr, passes) diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 5148d4ef09d13..2ee8538e30ed9 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -496,7 +496,7 @@ def partition(): op_list = ["nn.batch_norm", "nn.conv2d"] mod = WhiteListAnnotator(op_list, "test_compiler")(mod) - opt_pass = transform.Sequential([ + opt_pass = tvm.transform.Sequential([ transform.InferType(), transform.PartitionGraph(), transform.SimplifyInference(), @@ -578,7 +578,7 @@ def partition(): op_list = ["nn.batch_norm", "nn.conv2d"] mod = WhiteListAnnotator(op_list, "test_compiler")(mod) - opt_pass = transform.Sequential([ + opt_pass = tvm.transform.Sequential([ transform.InferType(), transform.PartitionGraph(), transform.SimplifyInference(), @@ -878,13 +878,13 @@ def get_partitoned_mod(mod, params, pattern_table): # This is required for constant folding mod["main"] = bind_params_by_name(mod["main"], params) - remove_bn_pass = transform.Sequential([ + remove_bn_pass = tvm.transform.Sequential([ transform.InferType(), 
transform.SimplifyInference(), transform.FoldConstant(), transform.FoldScaleAxis(), ]) - composite_partition = transform.Sequential([ + composite_partition = tvm.transform.Sequential([ remove_bn_pass, transform.MergeComposite(pattern_table), transform.AnnotateTarget("dnnl"), diff --git a/tests/python/relay/test_pass_qnn_legalize.py b/tests/python/relay/test_pass_qnn_legalize.py index c291c4e6b1708..5f7deff23b064 100644 --- a/tests/python/relay/test_pass_qnn_legalize.py +++ b/tests/python/relay/test_pass_qnn_legalize.py @@ -37,8 +37,8 @@ def alpha_equal(x, y): def run_opt_pass(expr, passes): passes = passes if isinstance(passes, list) else [passes] mod = tvm.IRModule.from_expr(expr) - seq = transform.Sequential(passes) - with transform.PassContext(opt_level=3): + seq = tvm.transform.Sequential(passes) + with tvm.transform.PassContext(opt_level=3): mod = seq(mod) entry = mod["main"] return entry if isinstance(expr, relay.Function) else entry.body diff --git a/tests/python/relay/test_pass_to_a_normal_form.py b/tests/python/relay/test_pass_to_a_normal_form.py index d7babf37ed2a0..5a63db7d9d063 100644 --- a/tests/python/relay/test_pass_to_a_normal_form.py +++ b/tests/python/relay/test_pass_to_a_normal_form.py @@ -28,8 +28,8 @@ def run_opt_pass(expr, passes): passes = passes if isinstance(passes, list) else [passes] mod = tvm.IRModule.from_expr(expr) - seq = transform.Sequential(passes) - with transform.PassContext(opt_level=3): + seq = tvm.transform.Sequential(passes) + with tvm.transform.PassContext(opt_level=3): mod = seq(mod) entry = mod["main"] return entry if isinstance(expr, relay.Function) else entry.body diff --git a/tests/python/relay/test_pass_to_cps.py b/tests/python/relay/test_pass_to_cps.py index 4aaa9a0f0e9a5..6edf185446b54 100644 --- a/tests/python/relay/test_pass_to_cps.py +++ b/tests/python/relay/test_pass_to_cps.py @@ -71,7 +71,8 @@ def destroy_ref(x): x = run_infer_type(x) y = un_cps(x) y = run_infer_type(y) - x = run_opt_pass(x, transform.Sequential([transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)])) + x = run_opt_pass(x, tvm.transform.Sequential( + [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)])) assert Feature.fRefCreate not in detect_feature(x) unit = relay.Function([], relay.const(0., dtype='float32')) f_ref = relay.Var("f_ref") diff --git a/tutorials/dev/relay_pass_infra.py b/tutorials/dev/relay_pass_infra.py index 6b844ff24be0e..980d96ccc119f 100644 --- a/tutorials/dev/relay_pass_infra.py +++ b/tutorials/dev/relay_pass_infra.py @@ -29,7 +29,7 @@ The optimizations of a Relay program could be applied at various granularities, namely function-level and module-level using :py:class:`tvm.relay.transform.FunctionPass` and :py:class:`tvm.relay.transform.ModulePass` -respectively. Or users can rely on py:class:`tvm.relay.transform.Sequential` to apply a sequence of passes +respectively. Or users can rely on :py:class:`tvm.transform.Sequential` to apply a sequence of passes on a Relay program where the dependencies between passes can be resolved by the pass infra. For more details about each type of these passes, please refer to the :ref:`relay-pass-infra` @@ -130,22 +130,22 @@ def alter_conv2d(attrs, inputs, tinfos, out_type): # fusion, as this pass generates let bindings for each expression to # canonicalize a Relay program. 
# -# Relay, hence, provides :py:class:`tvm.relay.transform.Sequential` to alleviate developers from handling +# Relay, hence, provides :py:class:`tvm.transform.Sequential` to relieve developers from handling # these issues explicitly by specifying the required passes of each pass and # packing them as a whole to execute. For example, the same passes can now be -# applied using the sequential style as the following. :py:class:`tvm.relay.transform.Sequential` is +# applied using the sequential style as the following. :py:class:`tvm.transform.Sequential` is # similar to `torch.nn.sequential `_ # and `mxnet.gluon.block `_. # For example, `torch.nn.sequential` is used to contain a sequence of PyTorch # `Modules` that will be added to build a network. It focuses on the network -# layers. Instead, the :py:class:`tvm.relay.transform.Sequential` in our pass infra works on the optimizing +# layers. Instead, the :py:class:`tvm.transform.Sequential` in our pass infra works on the optimization # passes. -# Now let's execute some passes through :py:class:`tvm.relay.transform.Sequential` +# Now let's execute some passes through :py:class:`tvm.transform.Sequential` f = example() mod = tvm.IRModule.from_expr(f) # Gather the passes of interest. -seq = relay.transform.Sequential([relay.transform.FoldConstant(), +seq = tvm.transform.Sequential([relay.transform.FoldConstant(), relay.transform.EliminateCommonSubexpr(), relay.transform.FuseOps(fuse_opt_level=2)]) mod1 = seq(mod) @@ -156,7 +156,7 @@ def alter_conv2d(attrs, inputs, tinfos, out_type): # identical addition operations. This is because `EliminateCommonSubexpr` # was not actually performed. The reason is that only the passes that have # optimization level less than or equal to 2 will be executed by default under -# :py:class:`tvm.relay.transform.Sequential`. The pass infra, +# :py:class:`tvm.transform.Sequential`. The pass infra, # however, provides a configuration interface # for users to customize the optimization level that they want to execute. @@ -186,7 +186,7 @@ def alter_conv2d(attrs, inputs, tinfos, out_type): mod4 = seq(mod) print(mod4) -seq1 = relay.transform.Sequential([relay.transform.AlterOpLayout()]) +seq1 = tvm.transform.Sequential([relay.transform.AlterOpLayout()]) with relay.build_config(opt_level=3): with tvm.target.create("llvm"): mod5 = seq1(mod) @@ -237,11 +237,11 @@ def visit_constant(self, c): f = example() mod = tvm.IRModule.from_expr(f) -seq = relay.transform.Sequential([relay.transform.FoldConstant(), - relay.transform.PrintIR(False), - relay.transform.EliminateCommonSubexpr(), - relay.transform.FuseOps(), - relay.transform.PrintIR(False)]) +seq = tvm.transform.Sequential([relay.transform.FoldConstant(), + tvm.transform.PrintIR(), + relay.transform.EliminateCommonSubexpr(), + relay.transform.FuseOps(), + tvm.transform.PrintIR()]) with relay.build_config(opt_level=3): mod = seq(mod) diff --git a/vta/python/vta/top/graphpack.py b/vta/python/vta/top/graphpack.py index aca00a60618de..2334de7e6905e 100644 --- a/vta/python/vta/top/graphpack.py +++ b/vta/python/vta/top/graphpack.py @@ -24,7 +24,7 @@ def run_opt_pass(expr, opt_pass): """Execute a relay pass.""" - assert isinstance(opt_pass, transform.Pass) + assert isinstance(opt_pass, tvm.transform.Pass) mod = tvm.IRModule.from_expr(expr) mod = opt_pass(mod) entry = mod["main"]
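
A usage sketch, not part of the patch: it drives a small pipeline through the unified tvm.transform namespace and exercises the extended PrintIR(header, show_meta_data) signature introduced above. The toy function, shapes, and pass selection are illustrative assumptions, not code taken from the patch.

import tvm
from tvm import relay

# Build a tiny Relay function with a foldable constant and a repeated
# subexpression, so FoldConstant and EliminateCommonSubexpr have work to do.
x = relay.var("x", shape=(2, 3), dtype="float32")
c = relay.add(relay.const(0.5, "float32"), relay.const(0.5, "float32"))
body = relay.multiply(relay.add(x, c), relay.add(x, c))
mod = tvm.IRModule.from_expr(relay.Function([x], body))

# Infrastructure pieces (Sequential, PassContext, PrintIR) now come from
# tvm.transform, while Relay-specific passes stay under relay.transform.
seq = tvm.transform.Sequential([
    relay.transform.InferType(),
    relay.transform.FoldConstant(),
    tvm.transform.PrintIR(header="after FoldConstant", show_meta_data=False),
    relay.transform.EliminateCommonSubexpr(),
])

# EliminateCommonSubexpr is registered at opt_level 3, so raise the context level.
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)
print(mod)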
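
A second sketch of the decorator form that test_pass_manager.py now reaches through the unified namespace; the pass name and body here are hypothetical placeholders, not taken from the patch.

import tvm

@tvm.transform.module_pass(opt_level=1, name="ExampleIdentity")
def example_identity(mod, ctx):
    # A module pass is invoked with (tvm.IRModule, PassContext) and must
    # return an IRModule; this placeholder simply hands the module back.
    return mod

# The decorator produces a tvm.transform.ModulePass that can be called on a module.
assert isinstance(example_identity, tvm.transform.ModulePass)
mod = example_identity(tvm.IRModule())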