From 703ab73d92fb98b1e604b2ae741ee4db4df57cc9 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Thu, 9 Jan 2020 21:03:35 +0000 Subject: [PATCH] remove annotation helper --- include/tvm/relay/op_attr_types.h | 16 --- python/tvm/relay/__init__.py | 2 +- python/tvm/relay/build_module.py | 29 ----- python/tvm/relay/op/__init__.py | 2 +- python/tvm/relay/op/contrib/__init__.py | 1 - .../tvm/relay/op/contrib/annotate_compiler.py | 119 ------------------ .../relay/op/contrib/ccompiler/__init__.py | 20 --- .../op/contrib/ccompiler/annotate_compiler.py | 34 ----- python/tvm/relay/op/contrib/dnnl/__init__.py | 20 --- .../op/contrib/dnnl/annotate_compiler.py | 50 -------- python/tvm/relay/op/op.py | 20 --- python/tvm/relay/transform.py | 18 --- src/relay/op/annotation/annotation.cc | 4 +- src/relay/pass/annotate_compiler.cc | 102 --------------- src/relay/pass/partition_graph.cc | 10 +- .../python/relay/test_pass_partition_graph.py | 65 ++++++++-- 16 files changed, 63 insertions(+), 449 deletions(-) delete mode 100644 python/tvm/relay/op/contrib/annotate_compiler.py delete mode 100644 python/tvm/relay/op/contrib/ccompiler/__init__.py delete mode 100644 python/tvm/relay/op/contrib/ccompiler/annotate_compiler.py delete mode 100644 python/tvm/relay/op/contrib/dnnl/__init__.py delete mode 100644 python/tvm/relay/op/contrib/dnnl/annotate_compiler.py delete mode 100644 src/relay/pass/annotate_compiler.cc diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h index 1e7dc178c4c01..b6221e0ba8a5b 100644 --- a/include/tvm/relay/op_attr_types.h +++ b/include/tvm/relay/op_attr_types.h @@ -163,22 +163,6 @@ using FTVMLegalize = runtime::TypedPackedFunc< const Array& args, const Array& arg_types)>; -/*! - * \brief Annotates an expression to indicate which compiler an op - * should be used for codegen. - * - * \param attrs The attribute of the original expr. - * \param args The arguments of the original expr. - * \param compiler The compiler that is used to compile the op. - * - * \return true if this op should be registered to invoke a specific compiler - * for codegen, otherwise, false. - */ -using FTVMAnnotateCompiler = runtime::TypedPackedFunc< - bool(const Attrs& attrs, // NOLINT(*) - const Array& args, - const std::string& compiler)>; - /*! * \brief Forward rewriting rule for a specific op. * diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 7901dc4f5074a..c7cbcf096a6cf 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -29,7 +29,7 @@ from . import adt from . import analysis from . import transform -from .build_module import build, create_executor, optimize, build_extern_compiler +from .build_module import build, create_executor, optimize from .transform import build_config from . import prelude from . import parser diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 7775b7ca4c21a..28ce16b9b4523 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -30,7 +30,6 @@ from .module import Module as _Module from .backend import interpreter as _interpreter from .backend.vm import VMExecutor -from . import transform as _transform def _update_target(target): target = target if target else _target.current_target() @@ -297,34 +296,6 @@ def optimize(mod, target=None, params=None): return mod, params -def build_extern_compiler(mod, compiler): - """Helper function that annotates a Relay module and patitions the - expression init into various regions. These regions will be handled - by either default compilers in TVM stack or the provided external compiler. - - Parameters - ---------- - mod : relay.Module - The module to build. Using relay.Function is deprecated. - - compiler : str - The name of the external compiler. - - Returns - ------- - mod : relay.Module - The relay module contains partitioned program regions (e.g. functions) - that will be compiled using different compilers. - """ - if isinstance(mod, _expr.Function): - mod = _Module.from_expr(mod) - - seq = _transform.Sequential([_transform.AnnotateCompiler(compiler), - _transform.PartitionGraph()]) - mod = seq(mod) - return mod - - class GraphExecutor(_interpreter.Executor): """Wrapper around Executor interface. diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index 702573ddeb0d1..a089cab669c92 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -19,7 +19,7 @@ # operator defs from .op import get, register, register_schedule, register_compute, register_gradient, \ register_pattern, register_alter_op_layout, register_legalize, \ - register_annotate_compiler, schedule_injective, Op, OpPattern, debug + schedule_injective, Op, OpPattern, debug # Operators from .reduce import * diff --git a/python/tvm/relay/op/contrib/__init__.py b/python/tvm/relay/op/contrib/__init__.py index b7d6d92b9edd7..3159006486b33 100644 --- a/python/tvm/relay/op/contrib/__init__.py +++ b/python/tvm/relay/op/contrib/__init__.py @@ -18,5 +18,4 @@ """Neural network related operators.""" from __future__ import absolute_import as _abs from .contrib import * -from .annotate_compiler import * from . import _contrib diff --git a/python/tvm/relay/op/contrib/annotate_compiler.py b/python/tvm/relay/op/contrib/annotate_compiler.py deleted file mode 100644 index 4d1eeaeb01cfe..0000000000000 --- a/python/tvm/relay/op/contrib/annotate_compiler.py +++ /dev/null @@ -1,119 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name, unused-argument -""" -External compiler related feature registration. - -It implements dispatchers that check if an operator should use a given compiler -to generate code. - -Each compiler can customize the support of an operator. For example, they can -check the attribute of the operator and/or the features of the input arguments -to decide if we should use the compiler for codegen. -""" -from __future__ import absolute_import - -import logging -import pkgutil -from pathlib import Path -from importlib import import_module - -from .. import op as reg - -logger = logging.getLogger('AnnotateCompiler') - -# Load available contrib compilers -compilers = {} -for _, name, _ in pkgutil.iter_modules([Path(__file__).parent]): - compilers[name] = import_module( - '.%s' % name, package='.'.join(__name__.split('.')[:-1])) - - -def get_annotate_compiler(compiler, op_name): - """Get the annotate_compiler function from the registered compilers. - - Parameters - ---------- - compiler : Str - The name of a compiler that is used to generate code. - - op_name : Str - The name of an operator. - - Returns - ------- - ret : bool - If the operator uses the provided compiler for codegen. - """ - if compiler in compilers: - if hasattr(compilers[compiler], 'annotate_compiler'): - annotate_compiler = getattr(compilers[compiler], 'annotate_compiler') - if hasattr(annotate_compiler, op_name): - return getattr(annotate_compiler, op_name) - - logger.warning("%s in %s is not registered. Fallback to CPU", op_name, - compiler) - return lambda x, y: False - - -@reg.register_annotate_compiler("nn.conv2d") -def annotate_conv2d(attrs, args, compiler): - """Check if the provided compiler should be used for conv2d. - """ - return get_annotate_compiler(compiler, 'conv2d')(attrs, args) - - -@reg.register_annotate_compiler("nn.dense") -def annotate_dense(attrs, args, compiler): - """Check if the provided compiler should be used for dense. - """ - return get_annotate_compiler(compiler, 'dense')(attrs, args) - - -@reg.register_annotate_compiler("nn.relu") -def annotate_relu(attrs, args, compiler): - """Check if the provided compiler should be used for relu. - """ - return get_annotate_compiler(compiler, 'relu')(attrs, args) - - -@reg.register_annotate_compiler("nn.batch_norm") -def annotate_batch_norm(attrs, args, compiler): - """Check if the provided compiler should be used for batch_norm. - """ - return get_annotate_compiler(compiler, 'batch_norm')(attrs, args) - - -@reg.register_annotate_compiler("subtract") -def annotate_subtract(attrs, args, compiler): - """Check if the provided compiler should be used for subtract. - """ - return get_annotate_compiler(compiler, 'subtract')(attrs, args) - - -@reg.register_annotate_compiler("add") -def annotate_add(attrs, args, compiler): - """Check if the provided compiler should be used for add. - """ - return get_annotate_compiler(compiler, 'add')(attrs, args) - - -@reg.register_annotate_compiler("multiply") -def annotate_multiply(attrs, args, compiler): - """Check if the provided compiler should be used for multiply. - """ - return get_annotate_compiler(compiler, 'multiply')(attrs, args) diff --git a/python/tvm/relay/op/contrib/ccompiler/__init__.py b/python/tvm/relay/op/contrib/ccompiler/__init__.py deleted file mode 100644 index fba7a13e160b0..0000000000000 --- a/python/tvm/relay/op/contrib/ccompiler/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=wildcard-import -"""Utilities that are defined in the ccompiler namespace.""" -from __future__ import absolute_import as _abs -from .annotate_compiler import * diff --git a/python/tvm/relay/op/contrib/ccompiler/annotate_compiler.py b/python/tvm/relay/op/contrib/ccompiler/annotate_compiler.py deleted file mode 100644 index 3f6bc110a1481..0000000000000 --- a/python/tvm/relay/op/contrib/ccompiler/annotate_compiler.py +++ /dev/null @@ -1,34 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name, unused-argument -"""C/C++ compiler supported operators.""" -from __future__ import absolute_import - -def subtract(attrs, args): - """Check if the external C source codegen should be used. - """ - return True - -def add(attrs, args): - """Check if the external C source codegen should be used. - """ - return True - -def multiply(attrs, args): - """Check if the external C source codegen should be used. - """ - return True diff --git a/python/tvm/relay/op/contrib/dnnl/__init__.py b/python/tvm/relay/op/contrib/dnnl/__init__.py deleted file mode 100644 index 07abbe951c809..0000000000000 --- a/python/tvm/relay/op/contrib/dnnl/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=wildcard-import -"""Utilities that are defined in the dnnl namespace.""" -from __future__ import absolute_import as _abs -from .annotate_compiler import * diff --git a/python/tvm/relay/op/contrib/dnnl/annotate_compiler.py b/python/tvm/relay/op/contrib/dnnl/annotate_compiler.py deleted file mode 100644 index b527395538d82..0000000000000 --- a/python/tvm/relay/op/contrib/dnnl/annotate_compiler.py +++ /dev/null @@ -1,50 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name, unused-argument -"""DNNL library supported operators.""" -from __future__ import absolute_import - - -def conv2d(attrs, args): - """Check if the external DNNL codegen should be used. - """ - return True - - -def dense(attrs, args): - """Check if the external DNNL codegen should be used. - """ - return True - - -def relu(attrs, args): - """Check if the external DNNL codegen should be used. - """ - return True - - -def batch_norm(attrs, args): - """Check if the external DNNL codegen should be used. - FIXME(@zhiics, @comaniac): Turn off due to not support of multiple outputs. - """ - return False - - -def add(attrs, args): - """Check if the external DNNL codegen should be used. - """ - return True diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 33e0282104855..382f667b86a92 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -246,7 +246,6 @@ def register_pattern(op_name, pattern, level=10): """ return register(op_name, "TOpPattern", pattern, level) - def register_gradient(op_name, fgradient=None, level=10): """Register operator pattern for an op. @@ -284,25 +283,6 @@ def register_shape_func(op_name, data_dependant, shape_func=None, level=10): get(op_name).set_attr("TShapeDataDependant", data_dependant, level) return register(op_name, "FShapeFunc", shape_func, level) -def register_annotate_compiler(op_name, fannotate=None, level=10): - """Register the compiler for an op. - - Parameters - ---------- - op_name : str - The name of the operator. - - fannotate : function (attrs: Attrs, args: List[Expr], compiler: str) - -> new_expr: Expr - The function for wrapping a call expr with compiler_begin and - compiler_end. - - level : int - The priority level - """ - return register(op_name, "FTVMAnnotateCompiler", fannotate, level) - - _init_api("relay.op", __name__) @register_func("relay.op.compiler._lower") diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 13173f2dac784..c4fbde60a6eb9 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -508,24 +508,6 @@ def Legalize(legalize_map_attr_name="FTVMLegalize"): return _transform.Legalize(legalize_map_attr_name) -def AnnotateCompiler(compiler): - """Annotate ops in an experession with a provied compiler and then use it - for codegen. - - Parameters - ---------- - compiler : str - The compiler used for codegen. - - Returns - ------- - ret : tvm.relay.Pass - The annotated pass that wrapps ops with subgraph_start and - subgraph_end. - """ - return _transform.AnnotateCompiler(compiler) - - def RewriteAnnotatedOps(fallback_device): """Rewrite the annotated program where annotation operators, e.g. `on_deivce`, mark which device an expression should be scheduled to. diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index 61b9e50cb683a..3d03f884e2470 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -189,7 +189,7 @@ Beginning of a region that is handled by a given compiler. }); TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_begin") -.set_body_typed([](Expr expr, std::string compiler) { +.set_body_typed([](Expr expr, std::string compiler) { auto attrs = make_object(); attrs->compiler = compiler; static const Op& op = Op::Get("annotation.compiler_begin"); @@ -214,7 +214,7 @@ End of a region that is handled by a given compiler. }); TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_end") -.set_body_typed([](Expr expr, std::string compiler) { +.set_body_typed([](Expr expr, std::string compiler) { auto attrs = make_object(); attrs->compiler = compiler; static const Op& op = Op::Get("annotation.compiler_end"); diff --git a/src/relay/pass/annotate_compiler.cc b/src/relay/pass/annotate_compiler.cc deleted file mode 100644 index 4b88a8ca29dc7..0000000000000 --- a/src/relay/pass/annotate_compiler.cc +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/relay/pass/annotate_compiler.cc - * \brief Wraps a call with compiler_begin and compiler_end to indicate that - * the op of this call node will use external compiler. - */ - -#include -#include -#include -#include -#include - -namespace tvm { -namespace relay { -namespace annotate_compiler { - -// A helper class to insert annotation boundaries for a program region that will -// be handled by a specific compiler. -class AnnotateCompilerWrapper : public ExprMutator { - public: - explicit AnnotateCompilerWrapper(const std::string& compiler) : compiler_(compiler) {} - - Expr VisitExpr_(const CallNode* cn) { - auto new_e = ExprMutator::VisitExpr_(cn); - - Call call = Downcast(new_e); - static auto fannotate = Op::GetAttr("FTVMAnnotateCompiler"); - Op op = Downcast(call->op); - CHECK(op.defined()); - - if (fannotate.count(op)) { - bool external = fannotate[op](call->attrs, call->args, compiler_); - if (external) { - tvm::Array compiler_begins; - for (const auto& it : call->args) { - const auto* begin_op = - runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); - CHECK(begin_op); - Expr begin = (*begin_op)(it, compiler_); - compiler_begins.push_back(begin); - } - Expr update_call = CallNode::make(call->op, compiler_begins, call->attrs); - const auto* end_op = - runtime::Registry::Get("relay.op.annotation._make.compiler_end"); - CHECK(end_op); - Expr end = (*end_op)(update_call, compiler_); - return end; - } - } else { - LOG(WARNING) << op->name << " in " << compiler_ << " is not registered"; - } - return new_e; - } - - private: - std::string compiler_; -}; - -Expr AnnotateCompiler(const Expr& expr, const std::string& compiler) { - return AnnotateCompilerWrapper(compiler).Mutate(expr); -} - -} // namespace annotate_compiler - -namespace transform { - -Pass AnnotateCompiler(const std::string& compiler) { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { - return Downcast(relay::annotate_compiler::AnnotateCompiler(f, compiler)); - }; - auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateCompilerFunc", - {ir::StringImm::make("InferType")}); - return transform::Sequential({func_pass, InferType()}, "AnnotateCompiler"); -} - -TVM_REGISTER_GLOBAL("relay._transform.AnnotateCompiler") -.set_body_typed(AnnotateCompiler); - -} // namespace transform - -} // namespace relay -} // namespace tvm diff --git a/src/relay/pass/partition_graph.cc b/src/relay/pass/partition_graph.cc index 748b9fdb6f9cf..6d8ab4cc248b5 100644 --- a/src/relay/pass/partition_graph.cc +++ b/src/relay/pass/partition_graph.cc @@ -61,7 +61,7 @@ struct Subgraph { std::vector> args; /*! \brief Nodes in this subgraph. */ - std::unordered_set nodes; + std::unordered_set nodes; }; /*! @@ -210,10 +210,10 @@ class Partitioner : public ExprMutator { Expr arg0 = call->args[0]; std::string name = compiler_attrs->compiler + "_" + std::to_string(subgraph->id); subgraph_func = - FunctionSetAttr(subgraph_func, attr::kExternalSymbol, tvm::ir::StringImm::make(name)); + FunctionSetAttr(subgraph_func, attr::kExternalSymbol, tvm::ir::StringImmNode::make(name)); subgraph_func = FunctionSetAttr(subgraph_func, attr::kPrimitive, tvm::Integer(1)); subgraph_func = FunctionSetAttr(subgraph_func, attr::kCompiler, - tvm::ir::StringImm::make(compiler_attrs->compiler)); + tvm::ir::StringImmNode::make(compiler_attrs->compiler)); return CallNode::make(subgraph_func, args); } } @@ -367,8 +367,8 @@ Expr PartitionGraph(const Expr& expr) { namespace transform { Pass PartitionGraph() { - runtime::TypedPackedFunc part_func = - [=](Function f, Module m, PassContext pc) { + runtime::TypedPackedFunc part_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast(partitioning::PartitionGraph(f)); }; auto partitioned = CreateFunctionPass(part_func, 0, "PartitionGraph", {}); diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index bf574909fb482..4ffb373116968 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -22,12 +22,39 @@ import tvm import tvm.relay.testing -import tvm.relay.transform +import tvm.relay.transform as transform from tvm import relay from tvm.contrib import util from tvm.relay.annotation import compiler_begin, compiler_end from tvm.relay.expr_functor import ExprMutator +# Leverage the pass manager to write a simple white list based annotator +@transform.function_pass(opt_level=0) +class WhiteListAnnotator: + def __init__(self, op_list, compiler): + assert isinstance(op_list, (list, tuple, set)) + self.op_list = op_list + self.compiler = compiler + + def transform_function(self, func, mod, ctx): + + annotator = self + class Annotator(tvm.relay.ExprMutator): + def visit_call(self, call): + op_name = call.op.name + if op_name in annotator.op_list: + new_args = [] + for arg in call.args: + ann = compiler_begin(super().visit(arg), + annotator.compiler) + new_args.append(ann) + new_call = relay.Call(call.op, new_args, call.attrs, + call.type_args) + return compiler_end(new_call, annotator.compiler) + else: + return super().visit_call(call) + return Annotator().visit(func) + class CcompilerAnnotator(ExprMutator): """ @@ -220,8 +247,8 @@ def test_multi_node_compiler(): mod = relay.Module() ann = CcompilerAnnotator() mod["main"] = ann.visit(f) - mod = relay.transform.PartitionGraph()(mod) - mod = relay.transform.InferType()(mod) + mod = transform.PartitionGraph()(mod) + mod = transform.InferType()(mod) x_data = np.random.rand(10, 10).astype('float32') w_data = [] @@ -239,6 +266,19 @@ def test_multi_node_compiler(): def test_extern_ccompiler_single_op(): + @transform.function_pass(opt_level=0) + class MyAnnotator: + def transform_function(self, func, mod, ctx): + class Annotator(tvm.relay.ExprMutator): + def visit_call(self, call): + new_args = [] + for arg in call.args: + ann = compiler_begin(self.visit(arg), "ccompiler") + new_args.append(ann) + new_call = relay.Call(call.op, new_args) + return compiler_end(new_call, "ccompiler") + return Annotator().visit(func) + x = relay.var('x', shape=(8, 8)) y = relay.var('y', shape=(8, 8)) z = x + y @@ -247,7 +287,8 @@ def test_extern_ccompiler_single_op(): y_data = np.random.rand(8, 8).astype('float32') mod = relay.Module() mod["main"] = f - mod = relay.build_extern_compiler(mod, "ccompiler") + mod = MyAnnotator()(mod) + mod = transform.PartitionGraph()(mod) check_result(mod, {"x": x_data, "y": y_data}, (8, 8), x_data + y_data) @@ -290,9 +331,10 @@ def expected(): f = relay.Function([x, y], concat) mod = relay.Module() mod["main"] = f - mod = relay.build_extern_compiler(mod, "ccompiler") + mod = WhiteListAnnotator(["add", "subtract", "multiply"], "ccompiler")(mod) + mod = transform.PartitionGraph()(mod) - fused_mod = relay.transform.FuseOps(2)(mod) + fused_mod = transform.FuseOps(2)(mod) expected_mod = expected() assert relay.alpha_equal(fused_mod, expected_mod) @@ -313,7 +355,8 @@ def test_extern_ccompiler(): y_data = np.random.rand(2, 2).astype('float32') mod = relay.Module() mod["main"] = f - mod = relay.build_extern_compiler(mod, "ccompiler") + mod = WhiteListAnnotator(["add", "subtract", "multiply"], "ccompiler")(mod) + mod = transform.PartitionGraph()(mod) check_result(mod, {"x": x_data, "y": y_data}, (2, 2), (y_data * y_data) - (x_data + x_data)) @@ -344,7 +387,7 @@ def test_extern_dnnl(): mod = relay.Module() mod['main'] = WholeGraphAnnotator('dnnl').visit(f) - mod = relay.transform.PartitionGraph()(mod) + mod = transform.PartitionGraph()(mod) ref_mod = relay.Module() ref_mod['main'] = f @@ -368,8 +411,9 @@ def test_extern_dnnl_mobilenet(): mod, params = relay.testing.mobilenet.get_workload( batch_size=1, dtype='float32') - mod = relay.build_extern_compiler(mod, "dnnl") - + op_list = ["nn.conv2d", "nn.dense", "nn.relu", "add"] + mod = WhiteListAnnotator(op_list, "dnnl")(mod) + mod = transform.PartitionGraph()(mod) i_data = np.random.uniform(0, 1, ishape).astype(dtype) ref_mod, params = relay.testing.mobilenet.get_workload(batch_size=1, @@ -381,7 +425,6 @@ def test_extern_dnnl_mobilenet(): (1, 1000), ref_res.asnumpy(), tol=1e-5, params=params) - if __name__ == "__main__": test_multi_node_compiler() test_extern_ccompiler_single_op()