[TFLite] Add support to int16 data type in TFLite frontend #10915

Merged: 2 commits, May 18, 2022. Changes shown from 1 commit.
11 changes: 8 additions & 3 deletions python/tvm/relay/frontend/tflite.py
@@ -390,6 +390,7 @@ def get_tensor_type_as_numpy(self, tensor_wrapper):
return {
TensorType.UINT8: np.uint8,
TensorType.INT8: np.int8,
TensorType.INT16: np.int16,
TensorType.FLOAT16: np.float16,
TensorType.FLOAT32: np.float32,
TensorType.INT32: np.int32,
@@ -430,6 +431,8 @@ def get_tensor_type_str(self, tensor_type):

if tensor_type == TensorType.INT8:
return "int8"
if tensor_type == TensorType.INT16:
return "int16"
if tensor_type == TensorType.UINT8:
return "uint8"
if tensor_type == TensorType.FLOAT16:
@@ -2149,7 +2152,9 @@ def convert_conv(self, op, conv_type):
qnn_conv2d_params = dict(params)
qnn_conv2d_params["input_zero_point"] = input_tensor.qnn_params["zero_point"]
qnn_conv2d_params["kernel_zero_point"] = weight_tensor.qnn_params["zero_point"]
qnn_conv2d_params["out_dtype"] = "int32"
qnn_conv2d_params["out_dtype"] = (
"int64" if output_tensor_type_str == "int16" else "int32"
)
qnn_conv2d_params["input_scale"] = input_tensor.qnn_params["scale"]
qnn_conv2d_params["kernel_scale"] = weight_tensor.qnn_params["scale"]
out = _qnn.op.conv2d(in_expr, weight_expr, **qnn_conv2d_params)
@@ -2160,8 +2165,8 @@ def convert_conv(self, op, conv_type):
if len(input_tensors) == 3:
bias_tensor = input_tensors[2]
bias_tensor_type = bias_tensor.tensor.Type()
# bias tensor type should be INT32 (quantization) or FLOAT32
assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
# bias tensor type should be INT32 (int8 qnn) or INT64 (int16 qnn) or FLOAT32
assert bias_tensor_type in (TensorType.INT32, TensorType.INT64, TensorType.FLOAT32)
bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
if self.has_expr(bias_tensor.tensor_idx):
bias_expr = self.get_expr(bias_tensor.tensor_idx)
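The frontend change above selects a wider accumulator when the graph is quantized to int16: the `qnn.conv2d` it emits accumulates into int64 rather than int32. A minimal sketch of the resulting Relay call (shapes, scales, and zero points are made up; `relay.qnn.op.conv2d` usage assumed, not code from this PR):

```python
import tvm
from tvm import relay

# int16-quantized operands, as the TFLite frontend above would produce.
data = relay.var("data", shape=(1, 16, 32, 32), dtype="int16")
weight = relay.var("weight", shape=(32, 16, 3, 3), dtype="int16")

# Mirrors the frontend rule above: int16 graphs accumulate into int64.
output_tensor_type_str = "int16"  # assumed, taken from the TFLite output tensor
out_dtype = "int64" if output_tensor_type_str == "int16" else "int32"

conv = relay.qnn.op.conv2d(
    data, weight,
    input_zero_point=relay.const(0, "int32"),
    kernel_zero_point=relay.const(0, "int32"),
    input_scale=relay.const(0.05, "float32"),
    kernel_scale=relay.const(0.02, "float32"),
    kernel_size=(3, 3), channels=32,
    out_dtype=out_dtype,
)
print(conv)
```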
50 changes: 30 additions & 20 deletions src/relay/qnn/op/convolution.cc
@@ -50,12 +50,14 @@ bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
if (data == nullptr || weight == nullptr) return false;
const auto* param = attrs.as<Conv2DAttrs>();
ICHECK(param != nullptr) << "Conv2DAttrs cannot be nullptr.";
ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8))
<< "Expected qnn conv2d type(int8, uint8) for input but was " << data->dtype;
ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8) ||
data->dtype == DataType::Int(16))
<< "Expected qnn conv2d type(int8, uint8, int16) for input but was " << data->dtype;
ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8))
<< "Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype;
ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32))
<< "Expected qnn conv2d type(int32, int16) for output but was " << param->out_dtype;
<< "Expected qnn conv2d type(int8, uint8, int16) for weight but was " << weight->dtype;
ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32) ||
param->out_dtype == DataType::Int(64))
<< "Expected qnn conv2d type(int16, int32, int64) for output but was " << param->out_dtype;
ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";

// Check the types of scale and zero points.
@@ -190,19 +192,21 @@ WorkloadType GetWorkload(const Array<tvm::relay::Type>& arg_types, const Conv2DA
*/
Expr Conv2DFallBack(const Expr& data, const Expr& weight, const Expr& input_zero_point,
const Expr& kernel_zero_point, const Conv2DAttrs* param) {
// Upcast the zero point to Int16.
auto zp_data = Cast(input_zero_point, DataType::Int(16));
auto zp_kernel = Cast(kernel_zero_point, DataType::Int(16));
// Upcast the parameters to be at least int32 to avoid overflow
auto upcast_bits = param->out_dtype.bits() < 32 ? 32 : param->out_dtype.bits();

auto shifted_data = Cast(data, DataType::Int(16));
auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
auto zp_data = Cast(input_zero_point, DataType::Int(upcast_bits));
auto zp_kernel = Cast(kernel_zero_point, DataType::Int(upcast_bits));

auto shifted_data = Cast(data, DataType::Int(upcast_bits));
auto zero_scalar = MakeConstantScalar(DataType::Int(upcast_bits), 0);
if (!IsEqualScalar(input_zero_point, zero_scalar)) {
shifted_data = Subtract(Cast(data, DataType::Int(16)), zp_data);
shifted_data = Subtract(Cast(data, DataType::Int(upcast_bits)), zp_data);
}

auto shifted_kernel = Cast(weight, DataType::Int(16));
auto shifted_kernel = Cast(weight, DataType::Int(upcast_bits));
if (!IsEqualScalar(kernel_zero_point, zero_scalar)) {
shifted_kernel = Subtract(Cast(weight, DataType::Int(16)), zp_kernel);
shifted_kernel = Subtract(Cast(weight, DataType::Int(upcast_bits)), zp_kernel);
}

return Conv2D(shifted_data, shifted_kernel, param->strides, param->padding, param->dilation,
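The switch from fixed Int(16)/Int(32) casts to `upcast_bits` is what makes the int16 path safe: an int16 * int16 product already needs 32 bits, so accumulating over the reduction axis can overflow int32. A quick numpy illustration (kernel dimensions assumed):

```python
import numpy as np

# Worst-case int16 multiply-accumulate, reduced over an assumed
# in_channels=16, 3x3 kernel (144 terms).
reduction = 16 * 3 * 3
worst_product = np.int64(32767) * np.int64(-32768)  # ~ -1.07e9, needs 32 bits alone
worst_sum = worst_product * reduction                # ~ -1.5e11
print(worst_sum < np.iinfo(np.int32).min)            # True: int32 would overflow
```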
@@ -557,17 +561,19 @@ Expr Conv2DThirdTerm(const Expr& weight, const Expr& input_zero_point, const Con
* \param in_channels The number of input channels.
* \param kernel_h The height of kernel.
* \param kernel_w The width of kernel.
* \param param The qnn conv2d attributes.
* \return The sequence of Relay operators for term4.
* \note The term4 looks like this
*
* Sigma(c,r,s) zp_a * zp_w
*
*/
Expr Conv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int in_channels,
int kernel_h, int kernel_w) {
int kernel_h, int kernel_w, const Conv2DAttrs* param) {
auto upcast_bits = param->out_dtype.bits() < 32 ? 32 : param->out_dtype.bits();
int scalar_term4 =
input_zero_point_int * kernel_zero_point_int * in_channels * kernel_h * kernel_w;
return MakeConstantScalar(DataType::Int(32), scalar_term4);
return MakeConstantScalar(DataType::Int(upcast_bits), scalar_term4);
}

/*
@@ -578,15 +584,18 @@ Expr Conv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int i
* \param in_channels The number of input channels.
* \param kernel_h The height of kernel.
* \param kernel_w The width of kernel.
* \param param The qnn conv2d attributes.
* \return The sequence of Relay operators for term4.
* \note The term4 looks like this
*
* Sigma(c,r,s) zp_a * zp_w
*
*/
Expr Conv2DFourthTerm(const Expr& input_zero_point, const Expr& kernel_zero_point, int in_channels,
int kernel_h, int kernel_w) {
Expr scalar_term4 = MakeConstantScalar(DataType::Int(32), in_channels * kernel_h * kernel_w);
int kernel_h, int kernel_w, const Conv2DAttrs* param) {
auto upcast_bits = param->out_dtype.bits() < 32 ? 32 : param->out_dtype.bits();
Expr scalar_term4 =
MakeConstantScalar(DataType::Int(upcast_bits), in_channels * kernel_h * kernel_w);
Expr variable_term4 = Multiply(input_zero_point, kernel_zero_point);
return Multiply(scalar_term4, variable_term4);
}
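Sigma(c,r,s) zp_a * zp_w sums a data-independent product, so term4 collapses to `zp_a * zp_w * in_channels * kernel_h * kernel_w`; the only change in both overloads is that the constant is now created at the upcast width. A worked check with hypothetical zero points:

```python
zp_a, zp_w = 3, 5                           # hypothetical zero points
in_channels, kernel_h, kernel_w = 16, 3, 3  # hypothetical kernel geometry
term4 = zp_a * zp_w * in_channels * kernel_h * kernel_w
print(term4)  # 2160, emitted as a single Int(upcast_bits) constant scalar
```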
@@ -791,10 +800,11 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
auto term3 = Conv2DThirdTerm(weight, input_zero_point, param, out_channels);
Expr term4;
if (dynamic_zp) {
term4 = Conv2DFourthTerm(input_zero_point, kernel_zero_point, in_channels, kernel_h, kernel_w);
term4 = Conv2DFourthTerm(input_zero_point, kernel_zero_point, in_channels, kernel_h, kernel_w,
param);
} else {
term4 = Conv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, in_channels, kernel_h,
kernel_w);
kernel_w, param);
}
return Conv2DCombineTerms(term1, term2, term3, term4, input_zero_point_int,
kernel_zero_point_int);
@@ -829,7 +839,7 @@ This operator convolves quantized weight with quantized data. The scale of the
output quantized tensor is the product of the weight_scale and input_scale of
the input quantized tensors. The zero point of the output quantized tensor is
0. By default, the dtype of output is int32. Please also refer to Requantize
operator to understand how to scale back the int32 output to (u)int8.
operator to understand how to scale back the int32 output to (u)int8 or (u)int16.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
(batch_size, in_channels, height, width) if `layout` is `NCHW`.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
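As the updated docstring says, the wide accumulator is scaled back with Requantize; for int16 models that means int64 in, int16 out. A hedged sketch of that final step (made-up scales; `relay.qnn.op.requantize` usage assumed):

```python
import tvm
from tvm import relay

# int64 accumulator coming out of qnn.conv2d (+ bias add) above.
acc = relay.var("acc", shape=(1, 32, 30, 30), dtype="int64")

out = relay.qnn.op.requantize(
    acc,
    input_scale=relay.const(0.001, "float32"),
    input_zero_point=relay.const(0, "int32"),
    output_scale=relay.const(0.05, "float32"),
    output_zero_point=relay.const(0, "int32"),
    out_dtype="int16",  # scale the wide accumulator back to int16
)
print(out)
```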
4 changes: 2 additions & 2 deletions src/relay/qnn/op/dequantize.cc
@@ -47,8 +47,8 @@ bool DequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,

const auto input_dtype = data->dtype;
ICHECK(input_dtype == DataType::Int(8) || input_dtype == DataType::UInt(8) ||
input_dtype == DataType::Int(32))
<< "Input type should be one of the quantized types [unit8, int8, int32] but was "
input_dtype == DataType::Int(16) || input_dtype == DataType::Int(32))
<< "Input type should be one of the quantized types [unit8, int8, int16, int32] but was "
<< input_dtype;

const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();
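DequantizeRel now also admits int16 inputs; the affine mapping itself is unchanged, x = scale * (q - zero_point). An illustrative numpy version (not the TVM kernel):

```python
import numpy as np

def dequantize_int16(q, scale, zero_point):
    # x = scale * (q - zp), the mapping DequantizeRel type-checks above.
    return scale * (q.astype(np.float32) - zero_point)

print(dequantize_int16(np.array([10, -20], dtype=np.int16), 0.05, 0))  # [ 0.5 -1. ]
```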
4 changes: 2 additions & 2 deletions src/relay/qnn/op/quantize.cc
@@ -76,8 +76,8 @@ bool QuantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const Array<tvm::PrimExpr> oshape = data->shape;
const DataType out_dtype = quantize_attrs->out_dtype;
ICHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) ||
out_dtype == DataType::Int(32))
<< "Output type should be one of [int8, unit8, int32] but was " << out_dtype;
out_dtype == DataType::Int(16) || out_dtype == DataType::Int(32))
<< "Output type should be one of [int8, unit8, int16, int32] but was " << out_dtype;
// assign output type
reporter->Assign(types[3], TensorType(oshape, out_dtype));
return true;
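QuantizeRel gains the symmetric int16 output, i.e. q = clip(round(x / scale) + zero_point, -32768, 32767). A minimal numpy sketch (not the TVM kernel):

```python
import numpy as np

def quantize_int16(x, scale, zero_point):
    # q = clip(round(x / scale) + zp, int16 min, int16 max)
    q = np.round(x / scale) + zero_point
    return np.clip(q, -32768, 32767).astype(np.int16)

print(quantize_int16(np.array([0.5, -1.0]), 0.05, 0))  # [ 10 -20]
```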
8 changes: 4 additions & 4 deletions src/relay/qnn/op/requantize.cc
@@ -480,8 +480,8 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
}
const auto in_dtype = data->dtype;
ICHECK(in_dtype == DataType::Int(8) || in_dtype == DataType::UInt(8) ||
in_dtype == DataType::Int(32))
<< "Input type should be one of [int8, uint8, int32] but was " << in_dtype;
in_dtype == DataType::Int(32) || in_dtype == DataType::Int(64))
<< "Input type should be one of [int8, uint8, int32, int64] but was " << in_dtype;

const RequantizeAttrs* requantize_attrs = attrs.as<RequantizeAttrs>();
int axis = requantize_attrs->axis;
@@ -507,8 +507,8 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
// assign output type
auto out_dtype = requantize_attrs->out_dtype;
ICHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) ||
out_dtype == DataType::Int(32))
<< "Output type should be one of [int8, uint8, int32] but was " << out_dtype;
out_dtype == DataType::Int(16) || out_dtype == DataType::Int(32))
<< "Output type should be one of [int8, uint8, int16, int32] but was " << out_dtype;
reporter->Assign(types[5], TensorType(oshape, out_dtype));
return true;
}