[ETHOSN] Adding support for Leaky ReLU (#11261)
* [ETHOSN] Adding support for Leaky ReLU

Change-Id: Icad69b2ae6ed4b3f3949cf5673efe2571aa66f5f

* add some missing error reporting

Change-Id: I935054c4d19a939e122092fab3c6c77204d9ead8
lhutton1 committed May 11, 2022
1 parent cfb5674 commit 3be5622
Showing 6 changed files with 183 additions and 1 deletion.
14 changes: 14 additions & 0 deletions python/tvm/relay/op/contrib/ethosn.py
@@ -131,6 +131,12 @@ def qnn_tanh_pattern():
pattern = is_op("qnn.quantize")(pattern, is_constant(), is_constant())
return pattern

def qnn_leaky_relu_pattern():
pattern = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
pattern = is_op("nn.leaky_relu")(pattern)
pattern = is_op("qnn.quantize")(pattern, is_constant(), is_constant())
return pattern

def check_conv2d(extract):
"""Check if a conv2d is supported by Ethos-N."""
if not ethosn_available():
@@ -173,13 +179,21 @@ def check_tanh(extract):

return support.tanh(extract)

def check_leaky_relu(extract):
"""Check if Leaky ReLU is supported."""
if not ethosn_available():
return False

return support.leaky_relu(extract)

return [
("ethos-n.qnn_conv2d", qnn_conv_pattern(), check_conv2d),
("ethos-n.qnn_avg_pool2d", qnn_avg_pool2d_pattern(), check_avg_pool2d),
("ethos-n.qnn_sigmoid", qnn_sigmoid_pattern(), check_sigmoid),
("ethos-n.qnn_fc", qnn_fc_pattern(), check_fc),
("ethos-n.qnn_mean", qnn_mean_pattern(), check_mean),
("ethos-n.qnn_tanh", qnn_tanh_pattern(), check_tanh),
("ethos-n.qnn_leaky_relu", qnn_leaky_relu_pattern(), check_leaky_relu),
]


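For context, the new composite pattern matches a quantized leaky ReLU expressed as dequantize → nn.leaky_relu → quantize. A minimal sketch of a matching Relay graph (mirroring the _get_model helper in the new test file below; the shape, scale, zero-point, and alpha values here are illustrative only):

    from tvm import relay

    x = relay.var("x", shape=(1, 8, 8, 2), dtype="int8")
    x = relay.qnn.op.dequantize(
        x,
        input_scale=relay.const(0.0078125, "float32"),
        input_zero_point=relay.const(0, "int32"),
    )
    x = relay.nn.leaky_relu(x, alpha=0.2)  # must satisfy 0 < alpha < 1 for Ethos-N
    out = relay.qnn.op.quantize(
        x,
        output_scale=relay.const(0.0078125, "float32"),
        output_zero_point=relay.const(0, "int32"),
        out_dtype="int8",
    )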
44 changes: 43 additions & 1 deletion src/relay/backend/contrib/ethosn/codegen.cc
@@ -120,6 +120,10 @@ void InferTensorsVisitor::InferCall(const CallNode* cn) {
TanhParams params;
err += EthosnAPI::Tanh(cn->op.as<FunctionNode>()->body, &params);
tensor_table_[cn->args[0]] = {params.input_info};
} else if (IsEthosnFunc(call, "ethos-n.qnn_leaky_relu")) {
LeakyReLUParams params;
err += EthosnAPI::LeakyReLU(cn->op.as<FunctionNode>()->body, &params);
tensor_table_[cn->args[0]] = {params.input_info};
} else if (IsEthosnOp(call, "qnn.concatenate")) {
ConcatenateParams params;
err = EthosnAPI::Concatenate(call, &params);
@@ -290,6 +294,9 @@ sl::TensorsAndId ConstructNetworkVisitor::HandleCall(const CallNode* cn) {
} else if (IsEthosnFunc(call, "ethos-n.qnn_tanh")) {
if ((err = MakeTanhLayer(call, &tensor))) ReportFatalError(call, err);
return MakeOps(tensor);
} else if (IsEthosnFunc(call, "ethos-n.qnn_leaky_relu")) {
if ((err = MakeLeakyReLULayer(call, &tensor))) ReportFatalError(call, err);
return MakeOps(tensor);
} else if (IsEthosnOp(call, "qnn.concatenate")) {
if ((err = MakeConcatenateLayer(call, &tensor))) ReportFatalError(call, err);
return MakeOps(tensor);
@@ -492,6 +499,24 @@ EthosnError ConstructNetworkVisitor::MakeTanhLayer(const Call& call,
return EthosnError();
}

EthosnError ConstructNetworkVisitor::MakeLeakyReLULayer(const Call& call,
sl::TensorAndId<sl::Operand>* out) {
LeakyReLUParams params;
params.input_info = GetTensorInfo(tensor_table_, call);
if (auto err = EthosnAPI::LeakyReLU(call->op.as<FunctionNode>()->body, &params)) {
return err;
}

auto input = operand_table_[call->args[0]][0];

try {
*out = AddLeakyRelu(network_, *input, params.leaky_relu_info);
} catch (const sl::NotSupportedException& e) {
return EthosnError(e.what());
}
return EthosnError();
}

EthosnError ConstructNetworkVisitor::MakeConcatenateLayer(const Call& call,
sl::TensorAndId<sl::Operand>* out) {
ConcatenateParams params;
@@ -793,7 +818,24 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.tanh")
TanhParams params;
auto err = EthosnAPI::Tanh(call, &params);
err += EthosnCompiler::SupportedSetup();
*rv = !err && EthosnCompiler::GetSupported()->IsTanhSupported(params.input_info);
char reason[kReasonMaxLength];
reason[0] = '\0';
*rv = !err && EthosnCompiler::GetSupported()->IsTanhSupported(params.input_info, nullptr,
reason, sizeof(reason));
err += EthosnError(reason);
});

TVM_REGISTER_GLOBAL("relay.ethos-n.support.leaky_relu")
.set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
Call call = args[0];
LeakyReLUParams params;
auto err = EthosnAPI::LeakyReLU(call, &params);
err += EthosnCompiler::SupportedSetup();
char reason[kReasonMaxLength];
reason[0] = '\0';
*rv = !err && EthosnCompiler::GetSupported()->IsLeakyReluSupported(
params.leaky_relu_info, params.input_info, nullptr, reason, sizeof(reason));
err += EthosnError(reason);
});

TVM_REGISTER_GLOBAL("relay.ethos-n.support.concatenate")
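These registered globals back the Python-side checks in ethosn.py (support.leaky_relu and friends). As an illustration, a registered global can also be looked up directly by name; a hedged sketch, assuming an Ethos-N-enabled build of TVM, where composite_func_body stands for the body of a partitioned composite function:

    import tvm

    # Hypothetical direct probe of the capability check registered above.
    leaky_relu_supported = tvm.get_global_func("relay.ethos-n.support.leaky_relu")
    # supported = leaky_relu_supported(composite_func_body)  # True/False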
1 change: 1 addition & 0 deletions src/relay/backend/contrib/ethosn/codegen_ethosn.h
@@ -211,6 +211,7 @@ class ConstructNetworkVisitor : public MixedModeVisitor, private ErrorReportingPass {
EthosnError MakeSplitLayer(const Call& call, sl::TensorsAndId* outs);
EthosnError MakeDepthToSpaceLayer(const Call& call, sl::TensorAndId<sl::Operand>* out);
EthosnError MakeReluLayer(const Call& call, sl::TensorAndId<sl::Operand>* out);
EthosnError MakeLeakyReLULayer(const Call& call, sl::TensorAndId<sl::Operand>* out);

/*! \brief A look-up table from Expr to layers. */
std::map<Expr, std::vector<std::shared_ptr<sl::Operand>>> operand_table_;
32 changes: 32 additions & 0 deletions src/relay/backend/contrib/ethosn/ethosn_api.cc
@@ -445,6 +445,38 @@ EthosnError EthosnAPI::Tanh(const Expr& expr, TanhParams* params) {
return err;
}

EthosnError EthosnAPI::LeakyReLU(const Expr& expr, LeakyReLUParams* params) {
Call quantize = Downcast<Call>(expr);
Call leaky_relu = Downcast<Call>(quantize->args[0]);
Call dequantize = Downcast<Call>(leaky_relu->args[0]);

const auto* input_dtype = quantize->checked_type().as<TensorTypeNode>();
sl::TensorShape input_tensor_shape = {1, 1, 1, 1};
sl::DataType input_tensor_dtype;
EthosnError err = Tvm2Npu(input_dtype->shape, &input_tensor_shape);
err += Tvm2Npu(input_dtype->dtype, &input_tensor_dtype);
float input_sc;
int input_zp;
err += AsConstant(dequantize->args[2], &input_zp);
err += AsConstant(dequantize->args[1], &input_sc);
float output_sc;
int output_zp;
err += AsConstant(quantize->args[2], &output_zp);
err += AsConstant(quantize->args[1], &output_sc);

const auto* attrs = leaky_relu->attrs.as<LeakyReluAttrs>();
double alpha = attrs->alpha;
if (alpha >= 1.0f || alpha <= 0.0f) {
err += EthosnError(
ErrStrm() << "leaky relu alpha must be less than 1 and greater than 0, but was " << alpha);
return err;
}
params->leaky_relu_info = sl::LeakyReluInfo(alpha, sl::QuantizationInfo(output_zp, output_sc));
params->input_info = sl::TensorInfo(input_tensor_shape, input_tensor_dtype, sl::DataFormat::NHWC,
sl::QuantizationInfo(input_zp, input_sc));
return err;
}

EthosnError EthosnAPI::Concatenate(const Expr& expr, ConcatenateParams* params) {
Call call = Downcast<Call>(expr);
const auto& attrs = call->attrs.as<ConcatenateAttrs>();
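For intuition, the parameters extracted here (input/output scale and zero point, plus alpha) determine the quantized computation the NPU must perform. A rough numpy reference of that arithmetic, as an illustrative sketch only; the Support Library's fixed-point rounding and saturation may differ:

    import numpy as np

    def leaky_relu_qnn_reference(x_q, alpha, input_sc, input_zp, output_sc, output_zp):
        # Dequantize, apply leaky ReLU in float, then requantize.
        x = (x_q.astype("float32") - input_zp) * input_sc
        y = np.where(x >= 0, x, alpha * x)
        y_q = np.round(y / output_sc) + output_zp
        return np.clip(y_q, -128, 127).astype("int8")  # assuming int8 output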
7 changes: 7 additions & 0 deletions src/relay/backend/contrib/ethosn/ethosn_api.h
@@ -100,6 +100,11 @@ struct TanhParams {
sl::TensorInfo input_info;
};

struct LeakyReLUParams {
sl::LeakyReluInfo leaky_relu_info;
sl::TensorInfo input_info;
};

struct ConcatenateParams {
sl::QuantizationInfo qInfo;
sl::ConcatenationInfo concat_info = sl::ConcatenationInfo(1, qInfo);
@@ -204,6 +209,8 @@ class EthosnAPI {
static EthosnError Mean(const Expr& expr, MeanParams* params);
/*! \brief Extract the Support Library tanh params from an ethos-n tanh func */
static EthosnError Tanh(const Expr& expr, TanhParams* params);
/*! \brief Extract the Support Library leaky relu params from an ethos-n leaky relu func */
static EthosnError LeakyReLU(const Expr& expr, LeakyReLUParams* params);
/*! \brief Extract the Support Library concatenate params from a Relay qnn.concatenate call */
static EthosnError Concatenate(const Expr& expr, ConcatenateParams* params);
/*! \brief Extract the Support Library split params from a Relay split call */
86 changes: 86 additions & 0 deletions tests/python/contrib/test_ethosn/test_leaky_relu.py
@@ -0,0 +1,86 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Integration tests for Leaky ReLU"""

import pytest
import numpy as np

import tvm
from tvm import relay
from tvm.testing import requires_ethosn

from . import infrastructure as tei


def _get_model(shape, input_zp, input_sc, output_zp, output_sc, dtype, alpha):
x = relay.var("x", shape=shape, dtype=dtype)
x = relay.qnn.op.dequantize(
x,
input_scale=relay.const(input_sc, "float32"),
input_zero_point=relay.const(input_zp, "int32"),
)
x = relay.nn.leaky_relu(x, alpha=alpha)
return relay.qnn.op.quantize(
x,
output_scale=relay.const(output_sc, "float32"),
output_zero_point=relay.const(output_zp, "int32"),
out_dtype=dtype,
)


@requires_ethosn
@pytest.mark.parametrize("dtype", ["uint8", "int8"])
@pytest.mark.parametrize("shape", [(1, 52, 52, 3), (1, 3, 8, 2)])
@pytest.mark.parametrize("alpha", [0.001, 0.5678])
def test_leaky_relu(dtype, shape, alpha):
"""Compare Leaky ReLU output with TVM."""
np.random.seed(0)

iinfo = np.iinfo(dtype)
zp_min = iinfo.min
zp_max = iinfo.max
input_zp = zp_min + 120
input_sc = 0.0068132
output_zp = zp_min + 128
output_sc = 0.0078125

inputs = {"x": tvm.nd.array(np.random.randint(zp_min, high=zp_max, size=shape, dtype=dtype))}
outputs = []
for npu in [False, True]:
model = _get_model(shape, input_zp, input_sc, output_zp, output_sc, dtype, alpha)
mod = tei.make_module(model, [])
outputs.append(tei.build_and_run(mod, inputs, 1, {}, npu=npu))

tei.verify(outputs, dtype, 1)


@requires_ethosn
@pytest.mark.parametrize("dtype", ["int8"])
@pytest.mark.parametrize("shape", [(1, 14, 14, 2)])
@pytest.mark.parametrize("alpha", [-1.34, 2.32, 1, 0])
def test_leaky_relu_unsupported_alpha(dtype, shape, alpha):
"""Test unsupported values of alpha (<= 0, >= 1) in Leaky ReLU."""
iinfo = np.iinfo(dtype)
zp_min = iinfo.min

err_msg = f"leaky relu alpha must be less than 1 and greater than 0, but was {alpha}"

model = _get_model(shape, zp_min + 120, 0.0068132, zp_min + 128, 0.0078125, dtype, alpha)
model = tei.make_ethosn_composite(model, "ethos-n.qnn_leaky_relu")
mod = tei.make_ethosn_partition(model)
tei.test_error(mod, {}, err_msg)
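One way to make the test file directly executable is the common pytest main-guard idiom sketched below (an assumption, not part of the committed file; the tests above are skipped via @requires_ethosn unless a TVM build with the Ethos-N driver stack is available):

    if __name__ == "__main__":
        import sys

        # Run this file's tests under pytest, forwarding any CLI arguments.
        sys.exit(pytest.main([__file__] + sys.argv[1:]))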
