Skip to content

Commit

Permalink
Add FP requantize flow. Set float32 flow by default for llvm x86 targets with sse4.1 support (apache#9637)
Browse files Browse the repository at this point in the history
  • Loading branch information
Icemist authored and ylc committed Feb 16, 2022
1 parent 7e50fc4 commit c0ec226
Show file tree
Hide file tree
Showing 11 changed files with 1,060 additions and 340 deletions.
8 changes: 7 additions & 1 deletion include/tvm/relay/qnn/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ namespace qnn {
struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
int axis;
std::string rounding;
std::string compute_dtype;
DataType out_dtype;

TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") {
Expand All @@ -44,7 +45,7 @@ struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
"The output channel axis for channel wise quantization. Default value is -1,"
"which corresponds to the last axis.")
.set_default(-1);
TVM_ATTR_FIELD(rounding).set_default("UPWARD").describe(
TVM_ATTR_FIELD(rounding).set_default("None").describe(
"Defines the rounding direction when the value is midway between"
"two representable values. There are two supported modes - UPWARD"
"or TONEAREST. Both modes behave exactly same except at the"
Expand All @@ -54,6 +55,11 @@ struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
"value is rounded away from zero at midpoints (for example, -1.5"
"rounds to -2). More context can be found at following gblic manual"
"https://www.gnu.org/software/libc/manual/html_node/Rounding.html.");
TVM_ATTR_FIELD(compute_dtype)
.set_default("None")
.describe(
"Specifies the data type used during requantize. Supported "
"options: \"int64\", \"float32\", \"float64\"");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
Expand Down
21 changes: 21 additions & 0 deletions python/tvm/relay/qnn/op/_requantize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
"""Internal module for qnn requantization."""
import tvm._ffi

# Bind the C++ FFI functions registered under the "relay._requantize"
# prefix (e.g. the requantize-config scope enter/exit helpers used by
# qnn.py) as attributes of this module.
tvm._ffi._init_api("relay._requantize", __name__)
102 changes: 98 additions & 4 deletions python/tvm/relay/qnn/op/qnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,109 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
# pylint: disable=invalid-name,unused-argument, not-context-manager
"""QNN dialect operators."""

from __future__ import absolute_import as _abs

import tvm
import tvm.ir
from tvm import relay
from tvm.runtime import Object
from tvm.relay.expr import Tuple, TupleWrapper
from tvm.relay.op.nn.utils import get_pad_tuple2d
from tvm.topi.nn.qnn import SQNN_DTYPE_TO_CODE

from tvm.target import Target
from tvm.topi.x86.utils import target_has_sse41
from ... import op as reg
from ...op import OpPattern
from . import _make
from . import _requantize


@tvm._ffi.register_object("relay.qnn.op.RequantizeConfig")
class RequantizeConfig(Object):
    """Configure the requantization behavior by setting config variables.

    Note
    ----
    This object is backed by node system in C++, with arguments that can be
    exchanged between python and C++.
    Do not construct directly, use requantize_config instead.
    The fields that are backed by the C++ node are immutable once an instance
    is constructed. Use _node_defaults getters to get results for the fields.
    """

    @staticmethod
    def _get_node_default_rounding():
        # Default rounding mode used when the caller leaves it unspecified.
        return "UPWARD"

    @staticmethod
    def _get_node_default_compute_dtype():
        # Prefer the float32 requantize flow on llvm x86 targets that have
        # SSE4.1 (per the target's mcpu); otherwise use the int64
        # fixed-point flow.
        target = Target.current(True)
        if target and str(target.kind) == "llvm" and target_has_sse41(target.mcpu):
            return "float32"

        return "int64"

    # Map of config field name -> default-getter. The getters are stored as
    # plain functions (via __func__) so they can be called without a class
    # reference when building keyword defaults.
    _node_defaults = {
        "rounding": _get_node_default_rounding.__func__,
        "compute_dtype": _get_node_default_compute_dtype.__func__,
    }

    # pylint: disable=no-member
    def __init__(self, handle):
        """Initialize the function with handle

        Parameters
        ----------
        handle : SymbolHandle
            the handle to the underlying C++ Symbol
        """
        super(RequantizeConfig, self).__init__(handle)
        # NOTE: this assignment goes through the overridden __setattr__
        # below; "handle" is not in _node_defaults, so it is permitted.
        self.handle = handle

    def __enter__(self):
        # pylint: disable=protected-access
        # Push this config onto the C++-side scope stack so nested
        # requantize lowering picks it up.
        _requantize._EnterRequantizeConfigScope(self)
        return self

    def __exit__(self, ptype, value, trace):
        # Pop the scope unconditionally; exceptions propagate (no return True).
        _requantize._ExitRequantizeConfigScope()

    def __setattr__(self, name, value):
        # The node-backed config fields are immutable after construction.
        if name in RequantizeConfig._node_defaults:
            raise AttributeError("'%s' object cannot set attribute '%s'" % (str(type(self)), name))
        return super(RequantizeConfig, self).__setattr__(name, value)


def current_requantize_config():
    """Get the current requantization configuration.

    Returns
    -------
    config : RequantizeConfig
        The configuration at the top of the C++-side scope stack (the
        innermost active ``with requantize_config(...)`` block, or the
        defaults when none is active).
    """
    return _requantize._GetCurrentRequantizeConfig()


def requantize_config(**kwargs):
    """Configure the requantization behavior by setting config variables.

    Parameters
    ----------
    rounding : "UPWARD" or "TONEAREST"
        Rounding direction for fixed point multiplications.
    compute_dtype : str
        Specifies the data type used during requantize.
        Supported options: "int64", "float32", "float64"

    Returns
    -------
    config : RequantizeConfig
        The requantization configuration

    Raises
    ------
    ValueError
        If an unsupported configuration key is passed.
    """
    # Reject unknown keys explicitly; previously a misspelled option was
    # silently dropped and the default used instead.
    unknown = set(kwargs) - set(RequantizeConfig._node_defaults)
    if unknown:
        raise ValueError(
            "Unsupported requantize config option(s): %s. Supported options: %s"
            % (sorted(unknown), sorted(RequantizeConfig._node_defaults))
        )
    # Use the per-field default getter for any option the caller did not
    # supply; defaults may depend on the current target (see
    # RequantizeConfig._get_node_default_compute_dtype).
    node_args = {
        key: kwargs[key] if key in kwargs else default_getter()
        for key, default_getter in RequantizeConfig._node_defaults.items()
    }
    return tvm.ir.make_node("relay.qnn.op.RequantizeConfig", **node_args)


def requantize(
Expand All @@ -36,7 +126,8 @@ def requantize(
output_scale,
output_zero_point,
axis=-1,
rounding="UPWARD",
rounding="None",
compute_dtype="None",
out_dtype="int8",
):
r"""Requantized operator.
Expand Down Expand Up @@ -70,7 +161,9 @@ def requantize(
rounding : string, optional
Defines the rounding direction when the value is midway between two
representable values.
compute_dtype:
Specifies the data type used during requantize.
Supported options: \"int64\", \"float32\", \"float64\"
out_dtype : str, optional
Specifies the output data type.
Expand All @@ -88,6 +181,7 @@ def requantize(
output_zero_point,
axis,
rounding,
compute_dtype,
out_dtype,
)

Expand Down
22 changes: 22 additions & 0 deletions python/tvm/topi/x86/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,23 @@
import tvm


@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_sse41")
def target_has_sse41(target):
    """Return True if the given mcpu name supports SSE4.1.

    Any CPU that already has a newer SIMD tier (SSE4.2/AVX/AVX2/AVX-512/VNNI)
    implies SSE4.1 support; otherwise fall back to the explicit list of
    SSE4.1-only CPUs.
    """
    newer_isa_checks = (
        target_has_sse42,
        target_has_avx,
        target_has_avx2,
        target_has_avx512,
        target_has_vnni,
    )
    # Preserve short-circuit order: check newer ISAs first, then the
    # SSE4.1-only CPU names.
    for has_isa in newer_isa_checks:
        if has_isa(target):
            return True
    return target in {
        "btver2",
        "penryn",
    }


@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_sse42")
def target_has_sse42(target):
return (
target_has_avx(target)
Expand All @@ -42,6 +59,7 @@ def target_has_sse42(target):
)


@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_avx")
def target_has_avx(target):
return (
target_has_avx2(target)
Expand All @@ -51,6 +69,7 @@ def target_has_avx(target):
)


@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_avx2")
def target_has_avx2(target):
return (
target_has_avx512(target)
Expand All @@ -70,6 +89,7 @@ def target_has_avx2(target):
)


@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_avx512")
def target_has_avx512(target):
return target in {
"skylake-avx512",
Expand All @@ -89,6 +109,7 @@ def target_has_avx512(target):
}


@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_vnni")
def target_has_vnni(target):
return target in {
"cascadelake",
Expand All @@ -102,6 +123,7 @@ def target_has_vnni(target):
}


@tvm._ffi.register_func("tvm.topi.x86.utils.get_simd_32bit_lanes")
def get_simd_32bit_lanes():
mcpu = tvm.target.Target.current().mcpu
fp32_vec_len = 4
Expand Down
Loading

0 comments on commit c0ec226

Please sign in to comment.