Skip to content

Commit

Permalink
[microNPU] Add support for SPLIT and SPLIT_V (apache#9621)
Browse files Browse the repository at this point in the history
Both, SPLIT and SPLIT_V get lowered to relay.split and in the
legalization the Relay split gets turned into strided slices. This
patch adds the pattern and legalizer to enable offloading the TFLite's
splits to the NPU.
  • Loading branch information
ekalda authored and yangulei committed Jan 11, 2022
1 parent eb9608b commit 68a97ad
Show file tree
Hide file tree
Showing 4 changed files with 243 additions and 0 deletions.
20 changes: 20 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,25 @@ def callback(
return relay.Tuple(strided_slices)


class PartitionedSplitRewriter(DFPatternCallback):
"""This pass brings the split out of the partitioned function"""

def __init__(self):
super().__init__(require_type=True, rewrite_once=True)
self.pattern = (
wildcard().has_attr({"Composite": ethosu_patterns.SplitParams.composite_name})
)(wildcard())

def callback(
self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
) -> tvm.relay.Expr:
split_input = post.args[0]
split_params = ethosu_patterns.SplitParams(post.op.body)
indices_or_sections = split_params.indices_or_sections
axis = split_params.axis
return relay.op.split(split_input, indices_or_sections, axis=axis).astuple()


@ir.transform.module_pass(opt_level=1)
class LegalizeSplit:
"""This is the pass that wraps SplitRewriter"""
Expand All @@ -116,6 +135,7 @@ def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(PartitionedSplitRewriter(), func)
func = rewrite(SplitRewriter(), func)
mod.update_func(global_var, func)
return mod
Expand Down
43 changes: 43 additions & 0 deletions python/tvm/relay/op/contrib/ethosu.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,44 @@ def concat_pattern():
return concat


class SplitParams:
"""
This class will parse a call to a ethos-u.split composite function
and extract the parameter information.
"""

composite_name = "ethos-u.split"

def __init__(self, func_body):
self.split = func_body
self.input = TensorParams(func_body.args[0])
self.axis = func_body.attrs.axis
self.indices_or_sections = self.convert_indices_or_sections(
func_body.attrs.indices_or_sections
)

def convert_indices_or_sections(self, indices_or_sections):
# split_v
if isinstance(indices_or_sections, tvm.ir.container.Array):
values = [i.value for i in indices_or_sections]
# split
else:
values = indices_or_sections.value
return values

def is_valid(self):
"""Checks whether split has compatible attributes with the hardware"""
if not check_valid_dtypes([self.input], supported_dtypes=[np.int8]):
return False
return True


def split_pattern():
"Create the pattern for split"
split = is_op("split")(wildcard())
return split


@register_pattern_table("ethos-u")
def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
return [
Expand Down Expand Up @@ -1187,6 +1225,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
sigmoid_pattern(),
lambda pat: SigmoidParams(pat).is_valid(),
),
(
SplitParams.composite_name,
split_pattern(),
lambda pat: SplitParams(pat).is_valid(),
),
]


Expand Down
22 changes: 22 additions & 0 deletions tests/python/contrib/test_ethosu/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,5 +929,27 @@ def sigmoid_function(x):
_compare_tvm_with_tflite(sigmoid_function, [ifm_shape], accel_type)


# This codegen test checks both, split and split_v
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize(
"ifm_shape, num_or_size_splits, axis",
[
((1, 4, 6, 8), (1, 3, 4), 3),
((4, 6, 8), 2, 0),
((50,), 25, 0),
((5, 11), 1, 1),
((13,), (13,), 0),
((22, 7), (4, -1), 1),
],
)
def test_tflite_split(accel_type, ifm_shape, num_or_size_splits, axis):
@tf.function
def split_func(x):
op = tf.split(x, num_or_size_splits, axis=axis)
return op

_compare_tvm_with_tflite(split_func, [ifm_shape], accel_type)


if __name__ == "__main__":
pytest.main([__file__])
158 changes: 158 additions & 0 deletions tests/python/contrib/test_ethosu/test_legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,5 +1344,163 @@ def representative_dataset():
assert tuple(func_body.args[1].checked_type.shape) == (256,)


@pytest.mark.parametrize(
"ifm_shape, num_or_size_splits, axis",
[
((1, 4, 6, 8), 3, 2),
((4, 6, 8), 2, 0),
((5, 15), 3, 1),
((3, 7), 1, 1),
((100,), 25, 0),
],
)
def test_tflite_split_legalize(ifm_shape, num_or_size_splits, axis):
dtype = "int8"

def create_tflite_graph():
class Model(tf.Module):
@tf.function
def tf_function(self, x, num_or_size_splits, axis):
op = tf.split(x, num_or_size_splits, axis=axis)
return op

model = Model()
concrete_func = model.tf_function.get_concrete_function(
tf.TensorSpec(ifm_shape, tf.float32), num_or_size_splits, axis
)

def representative_dataset():
for _ in range(100):
data = np.random.rand(*tuple(ifm_shape))
yield [data.astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()

return tflite_model

def verify(ext_func):
# dig out the split
single_output_split = num_or_size_splits == 1
split = (
ext_func.body.tuple_value
if single_output_split
else ext_func.body.args[0][0].args[0].tuple_value
)
assert split.op.name == "split"

# Split is specified by number of equal chunks
assert split.attrs.indices_or_sections == num_or_size_splits

assert split.attrs.axis == axis

tflite_graph = create_tflite_graph()
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)

mod, _ = relay.frontend.from_tflite(
tflite_model,
shape_dict={"input": ifm_shape},
dtype_dict={"input": dtype},
)
mod = ethosu.partition_for_ethosu(mod)

mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
legalize.PartitionedSplitRewriter(), mod["tvmgen_default_ethos_u_main_0"]
)

