diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index a2fdc19badab..05b588051a1c 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -32,10 +32,17 @@
 - The other way is to implement the function by themselves to
   check the attributes of the op and decide if it should be offloaded to DNNL.
 """
+import logging
+
 import tvm.ir
+from tvm.relay import transform
+from tvm.relay.build_module import bind_params_by_name
+
 from ...dataflow_pattern import wildcard, is_op
 from .register import register_pattern_table
 
+logger = logging.getLogger("DNNL")
+
 
 def _register_external_op_helper(op_name, supported=True):
     """The helper function to indicate that a given operator can be supported
@@ -63,11 +70,26 @@ def _func_wrapper(expr):
 _register_external_op_helper("nn.conv2d")
 _register_external_op_helper("nn.dense")
 _register_external_op_helper("nn.relu")
+_register_external_op_helper("tanh")
+_register_external_op_helper("sigmoid")
 _register_external_op_helper("add")
 _register_external_op_helper("multiply")
 
 
-def make_pattern(with_bias=True):
+def make_conv_pattern(with_bias=True, with_eltwise=None):
+    """Create patterns related to nn.conv2d.
+
+    Parameters
+    ----------
+    with_bias : bool
+        Whether to attach `bias_add` to `nn.conv2d`.
+    with_eltwise : str
+        The attached elementwise post-op name.
+    Returns
+    -------
+    conv_out : CallPattern
+        Call node sequence.
+    """
     data = wildcard()
     weight = wildcard()
     bias = wildcard()
@@ -76,12 +98,120 @@ def make_pattern(with_bias=True):
         conv_out = is_op("add")(conv, bias)
     else:
         conv_out = conv
-    return is_op("nn.relu")(conv_out)
+    if with_eltwise:
+        return is_op(with_eltwise)(conv_out)
+    return conv_out
+
+
+def make_dense_pattern(with_bias=True, with_eltwise=None):
+    """Create patterns related to nn.dense.
+
+    Parameters
+    ----------
+    with_bias : bool
+        Whether to attach `bias_add` to `nn.dense`.
+    with_eltwise : str
+        The attached elementwise post-op name.
+    Returns
+    -------
+    dense_out : CallPattern
+        Call node sequence.
+    """
+    data = wildcard()
+    weight = wildcard()
+    bias = wildcard()
+    dense = is_op("nn.dense")(data, weight)
+    if with_bias:
+        dense_out = is_op("add")(dense, bias)
+    else:
+        dense_out = dense
+    if with_eltwise:
+        dense_out = is_op(with_eltwise)(dense_out)
+    return dense_out
+
+
+def make_dnnl_pattern(op, with_bias, with_eltwise):
+    """Create dnnl patterns.
+
+    Parameters
+    ----------
+    op : str
+        The first call node's op name.
+    with_bias : bool
+        Whether to attach `bias_add` to the pattern.
+    with_eltwise : str
+        The attached elementwise post-op name.
+    Returns
+    -------
+    pattern : Tuple(pattern_name, CallPattern)
+        Created pattern name, along with its CallPattern.
+    """
+    pat_name = "dnnl." + op
+    pat_name += "_bias" if with_bias else ""
+    pat_name += ("_" + with_eltwise.split(".")[-1]) if with_eltwise else ""
+    if op == "conv2d":
+        dnnl_pattern = (pat_name, make_conv_pattern(with_bias, with_eltwise))
+    elif op == "dense":
+        dnnl_pattern = (pat_name, make_dense_pattern(with_bias, with_eltwise))
+    else:
+        logger.warning("Currently, only conv2d and dense are supported, but got %s.", op)
+        dnnl_pattern = ()
+    return dnnl_pattern
 
 
 @register_pattern_table("dnnl")
 def pattern_table():
-    conv2d_bias_relu_pat = ("dnnl.conv2d_bias_relu", make_pattern(with_bias=True))
-    conv2d_relu_pat = ("dnnl.conv2d_relu", make_pattern(with_bias=False))
-    dnnl_patterns = [conv2d_bias_relu_pat, conv2d_relu_pat]
+    """Create dnnl patterns.
+
+    Returns
+    -------
+    dnnl_patterns : List[dnnl_pattern]
+        Created patterns.
+    """
+    elt_list = ["nn.relu", "tanh", "sigmoid", None]
+    dnnl_patterns = []
+    for with_bias in [True, False]:
+        for elt in elt_list:
+            if not with_bias and not elt:
+                return dnnl_patterns
+            dnnl_patterns.append(make_dnnl_pattern("conv2d", with_bias, elt))
+            dnnl_patterns.append(make_dnnl_pattern("dense", with_bias, elt))
     return dnnl_patterns
+
+
+def partition_for_dnnl(mod, params=None):
+    """Partition the graph greedily offloading supported operators to DNNL.
+
+    Parameters
+    ----------
+    mod : Module
+        The module to run passes on.
+    params : Optional[Dict[str, NDArray]]
+        Constant input parameters.
+    Returns
+    -------
+    mod : Module
+        Annotated and partitioned module.
+    """
+
+    if params:
+        mod["main"] = bind_params_by_name(mod["main"], params)
+    seq = tvm.transform.Sequential(
+        [
+            transform.CanonicalizeOps(),
+            transform.InferType(),
+            transform.SimplifyInference(),
+            transform.FoldConstant(),
+            transform.FoldScaleAxis(),
+            # fold consecutive add ops to simplify pattern `conv2d-bias_add-bn-relu`
+            transform.SimplifyExpr(),
+            transform.FoldConstant(),
+            transform.MergeComposite(pattern_table()),
+            transform.AnnotateTarget("dnnl"),
+            transform.MergeCompilerRegions(),
+            transform.PartitionGraph(),
+        ]
+    )
+    with tvm.transform.PassContext(opt_level=3):
+        mod = seq(mod)
+    return mod
diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc
index fa1dbc66d8a7..b1b2f580cf94 100644
--- a/src/relay/backend/contrib/dnnl/codegen.cc
+++ b/src/relay/backend/contrib/dnnl/codegen.cc
@@ -455,9 +455,27 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
 
       if (name == "dnnl.conv2d_bias_relu") {
        call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "nn.relu"});
+      } else if (name == "dnnl.conv2d_bias_tanh") {
+        call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "tanh"});
+        ICHECK(call->op.as<OpNode>()) << "Not op node";
+      } else if (name == "dnnl.conv2d_bias_sigmoid") {
+        call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "sigmoid"});
+        ICHECK(call->op.as<OpNode>()) << "Not op node";
+      } else if (name == "dnnl.conv2d_bias") {
+        call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "add"});
+        ICHECK(call->op.as<OpNode>()) << "Not op node";
      } else if (name == "dnnl.conv2d_relu") {
        call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "nn.relu"});
        ICHECK(call->op.as<OpNode>()) << "Not op node";
+      } else if (name == "dnnl.conv2d_tanh") {
+        call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "tanh"});
+        ICHECK(call->op.as<OpNode>()) << "Not op node";
+      } else if (name == "dnnl.conv2d_sigmoid") {
+        call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "sigmoid"});
+        ICHECK(call->op.as<OpNode>()) << "Not op node";
+      } else if (name == "dnnl.dense_bias") {
+        call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.dense", "add"});
+        ICHECK(call->op.as<OpNode>()) << "Not op node";
      } else {
        LOG(FATAL) << "Unrecognized DNNL pattern: " << name;
      }
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index b32d137a2566..f9f1961e2697 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -103,15 +103,31 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
         if ("nn.conv2d" == op_name) {
           Conv2d(nid);
         } else if ("dnnl.conv2d_relu" == op_name) {
-          Conv2d(nid, true, false);
+          Conv2d(nid, true, false, dnnl::algorithm::eltwise_relu);
+        } else if ("dnnl.conv2d_tanh" == op_name) {
+          Conv2d(nid, true, false, dnnl::algorithm::eltwise_tanh);
+        } else if ("dnnl.conv2d_sigmoid" == op_name) {
else if ("dnnl.conv2d_sigmoid" == op_name) { + Conv2d(nid, true, false, dnnl::algorithm::eltwise_logistic); + } else if ("dnnl.conv2d_bias" == op_name) { + Conv2d(nid, false, true); } else if ("dnnl.conv2d_bias_relu" == op_name) { - Conv2d(nid, true, true); + Conv2d(nid, true, true, dnnl::algorithm::eltwise_relu); + } else if ("dnnl.conv2d_bias_tanh" == op_name) { + Conv2d(nid, true, true, dnnl::algorithm::eltwise_tanh); + } else if ("dnnl.conv2d_bias_sigmoid" == op_name) { + Conv2d(nid, true, true, dnnl::algorithm::eltwise_logistic); } else if ("nn.dense" == op_name) { Dense(nid); + } else if ("dnnl.dense_bias" == op_name) { + Dense(nid, true); } else if ("nn.batch_norm" == op_name) { BatchNorm(nid); } else if ("nn.relu" == op_name) { - Relu(nid); + Eltwise(nid, dnnl::algorithm::eltwise_relu); + } else if ("tanh" == op_name) { + Eltwise(nid, dnnl::algorithm::eltwise_tanh); + } else if ("sigmoid" == op_name) { + Eltwise(nid, dnnl::algorithm::eltwise_logistic); } else if ("add" == op_name) { Binary(nid, dnnl::algorithm::binary_add); } else if ("multiply" == op_name) { @@ -150,7 +166,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase { return entry_out_mem_[eid].first; } - void Conv2d(const size_t& nid, const bool has_relu = false, const bool has_bias = false) { + void Conv2d(const size_t& nid, const bool has_elt = false, const bool has_bias = false, + dnnl::algorithm algo = dnnl::algorithm::eltwise_relu) { auto node = nodes_[nid]; // Setup attributes. @@ -159,24 +176,29 @@ class DNNLJSONRuntime : public JSONRuntimeBase { dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_]; std::vector str_strides = node.GetAttr>("strides"); + std::vector str_dilates = node.GetAttr>("dilation"); std::vector str_padding = node.GetAttr>("padding"); dnnl::memory::dim groups = std::stoi(node.GetAttr>("groups")[0]); - dnnl::memory::dim N = input_shape[0], // batch size - IC = input_shape[1], // input channels - IH = input_shape[2], // input height - IW = input_shape[3], // input width - OC = weight_shape[0], // output channels - KH = weight_shape[2], // weight height - KW = weight_shape[3], // weight width - PW_L = std::stoi(str_padding[1]), // width padding: left - PW_R = std::stoi(str_padding[3]), // width padding: right - PH_L = std::stoi(str_padding[0]), // height padding: top - PH_R = std::stoi(str_padding[2]), // height padding: bottom - SH = std::stoi(str_strides[0]), // height-wise stride - SW = std::stoi(str_strides[1]), // weight-wise stride - OH = (IH - KH + PH_L + PH_R) / SH + 1, // output height - OW = (IW - KW + PW_L + PW_R) / SW + 1; // output width + dnnl::memory::dim N = input_shape[0], // batch size + IC = input_shape[1], // input channels + IH = input_shape[2], // input height + IW = input_shape[3], // input width + OC = weight_shape[0], // output channels + KH = weight_shape[2], // weight height + KW = weight_shape[3], // weight width + PW_L = std::stoi(str_padding[1]), // width padding: left + PW_R = std::stoi(str_padding[3]), // width padding: right + PH_L = std::stoi(str_padding[0]), // height padding: top + PH_R = std::stoi(str_padding[2]), // height padding: bottom + SH = std::stoi(str_strides[0]), // height-wise stride + SW = std::stoi(str_strides[1]), // weight-wise stride + DH = std::stoi(str_dilates[0]) - 1, // height-wise dilate + DW = std::stoi(str_dilates[1]) - 1, // weight-wise dilate + DKH = 1 + (KH - 1) * (DH + 1), // dilated weight height + DKW = 1 
+        OH = (IH - DKH + PH_L + PH_R) / SH + 1,        // output height
+        OW = (IW - DKW + PW_L + PW_R) / SW + 1;        // output width
 
     // Memory shapes.
     dnnl::memory::dims src_dims = {N, IC, IH, IW};
@@ -187,6 +209,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     dnnl::memory::dims bias_dims = {OC};
     dnnl::memory::dims dst_dims = {N, OC, OH, OW};
     dnnl::memory::dims strides_dims = {SH, SW};
+    dnnl::memory::dims dilates_dims = {DH, DW};
     dnnl::memory::dims padding_dims_l = {PH_L, PW_L};
     dnnl::memory::dims padding_dims_r = {PH_R, PW_R};
 
@@ -199,13 +222,14 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     // Covn2d description.
     auto conv_desc = dnnl::convolution_forward::desc(
         dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct, conv_src_md,
-        conv_weights_md, conv_bias_md, conv_dst_md, strides_dims, padding_dims_l, padding_dims_r);
+        conv_weights_md, conv_bias_md, conv_dst_md, strides_dims, dilates_dims, padding_dims_l,
+        padding_dims_r);
 
-    // Enable ReLU
+    // Enable elementwise post-ops
     dnnl::primitive_attr attr;
-    if (has_relu) {
+    if (has_elt) {
       dnnl::post_ops ops;
-      ops.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f);
+      ops.append_eltwise(1.f, algo, 0.f, 0.f);
       attr.set_post_ops(ops);
     }
 
@@ -245,7 +269,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
                                      {DNNL_ARG_DST, conv2d_dst_memory}});
   }
 
-  void Dense(const size_t& nid) {
+  void Dense(const size_t& nid, const bool has_bias = false) {
     auto node = nodes_[nid];
 
     // Setup attributes.
@@ -281,9 +305,18 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     // Memories.
     auto data_memory = BindDNNLMemory(data_entry, data_md);
     auto weight_memory = BindDNNLMemory(weight_entry, weight_md);
+
+    // Bias memory.
     auto bias_memory = dnnl::memory(bias_md, engine_);
-    float bias[OC] = {0};
-    write_to_dnnl_memory(bias, bias_memory, OC * sizeof(float));
+    if (has_bias) {
+      auto bias_entry = node.GetInputs()[2];
+      BindDNNLMemory(bias_entry, bias_memory);
+    } else {
+      float bias[OC] = {0};
+      write_to_dnnl_memory(bias, bias_memory, OC * sizeof(float));
+    }
+
+    // Output memory.
     JSONGraphNodeEntry out_entry(nid, 0);
     auto dst_memory = BindDNNLMemory(out_entry, dense_prim_desc.dst_desc());
 
@@ -335,20 +368,20 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
                                      {DNNL_ARG_VARIANCE, variance_memory}});
   }
 
-  void Relu(const size_t& nid) {
+  void Eltwise(const size_t& nid, dnnl::algorithm algo) {
     auto node = nodes_[nid];
 
     auto data_entry = node.GetInputs()[0];
     dnnl::memory::dims shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
     dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dt::f32);
 
-    auto relu_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference,
-                                                 dnnl::algorithm::eltwise_relu, data_md, 0);
-    auto relu_prim_desc = dnnl::eltwise_forward::primitive_desc(relu_desc, engine_);
-    ICHECK(data_md == relu_prim_desc.dst_desc());
+    auto elt_desc =
+        dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, data_md, 0);
+    auto elt_prim_desc = dnnl::eltwise_forward::primitive_desc(elt_desc, engine_);
+    ICHECK(data_md == elt_prim_desc.dst_desc());
 
-    auto relu = dnnl::eltwise_forward(relu_prim_desc);
-    net_.push_back(relu);
+    auto elt = dnnl::eltwise_forward(elt_prim_desc);
+    net_.push_back(elt);
 
     auto data_memory = BindDNNLMemory(data_entry, data_md);
     JSONGraphNodeEntry out_entry(nid, 0);
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
new file mode 100755
index 000000000000..7adf3e40ad33
--- /dev/null
+++ b/tests/python/contrib/test_dnnl.py
@@ -0,0 +1,350 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import pytest
+import itertools
+import tvm
+import tvm.relay.testing
+from tvm import relay
+from tvm.relay.op.contrib import dnnl
+import tvm.testing
+
+has_dnnl_codegen = pytest.mark.skipif(
+    not tvm.get_global_func("relay.ext.dnnl", True), reason="DNNL codegen not available"
+)
+
+run_module = tvm.testing.parameter(
+    pytest.param(False, marks=[has_dnnl_codegen, *tvm.testing.requires_llvm()]),
+    pytest.param(True, marks=[has_dnnl_codegen, *tvm.testing.requires_llvm()]),
+    ids=["compile", "run"],
+)
+
+
+def vmobj_to_list(o):
+    if isinstance(o, tvm.nd.NDArray):
+        return [o.numpy()]
+    elif isinstance(o, tvm.runtime.container.ADT) or isinstance(o, list):
+        return [vmobj_to_list(f) for f in o]
+    else:
+        raise RuntimeError("Unknown object type: %s" % type(o))
+
+
+def assert_result_dict_holds(result_dict):
+    for k1, k2 in itertools.combinations(result_dict, 2):
+        res1 = vmobj_to_list(result_dict[k1])
+        res2 = vmobj_to_list(result_dict[k2])
+        for r1, r2 in zip(res1, res2):
+            tvm.testing.assert_allclose(r1, r2, rtol=1e-3, atol=1e-3)
+
+
+def run_and_verify(mod, input, params, target, run_module):
+    def check_dnnl_used(mod):
+        num_dnnl_subgraphs = sum(
+            [1 if "dnnl" in gv.name_hint else 0 for gv in mod.get_global_vars()]
+        )
+        assert num_dnnl_subgraphs >= 1
+
+    dev = tvm.cpu()
+    result_dict = dict()
+    for mode in ["graph", "vm"]:
+        for use_dnnl in [False, True]:
+            result_key = mode + ("_dnnl" if use_dnnl else "")
+            if use_dnnl:
+                mod = dnnl.partition_for_dnnl(mod, params)
+            with tvm.transform.PassContext(opt_level=3):
+                func = relay.create_executor(mode, mod=mod, device=dev, target=target).evaluate()
+            if run_module:
+                if isinstance(input, dict):
+                    result_dict[result_key] = func(**input, **params)
+                else:
+                    result_dict[result_key] = func(input, **params)
+
+    if run_module:
+        assert_result_dict_holds(result_dict)
+
+
+def run_and_verify_func(config, run_module, target="llvm", dtype="float32"):
+    """Test a Relay func by compiling, running, and comparing TVM and DNNL outputs.
+
+    Parameters
+    ----------
+    config : Tuple[relay.Function, Dict[str, NDArray], List[str]]
+        A tuple containing 1) The function to test, 2) A dictionary of var names to input shapes and
+        3) A list of which vars should be considered params.
+
+    run_module: bool
+        If True, the built module will be run after being compiled.
+ """ + f, input_shapes, is_param = config + params = {x: np.random.uniform(-1, 1, input_shapes[x]).astype(dtype) for x in is_param} + input_dict = { + k: np.random.uniform(-1, 1, v).astype(dtype) + for k, v in input_shapes.items() + if k not in is_param + } + run_and_verify(f, input_dict, params, target, run_module) + + +def get_conv2d( + x_shape=(1, 32, 8, 8), + k_shape=(16, 32, 3, 3), + groups=1, + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + activation=None, + dtype="float32", +): + x = relay.var("x", shape=(x_shape), dtype=dtype) + kernel = relay.var("kernel", shape=(k_shape), dtype=dtype) + out = relay.nn.conv2d( + x, + kernel, + kernel_size=k_shape[2:4], + groups=groups, + padding=padding, + strides=strides, + dilation=dilation, + channels=k_shape[0], + ) + dic = {"x": x_shape, "kernel": k_shape} + param_lst = ["kernel"] + + if activation == "relu": + return relay.nn.relu(out), dic, param_lst + elif activation == "tanh": + return relay.tanh(out), dic, param_lst + elif activation == "sigmoid": + return relay.sigmoid(out), dic, param_lst + else: + return out, dic, param_lst + + +def get_conv2d_weights_const( + x_shape=(1, 32, 8, 8), + k_shape=(16, 32, 3, 3), + groups=1, + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + dtype="float32", +): + x = relay.var("x", shape=(x_shape), dtype=dtype) + kernel = relay.const(np.ones(k_shape).astype(dtype)) + out = relay.nn.conv2d( + x, + kernel, + channels=k_shape[0], + kernel_size=k_shape[2:4], + groups=groups, + padding=padding, + strides=strides, + dilation=dilation, + ) + dic = {"x": x_shape} + param_lst = [] + return out, dic, param_lst + + +def get_conv2d_bias( + x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), activation=None, dtype="float32" +): + conv, dic, param_lst = get_conv2d(x_shape=x_shape, k_shape=k_shape, dtype=dtype) + bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype) + out = relay.nn.bias_add(conv, bias) + dic["bias"] = (k_shape[0],) + param_lst += ["bias"] + + if activation == "relu": + return relay.nn.relu(out), dic, param_lst + elif activation == "tanh": + return relay.tanh(out), dic, param_lst + elif activation == "sigmoid": + return relay.sigmoid(out), dic, param_lst + else: + return out, dic, param_lst + + +def get_conv2d_bias_bn_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"): + conv2d_bias, dic, param_lst = get_conv2d_bias(x_shape, k_shape, dtype=dtype) + beta = relay.const(np.zeros(k_shape[0]).astype(dtype)) + gamma = relay.const(np.ones(k_shape[0]).astype(dtype)) + moving_mean = relay.const(np.zeros(k_shape[0]).astype(dtype)) + moving_var = relay.const(np.ones(k_shape[0]).astype(dtype)) + conv2d_bias_bn, _, _ = relay.nn.batch_norm( + conv2d_bias, + gamma=gamma, + beta=beta, + moving_mean=moving_mean, + moving_var=moving_var, + axis=1, + center=True, + scale=True, + epsilon=1e-5, + ) + return relay.nn.relu(conv2d_bias_bn), dic, param_lst + + +def get_dense(x_shape=(1, 16), k_shape=(32, 16), activation=None, dtype="float32"): + x = relay.var("x", shape=(x_shape), dtype=dtype) + kernel = relay.var("kernel", shape=(k_shape), dtype=dtype) + out = relay.nn.dense(x, kernel, units=k_shape[0]) + dic = {"x": x_shape, "kernel": k_shape} + param_lst = ["kernel"] + return out, dic, param_lst + + +def get_dense_bias(x_shape=(1, 16), k_shape=(32, 16), activation=None, dtype="float32"): + dense, dic, param_lst = get_dense(x_shape=x_shape, k_shape=k_shape, dtype=dtype) + bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype) + out = relay.nn.bias_add(dense, bias) + dic["bias"] = 
+    param_lst += ["bias"]
+    return out, dic, param_lst
+
+
+def test_dnnl_not_compatible(run_module, target="llvm", dtype="float32"):
+    xshape = (1, 32, 14, 14)
+    x_data = np.random.uniform(-1, 1, xshape).astype(dtype)
+
+    x = relay.var("x", shape=(xshape), dtype=dtype)
+    y = relay.add(x, x)
+    z = relay.cast(relay.cast(y, "int32"), "float32")
+    out = relay.nn.relu(z)
+    f = relay.Function([x], out)
+    mod = tvm.IRModule()
+    mod["main"] = f
+    mod = dnnl.partition_for_dnnl(mod)
+    for mode in ["graph", "vm"]:
+        with tvm.transform.PassContext(opt_level=3):
+            func = relay.create_executor(mode, mod=mod, device=tvm.cpu(0), target=target).evaluate()
+        if run_module:
+            results = func(x_data)
+
+
+def test_multiple_outputs(run_module, dtype="float32"):
+    def get_graph():
+        x = relay.var("x", shape=(1, 3), dtype=dtype)
+        y = relay.var("y", shape=(1, 3), dtype=dtype)
+        z = relay.add(x, y)
+        w = relay.add(z, y)
+        out = relay.Tuple((z, w))
+        f = tvm.IRModule.from_expr(out)
+        return f, {"x": (1, 3), "y": (1, 3)}, []
+
+    run_and_verify_func(get_graph(), run_module=run_module, dtype=dtype)
+
+
+def test_unary(run_module):
+    def get_graph(op, x_shape=(1, 8, 3, 3)):
+        x = relay.var("x", shape=(x_shape), dtype="float32")
+        out = op(x)
+        f = tvm.IRModule.from_expr(out)
+        return f, {"x": x_shape}, []
+
+    for op in [
+        relay.nn.relu,
+        relay.tanh,
+        relay.sigmoid,
+    ]:
+        run_and_verify_func(get_graph(op), run_module=run_module)
+
+
+def test_conv2d(run_module, dtype="float32"):
+    x_shape = (1, 32, 8, 8)
+    for k_shape, groups in [((16, 32, 3, 3), 1), ((32, 1, 3, 3), 32)]:
+        for padding in [(0, 0), (1, 1)]:
+            for strides in [(1, 1), (2, 2)]:
+                for dilation in [(1, 1), (2, 2)]:
+                    conv2d, dic, param_lst = get_conv2d(
+                        x_shape=x_shape,
+                        k_shape=k_shape,
+                        groups=groups,
+                        padding=padding,
+                        strides=strides,
+                        dilation=dilation,
+                        dtype=dtype,
+                    )
+                    conv2d = tvm.IRModule.from_expr(conv2d)
+                    config = conv2d, dic, param_lst
+                    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+
+def test_conv2d_weights_const(run_module, dtype="float32"):
+    x_shape = (1, 32, 8, 8)
+    k_shape = (16, 32, 3, 3)
+    conv2d, dic, param_lst = get_conv2d_weights_const(x_shape, k_shape, dtype=dtype)
+    conv2d = tvm.IRModule.from_expr(conv2d)
+    config = conv2d, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+
+def test_conv2d_pattern(run_module, dtype="float32"):
+    x_shape = (1, 32, 8, 8)
+    k_shape = (16, 32, 3, 3)
+    activation_lst = [None, "relu", "tanh", "sigmoid"]
+    for a in activation_lst:
+        conv2d, dic, param_lst = get_conv2d(x_shape, k_shape, activation=a, dtype=dtype)
+        conv2d = tvm.IRModule.from_expr(conv2d)
+        config = conv2d, dic, param_lst
+        run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+        conv2d_bias, dic, param_lst = get_conv2d_bias(x_shape, k_shape, activation=a, dtype=dtype)
+        conv2d_bias = tvm.IRModule.from_expr(conv2d_bias)
+        config = conv2d_bias, dic, param_lst
+        run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+        conv2d_bias_bn_relu, dic, param_lst = get_conv2d_bias_bn_relu(x_shape, k_shape, dtype=dtype)
+        conv2d_bias_bn_relu = tvm.IRModule.from_expr(conv2d_bias_bn_relu)
+        config = conv2d_bias_bn_relu, dic, param_lst
+        run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+
+def test_dense(run_module, dtype="float32"):
+    x_shape = (1, 16)
+    k_shape = (32, 16)
+
+    dense, dic, param_lst = get_dense(x_shape, k_shape, dtype=dtype)
+    dense = tvm.IRModule.from_expr(dense)
+    config = dense, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+    dense, dic, param_lst = get_dense(x_shape, k_shape=(1, 16), dtype=dtype)
+    dense = tvm.IRModule.from_expr(dense)
+    config = dense, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+
+def test_dense_pattern(run_module, dtype="float32"):
+    x_shape = (1, 16)
+    k_shape = (32, 16)
+
+    dense, dic, param_lst = get_dense(x_shape, k_shape, dtype=dtype)
+    dense = tvm.IRModule.from_expr(dense)
+    config = dense, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+    dense_bias, dic, param_lst = get_dense_bias(x_shape, k_shape, dtype=dtype)
+    dense_bias = tvm.IRModule.from_expr(dense_bias)
+    config = dense_bias, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+
+if __name__ == "__main__":
+    import sys
+
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 80fb2e03af2c..736ece265bde 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -919,10 +919,25 @@ def expected():
 
 def test_dnnl_fuse():
     dnnl_patterns = get_pattern_table("dnnl")
-    conv2d_bias_relu_pat, conv2d_relu_pat = dnnl_patterns
-
-    def get_blocks(prefix, data, in_channel, out_channel, include_bn=True, include_sigmoid=False):
+    (
+        conv2d_bias_relu_pat,
+        conv2d_bias_sigmoid_pat,
+        conv2d_bias_pat,
+        conv2d_relu_pat,
+        conv2d_sigmoid_pat,
+    ) = (dnnl_patterns[0], dnnl_patterns[4], dnnl_patterns[6], dnnl_patterns[8], dnnl_patterns[12])
+
+    def get_blocks(
+        prefix,
+        data,
+        in_channel,
+        out_channel,
+        include_bias_add=True,
+        include_bn=True,
+        include_sigmoid=False,
+    ):
         weight = relay.var(prefix + "weight")
+        bias = relay.var(prefix + "bias")
         bn_gamma = relay.var(prefix + "bn_gamma")
         bn_beta = relay.var(prefix + "bn_beta")
         bn_mmean = relay.var(prefix + "bn_mean")
@@ -931,6 +946,8 @@ def get_blocks(prefix, data, in_channel, out_channel, include_bn=True, include_s
         layer = relay.nn.conv2d(
             data=data, weight=weight, kernel_size=(3, 3), channels=out_channel, padding=(1, 1)
         )
+        if include_bias_add:
+            layer = relay.nn.bias_add(layer, bias)
         if include_bn:
             bn_output = relay.nn.batch_norm(layer, bn_gamma, bn_beta, bn_mmean, bn_mvar)
             layer = bn_output[0]
@@ -940,11 +957,11 @@ def get_blocks(prefix, data, in_channel, out_channel, include_bn=True, include_s
             layer = relay.nn.relu(layer)
         return layer
 
-    def get_net(include_bn=True, include_sigmoid=False):
+    def get_net(include_bias_add=True, include_bn=True, include_sigmoid=False):
         data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
-        block1 = get_blocks("block1_", data, 3, 8, include_bn, include_sigmoid)
+        block1 = get_blocks("block1_", data, 3, 8, include_bias_add, include_bn, include_sigmoid)
         # The second block is always conv + relu, to make it more interesting
-        block2 = get_blocks("block2_", block1, 8, 8, False, include_sigmoid)
+        block2 = get_blocks("block2_", block1, 8, 8, False, False, include_sigmoid)
         return relay.Function(relay.analysis.free_vars(block2), block2)
 
     def get_partitoned_mod(mod, params, pattern_table):
@@ -959,9 +976,18 @@ def get_partitoned_mod(mod, params, pattern_table):
                 transform.FoldScaleAxis(),
             ]
         )
+        # fold consecutive add ops to simplify pattern `conv2d-bias_add-bn-relu`
+        remove_linear_pass = tvm.transform.Sequential(
+            [
+                transform.SimplifyExpr(),
+                transform.FoldConstant(),
+            ]
+        )
         composite_partition = tvm.transform.Sequential(
             [
+                transform.CanonicalizeOps(),
                 remove_bn_pass,
+                remove_linear_pass,
                 transform.MergeComposite(pattern_table),
                 transform.AnnotateTarget("dnnl"),
                 transform.PartitionGraph(),
@@ -971,25 +997,38 @@ def get_partitoned_mod(mod, params, pattern_table):
         with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
             return composite_partition(mod)
 
-    def test_detect_pattern(pattern_table, include_bn, include_sigmoid, num_expected_partition):
-        net = get_net(include_bn, include_sigmoid)
+    def test_detect_pattern(
+        pattern_table, include_bias_add, include_bn, include_sigmoid, num_expected_partition
+    ):
+        net = get_net(include_bias_add, include_bn, include_sigmoid)
         mod, params = tvm.relay.testing.create_workload(net)
         mod = get_partitoned_mod(mod, params, pattern_table)
         assert len(mod.functions) - 1 == num_expected_partition  # -1 for main
 
     def test_partition():
         # conv + bn + relu, conv + relu -> fused conv_bias_relu, conv, and relu
-        test_detect_pattern([conv2d_bias_relu_pat], True, False, 3)
+        test_detect_pattern([conv2d_bias_relu_pat], False, True, False, 3)
         # conv + bn + relu, conv + relu -> conv, bias, relu, and fused conv_relu
-        test_detect_pattern([conv2d_relu_pat], True, False, 4)
+        test_detect_pattern([conv2d_relu_pat], False, True, False, 4)
         # conv + bn + relu, conv + relu -> fused conv_bias_relu, and fused conv_relu
-        test_detect_pattern([conv2d_bias_relu_pat, conv2d_relu_pat], True, False, 2)
+        test_detect_pattern([conv2d_bias_relu_pat, conv2d_relu_pat], False, True, False, 2)
+        # conv + bias_add + bn + relu, conv + relu -> fused conv_bias_relu, and fused conv_relu
+        test_detect_pattern([conv2d_bias_relu_pat, conv2d_relu_pat], True, True, False, 2)
         # conv + relu, conv + relu -> two fused conv_relu
-        test_detect_pattern([conv2d_relu_pat], False, False, 2)
+        test_detect_pattern([conv2d_relu_pat], False, False, False, 2)
         # conv + relu, conv + relu -> no fusion, 4 partition each with a single op
-        test_detect_pattern([conv2d_bias_relu_pat], False, False, 4)
+        test_detect_pattern([conv2d_bias_relu_pat], False, False, False, 4)
         # conv + bn + sigmoid + relu, conv + sigmoid + relu -> no fusion
-        test_detect_pattern([conv2d_bias_relu_pat, conv2d_relu_pat], True, True, 5)
+        test_detect_pattern([conv2d_bias_relu_pat, conv2d_relu_pat], False, True, True, 7)
+        # conv + bias_add + bn + sigmoid + relu, conv + sigmoid + relu -> fused conv_bias
+        # and single op sigmoid, relu, conv, sigmoid, relu
+        test_detect_pattern([conv2d_bias_pat, conv2d_relu_pat], True, True, True, 6)
+        # conv + bias_add + bn + sigmoid + relu, conv + sigmoid + relu -> fused conv_bias_sigmoid
+        # and single op relu, conv, sigmoid, relu
+        test_detect_pattern([conv2d_bias_sigmoid_pat, conv2d_relu_pat], True, True, True, 5)
+        # conv + bias_add + bn + sigmoid + relu, conv + sigmoid + relu -> fused conv_bias_sigmoid,
+        # fused conv_sigmoid and single op relu, relu
+        test_detect_pattern([conv2d_bias_sigmoid_pat, conv2d_sigmoid_pat], True, True, True, 4)
 
     def test_partition_mobilenet():
         mod, params = relay.testing.mobilenet.get_workload()
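
Illustrative usage sketch (not part of the diff above): a minimal end-to-end flow for the new `partition_for_dnnl` entry point and the fused patterns, assuming TVM was built with the DNNL/oneDNN codegen enabled (USE_DNNL_CODEGEN ON). The graph shape, variable names, and the "llvm" target below are arbitrary choices for the example.

import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor
from tvm.relay.op.contrib import dnnl

# A small conv2d -> bias_add -> relu graph; after partitioning this is expected to be
# rewritten into a composite function matching the "dnnl.conv2d_bias_relu" pattern.
data = relay.var("data", shape=(1, 32, 8, 8), dtype="float32")
weight = relay.var("weight", shape=(16, 32, 3, 3), dtype="float32")
bias = relay.var("bias", shape=(16,), dtype="float32")
conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), channels=16)
out = relay.nn.relu(relay.nn.bias_add(conv, bias))
mod = tvm.IRModule.from_expr(out)

params = {
    "weight": np.random.uniform(-1, 1, (16, 32, 3, 3)).astype("float32"),
    "bias": np.random.uniform(-1, 1, (16,)).astype("float32"),
}

# Bind the constants, run the simplification passes and offload supported regions to DNNL.
mod = dnnl.partition_for_dnnl(mod, params)

# Every offloaded region becomes a global function whose name contains "dnnl".
print([gv.name_hint for gv in mod.get_global_vars() if "dnnl" in gv.name_hint])

# The partitioned module builds and runs like any other Relay module.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm")
rt = graph_executor.GraphModule(lib["default"](tvm.cpu()))
rt.set_input("data", np.random.uniform(-1, 1, (1, 32, 8, 8)).astype("float32"))
rt.run()
tvm_out = rt.get_output(0).numpy()

The printed list should contain one entry per offloaded region; this is the same check that run_and_verify in tests/python/contrib/test_dnnl.py automates for both the graph and VM executors.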