[FRONTEND][TFLITE]Gather, StridedSlice op support added (apache#4788)
* [FRONTEND][TFLITE]Gather, StridedSlice op added

* Review comments fixed
siju-samuel authored and dpankratz committed Apr 24, 2020
1 parent 2101a44 commit 1f0889b
Showing 2 changed files with 289 additions and 0 deletions.
214 changes: 214 additions & 0 deletions python/tvm/relay/frontend/tflite.py
@@ -17,6 +17,7 @@
# pylint: disable=invalid-name, unused-argument, too-many-lines, import-outside-toplevel
"""Tensorflow lite frontend."""
import math
import itertools
import numpy as np
import tvm
from tvm.ir import IRModule
@@ -82,6 +83,7 @@ def __init__(self, model, subgraph, exp_tab):
'FLOOR_MOD': self.convert_floor_mod,
'FLOOR': self.convert_floor,
'FULLY_CONNECTED': self.convert_fully_connected,
'GATHER': self.convert_gather,
'GREATER_EQUAL': self.convert_greater_equal,
'GREATER': self.convert_greater,
'HARD_SWISH': self.convert_hard_swish,
@@ -125,6 +127,7 @@ def __init__(self, model, subgraph, exp_tab):
'SQUARE': self.convert_square,
'SQUARED_DIFFERENCE': self.convert_squared_difference,
'SQUEEZE': self.convert_squeeze,
'STRIDED_SLICE': self.convert_strided_slice,
'SUB': self.convert_sub,
'SUM': self._convert_reduce_sum,
'TAN': self.convert_tan,
@@ -983,6 +986,217 @@ def convert_logical_or(self, op):
"""Convert tflite LOGICAL_OR"""
return self._convert_logical_binary(_op.logical_or, op)

def convert_gather(self, op):
"""Method to Convert TFLite GATHER operator"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.GatherOptions import GatherOptions
from tflite.TensorType import TensorType
except ImportError:
raise ImportError("The tflite package must be installed")

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2"

data = self.get_expr(input_tensors[0].tensor_idx)

indices = input_tensors[1]
indices_type = indices.tensor.Type()
assert indices_type in (TensorType.INT32, TensorType.INT64)
indices_type_str = self.get_tensor_type_str(indices_type)
indices = self.exp_tab.new_const(self.get_tensor_value(indices),
dtype=indices_type_str)

assert op.BuiltinOptionsType() == BuiltinOptions.GatherOptions
op_options = op.BuiltinOptions()
gather_options = GatherOptions()
gather_options.Init(op_options.Bytes, op_options.Pos)
axis = gather_options.Axis()

        # Check that the indices are within bounds.
data_shape = list(input_tensors[0].tensor.ShapeAsNumpy())
data_dim = len(data_shape)

        axis_n = axis
        if axis_n < 0:
            axis_n += data_dim
        assert axis_n >= 0, "Axis out of bounds"
        assert axis_n < data_dim, "Axis out of bounds"

indices_val = self.get_tensor_value(input_tensors[1])
indices_shape = list(indices_val.shape)
indices_len = len(indices_shape)

out_shape = []
for i in range(data_dim):
if axis_n == i:
for j in range(indices_len):
out_shape.append(indices_shape[j])
else:
out_shape.append(data_shape[i])

loopover = [range(s) for s in out_shape]
for idx in list(itertools.product(*loopover)):
indices_position = [idx[j] for j in range(axis_n, axis_n+indices_len)]

real_indices = [idx[j] for j in range(axis_n)]
real_indices.append(indices_val[tuple(indices_position)])
real_indices.extend([idx[j] for j in range(axis_n + indices_len, len(idx))])
for r, d in zip(real_indices, data_shape):
if r >= d:
raise ValueError("TFLite out of bound indices are not supported.")

# Use mode 'fast' since indices are already checked within bounds.
out = _op.take(data, indices, axis=axis, mode="fast")
return out
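
The following standalone NumPy sketch (added for clarity, not part of the patch; names and shapes are hypothetical) mirrors the semantics the GATHER conversion relies on: relay.take, like np.take, replaces the gathered axis of the data shape with the shape of the indices, and the loop above rejects any index that would read out of bounds, since mode="fast" skips that check at runtime.

import numpy as np

data = np.arange(2 * 3).reshape(2, 3)       # hypothetical input, shape (2, 3)
indices = np.array([[1, 0], [0, 1]])        # hypothetical indices, shape (2, 2)
axis = 1

out = np.take(data, indices, axis=axis)     # output shape (2, 2, 2): axis 1 of the
assert out.shape == (2, 2, 2)               # data shape is replaced by indices.shape
assert (indices < data.shape[axis]).all()   # mirrors the converter's bounds check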

def convert_strided_slice(self, op):
"""Method to Convert TFLite STRIDED_SLICE operator.
NOTE: Eventhough tensorflow supports begin_mask, end_mask, ellipsis_mask, new_axis_mask
and shrink_axis_mask, tflite doesn't support these and expect these values to be zero.
But in future, they may open up the mask implementation, so kept the implementation
same as tensorflow.
This op extracts a slice of size (end - begin) / stride from the given input tensor.
Starting at the location specified by begin the slice continues by adding stride to the
index until all dimensions are not less than end. Note that a stride can be negative,
which causes a reverse slice.
For slice input[val0, val1, ..., valn], begin/end/strides will be vectors of length n.
In each mask field(begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask)
the ith bit will correspond to the ith val.
If the ith bit of begin_mask is set, begin[i] is ignored and the fullest possible range
in that dimension is used instead.
If the ith bit of ellipsis_mask is set, as many unspecified dimensions as needed will be
inserted between other dimensions. Only one non-zero bit is allowed in ellipsis_mask.
If the ith bit of new_axis_mask is set, then begin, end, and stride are ignored and a
new length 1 dimension is added at this point in the output tensor.
If the ith bit of shrink_axis_mask is set, it implies that the ith specification shrinks
the dimensionality by 1, taking on the value at index begin[i]. end[i] and strides[i]
are ignored in this case.
begin and end are zero-indexed. strides entries must be non-zero.
TVM Relay implementation of doesn't support mask, so the mask values are processed in
this function and begin/end/strides are updated accordingly. If any mask is present, and
since tvm doesn't support mask computation directly, the output need a final reshape.
"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.StridedSliceOptions import StridedSliceOptions
except ImportError:
raise ImportError("The tflite package must be installed")

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 4, "input tensors length should be 4"

data_expr = self.get_expr(input_tensors[0].tensor_idx)

begin = list(self.get_tensor_value(input_tensors[1]))
end = list(self.get_tensor_value(input_tensors[2]))
stride = list(self.get_tensor_value(input_tensors[3]))

assert op.BuiltinOptionsType() == BuiltinOptions.StridedSliceOptions
op_options = op.BuiltinOptions()
options = StridedSliceOptions()
options.Init(op_options.Bytes, op_options.Pos)
begin_mask = options.BeginMask()
end_mask = options.EndMask()
ellipsis_mask = options.EllipsisMask()
new_axis_mask = options.NewAxisMask()
shrink_axis_mask = options.ShrinkAxisMask()

data_shape = list(input_tensors[0].tensor.ShapeAsNumpy())
data_dim = len(data_shape)
stride_dim = len(stride)
def _transform_mask(stride_dim, ellipsis_mask):
"""Handle mask inputs to create new begin, end, stride and output shape"""
m_begin = [0] * data_dim
m_end = [0] * data_dim
m_stride = [0] * data_dim
fshape_indices = []
            # Count new axes appearing after the ellipsis; considered while applying ellipsis_mask.
ellipsis_seen = False
new_axes_after_ellipsis = 0
for i in range(stride_dim):
mask = 1 << i
if ellipsis_seen and (mask & new_axis_mask) != 0:
new_axes_after_ellipsis += 1
if (mask & ellipsis_mask) != 0:
ellipsis_seen = True
if not ellipsis_seen:
                # Used later for extending the stride attributes in the loop below.
ellipsis_mask |= (1 << stride_dim)
stride_dim += 1
final_index = 0
for index in range(stride_dim):
mask = 1 << index
if mask & ellipsis_mask:
                    # Identify the end index for applying ellipsis_mask
to_index = min(((data_dim - (stride_dim-index)) + 1 \
+ new_axes_after_ellipsis), data_dim)
for i in range(final_index, to_index):
m_begin[final_index] = 0
m_end[final_index] = data_shape[final_index]
m_stride[final_index] = 1
fshape_indices.append(final_index)
final_index += 1
                elif mask & new_axis_mask:
fshape_indices.append(-1)
elif not mask & new_axis_mask:
if final_index == len(m_begin):
break
if mask & begin_mask:
m_begin[final_index] = data_shape[final_index] \
if stride[index] < 0 else 0
elif begin[index]:
m_begin[final_index] = begin[index]
if mask & end_mask:
m_end[final_index] = 0 if stride[index] < 0 \
else data_shape[final_index]
elif end[index]:
m_end[final_index] = end[index]
m_stride[final_index] = stride[index]
if mask & shrink_axis_mask:
                        # TensorFlow makes the axis with shrink_axis_mask set a dimension of size 1
m_begin[final_index] = data_shape[final_index] + begin[index] \
if begin[index] < 0 else begin[index]
m_end[final_index] = begin[index] + 1
m_stride[final_index] = 1
fshape_indices.append(-2)
else:
fshape_indices.append(final_index)

final_index += 1
return m_begin, m_end, m_stride, fshape_indices

fshape_indices = None
if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)

out = _op.strided_slice(data_expr, begin=begin, end=end, strides=stride)
out_shape = _infer_shape(out)
if not fshape_indices:
fshape_indices = range(len(out_shape))

        # Create final output shape.
final_output = []
for gather_index in fshape_indices:
if gather_index == -1:
final_output.append(1)
elif gather_index == -2:
pass
else:
final_output.append(out_shape[gather_index])

if not final_output:
return out
return _op.reshape(out, newshape=tuple(final_output))
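
As a concrete illustration of why the final reshape is needed, the following standalone NumPy sketch (added for clarity, not part of the patch; shapes and values are hypothetical) walks through the shrink_axis_mask case: Relay's strided_slice has no mask support, so the converter slices without shrinking and then drops the shrunk axis with a reshape driven by fshape_indices.

import numpy as np

data = np.arange(12).reshape(3, 4)   # hypothetical input, shape (3, 4)

# TensorFlow semantics for strided_slice(begin=[1, 0], end=[2, 4], stride=[1, 1],
# shrink_axis_mask=1): dimension 0 is shrunk away, i.e. the result equals
# data[1, 0:4] and has shape (4,).
tf_like = data[1, 0:4]

# The converter first slices without shrinking, which yields shape (1, 4) ...
sliced = data[1:2, 0:4]

# ... and then reshapes using fshape_indices, where -2 marks a shrunk axis that
# is dropped from the final output shape.
reshaped = sliced.reshape(4)
assert np.array_equal(tf_like, reshaped)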

def convert_zeros_like(self, op):
"""Convert TFLite ZEROS LIKE"""
input_tensors = self.get_input_tensors(op)
75 changes: 75 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
@@ -291,6 +291,79 @@ def test_forward_topk():
_test_topk((3, 5, 7), 3)
_test_topk((3, 5, 7), 3)

#######################################################################
# Gather
# ------

def _test_gather(dshape, indices, axis, dtype, quantized=False, oob=False):
""" One iteration of Gather """
indices = np.asarray(indices).astype('int32')
data = np.random.uniform(1, 10, size=dshape)
data = data.astype(np.uint8) if quantized else data.astype(dtype)
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="in_data")
if axis:
out = array_ops.gather(in_data, indices, axis=axis)
else:
            out = array_ops.gather(in_data, indices)  # TFLite conversion fails for None axis
input_range = {'in_data': (-100, 100)} if quantized else None
try:
compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out],
quantized=quantized, input_range=input_range)
except ValueError as e:
if not oob:
raise e
except Exception as e:
raise e

def test_forward_gather():
""" GATHER """
for quantized in [False, True]:
_test_gather((4,), [1], 0, 'float32', quantized)
_test_gather((4,), [1], None, 'int32', quantized)
_test_gather((1, 4), [0], 0, 'int32', quantized)
_test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32', quantized)
_test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32', quantized)
_test_gather((2, 2), [[[1, 0], [0, 1]]], None, 'float32', quantized)
_test_gather((3, 3, 3), [[[1, 0]]], 0, 'int32', quantized)
_test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32', quantized)
_test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32', quantized)
_test_gather((3, 3, 3), [[[2, 1]]], -1, 'int32', quantized)
_test_gather((4,), [16], 0, 'float32', quantized, oob=True)
_test_gather((1, 3, 3), [12], 0, 'int32', quantized, oob=True)
_test_gather((1, 3, 3), [20], 1, 'float32', quantized, oob=True)
_test_gather((1, 3, 3), [20, 20], 2, 'float32', quantized, oob=True)

#######################################################################
# StridedSlice
# ------------

def _test_stridedslice(ip_shape, begin, end, stride, dtype,
begin_mask=0, end_mask=0, new_axis_mask=0,
shrink_axis_mask=0, ellipsis_mask=0, quantized=False):
""" One iteration of a Stridedslice """
data = np.random.uniform(size=ip_shape).astype(dtype)
data = data.astype(np.uint8) if quantized else data.astype(dtype)
with tf.Graph().as_default():
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
out = array_ops.strided_slice(in_data, begin, end, stride,
begin_mask=begin_mask,
end_mask=end_mask,
new_axis_mask=new_axis_mask,
shrink_axis_mask=shrink_axis_mask,
ellipsis_mask=ellipsis_mask)
input_range = {'in_data': (-100, 100)} if quantized else None
compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out], quantized=quantized,
input_range=input_range)

def test_forward_stridedslice():
'''test StridedSlice'''
for quantized in [False, True]:
_test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1, quantized=quantized)
_test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32', quantized=quantized)
_test_stridedslice((3, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=0, quantized=quantized)
_test_stridedslice((4, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=2, quantized=quantized)

#######################################################################
# transpose
# ---------
@@ -1855,6 +1928,8 @@ def test_forward_mediapipe_hand_landmark():
test_forward_squeeze()
test_forward_slice()
test_forward_topk()
test_forward_gather()
test_forward_stridedslice()
test_forward_depthtospace()
test_forward_spacetodepth()
