Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend][TFLite] Densify Op added #7048

Merged
merged 4 commits into from
Jan 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 205 additions & 10 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(self, model, subgraph, exp_tab):
self.builtin_op_code = build_str_map(BuiltinOperator())
self.activation_fn_type = build_str_map(ActivationFunctionType())
self.builtin_options = build_str_map(BuiltinOptions())
self.prefetched_nodes = {}

# Add more operators
self.convert_map = {
Expand All @@ -80,6 +81,7 @@ def __init__(self, model, subgraph, exp_tab):
"CONCATENATION": self.convert_concatenation,
"CONV_2D": self.convert_conv2d,
"COS": self.convert_cos,
"DENSIFY": self.convert_densify,
"DEPTH_TO_SPACE": self.convert_depth_to_space,
"DEPTHWISE_CONV_2D": self.convert_depthwise_conv2d,
"DEQUANTIZE": self.convert_dequantize,
Expand Down Expand Up @@ -200,6 +202,10 @@ def convert_op_to_relay(self):
assert isinstance(op, Operator)
ret = self.convert_map[op_code_str](op)

# In case the Op can be prefetched, the output can be optimized out
if ret is None:
continue

if len(output_tensors) == 1:
tensor_idx = output_tensors[0].tensor_idx
self.exp_tab.set_expr(get_tensor_name(self.subgraph, tensor_idx), ret)
Expand Down Expand Up @@ -338,7 +344,8 @@ def get_tensor_type_as_numpy(self, tensor_wrapper):
"Tensor type '{}' currently not supported".format(tensor_wrapper.tensor.Type())
)

def get_tensor_value(self, tensor_wrapper):
# pylint: disable=no-else-return
def get_tensor_value(self, tensor_wrapper, is_sparse=False):
"""Get tensor buffer value from given tensor wrapper"""
assert isinstance(tensor_wrapper, TensorWrapper)

Expand All @@ -350,7 +357,10 @@ def get_tensor_value(self, tensor_wrapper):
else:
shape = []

return np.frombuffer(data, dtype=dtype).reshape(shape)
if is_sparse:
return np.frombuffer(data, dtype=dtype)
else:
return np.frombuffer(data, dtype=dtype).reshape(shape)

def get_tensor_type_str(self, tensor_type):
"""Get tensor type string representation when given TFLite tensor type"""
Expand Down Expand Up @@ -1645,11 +1655,15 @@ def _convert_reduce(self, relay_op, op):
axis = tuple(axis_value) if len(axis_value.shape) > 0 else tuple((axis_value.item(),))

# Options - keep_dims (bool)
assert op.BuiltinOptionsType() == BuiltinOptions.ReducerOptions
reduce_options = ReducerOptions()
op_options = op.BuiltinOptions()
reduce_options.Init(op_options.Bytes, op_options.Pos)
keep_dims = reduce_options.KeepDims()
# In case Options are not present, set keep_dims to False(default)
if op.BuiltinOptionsType():
assert op.BuiltinOptionsType() == BuiltinOptions.ReducerOptions
reduce_options = ReducerOptions()
op_options = op.BuiltinOptions()
reduce_options.Init(op_options.Bytes, op_options.Pos)
keep_dims = reduce_options.KeepDims()
else:
keep_dims = False

if input_tensor.qnn_params:
in_expr = _op.cast(in_expr, "int32")
Expand Down Expand Up @@ -2008,7 +2022,11 @@ def convert_conv(self, op, conv_type):
else:
weight_expr = _op.transpose(weight_expr, axes=(1, 2, 3, 0))
else:
weight_value = self.get_tensor_value(weight_tensor)
if self.is_prefetched(weight_tensor.tensor_idx):
weight_value = self.get_prefetched_node(weight_tensor.tensor_idx)
else:
weight_value = self.get_tensor_value(weight_tensor)

# TFLite kernel layout:
# convolution:
# OC KH KW IC, we require KH KW IC OC (HWIO)
Expand Down Expand Up @@ -3167,22 +3185,199 @@ def convert_matrix_diag(self, op):
out = _op.matrix_set_diag(input_expr, diagonal_expr)
return out

def convert_densify(self, op):
"""Convert TFLite DENSIFY"""
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
output_tensor = output_tensors[0]

sparse_weight_tensor = input_tensors[0]
sparse_weight_tensor_type_str = self.get_tensor_type_str(sparse_weight_tensor.tensor.Type())

# NOTE: With current implementation in TFLite, Densify Op does not need to be present
# in runtime.
# TODO(ANSHUMAN87): we need to use the sparse_indices output
# from below function and use that in sparse_to_dense Op.
# Once the stack corruption issue is resolved in sparse_to_dense Op.
_, dense_weight = prepare_dense_matrix_from_sparse(
sparse_weight_tensor.tensor,
self.get_tensor_value(sparse_weight_tensor, is_sparse=True),
sparse_weight_tensor_type_str,
)

