Skip to content

Commit

Permalink
Support quantised RSQRT operator in TFLite
Browse files Browse the repository at this point in the history
The commit tests _convert_unary_elemwise function for the quantised and
non quantized tensor for the RSQRT op.
Other operators will be tested in future (separated )commits.
  • Loading branch information
ophirfrish committed Oct 3, 2021
1 parent 719d2f6 commit 5183900
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 17 deletions.
14 changes: 9 additions & 5 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,8 +1116,16 @@ def _convert_unary_elemwise(self, relay_op, op):

input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)
out = relay_op(in_expr)

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
output_tensor = output_tensors[0]

if input_tensor.qnn_params:
in_expr = self.dequantize(in_expr, input_tensor)
out = relay_op(in_expr)
if output_tensor.qnn_params:
out = self.quantize(out, output_tensor)
return out

def convert_abs(self, op):
Expand Down Expand Up @@ -1186,10 +1194,6 @@ def convert_sqrt(self, op):

def convert_rsqrt(self, op):
"""Convert TFLite RSQRT"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
"TFlite quantized RSQRT operator is not supported yet."
)
return self._convert_unary_elemwise(_op.rsqrt, op)

def convert_neg(self, op):
Expand Down
53 changes: 41 additions & 12 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1868,16 +1868,6 @@ def _test_sqrt(data):
return _test_unary_elemwise(math_ops.sqrt, data)


#######################################################################
# Rsqrt
# -----


def _test_rsqrt(data):
"""One iteration of rsqrt"""
return _test_unary_elemwise(math_ops.rsqrt, data)


#######################################################################
# Neg
# ---
Expand Down Expand Up @@ -1910,7 +1900,7 @@ def _test_elu(data):

def _test_forward_unary_elemwise(test_op):
# functions that need positive input
if test_op.__name__ in {"_test_log", "_test_sqrt", "_test_rsqrt"}:
if test_op.__name__ in {"_test_log", "_test_sqrt"}:
test_op(np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3)))
else:
test_op(np.random.uniform(-10, 10, (3, 2)).astype(np.float32))
Expand All @@ -1923,7 +1913,6 @@ def test_all_unary_elemwise():
_test_forward_unary_elemwise(_test_log)
_test_forward_unary_elemwise(_test_sin)
_test_forward_unary_elemwise(_test_sqrt)
_test_forward_unary_elemwise(_test_rsqrt)
_test_forward_unary_elemwise(_test_neg)
_test_forward_unary_elemwise(_test_square)
# ceil and cos come with TFLite 1.14.0.post1 fbs schema
Expand Down Expand Up @@ -3352,6 +3341,45 @@ def test_forward_tanh():
_test_tanh(np.arange(0, 256, 30, dtype=np.uint8), quantized=True)


#######################################################################
# RSQRT
# ----


def _test_rsqrt(data, quantized=False):
"""One iteration of RSQRT"""
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype="float32", name="in_0")

if quantized:
inq_data = tf.quantization.fake_quant_with_min_max_args(
in_data, min=1, max=6, name="inq_0"
)
input_range = {"inq_0": (1, 6)}
out = math_ops.rsqrt(inq_data)
out = tf.quantization.fake_quant_with_min_max_args(out, min=1, max=6, name="out")
compare_tflite_with_tvm(
data,
"inq_0:0",
[inq_data],
[out],
quantized=True,
input_range=input_range,
experimental_new_converter=True,
)
else:
out = math_ops.rsqrt(in_data)
compare_tflite_with_tvm(data, "in_0:0", [in_data], [out])


def test_forward_rsqrt():
"""RSQRT"""
_test_rsqrt(np.arange(1.0, 7.0, dtype=np.float32), quantized=False)
_test_rsqrt(np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3)), quantized=False)
_test_rsqrt(np.arange(1, 240, 40, dtype=np.uint8), quantized=True)
_test_rsqrt(np.arange(1, 240, 40, dtype=np.uint8).reshape((2, 1, 3)), quantized=True)


#######################################################################
# ReLu
# ----
Expand Down Expand Up @@ -4561,6 +4589,7 @@ def test_prevent_tensorflow_dynamic_range():
test_forward_l2_pool2d()
test_forward_softmax()
test_forward_tanh()
test_forward_rsqrt()
test_forward_relu()
test_forward_relu6()
test_forward_leaky_relu()
Expand Down

0 comments on commit 5183900

Please sign in to comment.