[TFLite] Added ability to infer shapes for arguments
Added the ability to infer argument shapes when shapes are not present in
TFLite files. The set of networks on which the patch was tested is
internal to Arm; any help with creating unit tests would be appreciated.
d-smirnov committed Jan 15, 2021
1 parent 1677bb2 commit db32f0e
Showing 1 changed file with 23 additions and 12 deletions.
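Since the commit message asks for help with unit tests, the sketch below is one possible starting point (it is not part of the commit): build a tiny Keras model, convert it to a TFLite flatbuffer, and import it through relay.frontend.from_tflite. The layer choice and variable names are assumptions; the missing piece, which is what the author asks for, is a model whose tensors actually omit static shapes so that the new inference path is exercised.

import tensorflow as tf
from tvm import relay

# Build a tiny Keras model and convert it to a TFLite flatbuffer.
keras_model = tf.keras.Sequential(
    [tf.keras.layers.InputLayer(input_shape=(8, 8, 3)), tf.keras.layers.AveragePooling2D(2)]
)
tflite_buf = tf.lite.TFLiteConverter.from_keras_model(keras_model).convert()

# Parse the flatbuffer; the attribute layout differs between tflite package versions.
try:
    import tflite.Model
    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_buf, 0)
except AttributeError:
    import tflite
    tflite_model = tflite.Model.GetRootAsModel(tflite_buf, 0)

# Import into Relay; shapes here still come from the model itself. Exercising the
# new fallback would require a model whose tensors omit static shapes.
mod, params = relay.frontend.from_tflite(tflite_model)
print(mod)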
35 changes: 23 additions & 12 deletions python/tvm/relay/frontend/tflite.py
@@ -353,7 +353,7 @@ def get_tensor_value(self, tensor_wrapper, is_sparse=False):
data = tensor_wrapper.buffer.DataAsNumpy()

if tensor_wrapper.tensor.ShapeLength() != 0:
- shape = to_int_list(tensor_wrapper.tensor.ShapeAsNumpy())
+ shape = to_int_list(self.get_tensor_shape(tensor_wrapper))
else:
shape = []

@@ -1417,7 +1417,7 @@ def convert_gather(self, op):
axis = gather_options.Axis()

# Check the indices are with in bounds.
- data_shape = to_int_list(input_tensors[0].tensor.ShapeAsNumpy())
+ data_shape = to_int_list(self.get_tensor_shape(input_tensors[0]))
data_dim = len(data_shape)

axis = data_dim + axis if axis < 0 else axis
@@ -1535,7 +1535,7 @@ def convert_strided_slice(self, op):
new_axis_mask = options.NewAxisMask()
shrink_axis_mask = options.ShrinkAxisMask()

- data_shape = to_int_list(input_tensors[0].tensor.ShapeAsNumpy())
+ data_shape = to_int_list(self.get_tensor_shape(input_tensors[0]))
data_dim = len(data_shape)
stride_dim = len(stride)

@@ -1792,7 +1792,7 @@ def convert_fully_connected(self, op):
output_tensor_type = output_tensor.tensor.Type()
output_tensor_type_str = self.get_tensor_type_str(output_tensor_type)

- weight_tensor_shape = to_int_list(weight_tensor.tensor.ShapeAsNumpy())
+ weight_tensor_shape = to_int_list(self.get_tensor_shape(weight_tensor))

# Weight should have only 2 dimensions(TFLite convention)
assert len(weight_tensor_shape) == 2, "Weight should be only 2-dim"
@@ -1987,16 +1987,16 @@ def convert_conv(self, op, conv_type):
padding = conv_options.Padding()
fused_activation_fn = conv_options.FusedActivationFunction()

- _, input_h, input_w, input_c = to_int_list(input_tensor.tensor.ShapeAsNumpy())
+ _, input_h, input_w, input_c = to_int_list(self.get_tensor_shape(input_tensor))

if is_depthwise_conv:
# TFLite depthwise convolution kernel layout is:
# 1 KH KW C(input_c * depth_multiplier)
- _, kernel_h, kernel_w, in_channels = to_int_list(weight_tensor.tensor.ShapeAsNumpy())
+ _, kernel_h, kernel_w, in_channels = to_int_list(self.get_tensor_shape(weight_tensor))
assert in_channels == input_c * depth_multiplier
else:
output_channels, kernel_h, kernel_w, _ = to_int_list(
- weight_tensor.tensor.ShapeAsNumpy()
+ self.get_tensor_shape(weight_tensor)
)

dilated_kernel_h = dilation_h * (kernel_h - 1) + 1
@@ -2219,7 +2219,7 @@ def convert_slice(self, op):
size = list(self.get_tensor_value(input_tensors[2]))
# strided_slice(Relay) needs the slice's end indices, not the size
end = size
- input_tensor_shape = to_int_list(input_tensor.tensor.ShapeAsNumpy())
+ input_tensor_shape = to_int_list(self.get_tensor_shape(input_tensor))
input_tensor_rank = len(input_tensor_shape)
for i in range(input_tensor_rank):
if size[i] == -1:
@@ -2381,7 +2381,8 @@ def convert_pool2d(self, op, pool_type):

in_expr = self.get_expr(input_tensor_idx)

- _, input_h, input_w, _ = to_int_list(input_tensor.tensor.ShapeAsNumpy())
+ _, input_h, input_w, _ = to_int_list(self.get_tensor_shape(input_tensor))

if padding == Padding.VALID:
pass
elif padding == Padding.SAME:
@@ -2771,12 +2772,13 @@ def convert_transpose_conv(self, op):

# Input (data) Tensor. NHWC layout
input_tensor = input_tensors[2]
- _, input_h, input_w, input_c = to_int_list(input_tensor.tensor.ShapeAsNumpy())
+ _, input_h, input_w, input_c = to_int_list(self.get_tensor_shape(input_tensor))
# Weights tensor. TFLite uses OHWI layout
weights_tensor = input_tensors[1]
out_channels, kernel_h, kernel_w, in_channels = to_int_list(
- weights_tensor.tensor.ShapeAsNumpy()
+ self.get_tensor_shape(weights_tensor)
)

assert (
input_c == in_channels
), "Input channel in the filter should match to channel in the input"
@@ -3204,7 +3206,7 @@ def convert_matrix_diag(self, op):
), "TFLite MATRIX_DIAG requires diagonal and output tensors' \
scale and zero points to be equal"

- shape = to_int_list(diagonal.tensor.ShapeAsNumpy())
+ shape = to_int_list(self.get_tensor_shape(diagonal))
shape = np.append(shape, shape[-1])
dtype = self.get_tensor_type_str(diagonal.tensor.Type())

@@ -3265,6 +3267,15 @@ def get_tensor_expr(self, tensor, is_sparse=False):
expr = self.exp_tab.new_const(self.get_tensor_value(tensor, is_sparse), dtype=type_str)
return expr

+ def get_tensor_shape(self, tensor_wrapper):
+     """ Returns tensor shape. Infers shape if the shape is empty. """
+     assert isinstance(tensor_wrapper, TensorWrapper), "Expecting TensorWrapper here"
+     return (
+         tensor_wrapper.tensor.ShapeAsNumpy()
+         if tensor_wrapper.tensor.ShapeLength() > 0
+         else _infer_shape(self.get_tensor_expr(tensor_wrapper))
+     )


# pylint: disable=no-else-return
def prepare_dense_matrix_from_sparse(sparse_tensor, sparse_tensor_value, sparse_tensor_type):
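The new get_tensor_shape helper prefers the static shape recorded in the flatbuffer and falls back to Relay shape inference (_infer_shape, which tflite.py imports from relay.frontend.common) only when ShapeLength() is zero. Below is a minimal sketch of that fallback in isolation; the pooling expression is an assumption standing in for whatever Relay expression the converter has already registered for the tensor.

from tvm import relay
from tvm.relay.frontend.common import infer_shape  # aliased as _infer_shape in tflite.py

# Stand-in for the Relay expression built for a TFLite tensor whose shape
# field is empty in the flatbuffer.
x = relay.var("x", shape=(1, 8, 8, 3), dtype="float32")
y = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), layout="NHWC")

# Shape inference recovers the missing static shape from the expression.
print(infer_shape(y))  # expected: (1, 4, 4, 3)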