self.set_prefetched_node(output_tensor.tensor_idx, dense_weight)

def get_expr(self, input_tensor_idx):
return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))

def has_expr(self, input_tensor_idx):
return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx))

def get_tensor_expr(self, tensor):
def is_prefetched(self, input_tensor_idx):
return (
self.prefetched_nodes.get(get_tensor_name(self.subgraph, input_tensor_idx)) is not None
)

def set_prefetched_node(self, input_tensor_idx, value):
self.prefetched_nodes[get_tensor_name(self.subgraph, input_tensor_idx)] = value

def get_prefetched_node(self, input_tensor_idx):
return self.prefetched_nodes[get_tensor_name(self.subgraph, input_tensor_idx)]

def get_tensor_expr(self, tensor, is_sparse=False):
""" Return the Relay expr for tensor. """
if self.has_expr(tensor.tensor_idx):
expr = self.get_expr(tensor.tensor_idx)
else:
type_str = self.get_tensor_type_str(tensor.tensor.Type())
expr = self.exp_tab.new_const(self.get_tensor_value(tensor), dtype=type_str)
expr = self.exp_tab.new_const(self.get_tensor_value(tensor, is_sparse), dtype=type_str)
return expr


# pylint: disable=no-else-return
def prepare_dense_matrix_from_sparse(sparse_tensor, sparse_tensor_value, sparse_tensor_type):
""" Prepare sparse indices and dense matrix from TFLite sparse parameters. """
# The function is implemented based on TFLite sparse parameter specifications
# Please refer
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs#L89
# for details about each parameters
sparsity = sparse_tensor.Sparsity()
dense_shape = sparse_tensor.ShapeAsNumpy()
orig_rank = len(dense_shape)

# The traversal order of the dimensions defined in the `shape` field of the to be dense tensor.
traversal_order = sparsity.TraversalOrderAsNumpy()

# For an n-dimensional tensor with a k-dimensional block (0 <= k <= n),
# stores how a block dimension in (dn, ..., dn+k-1) maps to the original
# tensor dimension in (d0, ..., dn). It's stored in the order of (dn, ..., dn+k-1).
# If not block-sparse, this field is NULL.
block_map = sparsity.BlockMapAsNumpy()

total_rank = sparsity.TraversalOrderLength()
dense_mat = np.full(shape=dense_shape, fill_value=0, dtype=sparse_tensor_type).flatten()

from enum import Enum

# NOTE: Here the Vector term is borrowed from TFLite spec.
class VectorType(Enum):
Empty = 0
Int32 = 1
Uint16 = 2
Uint8 = 3

def _get_vector_flag(v_type):
if VectorType(v_type) == VectorType.Int32:
return N.Int32Flags
elif VectorType(v_type) == VectorType.Uint16:
return N.Uint16Flags
elif VectorType(v_type) == VectorType.Uint8:
return N.Uint8Flags
else:
raise tvm.error.OpNotImplemented("The provided type {} is not supported".format(v_type))

def _get_flattened_index(indices, shape):
index = 0
sub_elements = 1
for i in reversed(range(0, len(dense_shape))):
index += indices[i] * sub_elements
sub_elements *= shape[i]
return index

# DimensionMetadata per dimension: the metadata needed for
# each dimension to locate the non-zero values in the original dense tensor
# inline with traversal order parameter.
#
# sp_format has 2 possible values: {DENSE = 0, SPARSE_CSR = 1}
# If format = DENSE{0} : DenseSize represents size of that dimension
# If format = SPARSE_CSR{1} : array_segments represents how to segment the indices array,
# each segment corresponds to one element in the previous dimension. array_indices
# represents the index of the non-zero elements within this dimension
# (as those in the CSR matrix format, where the first array is row pointers
# and the second array is column indices).
sp_format = np.zeros(sparsity.DimMetadataLength())
dim_metadata = [None] * (2 * sparsity.DimMetadataLength())
ANSHUMAN87 marked this conversation as resolved.
Show resolved Hide resolved

# Below loop will fetch all meta data per dimension based on format type
# Dense or Sparse and will put it in an agnostic array for easy access
# while preparing dense buffer or indices.
for i in range(sparsity.DimMetadataLength()):
sp_format[i] = sparsity.DimMetadata(i).Format()
if sp_format[i] == 0:
dim_metadata[2 * i] = [sparsity.DimMetadata(i).DenseSize()]
else:
from flatbuffers import number_types as N

