Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[QNN] Optimize requantize for power of 2 and fix dequantize for per-channel quantized input #6675

Merged
merged 5 commits into from
Oct 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/relay/qnn/op/dequantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale,
expanded_input_zero_point = ExpandBiasToMatchAxis(input_zero_point, n_dim, {axis});
}

auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), input_zero_point);
auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), input_scale);
auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), expanded_input_zero_point);
auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), expanded_input_scale);
return scaled_output;
}

Expand Down
93 changes: 62 additions & 31 deletions src/target/intrin_rule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,37 +128,68 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.q_multiply_shift")
PrimExpr q = call->args[2];
PrimExpr s = call->args[3];

// Only int32 types are supported (any number of lanes is allowed)
ICHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32);
ICHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits() == 32);

DataType hp_dtype = DataType::Int(64, x.dtype().lanes());
DataType lp_dtype = DataType::Int(32, x.dtype().lanes());

// 1) Calculating the integer multiplier and integer shift
PrimExpr zero = make_const(s.dtype(), 0);
PrimExpr left_shift = tir::Select(s > zero, s, zero);
PrimExpr right_shift = tir::Select(s > zero, zero, -s);

// 2) Cast and Multiply the integer multiplier
PrimExpr one = make_const(hp_dtype, 1);
x = cast(hp_dtype, x);
y = cast(hp_dtype, y);
x = tir::Select(left_shift != zero, x << left_shift, x);

// 3) Perform the multiplication in higher precision.
x = x * y;

// 4) Find the rounding scalar
PrimExpr total_right_shift = right_shift + q;
PrimExpr pos_rounding_value = (one << (total_right_shift - 1));
x = x + pos_rounding_value;

// 5) Simply right shift the result to get the final output.
x = x >> total_right_shift;

// 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
*rv = cast(lp_dtype, x);
// Extracts the integer constant carried by `node`.
// `node` must be either an IntImm or a Broadcast whose value is an IntImm; any other
// expression trips one of the CHECKs below. The expression is taken by const reference so
// that calling the helper does not copy it.
auto get_int_value = [](const PrimExpr& node) {
  if (auto int_node = node.as<IntImmNode>()) {
    return int_node->value;
  }
  // Not a scalar constant: the only other accepted form is a broadcast of one.
  auto broadcast_node = node.as<BroadcastNode>();
  CHECK(broadcast_node != nullptr)
      << "q_multiply_shift: expected an integer constant or a broadcast of one";
  auto int_node = broadcast_node->value.as<IntImmNode>();
  CHECK(int_node != nullptr) << "q_multiply_shift: broadcast value is not an integer constant";
  return int_node->value;
};
// Power of 2 is determined by the fixed_point_multiplier == 1 << 30. In case of power of 2,
// fixed point multiplier will represent a float value of 0.5. In fixed point, this is
// represented by 1 << 30.
if (get_int_value(y) == (1 << 30)) {
PrimExpr exp = s - 1;
int exp_val = get_int_value(s) - 1;
if (exp_val > 0) {
// power of 2 is greater than 0, apply left shift.
*rv = x << exp;
} else {
// power of 2 is less than 0, round and then apply right shift.
DataType lp_dtype = DataType::Int(32, x.dtype().lanes());
PrimExpr one = make_const(lp_dtype, 1);
exp = -exp;
PrimExpr rounding_factor = one << (exp - 1);
PrimExpr rounded_t = x + rounding_factor;
*rv = rounded_t >> exp;
}
} else {
// Only int32 types are supported (any number of lanes is allowed)
ICHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32);
ICHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits() == 32);

DataType hp_dtype = DataType::Int(64, x.dtype().lanes());
DataType lp_dtype = DataType::Int(32, x.dtype().lanes());

// 1) Calculating the integer multiplier and integer shift
PrimExpr zero = make_const(s.dtype(), 0);
PrimExpr left_shift = tir::Select(s > zero, s, zero);
PrimExpr right_shift = tir::Select(s > zero, zero, -s);

