diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h index f0280a90c604..deb900d52d09 100644 --- a/include/tvm/relay/qnn/attrs.h +++ b/include/tvm/relay/qnn/attrs.h @@ -36,6 +36,7 @@ namespace qnn { struct RequantizeAttrs : public tvm::AttrsNode { int axis; std::string rounding; + std::string compute_dtype; DataType out_dtype; TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") { @@ -44,7 +45,7 @@ struct RequantizeAttrs : public tvm::AttrsNode { "The output channel axis for channel wise quantization. Default value is -1," "which corresponds to the last axis.") .set_default(-1); - TVM_ATTR_FIELD(rounding).set_default("UPWARD").describe( + TVM_ATTR_FIELD(rounding).set_default("None").describe( "Defines the rounding direction when the value is midway between" "two representable values. There are two supported modes - UPWARD" "or TONEAREST. Both modes behave exactly same except at the" @@ -54,6 +55,11 @@ struct RequantizeAttrs : public tvm::AttrsNode { "value is rounded away from zero at midpoints (for example, -1.5" "rounds to -2). More context can be found at following gblic manual" "https://www.gnu.org/software/libc/manual/html_node/Rounding.html."); + TVM_ATTR_FIELD(compute_dtype) + .set_default("None") + .describe( + "Specifies the data type used during requantize. Supported " + "options: \"int64\", \"float32\", \"float64\""); TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type, set to explicit type under mixed precision setting"); diff --git a/python/tvm/relay/qnn/op/_requantize.py b/python/tvm/relay/qnn/op/_requantize.py new file mode 100644 index 000000000000..2e2fd9fd2980 --- /dev/null +++ b/python/tvm/relay/qnn/op/_requantize.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument +"""Internal module for qnn requantization.""" +import tvm._ffi + +tvm._ffi._init_api("relay._requantize", __name__) diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index 7f707c093ff3..aef514d81cc1 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -14,19 +14,109 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name +# pylint: disable=invalid-name,unused-argument, not-context-manager """QNN dialect operators.""" from __future__ import absolute_import as _abs +import tvm +import tvm.ir from tvm import relay +from tvm.runtime import Object from tvm.relay.expr import Tuple, TupleWrapper from tvm.relay.op.nn.utils import get_pad_tuple2d from tvm.topi.nn.qnn import SQNN_DTYPE_TO_CODE - +from tvm.target import Target +from tvm.topi.x86.utils import target_has_sse41 from ... 
import op as reg from ...op import OpPattern from . import _make +from . import _requantize + + +@tvm._ffi.register_object("relay.qnn.op.RequantizeConfig") +class RequantizeConfig(Object): + """Configure the requantization behavior by setting config variables. + + Note + ---- + This object is backed by node system in C++, with arguments that can be + exchanged between python and C++. + + Do not construct directly, use requantize_config instead. + + The fields that are backed by the C++ node are immutable once an instance + is constructed. Use _node_defaults getters to get results for the fields. + """ + + @staticmethod + def _get_node_default_rounding(): + return "UPWARD" + + @staticmethod + def _get_node_default_compute_dtype(): + target = Target.current(True) + if target and str(target.kind) == "llvm" and target_has_sse41(target.mcpu): + return "float32" + + return "int64" + + _node_defaults = { + "rounding": _get_node_default_rounding.__func__, + "compute_dtype": _get_node_default_compute_dtype.__func__, + } + + # pylint: disable=no-member + def __init__(self, handle): + """Initialize the function with handle + + Parameters + ---------- + handle : SymbolHandle + the handle to the underlying C++ Symbol + """ + super(RequantizeConfig, self).__init__(handle) + self.handle = handle + + def __enter__(self): + # pylint: disable=protected-access + _requantize._EnterRequantizeConfigScope(self) + return self + + def __exit__(self, ptype, value, trace): + _requantize._ExitRequantizeConfigScope() + + def __setattr__(self, name, value): + if name in RequantizeConfig._node_defaults: + raise AttributeError("'%s' object cannot set attribute '%s'" % (str(type(self)), name)) + return super(RequantizeConfig, self).__setattr__(name, value) + + +def current_requantize_config(): + """Get the current requantization configuration.""" + return _requantize._GetCurrentRequantizeConfig() + + +def requantize_config(**kwargs): + """Configure the requantization behavior by setting config variables. + + Parameters + --------- + rounding: "UPWARD" or "TONEAREST" + Rounding direction for fixed point multiplications. + compute_dtype: + Specifies the data type used during requantize. + Supported options: \"int64\", \"float32\", \"float64\" + + Returns + ------- + config: RequantizeConfig + The requantization configuration + """ + node_args = { + k: v() if k not in kwargs else kwargs[k] for k, v in RequantizeConfig._node_defaults.items() + } + return tvm.ir.make_node("relay.qnn.op.RequantizeConfig", **node_args) def requantize( @@ -36,7 +126,8 @@ def requantize( output_scale, output_zero_point, axis=-1, - rounding="UPWARD", + rounding="None", + compute_dtype="None", out_dtype="int8", ): r"""Requantized operator. @@ -70,7 +161,9 @@ def requantize( rounding : string, optional Defines the rounding direction when the value is midway between two representable values. - + compute_dtype: + Specifies the data type used during requantize. + Supported options: \"int64\", \"float32\", \"float64\" out_dtype : str, optional Specifies the output data type. 
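
For illustration only (not part of the diff): a minimal sketch of exercising the new compute_dtype argument end to end, assuming TVM is built with this patch; the tensor name, shape, and scales below are arbitrary.

import tvm
from tvm import relay

# Requantize an int32 tensor to int8, forcing the new float32 compute path.
data = relay.var("data", shape=(8,), dtype="int32")
expr = relay.qnn.op.requantize(
    data,
    input_scale=relay.const(0.25, "float32"),
    input_zero_point=relay.const(0, "int32"),
    output_scale=relay.const(0.5, "float32"),
    output_zero_point=relay.const(0, "int32"),
    rounding="TONEAREST",
    compute_dtype="float32",  # attribute introduced by this patch
    out_dtype="int8",
)
mod = tvm.IRModule.from_expr(relay.Function([data], expr))
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm")
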
@@ -88,6 +181,7 @@ def requantize( output_zero_point, axis, rounding, + compute_dtype, out_dtype, ) diff --git a/python/tvm/topi/x86/utils.py b/python/tvm/topi/x86/utils.py index 50c5c848ee0a..c364027022da 100644 --- a/python/tvm/topi/x86/utils.py +++ b/python/tvm/topi/x86/utils.py @@ -18,6 +18,23 @@ import tvm +@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_sse41") +def target_has_sse41(target): + return ( + target_has_sse42(target) + or target_has_avx(target) + or target_has_avx2(target) + or target_has_avx512(target) + or target_has_vnni(target) + or target + in { + "btver2", + "penryn", + } + ) + + +@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_sse42") def target_has_sse42(target): return ( target_has_avx(target) @@ -42,6 +59,7 @@ def target_has_sse42(target): ) +@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_avx") def target_has_avx(target): return ( target_has_avx2(target) @@ -51,6 +69,7 @@ def target_has_avx(target): ) +@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_avx2") def target_has_avx2(target): return ( target_has_avx512(target) @@ -70,6 +89,7 @@ def target_has_avx2(target): ) +@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_avx512") def target_has_avx512(target): return target in { "skylake-avx512", @@ -89,6 +109,7 @@ def target_has_avx512(target): } +@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_vnni") def target_has_vnni(target): return target in { "cascadelake", @@ -102,6 +123,7 @@ def target_has_vnni(target): } +@tvm._ffi.register_func("tvm.topi.x86.utils.get_simd_32bit_lanes") def get_simd_32bit_lanes(): mcpu = tvm.target.Target.current().mcpu fp32_vec_len = 4 diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index a7d214761b9b..ea143fe41713 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -26,9 +26,11 @@ #include #include +#include "../../op/op_common.h" #include "../../transforms/infer_layout_utils.h" #include "../../transforms/pattern_utils.h" #include "../utils.h" +#include "./requantize_config.h" namespace tvm { namespace relay { @@ -111,6 +113,65 @@ InferCorrectLayoutOutput RequantizeInferCorrectLayout(const Attrs& attrs, return InferCorrectLayoutOutput(input_layouts, output_layouts, Attrs(param)); } +bool has_current_target_sse41_support() { + auto target = Target::Current(true); + Optional mcpu = + target.defined() ? target->GetAttr("mcpu") : Optional(nullptr); + auto target_has_sse41_fn_ptr = tvm::runtime::Registry::Get("tvm.topi.x86.utils.target_has_sse41"); + ICHECK(target_has_sse41_fn_ptr) << "Function tvm.topi.x86.utils.target_has_sse41 not found"; + return mcpu && (*target_has_sse41_fn_ptr)(mcpu.value()); +} + +/* + * \brief TONEAREST is the standard rounding where the value is rounded away + * from zero at midpoints (for example, -1.5 rounds to -2). + * \param input_tensor The input tensor to rounding op. + * \return The sequence of existing Relay ops. 
+ */
+template <int Bits>
+Expr Tonearest(const Expr& input_tensor) {
+  if (has_current_target_sse41_support()) return Round(input_tensor);
+
+  auto half = MakeConstantScalar(DataType::Float(Bits), 0.5f);
+  auto zero = MakeConstantScalar(DataType::Float(Bits), 0.f);
+  auto pos_one = MakeConstantScalar(DataType::Float(Bits), +1.f);
+  auto neg_one = MakeConstantScalar(DataType::Float(Bits), -1.f);
+  auto multiplier = Where(Less(input_tensor, zero), neg_one, pos_one);
+  auto half_multiplied = Multiply(half, multiplier);
+  auto input_tensor_biased = Add(input_tensor, half_multiplied);
+  auto input_tensor_biased_multiplied = Multiply(input_tensor_biased, multiplier);
+  auto input_tensor_biased_multiplied_int =
+      Cast(input_tensor_biased_multiplied, DataType::Int(Bits));
+  auto input_tensor_biased_multiplied_float =
+      Cast(input_tensor_biased_multiplied_int, DataType::Float(Bits));
+  auto input_tensor_rounded = Multiply(input_tensor_biased_multiplied_float, multiplier);
+  return Where(IsFinite(input_tensor), input_tensor_rounded, input_tensor);
+}
+
+/*
+ * \brief UPWARD is the standard rounding except at midpoints where the value
+ * is rounded to positive infinity (for example, -1.5 rounds to -1).
+ * \param input_tensor The input tensor to rounding op.
+ * \return The sequence of existing Relay ops.
+ */
+template <int Bits>
+Expr Upward(const Expr& input_tensor) {
+  auto half = MakeConstantScalar(DataType::Float(Bits), 0.5f);
+  auto input_tensor_biased = Add(input_tensor, half);
+  if (has_current_target_sse41_support()) return Floor(input_tensor_biased);
+
+  auto zero = MakeConstantScalar(DataType::Float(Bits), 0.f);
+  auto one = MakeConstantScalar(DataType::Float(Bits), +1.f);
+  auto input_tensor_biased_int = Cast(input_tensor_biased, DataType::Int(Bits));
+  auto input_tensor_biased_float = Cast(input_tensor_biased_int, DataType::Float(Bits));
+  auto is_subtraction_not_necessary =
+      LogicalOr(Equal(input_tensor_biased, input_tensor_biased_float),
+                GreaterEqual(input_tensor_biased, zero));
+  auto input_tensor_rounded = Where(is_subtraction_not_necessary, input_tensor_biased_float,
+                                    Subtract(input_tensor_biased_float, one));
+  return Where(IsFinite(input_tensor), input_tensor_rounded, input_tensor);
+}
+
 // Lowering of qnn.requantize op
 
 /*
@@ -119,7 +180,7 @@ InferCorrectLayoutOutput RequantizeInferCorrectLayout(const Attrs& attrs,
  * \param param The requantize op attrs.
  * \param input_shape The input tensor shape of the requantize op.
  * \return The sequence of existing Relay ops.
- * \note Requantization using only integer computation. Here, the computation is
+ * \note RequantizeLowerInt uses only integer computation. Here, the computation is
  *       converted to a fixed point computation by computing output multiplier
  *       and shift. This is useful, if the target device does not support/have
  *       very expensive floating point computations.
@@ -131,10 +192,10 @@ InferCorrectLayoutOutput RequantizeInferCorrectLayout(const Attrs& attrs,
  * 4) Add the output zero point.
  * 5) Cast to the out_dtype.
  */
-Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
-                     const Expr& input_zero_point, const Expr& output_scale,
-                     const Expr& output_zero_point, const RequantizeAttrs* param,
-                     const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
+Expr RequantizeLowerInt(const Expr& input_tensor, const Expr& input_scale,
+                        const Expr& input_zero_point, const Expr& output_scale,
+                        const Expr& output_zero_point, const RequantizeAttrs* param,
+                        const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
   auto tensor = Cast(input_tensor, DataType::Int(32));
   auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
   if (!IsEqualScalar(input_zero_point, zero_scalar)) {
@@ -208,6 +269,142 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
   return Cast(clipped_t, out_dtype);
 }
 
+// Lowering of qnn.requantize op
+
+/*
+ * \brief Lower requantize to a sequence of ops.
+ * \param input_tensor The input tensor to requantize op.
+ * \param param The requantize op attrs.
+ * \param input_shape The input tensor shape of the requantize op.
+ * \return The sequence of existing Relay ops.
+ * \note RequantizeLowerFP uses floating point computation. All multiply/subtract/add
+ *       operations happen in a floating point data type; only at the end is the result
+ *       converted to int32 and clamped for the output data type.
+ *
+ *       The whole computation can be broken down into the following steps:
+ *       1) Subtract the input zero point.
+ *       2) Perform multiplication.
+ *       3) Add the output zero point.
+ *       4) Cast to the out_dtype.
+ */
+template <int Bits>
+Expr RequantizeLowerFP(const Expr& input_tensor, const Expr& input_scale,
+                       const Expr& input_zero_point, const Expr& output_scale,
+                       const Expr& output_zero_point, const RequantizeAttrs* param,
+                       const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
+  auto tensor = Cast(input_tensor, DataType::Float(Bits));
+  auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
+  if (!IsEqualScalar(input_zero_point, zero_scalar)) {
+    // Broadcast input zero point if needed.
+    int rank = static_cast<int>(input_shape.size());
+    int axis = (param->axis < 0) ? ((rank > 0) ? rank + param->axis : 0) : param->axis;
+    Expr input_zero_broadcast = ExpandBiasToMatchAxis(Reshape(input_zero_point,
+                                                              {
+                                                                  -1,
+                                                              }),
+                                                      rank, {axis});
+    tensor = Subtract(Cast(tensor, DataType::Float(Bits)),
+                      Cast(input_zero_broadcast, DataType::Float(Bits)));
+  } else {
+    tensor = Cast(tensor, DataType::Float(Bits));
+  }
+
+  // 2) If the input and output scales are the same, we can skip the multiplication. Check
+  // if the input scale is per-tensor or per-channel. If it is per-tensor, there is a single
+  // scale for the whole tensor. For per-channel (aka per-axis), there is a vector of scales
+  // for the input tensor. Depending on the quantization type, the appropriate multiplication
+  // routine is called.
+  auto scaled_fp_t = tensor;
+  double output_scale_float = GetScalarFromConstant<float>(output_scale);
+  if (IsConstScalar(input_scale)) {
+    // This is per-tensor quantization. Single scale.
+    double input_scale_float = GetScalarFromConstant<float>(input_scale);
+    double double_multiplier = static_cast<double>(input_scale_float) / output_scale_float;
+    // Skip if input and output scales are the same.
+    if (!IsEqualScalar(input_scale, output_scale)) {
+      double multiplier = double_multiplier;
+      auto m_scalar = MakeConstantScalar(DataType::Float(Bits), multiplier);
+      scaled_fp_t = Multiply(m_scalar, scaled_fp_t);
+    }
+
+  } else {
+    // This is per-channel (per-axis) quantization.
+    std::vector<double> double_multipliers;
+    auto input_axis_scales = GetFloatVectorFromConstant(input_scale);
+    double output_scale_float = GetScalarFromConstant<float>(output_scale);
+    for (auto input_axis_scale : input_axis_scales) {
+      double multiplier = static_cast<double>(input_axis_scale) / output_scale_float;
+      double_multipliers.push_back(multiplier);
+    }
+    int axis = param->axis;
+    axis = (axis == -1) ? input_shape.size() - 1 : axis;
+
+    auto fixed_pt_multiplier_expr = MakeConstantTensor(
+        DataType::Float(Bits), {(int64_t)double_multipliers.size()}, double_multipliers);
+    size_t n_dim = input_shape.size();
+    auto exp_fixed_pt_multiplier_expr =
+        ExpandBiasToMatchAxis(fixed_pt_multiplier_expr, n_dim, {axis});
+
+    scaled_fp_t = Multiply(scaled_fp_t, exp_fixed_pt_multiplier_expr);
+  }
+
+  // 3) Add the output zero point.
+  auto shifted_fp_t = scaled_fp_t;
+  if (!IsEqualScalar(output_zero_point, zero_scalar)) {
+    shifted_fp_t = Add(shifted_fp_t, Cast(output_zero_point, DataType::Float(Bits)));
+  }
+
+  if (param->rounding == "UPWARD") {
+    shifted_fp_t = Upward<Bits>(shifted_fp_t);
+  } else /*if (param->rounding == "TONEAREST")*/ {
+    shifted_fp_t = Tonearest<Bits>(shifted_fp_t);
+  }
+
+  shifted_fp_t = Cast(shifted_fp_t, DataType::Int(32));
+  // 4) Clip to the out_dtype min/max. Skip clipping if out_dtype is Int32. The fixed point
+  // multiplication keeps the value in int32 range.
+  if (out_dtype == DataType::Int(32)) {
+    return shifted_fp_t;
+  }
+
+  auto q_min = GetQmin(out_dtype);
+  auto q_max = GetQmax(out_dtype);
+  auto clipped_t = Clip(shifted_fp_t, q_min, q_max);
+  return Cast(clipped_t, out_dtype);
+}
+
+// Lowering of qnn.requantize op
+/*
+ * \brief Lower requantize to a sequence of ops.
+ * \param input_tensor The input tensor to requantize op.
+ * \param param The requantize op attrs.
+ * \param input_shape The input tensor shape of the requantize op.
+ * \return The sequence of existing Relay ops.
+ */
+Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
+                     const Expr& input_zero_point, const Expr& output_scale,
+                     const Expr& output_zero_point, const RequantizeAttrs* param,
+                     const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
+  // Check rounding validity.
+  ICHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST")
+      << "QNN requantize supports two rounding modes - UPWARD and "
+      << "TONEAREST";
+  // Check compute_dtype validity.
+  ICHECK(param->compute_dtype == "int64" || param->compute_dtype == "float32" ||
+         param->compute_dtype == "float64")
+      << "QNN requantize supports three compute_dtype variants - \"int64\", \"float32\" and "
+         "\"float64\"";
+  if (param->compute_dtype == "float32") {
+    return RequantizeLowerFP<32>(input_tensor, input_scale, input_zero_point, output_scale,
+                                 output_zero_point, param, input_shape, out_dtype);
+  } else if (param->compute_dtype == "float64") {
+    return RequantizeLowerFP<64>(input_tensor, input_scale, input_zero_point, output_scale,
+                                 output_zero_point, param, input_shape, out_dtype);
+  } else /*if (param->compute_dtype == "int64") */ {
+    return RequantizeLowerInt(input_tensor, input_scale, input_zero_point, output_scale,
+                              output_zero_point, param, input_shape, out_dtype);
  }
+}
+
 /*
  * \brief Forward rewrite the requantize op.
  * \param ref_call The original call that will be lowered.
@@ -230,8 +427,15 @@ Expr RequantizeQnnCanonicalize(const Attrs& attrs, const Array& new_args, auto& output_scale = new_args[3]; auto& output_zero_point = new_args[4]; const auto* param = attrs.as(); + const RequantizeConfig& cfg = RequantizeConfig::Current(); + ICHECK(param != nullptr); + const_cast(param)->rounding = + SelectRequntizeParameter(param->rounding, cfg->get_rounding(), cfg->is_default, "rounding"); + const_cast(param)->compute_dtype = SelectRequntizeParameter( + param->compute_dtype, cfg->get_compute_dtype(), cfg->is_default, "compute_dtype"); + // Find input shape. ICHECK_EQ(types.size(), 6); auto in_type = types[0]; @@ -246,11 +450,6 @@ Expr RequantizeQnnCanonicalize(const Attrs& attrs, const Array& new_args, ICHECK(out_tensor_type != nullptr) << "Type information missing." << " Please run infer_type pass."; auto out_dtype = out_tensor_type->dtype; - - // Check rounding validity. - ICHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST") - << "QNN requantize supports two rounding modes - UPWARD and " - << "TONEAREST"; return RequantizeLower(quantized_data, input_scale, input_zero_point, output_scale, output_zero_point, param, input_shape, out_dtype); } @@ -317,11 +516,13 @@ bool RequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, // Positional relay function to create qnn requantize operator // used by frontend FFI. Expr MakeRequantize(Expr data, Expr input_scale, Expr input_zero_point, Expr output_scale, - Expr output_zero_point, int axis, String rounding, DataType out_dtype) { + Expr output_zero_point, int axis, String rounding, String compute_dtype, + DataType out_dtype) { auto attrs = make_object(); attrs->axis = axis; attrs->rounding = std::move(rounding); attrs->out_dtype = std::move(out_dtype); + attrs->compute_dtype = std::move(compute_dtype); static const Op& op = Op::Get("qnn.requantize"); return Call(op, {data, input_scale, input_zero_point, output_scale, output_zero_point}, Attrs(attrs), {}); diff --git a/src/relay/qnn/op/requantize_config.cc b/src/relay/qnn/op/requantize_config.cc new file mode 100644 index 000000000000..4a52f56400c9 --- /dev/null +++ b/src/relay/qnn/op/requantize_config.cc @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/qnn/op/requantize_config.cc + * \brief QNN requantize config. + */ + +#include "./requantize_config.h" + +#include +#include +#include + +#include + +namespace tvm { +namespace relay { +namespace qnn { + +/*! \brief Entry to hold the BuildConfig context stack. */ +struct TVMRequantizeConfigThreadLocalEntry { + /*! \brief The default build config if the stack is empty */ + RequantizeConfig default_config; + + /*! 
\brief The current build config context */ + std::stack context_stack; + + TVMRequantizeConfigThreadLocalEntry() : default_config(make_object(true)) {} +}; + +/*! \brief Thread local store to hold the BuildConfig context stack. */ +typedef dmlc::ThreadLocalStore + TVMRequantizeConfigThreadLocalStore; + +void RequantizeConfig::EnterRequantizeConfigScope(const RequantizeConfig& build_config) { + TVMRequantizeConfigThreadLocalEntry* entry = TVMRequantizeConfigThreadLocalStore::Get(); + entry->context_stack.push(build_config); +} + +void RequantizeConfig::ExitRequantizeConfigScope() { + TVMRequantizeConfigThreadLocalEntry* entry = TVMRequantizeConfigThreadLocalStore::Get(); + entry->context_stack.pop(); +} + +RequantizeConfig& RequantizeConfig::Current() { + TVMRequantizeConfigThreadLocalEntry* entry = TVMRequantizeConfigThreadLocalStore::Get(); + if (entry->context_stack.size() > 0) { + return entry->context_stack.top(); + } + + return entry->default_config; +} + +TVM_REGISTER_NODE_TYPE(RequantizeConfigNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* op = static_cast(ref.get()); + p->stream << "requantize_config("; + p->stream << "rounding==" << op->get_rounding() << ", "; + p->stream << "compute_dtype==" << op->get_compute_dtype(); + p->stream << ")"; + }); + +TVM_REGISTER_GLOBAL("relay._requantize._GetCurrentRequantizeConfig") + .set_body_typed([]() -> RequantizeConfig { return RequantizeConfig::Current(); }); + +TVM_REGISTER_GLOBAL("relay._requantize._EnterRequantizeConfigScope") + .set_body_typed(RequantizeConfig::EnterRequantizeConfigScope); + +TVM_REGISTER_GLOBAL("relay._requantize._ExitRequantizeConfigScope") + .set_body_typed(RequantizeConfig::ExitRequantizeConfigScope); + +} // namespace qnn +} // namespace relay +} // namespace tvm diff --git a/src/relay/qnn/op/requantize_config.h b/src/relay/qnn/op/requantize_config.h new file mode 100644 index 000000000000..f1cd9219c32b --- /dev/null +++ b/src/relay/qnn/op/requantize_config.h @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/qnn/op/requantize_config.h + * \brief QNN requantize config. + */ + +#ifndef TVM_RELAY_QNN_OP_REQUANTIZE_CONFIG_H_ +#define TVM_RELAY_QNN_OP_REQUANTIZE_CONFIG_H_ + +#include +#include +#include +#include +#include + +#include + +#include "../../op/op_common.h" + +namespace tvm { +namespace relay { +namespace qnn { + +class RequantizeConfig; +/*! 
+ * \brief Container for build configuration options + */ +class RequantizeConfigNode : public Object { + std::string rounding; + std::string compute_dtype; + + public: + explicit RequantizeConfigNode(bool is_default = false) : is_default(is_default) {} + + std::string get_rounding() const { + if (!rounding.empty()) return rounding; + return "UPWARD"; + } + + std::string get_compute_dtype() const { + if (!compute_dtype.empty()) return compute_dtype; + + // For the x86 architecture, the float32 computation is expected to give significant speedup, + // with little loss in the accuracy of the requantize operation. + auto target = Target::Current(true); + auto target_has_sse41 = tvm::runtime::Registry::Get("tvm.topi.x86.utils.target_has_sse41"); + ICHECK(target_has_sse41) << "Function tvm.topi.x86.utils.target_has_sse41 not found"; + if (target.defined() && target->kind->name == "llvm" && + (target->GetAttr("mcpu") && + (*target_has_sse41)(target->GetAttr("mcpu").value()))) { + return "float32"; + } + + return "int64"; + } + + const bool is_default = false; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("rounding", &rounding); + v->Visit("compute_dtype", &compute_dtype); + } + + static constexpr const char* _type_key = "relay.qnn.op.RequantizeConfig"; + TVM_DECLARE_FINAL_OBJECT_INFO(RequantizeConfigNode, Object); +}; + +/*! + * \brief Container for build configuration options + */ +class RequantizeConfig : public ObjectRef { + public: + RequantizeConfig() {} + explicit RequantizeConfig(ObjectPtr n) : ObjectRef(n) {} + + const RequantizeConfigNode* operator->() const { + return static_cast(get()); + } + + RequantizeConfigNode* operator->() { return static_cast(get_mutable()); } + + /*! + * \brief Push a new BuildConfig context onto the thread local stack. + * \param build_config The configuration to set as the current context. + */ + static void EnterRequantizeConfigScope(const RequantizeConfig& requantize_config); + + /*! + * \brief Pop a build config off the thread local context stack, restoring the previous + * configuration as the current context. + */ + static void ExitRequantizeConfigScope(); + + /*! + * \brief Get the current BuildConfig context from thread local storage, or a default + * configuration if a BuildConfig scope has not been entered. + * \return The configuration that is the current context. + */ + static RequantizeConfig& Current(); + + using ContainerType = RequantizeConfigNode; +}; + +} // namespace qnn +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_QNN_OP_REQUANTIZE_CONFIG_H_ diff --git a/src/relay/qnn/utils.cc b/src/relay/qnn/utils.cc index 982efa0a61c1..7dfd788d96c6 100644 --- a/src/relay/qnn/utils.cc +++ b/src/relay/qnn/utils.cc @@ -199,6 +199,22 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector multipliers, return Cast(tensor, DataType::Int(32)); } +std::string SelectRequntizeParameter(const std::string& arg_value, const std::string& cfg_value, + const bool is_cfg_default, const std::string& name) { + if (arg_value == "None") { + return cfg_value; + } else { + if (!is_cfg_default && arg_value != cfg_value) { + DLOG(INFO) << "The value of parameter \"" << name + << "\" from the non-default requantize config will not be used. The value " + "provided from " + "requantize function argument will be used instead. 
The value used is \"" + << arg_value << "\"."; + } + return arg_value; + } +} + } // namespace qnn } // namespace relay } // namespace tvm diff --git a/src/relay/qnn/utils.h b/src/relay/qnn/utils.h index c8f3524d51ea..0f3645a9882a 100644 --- a/src/relay/qnn/utils.h +++ b/src/relay/qnn/utils.h @@ -35,6 +35,8 @@ #include #include +#include "./op/requantize_config.h" + namespace tvm { namespace relay { namespace qnn { @@ -98,13 +100,21 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale, const Expr& output_zero_point, const RequantizeAttrs* param, const Array& input_shape, const DataType& out_dtype); +std::string SelectRequntizeParameter(const std::string& arg_value, const std::string& cfg_value, + const bool is_cfg_default, const std::string& name); + static inline Expr Requantize(const Expr& data, const Array& input_shape, const Expr& input_scale, const Expr& input_zero_point, const Expr& output_scale, const Expr& output_zero_point, - const DataType& out_dtype, const std::string& rounding = "UPWARD") { + const DataType& out_dtype, const std::string& rounding = "None", + const std::string& compute_dtype = "None") { auto attrs = make_object(); - attrs->rounding = std::move(rounding); attrs->out_dtype = std::move(out_dtype); + const RequantizeConfig& cfg = RequantizeConfig::Current(); + attrs->rounding = + SelectRequntizeParameter(rounding, cfg->get_rounding(), cfg->is_default, "rounding"); + attrs->compute_dtype = SelectRequntizeParameter(compute_dtype, cfg->get_compute_dtype(), + cfg->is_default, "compute_dtype"); return RequantizeLower(data, input_scale, input_zero_point, output_scale, output_zero_point, attrs.operator->(), input_shape, attrs->out_dtype); } diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h index 16a23a4ba699..7d2657eb04f2 100644 --- a/src/relay/transforms/pattern_utils.h +++ b/src/relay/transforms/pattern_utils.h @@ -565,6 +565,11 @@ inline Expr Round(Expr x) { return Call(op, {x}, Attrs(), {}); } +inline Expr Floor(Expr x) { + static const Op& op = Op::Get("floor"); + return Call(op, {x}, Attrs(), {}); +} + inline Expr Clip(Expr x, double a_min, double a_max) { return MakeClip(x, a_min, a_max); } inline Expr FixedPointMultiply(Expr x, int32_t multiplier, int32_t shift) { @@ -662,11 +667,31 @@ static inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) { return Call(op, {condition, x, y}); } +static inline Expr LogicalOr(const Expr& lhs, const Expr& rhs) { + static const Op& op = Op::Get("logical_or"); + return Call(op, {lhs, rhs}, Attrs(), {}); +} + static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) { static const Op& op = Op::Get("greater_equal"); return Call(op, {lhs, rhs}, Attrs(), {}); } +static inline Expr Equal(const Expr& lhs, const Expr& rhs) { + static const Op& op = Op::Get("equal"); + return Call(op, {lhs, rhs}, Attrs(), {}); +} + +static inline Expr Less(const Expr& lhs, const Expr& rhs) { + static const Op& op = Op::Get("less"); + return Call(op, {lhs, rhs}, Attrs(), {}); +} + +static inline Expr IsFinite(const Expr x) { + static const Op& op = Op::Get("isfinite"); + return Call(op, {x}, Attrs(), {}); +} + static inline Expr Full(Expr fill_value, Array shape, DataType dtype) { return MakeFull(fill_value, CheckConstantShapeArrayInteger(shape), dtype); } diff --git a/tests/python/relay/test_op_qnn_requantize.py b/tests/python/relay/test_op_qnn_requantize.py index 0f512df25cdf..64306476dfe9 100644 --- a/tests/python/relay/test_op_qnn_requantize.py +++ 
b/tests/python/relay/test_op_qnn_requantize.py @@ -22,14 +22,15 @@ from tvm.contrib import graph_executor roundings = ["UPWARD", "TONEAREST"] +compute_dtypes = ["float32", "float64", "int64"] -def verify(mod, goldens): +def verify(mod, goldens, target="llvm"): with tvm.transform.PassContext(opt_level=3): - graph, lib, params = relay.build(mod, "llvm", params=None) + graph, lib, params = relay.build(mod, target, params=None) golden_data, golden_output = goldens rt_mod = graph_executor.create(graph, lib, device=tvm.cpu(0)) - rt_mod.set_input("quantized_data", golden_data) + rt_mod.set_input("input_data", golden_data) rt_mod.set_input(**params) rt_mod.run() res = rt_mod.get_output(0).numpy() @@ -44,10 +45,11 @@ def get_mod( output_scale, input_zero_point=0, output_zero_point=0, - rounding="TONEAREST", + rounding="None", + compute_dtype="None", axis=0, ): - quantized_data = relay.var("quantized_data", shape=data_shape, dtype=data_dtype) + input_data = relay.var("input_data", shape=data_shape, dtype=data_dtype) if isinstance(input_scale, float): input_scale_expr = relay.const(input_scale, "float32") else: @@ -59,13 +61,14 @@ def get_mod( input_zero_point_expr = relay.const(np.array(input_zero_point).astype("int32")) mod = relay.qnn.op.requantize( - quantized_data, + input_data, input_scale=input_scale_expr, input_zero_point=input_zero_point_expr, output_scale=relay.const(output_scale, "float32"), output_zero_point=relay.const(output_zero_point, "int32"), axis=axis, rounding=rounding, + compute_dtype=compute_dtype, out_dtype=out_dtype, ) @@ -78,327 +81,344 @@ def test_same_scale(): # Have same scales, everything within range golden_data = np.arange(-100, 100, 1).astype("int32") golden_output = golden_data - - for rounding in roundings: - mod = get_mod( - data_shape=(200,), - data_dtype="int32", - out_dtype="int8", - input_scale=0.5, - output_scale=0.5, - rounding=rounding, - ) - assert "right_shift" not in mod.astext() - verify(mod, (golden_data, golden_output)) + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(200,), + data_dtype="int32", + out_dtype="int8", + input_scale=0.5, + output_scale=0.5, + rounding=rounding, + compute_dtype=compute_dtype, + ) + assert "right_shift" not in mod.astext() + verify(mod, (golden_data, golden_output)) def test_scalar_same_scale(): # Have same scales, everything within range golden_data = np.array(-10).astype("int32") golden_output = golden_data - - for rounding in roundings: - mod = get_mod( - data_shape=(), - data_dtype="int32", - out_dtype="int8", - input_scale=0.5, - output_scale=0.5, - rounding=rounding, - ) - assert "right_shift" not in mod.astext() - verify(mod, (golden_data, golden_output)) + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(), + data_dtype="int32", + out_dtype="int8", + input_scale=0.5, + output_scale=0.5, + rounding=rounding, + compute_dtype=compute_dtype, + ) + assert "right_shift" not in mod.astext() + verify(mod, (golden_data, golden_output)) def test_downscale(): - for rounding in roundings: - mod = get_mod( - data_shape=(32,), - data_dtype="int32", - out_dtype="int8", - input_scale=1, - output_scale=16, - rounding=rounding, - ) - - # Try positive values - # 8 corresponds to 0.5, resulting in 1 - golden_data = np.arange(0, 32, 1).astype("int32") - golden_output = np.repeat([0, 1, 2], [8, 16, 8]) - verify(mod, (golden_data, golden_output)) - - # Try negative values - # -8 corresponds to -0.5. 
For UPWARD, this is 0 - golden_data = np.arange(0, -32, -1).astype("int32") - if rounding == "UPWARD": - golden_output = np.repeat([0, -1, -2], [9, 16, 7]) - else: - golden_output = np.repeat([0, -1, -2], [8, 16, 8]) - verify(mod, (golden_data, golden_output)) - - # Try a different scale - mod = get_mod( - data_shape=(32,), - data_dtype="int32", - out_dtype="int8", - input_scale=1, - output_scale=4, - rounding=rounding, - ) - - # Try positive values - # 2I corresponds to 0.5, resulting in 1 - golden_data = np.arange(0, 32, 1).astype("int32") - golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8], [2, 4, 4, 4, 4, 4, 4, 4, 2]) - verify(mod, (golden_data, golden_output)) - - # Try negative values - # -8 corresponds to -0.5. For UPWARD, this is 0 - golden_data = np.arange(0, -32, -1).astype("int32") - if rounding == "UPWARD": - golden_output = np.repeat( - [0, -1, -2, -3, -4, -5, -6, -7, -8], [3, 4, 4, 4, 4, 4, 4, 4, 1] + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="int8", + input_scale=1, + output_scale=16, + rounding=rounding, + compute_dtype=compute_dtype, ) - else: - golden_output = np.repeat( - [0, -1, -2, -3, -4, -5, -6, -7, -8], [2, 4, 4, 4, 4, 4, 4, 4, 2] + + # Try positive values + # 8 corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype("int32") + golden_output = np.repeat([0, 1, 2], [8, 16, 8]) + verify(mod, (golden_data, golden_output)) + + # Try negative values + # -8 corresponds to -0.5. For UPWARD, this is 0 + golden_data = np.arange(0, -32, -1).astype("int32") + if rounding == "UPWARD": + golden_output = np.repeat([0, -1, -2], [9, 16, 7]) + else: + golden_output = np.repeat([0, -1, -2], [8, 16, 8]) + verify(mod, (golden_data, golden_output)) + + # Try a different scale + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="int8", + input_scale=1, + output_scale=4, + rounding=rounding, ) - verify(mod, (golden_data, golden_output)) - - # Try uint8 out_dtype - mod = get_mod( - data_shape=(32,), - data_dtype="int32", - out_dtype="uint8", - input_scale=1, - output_scale=16, - rounding=rounding, - ) - - # Try positive values - # 8 corresponds to 0.5, resulting in 1 - golden_data = np.arange(0, 32, 1).astype("int32") - golden_output = np.repeat([0, 1, 2], [8, 16, 8]) - verify(mod, (golden_data, golden_output)) - - # Try uint8 in_dtyope and uint8 out_dtype - mod = get_mod( - data_shape=(32,), - data_dtype="uint8", - out_dtype="uint8", - input_scale=1, - output_scale=16, - rounding=rounding, - ) - - # Try positive values - # 8 corresponds to 0.5, resulting in 1 - golden_data = np.arange(0, 32, 1).astype("int32") - golden_output = np.repeat([0, 1, 2], [8, 16, 8]) - verify(mod, (golden_data, golden_output)) + + # Try positive values + # 2I corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype("int32") + golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8], [2, 4, 4, 4, 4, 4, 4, 4, 2]) + verify(mod, (golden_data, golden_output)) + + # Try negative values + # -8 corresponds to -0.5. 
For UPWARD, this is 0 + golden_data = np.arange(0, -32, -1).astype("int32") + if rounding == "UPWARD": + golden_output = np.repeat( + [0, -1, -2, -3, -4, -5, -6, -7, -8], [3, 4, 4, 4, 4, 4, 4, 4, 1] + ) + else: + golden_output = np.repeat( + [0, -1, -2, -3, -4, -5, -6, -7, -8], [2, 4, 4, 4, 4, 4, 4, 4, 2] + ) + verify(mod, (golden_data, golden_output)) + + # Try uint8 out_dtype + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="uint8", + input_scale=1, + output_scale=16, + rounding=rounding, + ) + + # Try positive values + # 8 corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype("int32") + golden_output = np.repeat([0, 1, 2], [8, 16, 8]) + verify(mod, (golden_data, golden_output)) + + # Try uint8 in_dtyope and uint8 out_dtype + mod = get_mod( + data_shape=(32,), + data_dtype="uint8", + out_dtype="uint8", + input_scale=1, + output_scale=16, + rounding=rounding, + ) + + # Try positive values + # 8 corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype("int32") + golden_output = np.repeat([0, 1, 2], [8, 16, 8]) + verify(mod, (golden_data, golden_output)) def test_upscale(): - for rounding in roundings: - mod = get_mod( - data_shape=(32,), - data_dtype="int32", - out_dtype="int8", - input_scale=2, - output_scale=1, - rounding=rounding, - ) - - # Try positive values - # 8 corresponds to 0.5, resulting in 1 - golden_data = np.arange(0, 32, 1).astype("int32") - golden_output = np.multiply(2, golden_data) - verify(mod, (golden_data, golden_output)) - - # Try negative values - # -8 corresponds to -0.5. For UPWARD, this is 0 - golden_data = np.arange(0, -32, -1).astype("int32") - golden_output = np.multiply(2, golden_data) - verify(mod, (golden_data, golden_output)) + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="int8", + input_scale=2, + output_scale=1, + rounding=rounding, + compute_dtype=compute_dtype, + ) + + # Try positive values + # 8 corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype("int32") + golden_output = np.multiply(2, golden_data) + verify(mod, (golden_data, golden_output)) + + # Try negative values + # -8 corresponds to -0.5. 
For UPWARD, this is 0 + golden_data = np.arange(0, -32, -1).astype("int32") + golden_output = np.multiply(2, golden_data) + verify(mod, (golden_data, golden_output)) def test_non_power_of_two(): - for rounding in roundings: - mod = get_mod( - data_shape=(32,), - data_dtype="int32", - out_dtype="int8", - input_scale=1, - output_scale=3, - rounding=rounding, - ) - - # Try positive values - golden_data = np.multiply(np.arange(0, 32, 1).astype("int32"), 3) - golden_output = np.arange(0, 32, 1) - verify(mod, (golden_data, golden_output)) - - # Try negative values - golden_data = np.multiply(np.arange(0, -32, -1).astype("int32"), 3) - golden_output = np.arange(0, -32, -1) - verify(mod, (golden_data, golden_output)) - - # Try a different scale - mod = get_mod( - data_shape=(32,), - data_dtype="int32", - out_dtype="int8", - input_scale=3, - output_scale=1, - rounding=rounding, - ) - - # Try positive values - golden_data = np.arange(0, 32, 1).astype("int32") - golden_output = np.multiply(golden_data, 3) - verify(mod, (golden_data, golden_output)) - - # Try negative values - golden_data = np.arange(0, -32, -1).astype("int32") - golden_output = np.multiply(golden_data, 3) - verify(mod, (golden_data, golden_output)) + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="int8", + input_scale=1, + output_scale=3, + rounding=rounding, + compute_dtype=compute_dtype, + ) + + # Try positive values + golden_data = np.multiply(np.arange(0, 32, 1).astype("int32"), 3) + golden_output = np.arange(0, 32, 1) + verify(mod, (golden_data, golden_output)) + + # Try negative values + golden_data = np.multiply(np.arange(0, -32, -1).astype("int32"), 3) + golden_output = np.arange(0, -32, -1) + verify(mod, (golden_data, golden_output)) + + # Try a different scale + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="int8", + input_scale=3, + output_scale=1, + rounding=rounding, + ) + + # Try positive values + golden_data = np.arange(0, 32, 1).astype("int32") + golden_output = np.multiply(golden_data, 3) + verify(mod, (golden_data, golden_output)) + + # Try negative values + golden_data = np.arange(0, -32, -1).astype("int32") + golden_output = np.multiply(golden_data, 3) + verify(mod, (golden_data, golden_output)) def test_saturation(): - for rounding in roundings: - mod = get_mod( - data_shape=(16,), - data_dtype="int32", - out_dtype="int8", - input_scale=0.5, - output_scale=0.5, - rounding=rounding, - ) - golden_data = np.arange(0, 16, 1).astype("int32") - golden_data = np.add(120, golden_data) - output = np.array( - [120, 121, 122, 123, 124, 125, 126, 127, 127, 127, 127, 127, 127, 127, 127, 127] - ) - golden_output = output - verify(mod, (golden_data, golden_output)) - - # Try negative numbers - golden_data = np.arange(0, -16, -1).astype("int32") - golden_data = np.add(-120, golden_data) - output = np.array( - [ - -120, - -121, - -122, - -123, - -124, - -125, - -126, - -127, - -128, - -128, - -128, - -128, - -128, - -128, - -128, - -128, - ] - ) - golden_output = output - verify(mod, (golden_data, golden_output)) + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(16,), + data_dtype="int32", + out_dtype="int8", + input_scale=0.5, + output_scale=0.5, + rounding=rounding, + compute_dtype=compute_dtype, + ) + golden_data = np.arange(0, 16, 1).astype("int32") + golden_data = np.add(120, golden_data) + output = np.array( + [120, 121, 122, 123, 124, 125, 126, 127, 127, 127, 127, 
127, 127, 127, 127, 127] + ) + golden_output = output + verify(mod, (golden_data, golden_output)) + + # Try negative numbers + golden_data = np.arange(0, -16, -1).astype("int32") + golden_data = np.add(-120, golden_data) + output = np.array( + [ + -120, + -121, + -122, + -123, + -124, + -125, + -126, + -127, + -128, + -128, + -128, + -128, + -128, + -128, + -128, + -128, + ] + ) + golden_output = output + verify(mod, (golden_data, golden_output)) def test_zero_point(): # Output zero point - for rounding in roundings: - mod = get_mod( - data_shape=(32,), - data_dtype="int32", - out_dtype="int8", - input_scale=1, - output_scale=16, - output_zero_point=1, - rounding=rounding, - ) - - # Try positive values - # 8 corresponds to 0.5, resulting in 1 - golden_data = np.arange(0, 32, 1).astype("int32") - golden_output = np.repeat([0, 1, 2], [8, 16, 8]) - golden_output = np.add(1, golden_output) - verify(mod, (golden_data, golden_output)) - - # Try negative values - # -8 corresponds to -0.5. For UPWARD, this is 0 - golden_data = np.arange(-32, -64, -1).astype("int32") - if rounding == "UPWARD": - golden_output = np.repeat([-2, -3, -4], [9, 16, 7]) - else: - golden_output = np.repeat([-2, -3, -4], [8, 16, 8]) - golden_output = np.add(1, golden_output) - verify(mod, (golden_data, golden_output)) + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="int8", + input_scale=1, + output_scale=16, + output_zero_point=1, + rounding=rounding, + compute_dtype=compute_dtype, + ) + + # Try positive values + # 8 corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype("int32") + golden_output = np.repeat([0, 1, 2], [8, 16, 8]) + golden_output = np.add(1, golden_output) + verify(mod, (golden_data, golden_output)) + + # Try negative values + # -8 corresponds to -0.5. 
For UPWARD, this is 0 + golden_data = np.arange(-32, -64, -1).astype("int32") + if rounding == "UPWARD": + golden_output = np.repeat([-2, -3, -4], [9, 16, 7]) + else: + golden_output = np.repeat([-2, -3, -4], [8, 16, 8]) + golden_output = np.add(1, golden_output) + verify(mod, (golden_data, golden_output)) # Input zero point - for rounding in roundings: - mod = get_mod( - data_shape=(32,), - data_dtype="int32", - out_dtype="int8", - input_scale=1, - output_scale=16, - input_zero_point=16, - rounding=rounding, - ) - - # Try positive values - golden_data = np.arange(32, 64, 1).astype("int32") - golden_output = np.repeat([2, 3, 4], [8, 16, 8]) - golden_output = np.subtract(golden_output, 1) - verify(mod, (golden_data, golden_output)) - - # Try negative values - golden_data = np.arange(-32, -64, -1).astype("int32") - if rounding == "UPWARD": - golden_output = np.repeat([-2, -3, -4], [9, 16, 7]) - else: - golden_output = np.repeat([-2, -3, -4], [8, 16, 8]) - golden_output = np.subtract(golden_output, 1) - verify(mod, (golden_data, golden_output)) + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="int8", + input_scale=1, + output_scale=16, + input_zero_point=16, + rounding=rounding, + compute_dtype=compute_dtype, + ) + + # Try positive values + golden_data = np.arange(32, 64, 1).astype("int32") + golden_output = np.repeat([2, 3, 4], [8, 16, 8]) + golden_output = np.subtract(golden_output, 1) + verify(mod, (golden_data, golden_output)) + + # Try negative values + golden_data = np.arange(-32, -64, -1).astype("int32") + if rounding == "UPWARD": + golden_output = np.repeat([-2, -3, -4], [9, 16, 7]) + else: + golden_output = np.repeat([-2, -3, -4], [8, 16, 8]) + golden_output = np.subtract(golden_output, 1) + verify(mod, (golden_data, golden_output)) def test_per_channel_same_scale(): # Have same scales, everything within range golden_data = np.arange(-5, 5, 1).astype("int32").reshape((5, 2)) golden_output = golden_data - - for rounding in roundings: - mod = get_mod( - data_shape=(5, 2), - data_dtype="int32", - out_dtype="int8", - input_scale=[0.5, 0.5], - output_scale=0.5, - axis=1, - rounding=rounding, - ) - verify(mod, (golden_data, golden_output)) + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(5, 2), + data_dtype="int32", + out_dtype="int8", + input_scale=[0.5, 0.5], + output_scale=0.5, + axis=1, + rounding=rounding, + compute_dtype=compute_dtype, + ) + verify(mod, (golden_data, golden_output)) # Change axis golden_data = np.arange(-10, 10, 1).astype("int32").reshape((2, 2, 5)) golden_output = golden_data - for rounding in roundings: - mod = get_mod( - data_shape=(2, 2, 5), - data_dtype="int32", - out_dtype="int8", - input_scale=[0.5, 0.5], - output_scale=0.5, - axis=1, - rounding=rounding, - ) - verify(mod, (golden_data, golden_output)) + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(2, 2, 5), + data_dtype="int32", + out_dtype="int8", + input_scale=[0.5, 0.5], + output_scale=0.5, + axis=1, + rounding=rounding, + compute_dtype=compute_dtype, + ) + verify(mod, (golden_data, golden_output)) def test_per_channel_different_scale(): @@ -406,17 +426,19 @@ def test_per_channel_different_scale(): golden_data = np.arange(-5, 5, 1).astype("int32").reshape((5, 2)) golden_output = np.array([-5, -2, -3, -1, -1, 0, 1, 1, 3, 2]).reshape((5, 2)) - for rounding in roundings: - mod = get_mod( - data_shape=(5, 2), - 
data_dtype="int32", - out_dtype="int8", - input_scale=[0.5, 0.25], - output_scale=0.5, - axis=1, - rounding=rounding, - ) - verify(mod, (golden_data, golden_output)) + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(5, 2), + data_dtype="int32", + out_dtype="int8", + input_scale=[0.5, 0.25], + output_scale=0.5, + axis=1, + rounding=rounding, + compute_dtype=compute_dtype, + ) + verify(mod, (golden_data, golden_output)) # Change axis golden_data = np.arange(-20, 20, 2).astype("int32").reshape((2, 2, 5)) @@ -424,33 +446,113 @@ def test_per_channel_different_scale(): [-20, -18, -16, -14, -12, -5, -4, -3, -2, -1, 0, 2, 4, 6, 8, 5, 6, 7, 8, 9] ).reshape((2, 2, 5)) - for rounding in roundings: - mod = get_mod( - data_shape=(2, 2, 5), - data_dtype="int32", - out_dtype="int8", - input_scale=[0.5, 0.25], - output_scale=0.5, - axis=1, - rounding=rounding, - ) - verify(mod, (golden_data, golden_output)) + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(2, 2, 5), + data_dtype="int32", + out_dtype="int8", + input_scale=[0.5, 0.25], + output_scale=0.5, + axis=1, + rounding=rounding, + compute_dtype=compute_dtype, + ) + verify(mod, (golden_data, golden_output)) # Have input scale > output scale golden_data = np.arange(-5, 5, 1).astype("int32").reshape((5, 2)) golden_output = np.array([-10, -2, -6, -1, -2, 0, 2, 1, 6, 2]).reshape((5, 2)) + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(5, 2), + data_dtype="int32", + out_dtype="int8", + input_scale=[1.0, 0.25], + output_scale=0.5, + axis=1, + rounding=rounding, + compute_dtype=compute_dtype, + ) + verify(mod, (golden_data, golden_output)) + + +def test_default_cfg_and_no_args(): + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="int8", + input_scale=1, + output_scale=16, + ) + golden_data = np.arange(0, -32, -1).astype("int32") + golden_output = np.repeat([0, -1, -2], [9, 16, 7]) + verify(mod, (golden_data, golden_output)) + + +def test_non_default_cfg_and_no_args(): + for rounding_cfg in roundings: + with relay.qnn.op.requantize_config(rounding=rounding_cfg): + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="int8", + input_scale=1, + output_scale=16, + ) + + golden_data = np.arange(0, -32, -1).astype("int32") + + if rounding_cfg == "UPWARD": + golden_output = np.repeat([0, -1, -2], [9, 16, 7]) + else: + golden_output = np.repeat([0, -1, -2], [8, 16, 8]) + verify(mod, (golden_data, golden_output)) + + +def test_default_cfg_and_args(): for rounding in roundings: - mod = get_mod( - data_shape=(5, 2), - data_dtype="int32", - out_dtype="int8", - input_scale=[1.0, 0.25], - output_scale=0.5, - axis=1, - rounding=rounding, - ) - verify(mod, (golden_data, golden_output)) + with relay.qnn.op.requantize_config(rounding="UPWARD"): + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="int8", + input_scale=1, + output_scale=16, + rounding=rounding, + ) + + golden_data = np.arange(0, -32, -1).astype("int32") + + if rounding == "UPWARD": + golden_output = np.repeat([0, -1, -2], [9, 16, 7]) + else: + golden_output = np.repeat([0, -1, -2], [8, 16, 8]) + verify(mod, (golden_data, golden_output)) + + +def test_non_default_cfg_and_args(): + for rounding_arg in roundings: + for rounding_cfg in roundings: + with relay.qnn.op.requantize_config(rounding=rounding_cfg): + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="int8", + input_scale=1, + 
output_scale=16, + rounding=rounding_arg, + ) + + golden_data = np.arange(0, -32, -1).astype("int32") + + if rounding_arg == "UPWARD": + golden_output = np.repeat([0, -1, -2], [9, 16, 7]) + else: + golden_output = np.repeat([0, -1, -2], [8, 16, 8]) + verify(mod, (golden_data, golden_output)) if __name__ == "__main__": @@ -463,3 +565,7 @@ def test_per_channel_different_scale(): test_zero_point() test_per_channel_same_scale() test_per_channel_different_scale() + test_default_cfg_and_no_args() + test_non_default_cfg_and_no_args() + test_default_cfg_and_args() + test_non_default_cfg_and_args()
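
For illustration only (not part of the diff): a minimal sketch of driving the new defaults through the requantize_config scope instead of per-call arguments, assuming TVM is built with this patch; shapes and scales are arbitrary. Explicit rounding/compute_dtype arguments on qnn.requantize still take precedence over a non-default config, as exercised in the tests above.

import tvm
from tvm import relay

data = relay.var("data", shape=(4,), dtype="int32")
expr = relay.qnn.op.requantize(
    data,
    input_scale=relay.const(1.0, "float32"),
    input_zero_point=relay.const(0, "int32"),
    output_scale=relay.const(16.0, "float32"),
    output_zero_point=relay.const(0, "int32"),
    out_dtype="int8",
)  # rounding/compute_dtype stay "None" and are resolved from the active config
mod = tvm.IRModule.from_expr(relay.Function([data], expr))

# Inside this scope canonicalization picks TONEAREST / float64.
with relay.qnn.op.requantize_config(rounding="TONEAREST", compute_dtype="float64"):
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm")
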