dim_metadata[2 * i] = (
sparsity.DimMetadata(i)
.ArraySegments()
.GetVectorAsNumpy(
flags=_get_vector_flag(sparsity.DimMetadata(i).ArraySegmentsType()), off=4
)
)
dim_metadata[2 * i + 1] = (
sparsity.DimMetadata(i)
.ArrayIndices()
.GetVectorAsNumpy(
flags=_get_vector_flag(sparsity.DimMetadata(i).ArrayIndicesType()), off=4
)
)

block_dim = 0
block_size = np.zeros(sparsity.BlockMapLength())

# Block size parameter if encoded in BSR format
for i in range(orig_rank):
if block_dim < sparsity.BlockMapLength() and block_map[block_dim] == i:
orig_dim = traversal_order[orig_rank + block_dim]
block_size[block_dim] = sparsity.DimMetadata(orig_dim).DenseSize()
block_dim += 1

indices_list = []

# Below function iterates through each applicable indices per dimension
# based on format type specified and finaly produce the dense matrix and the NZ indices.
def _def_prepare_dense_matrix_from_sparse(indices, level, prev_idx):
ANSHUMAN87 marked this conversation as resolved.
Show resolved Hide resolved
if level == len(indices):
start_pos = 0
orig_idx = np.zeros(orig_rank, dtype="int32")
while start_pos < orig_rank:
orig_idx[traversal_order[start_pos]] = indices[start_pos]
start_pos += 1
while start_pos < len(indices):
block_idx = traversal_order[start_pos] - orig_rank
orig_dim = block_map[block_idx]
orig_idx[orig_dim] = orig_idx[orig_dim] * block_size[block_idx] + indices[start_pos]
start_pos += 1
indices_list.append(orig_idx)
nonlocal value_idx
dense_mat[_get_flattened_index(orig_idx, dense_shape)] = sparse_tensor_value[value_idx]
value_idx += 1
else:
metadata_idx = 2 * level
if sp_format[level] == 0:
shape_of_level = dim_metadata[metadata_idx][0]
for idx in range(shape_of_level):
indices[level] = idx
_def_prepare_dense_matrix_from_sparse(
indices, level + 1, prev_idx * shape_of_level + idx
)
else:
array_segments = dim_metadata[metadata_idx]
array_indices = dim_metadata[metadata_idx + 1]
for idx in range(array_segments[prev_idx], array_segments[prev_idx + 1]):
indices[level] = array_indices[idx]
_def_prepare_dense_matrix_from_sparse(indices, level + 1, idx)

indices = np.zeros(total_rank)
value_idx = 0
_def_prepare_dense_matrix_from_sparse(indices, 0, 0)
return np.array(indices_list, dtype="int32"), dense_mat.reshape(dense_shape)


def get_scalar_from_constant(expr):
""" Returns scalar value from Relay constant scalar. """
assert (
Expand Down
48 changes: 48 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3567,6 +3567,50 @@ def test_forward_mobilenet_v3():
)


#######################################################################
# Mobilenet V1 Sparse
# -----------------


def test_forward_sparse_mobilenet_v1():
"""Test the Sparse version of Mobilenet V1 TF Lite model."""
# MobilenetV1
tflite_model_file = download_testdata(
"https://storage.googleapis.com/fast-convnets/tflite-models/mbv1_140_90_12b4_720.tflite",
"mbv1_140_90_12b4_720.tflite",
)
with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read()
data = np.random.uniform(size=(1, 224, 224, 3)).astype("float32")
tflite_output = run_tflite_graph(tflite_model_buf, data)
tvm_output = run_tvm_graph(tflite_model_buf, data, "float_image_input")
tvm.testing.assert_allclose(
np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5
)


#######################################################################
# Mobilenet V2 Sparse
# -----------------


def test_forward_sparse_mobilenet_v2():
"""Test the Sparse version of Mobilenet V2 TF Lite model."""
# MobilenetV1
tflite_model_file = download_testdata(
"https://storage.googleapis.com/fast-convnets/tflite-models/mbv2_200_85_11-16b2_744.tflite",
"mbv2_200_85_11-16b2_744.tflite",
)
with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read()
data = np.random.uniform(size=(1, 224, 224, 3)).astype("float32")
tflite_output = run_tflite_graph(tflite_model_buf, data)
tvm_output = run_tvm_graph(tflite_model_buf, data, "float_image_input")
tvm.testing.assert_allclose(
np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5
)


#######################################################################
# Inception
# ---------
Expand Down Expand Up @@ -4073,6 +4117,10 @@ def test_forward_mediapipe_hand_landmark():
test_forward_coco_ssd_mobilenet_v1()
test_forward_mediapipe_hand_landmark()

# End to End Sparse models
test_forward_sparse_mobilenet_v1()
test_forward_sparse_mobilenet_v2()

# End to End quantized
test_forward_qnn_inception_v1_net()
test_forward_qnn_mobilenet_v1_net()
Expand Down