Skip to content

Commit

Permalink
Rebase the tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ekalda committed Dec 8, 2021
1 parent 1abac94 commit 5c10153
Showing 1 changed file with 5 additions and 61 deletions.
66 changes: 5 additions & 61 deletions tests/python/contrib/test_ethosu/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,68 +870,12 @@ def sigmoid_function(x):
],
)
def test_tflite_split(accel_type, ifm_shape, num_or_size_splits, axis):
dtype = "int8"

def get_tflite_graph():
class Model(tf.Module):
@tf.function
def tf_function(self, x, num_or_size_splits, axis):
op = tf.split(x, num_or_size_splits, axis=axis)
return op

model = Model()

concrete_func = model.tf_function.get_concrete_function(
tf.TensorSpec(ifm_shape, dtype=tf.float32), num_or_size_splits, axis
)

# Convert the model
def representative_dataset():
for _ in range(100):
data = np.random.rand(*tuple(ifm_shape))
yield [data.astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()
return tflite_model

tflite_graph = get_tflite_graph()
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)

relay_module, params = relay.frontend.from_tflite(
tflite_model,
shape_dict={"input": ifm_shape},
dtype_dict={"input": dtype},
)
mod = partition_for_ethosu(relay_module, params)

# Generate reference data
input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)

compiled_models = infra.build_source(
mod,
input_data,
output_data,
accel_type,
)

# Assumes only two runtime.Modules are created -- i.e. single offload module
imported_modules = compiled_models[0].executor_factory.lib.imported_modules
assert len(imported_modules) == 2
ethosu_module = imported_modules[0]

# Verify generated C source
get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
cmms = get_cs(ethosu_module)
cmms = bytes.fromhex(cmms)
@tf.function
def split_func(x):
op = tf.split(x, num_or_size_splits, axis=axis)
return op

infra.print_payload(cmms)
infra.verify_source(compiled_models, accel_type)
_compare_tvm_with_tflite(split_func, [ifm_shape], accel_type)


if __name__ == "__main__":
Expand Down

0 comments on commit 5c10153

Please sign in to comment.