[QNN] Use Int16 upcast in Fallback Conv2D. Fix test names. #4329

Merged — 1 commit, merged on Nov 14, 2019
41 changes: 22 additions & 19 deletions src/relay/qnn/op/convolution.cc
@@ -106,25 +106,26 @@ WorkloadType GetWorkload(const Array<tvm::relay::Type>& arg_types, const QnnConv
* \brief Fallback to simpler lowering for dilation or depthwise conv.
* \param data The input expr.
* \param weight The weight expr.
* \param zp_data The data zero point expr.
* \param zp_kernel The kernel zero point expr.
* \param param The qnn conv2d attributes.
* \return The fallback lowered sequence of Relay expr.
* \note In case of dilation, normal lowering would require a dilated pool.
* Since, we don't have dilated pool, we fallback to a simpler sequence of
* Relay operations. This will potentially lead to performance degradation
* as the convolution is called on int32 tensors instead of int8 tensors.
*/
Expr Conv2DFallBack(const Expr& data, const Expr& weight, const Expr& zp_data,
const Expr& zp_kernel, const QnnConv2DAttrs* param) {
auto shifted_data = data;
Expr Conv2DFallBack(const Expr& data, const Expr& weight, const QnnConv2DAttrs* param) {
// Upcast the zero point to Int16.
Contributor comment: This is pretty straightforward :)

auto zp_data = MakeConstantScalar(Int(16), param->input_zero_point);
auto zp_kernel = MakeConstantScalar(Int(16), param->kernel_zero_point);

auto shifted_data = Cast(data, Int(16));
if (param->input_zero_point != 0) {
shifted_data = Subtract(Cast(data, Int(32)), zp_data);
shifted_data = Subtract(Cast(data, Int(16)), zp_data);
}

auto shifted_kernel = weight;
auto shifted_kernel = Cast(weight, Int(16));
if (param->kernel_zero_point != 0) {
shifted_kernel = Subtract(Cast(weight, Int(32)), zp_kernel);
shifted_kernel = Subtract(Cast(weight, Int(16)), zp_kernel);
}

return Conv2D(shifted_data, shifted_kernel, param->strides, param->padding, param->dilation,
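As a reference for the change above, here is a minimal NumPy sketch of the new zero-point handling in the fallback path, using made-up shapes and zero points (not taken from the PR). After the upcast, the shifted uint8 values lie in [-255, 255], so int16 has enough headroom and the earlier int32 upcast of the inputs is unnecessary.

```python
import numpy as np

# Illustrative uint8 tensors and zero points (hypothetical values, NCHW / OIHW).
data = np.random.randint(0, 256, size=(1, 4, 8, 8), dtype=np.uint8)
weight = np.random.randint(0, 256, size=(8, 4, 3, 3), dtype=np.uint8)
zp_data, zp_kernel = 128, 120

# Upcast to int16 before subtracting the zero points: the shifted values fit in
# [-255, 255], so int16 is sufficient and cheaper than the previous int32 path.
shifted_data = data.astype(np.int16) - np.int16(zp_data)
shifted_weight = weight.astype(np.int16) - np.int16(zp_kernel)

# The fallback then runs a normal conv2d on these int16 tensors,
# accumulating the products into int32.
```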
@@ -186,7 +187,6 @@ Expr Conv2DFirstTerm(const Expr& padded_data, const Expr& weight, const QnnConv2
/*
* \brief Calculates the second term in the qnn.conv2d lowering sequence.
* \param padded_data The padded data expr.
* \param zp_kernel The kernel zero point expr.
* \param param The qnn conv2d attributes.
* \param kernel_h The height of kernel.
* \param kernel_w The width of kernel.
@@ -200,8 +200,11 @@ Expr Conv2DFirstTerm(const Expr& padded_data, const Expr& weight, const QnnConv2
* followed by a reduce on the C axis. Using avg_pool2d also gives an
* opportunity to reuse alter_op_layout infrastructure.
*/
Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& zp_kernel, const QnnConv2DAttrs* param,
int kernel_h, int kernel_w, int out_channels) {
Expr Conv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* param, int kernel_h,
int kernel_w, int out_channels) {
// Constant Expr for the kernel zero point.
auto zp_kernel = MakeConstantScalar(Int(32), param->kernel_zero_point);

auto casted_t2 = Cast(padded_data, Int(32));

// We can reduce the H and W axis by using avg_pool2d. However, avg_pool2d averages the sum.
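For intuition, a rough NumPy sketch of what this second term computes, with hypothetical shapes, strides, and kernel zero point; the actual lowering expresses the windowed sum as avg_pool2d scaled by kernel_h * kernel_w followed by a reduce on the C axis, not as explicit loops.

```python
import numpy as np

# Hypothetical NCHW input, already padded for the convolution (illustrative only).
padded = np.random.randint(0, 256, size=(1, 3, 6, 6)).astype(np.int32)
kernel_h, kernel_w, zp_kernel, stride = 3, 3, 7, 1

n, c, h, w = padded.shape
out_h = (h - kernel_h) // stride + 1
out_w = (w - kernel_w) // stride + 1
term2 = np.zeros((n, 1, out_h, out_w), dtype=np.int32)
for i in range(out_h):
    for j in range(out_w):
        window = padded[:, :, i*stride:i*stride+kernel_h, j*stride:j*stride+kernel_w]
        # avg_pool2d * (kernel_h * kernel_w) recovers this windowed sum;
        # the extra reduce collapses the channel axis.
        term2[:, 0, i, j] = zp_kernel * window.sum(axis=(1, 2, 3))
```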
@@ -241,7 +244,6 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& zp_kernel, const QnnC
/*
* \brief Calculates the third term in the qnn.conv2d lowering sequence.
* \param weight The weight expr.
* \param zp_data The data zero point expr.
* \param param The qnn conv2d attributes.
* \param batch_size The batch size.
* \param out_channels The number of output channels.
@@ -254,8 +256,11 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& zp_kernel, const QnnC
* a 1D tensor. The tensor is then reshaped to conform to NHWC/NCHW
* format.
*/
Expr Conv2DThirdTerm(const Expr& weight, const Expr& zp_data, const QnnConv2DAttrs* param,
int batch_size, int out_channels) {
Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int batch_size,
int out_channels) {
// Constant expr for input zero point.
auto zp_data = MakeConstantScalar(Int(32), param->input_zero_point);

// Find which dimensions are C, R, S.
Array<Integer> axes_t3;
if (param->kernel_layout == "OIHW") {
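As a sanity check on the third term, a small NumPy sketch for the OIHW layout with a hypothetical weight shape and input zero point; the real code builds the same reduction and reshape out of Relay ops (an NHWC output would instead reshape to (1, 1, 1, out_channels)).

```python
import numpy as np

# Hypothetical OIHW weight and input zero point (illustrative values only).
weight = np.random.randint(0, 256, size=(16, 4, 3, 3)).astype(np.int32)
zp_data = 53
out_channels = weight.shape[0]

# Reduce the weight over the (C, R, S) axes -> one scalar per output channel,
# then reshape so it broadcasts against an NCHW conv2d output.
term3 = zp_data * weight.sum(axis=(1, 2, 3))   # shape (out_channels,)
term3 = term3.reshape(1, out_channels, 1, 1)   # broadcasts over N, H, W
```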
@@ -415,21 +420,19 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
int batch_size, in_channels, out_channels, kernel_h, kernel_w;
std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w) =
GetWorkload(arg_types, param);
auto zp_data = MakeConstantScalar(Int(32), param->input_zero_point);
auto zp_kernel = MakeConstantScalar(Int(32), param->kernel_zero_point);

// Fallback to int32 conv if there is dilation or depthwise conv2d
CHECK_EQ(param->dilation.size(), 2) << "qnn.conv2d only supports 2D dilation";
auto dilation_h = get_const_int(param->dilation[0]);
auto dilation_w = get_const_int(param->dilation[1]);
if (dilation_h != 1 || dilation_w != 1 || param->groups != 1) {
return Conv2DFallBack(data, weight, zp_data, zp_kernel, param);
return Conv2DFallBack(data, weight, param);
}

auto padded_data = Conv2DPadInput(data, param);
auto term1 = Conv2DFirstTerm(padded_data, weight, param);
auto term2 = Conv2DSecondTerm(padded_data, zp_kernel, param, kernel_h, kernel_w, out_channels);
auto term3 = Conv2DThirdTerm(weight, zp_data, param, batch_size, out_channels);
auto term2 = Conv2DSecondTerm(padded_data, param, kernel_h, kernel_w, out_channels);
auto term3 = Conv2DThirdTerm(weight, param, batch_size, out_channels);
auto term4 = Conv2DFourthTerm(param, batch_size, in_channels, kernel_h, kernel_w);
return Conv2DCombineTerms(term1, term2, term3, term4, param);
}
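The four terms combined here come from expanding conv(data - zp_data, weight - zp_kernel). A tiny single-output-element check of that identity, with made-up values (not from the PR):

```python
import numpy as np

# One receptive field (C, R, S) of the data and one output channel of the weight.
x = np.random.randint(0, 256, size=(4, 3, 3)).astype(np.int32)
w = np.random.randint(0, 256, size=(4, 3, 3)).astype(np.int32)
zp_x, zp_w = 128, 97
c_in, k_h, k_w = x.shape

lhs = ((x - zp_x) * (w - zp_w)).sum()   # what qnn.conv2d must produce
term1 = (x * w).sum()                   # plain conv on the raw tensors
term2 = zp_w * x.sum()                  # handled via avg_pool2d + reduce
term3 = zp_x * w.sum()                  # handled via reduce + reshape
term4 = zp_x * zp_w * c_in * k_h * k_w  # pure constant
assert lhs == term1 - term2 - term3 + term4
```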
52 changes: 26 additions & 26 deletions tests/python/relay/test_op_qnn_conv2d.py
@@ -160,7 +160,7 @@ def get_output(func, golden_inputs):
qnn_output = get_output(qnn_func, golden_inputs)
np.testing.assert_equal(qnn_output, golden_output)

def no_zero_point_test():
def test_no_zero_point():
# uint8 input
data_shape = (2, 1, 2, 4)
data_dtype = 'uint8'
@@ -203,7 +203,7 @@ def no_zero_point_test():
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)

def kernel_zero_point_test():
def test_kernel_zero_point():
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
Expand Down Expand Up @@ -247,7 +247,7 @@ def kernel_zero_point_test():
kernel_shape, kernel_dtype)


def input_zero_point_test():
def test_input_zero_point():
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
@@ -290,7 +290,7 @@ def input_zero_point_test():
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)

def both_zero_point_test():
def test_both_zero_point():
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
Expand Down Expand Up @@ -333,7 +333,7 @@ def both_zero_point_test():
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)

def layout_test():
def test_layout():
# uint8 input
data_shape = (2, 2, 4, 4) # NHWC
data_dtype = 'uint8'
Expand Down Expand Up @@ -378,7 +378,7 @@ def layout_test():



def padding_test():
def test_padding():
# uint8 input
data_shape = (1, 4, 2, 2)
data_dtype = 'uint8'
Expand Down Expand Up @@ -421,7 +421,7 @@ def padding_test():
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)

def dilation_test():
def test_dilation():
# uint8 input
data_shape = (2, 4, 4, 4)
data_dtype = 'uint8'
@@ -444,7 +444,7 @@ def dilation_test():
kernel_shape, kernel_dtype)


def const_folding_test():
def test_const_folding():
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
kernel_shape = (3, 4, 2, 2)
@@ -470,7 +470,7 @@ def const_folding_test():
folded_func = folded_mod["main"]
assert "reshape" not in folded_func.astext()

def kernel_size_1x1_test():
def test_kernel_size_1x1():
# uint8 input
data_shape = (2, 4, 2, 4)
data_dtype = 'uint8'
@@ -493,7 +493,7 @@ def kernel_size_1x1_test():
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)

def tflite_large_irregular_test():
def test_tflite_large_irregular():
# uint8 input
data_shape = (1, 1024, 1, 1)
data_dtype = 'uint8'
@@ -526,7 +526,7 @@ def tflite_large_irregular_test():
golden_output = np.full((1, 1001, 1, 1), 0).astype('uint8')
np.testing.assert_equal(qnn_output, golden_output)

def tflite_output_multiplier_greater_than_one():
def test_tflite_output_multiplier_greater_than_one():
# uint8 input
data_shape = (2, 1, 2, 4)
data_dtype = 'uint8'
@@ -570,7 +570,7 @@ def tflite_output_multiplier_greater_than_one():
0, 0)).reshape(2, 3, 1, 2)
np.testing.assert_equal(qnn_output, golden_output)

def tflite_anistropic_strides():
def test_tflite_anistropic_strides():
# uint8 input
data_shape = (1, 1, 3, 6)
data_dtype = 'uint8'
@@ -607,7 +607,7 @@ def tflite_anistropic_strides():
golden_output = np.array((124, -92, 164, -132)).reshape(1, 1, 2, 2)
np.testing.assert_equal(qnn_output, golden_output)

def broadcast_layout_test():
def test_broadcast_layout():
# Test broadcast support for NHWC layout.
data_shape = (1, 229, 229, 3) # NHWC
data_dtype = 'uint8'
@@ -641,16 +641,16 @@ def broadcast_layout_test():
graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512")

if __name__ == "__main__":
no_zero_point_test()
input_zero_point_test()
kernel_zero_point_test()
both_zero_point_test()
layout_test()
padding_test()
dilation_test()
const_folding_test()
kernel_size_1x1_test()
tflite_large_irregular_test()
tflite_output_multiplier_greater_than_one()
tflite_anistropic_strides()
broadcast_layout_test()
test_no_zero_point()
test_input_zero_point()
test_kernel_zero_point()
test_both_zero_point()
test_layout()
test_padding()
test_dilation()
test_const_folding()
test_kernel_size_1x1()
test_tflite_large_irregular()
test_broadcast_layout()
test_tflite_output_multiplier_greater_than_one()
test_tflite_anistropic_strides()
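The renames above give every case the test_ prefix, presumably so that pytest-style discovery picks them up; with the old *_test suffix the functions are not matched by pytest's default collection rules, while the __main__ block still allows running the file directly. A hedged usage sketch (the command and -k filter are illustrative):

```python
# pytest collects functions whose names start with "test_", so after the rename:
#   pytest tests/python/relay/test_op_qnn_conv2d.py -k zero_point
# runs the zero-point cases; `python tests/python/relay/test_op_qnn_conv2d.py`
# continues to work via the __main__ block above.
```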