From 68a97ad99d0d90c674c7c614a57698111b66d11b Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Fri, 10 Dec 2021 16:38:29 +0000 Subject: [PATCH] [microNPU] Add support for SPLIT and SPLIT_V (#9621) Both SPLIT and SPLIT_V get lowered to relay.split, and during legalization the Relay split gets turned into strided slices. This patch adds the pattern and legalizer to enable offloading TFLite's splits to the NPU. --- .../relay/backend/contrib/ethosu/legalize.py | 20 +++ python/tvm/relay/op/contrib/ethosu.py | 43 +++++ .../contrib/test_ethosu/test_codegen.py | 22 +++ .../contrib/test_ethosu/test_legalize.py | 158 ++++++++++++++++++ 4 files changed, 243 insertions(+) diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 0db8db912a51f..ede9cd46371e4 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -108,6 +108,25 @@ def callback( return relay.Tuple(strided_slices) +class PartitionedSplitRewriter(DFPatternCallback): + """This pass brings the split out of the partitioned function""" + + def __init__(self): + super().__init__(require_type=True, rewrite_once=True) + self.pattern = ( + wildcard().has_attr({"Composite": ethosu_patterns.SplitParams.composite_name}) + )(wildcard()) + + def callback( + self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map + ) -> tvm.relay.Expr: + split_input = post.args[0] + split_params = ethosu_patterns.SplitParams(post.op.body) + indices_or_sections = split_params.indices_or_sections + axis = split_params.axis + return relay.op.split(split_input, indices_or_sections, axis=axis).astuple() + + @ir.transform.module_pass(opt_level=1) class LegalizeSplit: """This is the pass that wraps SplitRewriter""" @@ -116,6 +135,7 @@ def transform_module( self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext ) -> tvm.ir.IRModule: for global_var, func in mod.functions.items(): + 
func = rewrite(PartitionedSplitRewriter(), func) func = rewrite(SplitRewriter(), func) mod.update_func(global_var, func) return mod diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index a7d3da3200b53..73007cffe7268 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -1107,6 +1107,44 @@ def concat_pattern(): return concat +class SplitParams: + """ + This class will parse a call to a ethos-u.split composite function + and extract the parameter information. + """ + + composite_name = "ethos-u.split" + + def __init__(self, func_body): + self.split = func_body + self.input = TensorParams(func_body.args[0]) + self.axis = func_body.attrs.axis + self.indices_or_sections = self.convert_indices_or_sections( + func_body.attrs.indices_or_sections + ) + + def convert_indices_or_sections(self, indices_or_sections): + # split_v + if isinstance(indices_or_sections, tvm.ir.container.Array): + values = [i.value for i in indices_or_sections] + # split + else: + values = indices_or_sections.value + return values + + def is_valid(self): + """Checks whether split has compatible attributes with the hardware""" + if not check_valid_dtypes([self.input], supported_dtypes=[np.int8]): + return False + return True + + +def split_pattern(): + "Create the pattern for split" + split = is_op("split")(wildcard()) + return split + + @register_pattern_table("ethos-u") def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]: return [ @@ -1187,6 +1225,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal sigmoid_pattern(), lambda pat: SigmoidParams(pat).is_valid(), ), + ( + SplitParams.composite_name, + split_pattern(), + lambda pat: SplitParams(pat).is_valid(), + ), ] diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 0707ec27ca27b..ce2efc7dc3f5a 100644 --- 
a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -929,5 +929,27 @@ def sigmoid_function(x): _compare_tvm_with_tflite(sigmoid_function, [ifm_shape], accel_type) +# This codegen test checks both, split and split_v +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize( + "ifm_shape, num_or_size_splits, axis", + [ + ((1, 4, 6, 8), (1, 3, 4), 3), + ((4, 6, 8), 2, 0), + ((50,), 25, 0), + ((5, 11), 1, 1), + ((13,), (13,), 0), + ((22, 7), (4, -1), 1), + ], +) +def test_tflite_split(accel_type, ifm_shape, num_or_size_splits, axis): + @tf.function + def split_func(x): + op = tf.split(x, num_or_size_splits, axis=axis) + return op + + _compare_tvm_with_tflite(split_func, [ifm_shape], accel_type) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 9dc94d96fb274..9f979153f714a 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -1344,5 +1344,163 @@ def representative_dataset(): assert tuple(func_body.args[1].checked_type.shape) == (256,) +@pytest.mark.parametrize( + "ifm_shape, num_or_size_splits, axis", + [ + ((1, 4, 6, 8), 3, 2), + ((4, 6, 8), 2, 0), + ((5, 15), 3, 1), + ((3, 7), 1, 1), + ((100,), 25, 0), + ], +) +def test_tflite_split_legalize(ifm_shape, num_or_size_splits, axis): + dtype = "int8" + + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def tf_function(self, x, num_or_size_splits, axis): + op = tf.split(x, num_or_size_splits, axis=axis) + return op + + model = Model() + concrete_func = model.tf_function.get_concrete_function( + tf.TensorSpec(ifm_shape, tf.float32), num_or_size_splits, axis + ) + + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] + + converter = 
tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + + return tflite_model + + def verify(ext_func): + # dig out the split + single_output_split = num_or_size_splits == 1 + split = ( + ext_func.body.tuple_value + if single_output_split + else ext_func.body.args[0][0].args[0].tuple_value + ) + assert split.op.name == "split" + + # Split is specified by number of equal chunks + assert split.attrs.indices_or_sections == num_or_size_splits + + assert split.attrs.axis == axis + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + mod, _ = relay.frontend.from_tflite( + tflite_model, + shape_dict={"input": ifm_shape}, + dtype_dict={"input": dtype}, + ) + mod = ethosu.partition_for_ethosu(mod) + + mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite( + legalize.PartitionedSplitRewriter(), mod["tvmgen_default_ethos_u_main_0"] + ) + + mod["tvmgen_default_ethos_u_main_0"] = relay.transform.InferType()(mod)[ + "tvmgen_default_ethos_u_main_0" + ] + + verify(mod["tvmgen_default_ethos_u_main_0"]) + + +@pytest.mark.parametrize( + "ifm_shape, num_or_size_splits, axis", + [ + ((1, 4, 6, 8), (1, 3, 4), 3), + ((10, 18, 4), (1, 4, 3, 2), 0), + ((22, 7), (4, -1), 1), + ((25,), (25,), 0), + ], +) +def test_tflite_split_v_legalize(ifm_shape, num_or_size_splits, axis): + dtype = "int8" + + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def tf_function(self, x, num_or_size_splits, axis): + # TF split gets converted into TFLite's split_v + op = tf.split(x, num_or_size_splits, axis=axis) + return op + + model = Model() + concrete_func = model.tf_function.get_concrete_function( 
+ tf.TensorSpec(ifm_shape, tf.float32), num_or_size_splits, axis + ) + + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + + return tflite_model + + def verify(ext_func): + # dig out the split + single_output_split = len(num_or_size_splits) == 1 + split = ( + ext_func.body.tuple_value + if single_output_split + else ext_func.body.args[0][0].args[0].tuple_value + ) + assert split.op.name == "split" + + # Split is specified by the size of sections, so converting num_or_size_splits + # into the indices where the tensor is split at since this is how split is represented + # in Relay + split_sections = [] if single_output_split else [num_or_size_splits[0]] + for split_size in num_or_size_splits[1:-1]: + sec = split_sections[-1] + split_size + split_sections.append(sec) + assert list(split.attrs.indices_or_sections) == split_sections + + assert split.attrs.axis == axis + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + mod, _ = relay.frontend.from_tflite( + tflite_model, + shape_dict={"input": ifm_shape}, + dtype_dict={"input": dtype}, + ) + mod = ethosu.partition_for_ethosu(mod) + + mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite( + legalize.PartitionedSplitRewriter(), mod["tvmgen_default_ethos_u_main_0"] + ) + + mod["tvmgen_default_ethos_u_main_0"] = relay.transform.InferType()(mod)[ + "tvmgen_default_ethos_u_main_0" + ] + + verify(mod["tvmgen_default_ethos_u_main_0"]) + + if __name__ == "__main__": pytest.main([__file__])