// 2) Cast and Multiply the integer multiplier
PrimExpr one = make_const(hp_dtype, 1);
x = cast(hp_dtype, x);
y = cast(hp_dtype, y);
x = tir::Select(left_shift != zero, x << left_shift, x);

// 3) Perform the multiplication in higher precision.
x = x * y;

// 4) Find the rounding scalar
PrimExpr total_right_shift = right_shift + q;
PrimExpr pos_rounding_value = (one << (total_right_shift - 1));
x = x + pos_rounding_value;

// 5) Simply right shift the result to get the final output.
x = x >> total_right_shift;

// 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
*rv = cast(lp_dtype, x);
}
});

} // namespace intrin
Expand Down
18 changes: 18 additions & 0 deletions tests/python/relay/test_op_qnn_dequantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,26 @@ def test_channelwise_axis_1():
)


def test_channelwise_axis_0():
    """Dequantize a (2, 5) uint8 tensor using one zero point and one scale per index of axis 0."""
    shape = (2, 5)
    raw_values = [0, 1, 2, 3, 4, 243, 247, 249, 250, 251]
    expected_values = [-63.5, -63, -62.5, -62, -61.5, 30, 31, 31.5, 31.75, 32]

    quantized = np.array(raw_values).astype("uint8").reshape(shape)
    expected = np.array(expected_values).astype("float32").reshape(shape)

    # Two entries each: one (zero_point, scale) pair per row of the input.
    per_channel_params = {
        "in_zero_point": np.array([127, 123]).astype("int32"),
        "in_scale": np.array([0.5, 0.25]).astype("float32"),
    }

    dequantize_test_driver(
        in_dtype="uint8",
        quant_args=per_channel_params,
        in_data=quantized,
        verify_output_data=expected,
        axis=0,
    )


if __name__ == "__main__":
    # When run as a script, call the tests below one after another in this order.
    for run_test in (
        test_uint8_to_float32,
        test_int8_to_float32,
        test_int32_to_float32,
        test_channelwise_axis_1,
        test_channelwise_axis_0,
    ):
        run_test()
43 changes: 43 additions & 0 deletions tests/python/relay/test_op_qnn_requantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,48 @@ def test_upscale():
verify(mod, (golden_data, golden_output))


def test_non_power_of_two():
    """Check scale pairs 1 -> 3 and 3 -> 1 for every entry of `roundings`.

    Each scale pair is verified on a run of 32 non-negative values and on a run of
    32 non-positive values.
    """
    # (start, stop, step) for the positive sweep, then the negative sweep.
    sweeps = ((0, 32, 1), (0, -32, -1))

    for rounding in roundings:
        # input_scale=1, output_scale=3: inputs are multiples of 3, outputs the plain range.
        shrink_mod = get_mod(
            data_shape=(32,),
            data_dtype="int32",
            out_dtype="int8",
            input_scale=1,
            output_scale=3,
            rounding=rounding,
        )
        for start, stop, step in sweeps:
            inputs = np.multiply(np.arange(start, stop, step).astype("int32"), 3)
            expected = np.arange(start, stop, step)
            verify(shrink_mod, (inputs, expected))

        # input_scale=3, output_scale=1: inputs are the plain range, outputs its multiples of 3.
        grow_mod = get_mod(
            data_shape=(32,),
            data_dtype="int32",
            out_dtype="int8",
            input_scale=3,
            output_scale=1,
            rounding=rounding,
        )
        for start, stop, step in sweeps:
            inputs = np.arange(start, stop, step).astype("int32")
            expected = np.multiply(inputs, 3)
            verify(grow_mod, (inputs, expected))


def test_saturation():
for rounding in roundings:
mod = get_mod(
Expand Down Expand Up @@ -397,6 +439,7 @@ def test_per_channel_different_scale():
test_same_scale()
test_downscale()
test_upscale()
test_non_power_of_two()
test_saturation()
test_zero_point()
test_per_channel_same_scale()
Expand Down