mod["tvmgen_default_ethos_u_main_0"] = relay.transform.InferType()(mod)[
"tvmgen_default_ethos_u_main_0"
]

verify(mod["tvmgen_default_ethos_u_main_0"])


@pytest.mark.parametrize(
"ifm_shape, num_or_size_splits, axis",
[
((1, 4, 6, 8), (1, 3, 4), 3),
((10, 18, 4), (1, 4, 3, 2), 0),
((22, 7), (4, -1), 1),
((25,), (25,), 0),
],
)
def test_tflite_split_v_legalize(ifm_shape, num_or_size_splits, axis):
dtype = "int8"

def create_tflite_graph():
class Model(tf.Module):
@tf.function
def tf_function(self, x, num_or_size_splits, axis):
# TF split gets converted into TFLite's split_v
op = tf.split(x, num_or_size_splits, axis=axis)
return op

model = Model()
concrete_func = model.tf_function.get_concrete_function(
tf.TensorSpec(ifm_shape, tf.float32), num_or_size_splits, axis
)

def representative_dataset():
for _ in range(100):
data = np.random.rand(*tuple(ifm_shape))
yield [data.astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()

return tflite_model

def verify(ext_func):
# dig out the split
single_output_split = len(num_or_size_splits) == 1
split = (
ext_func.body.tuple_value
if single_output_split
else ext_func.body.args[0][0].args[0].tuple_value
)
assert split.op.name == "split"

# Split is specified by the size of sections, so converting num_or_size_splits
# into the indices where the tensor is split at since this is how split is represented
# in Relay
split_sections = [] if single_output_split else [num_or_size_splits[0]]
for split_size in num_or_size_splits[1:-1]:
sec = split_sections[-1] + split_size
split_sections.append(sec)
assert list(split.attrs.indices_or_sections) == split_sections

assert split.attrs.axis == axis

tflite_graph = create_tflite_graph()
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)

mod, _ = relay.frontend.from_tflite(
tflite_model,
shape_dict={"input": ifm_shape},
dtype_dict={"input": dtype},
)
mod = ethosu.partition_for_ethosu(mod)

mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
legalize.PartitionedSplitRewriter(), mod["tvmgen_default_ethos_u_main_0"]
)

mod["tvmgen_default_ethos_u_main_0"] = relay.transform.InferType()(mod)[
"tvmgen_default_ethos_u_main_0"
]

verify(mod["tvmgen_default_ethos_u_main_0"])


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 68a97ad

Please sign in to comment.