Skip to content

Commit

Permalink
lint format
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrdadh committed Mar 1, 2022
1 parent 503e86e commit f373d8a
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 33 deletions.
11 changes: 4 additions & 7 deletions python/tvm/contrib/hexagon/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,19 +254,16 @@ def get_graph_executor(self, libmod, remote_libmod_filename: str):
libmod.get_graph_json(), hexagon_mod, self.session.device
)

def get_aot_executor(self, libmod, remote_libmod_filename: str):
"""Create a local GraphModule which consumes a remote libmod.
def get_aot_executor(self, remote_libmod_filename: str):
"""Create a local AoTModule which consumes a remote libmod.
Parameters
----------
libmod : tvm.runtime.Module
The module of the corresponding function.
This library module is for remote hexagon runtime.
remote_libmod_filename : str
Module filename on remote. It is assumed this file lives under self._workspace path.
Returns
-------
graph_module : GraphModule
Runtime graph module that can be used to execute the graph.
aot_module : AotModule
Runtime AOT module that can be used to execute.
"""
self.session.__enter__()
hexagon_mod = self.get_module(remote_libmod_filename)
Expand Down
55 changes: 36 additions & 19 deletions python/tvm/contrib/hexagon/hexagon.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import functools as ft
import os
import pathlib
from typing import Union

import tvm
import tvm.ir
import tvm.contrib.cc as cc
Expand All @@ -40,16 +42,14 @@
#
# Subsequent calls to 'link_shared' will use the newly registered linker.

HEXAGON_TOOLCHAIN = os.environ.get(
"HEXAGON_TOOLCHAIN", default=None
HEXAGON_TOOLCHAIN = os.environ.get("HEXAGON_TOOLCHAIN", default="") # pylint: disable=invalid-name
HEXAGON_SDK_PATH = os.environ.get("HEXAGON_SDK_PATH", default="") # pylint: disable=invalid-name
HEXAGON_LINK_MAIN = (
pathlib.Path(HEXAGON_TOOLCHAIN) / "bin" / "hexagon-link"
) # pylint: disable=invalid-name
HEXAGON_CLANG_PLUS = (
pathlib.Path(HEXAGON_TOOLCHAIN) / "bin" / "hexagon-clang++"
) # pylint: disable=invalid-name
HEXAGON_SDK_PATH = os.environ.get("HEXAGON_SDK_PATH", default=None) # pylint: disable=invalid-name
HEXAGON_LINK_MAIN = os.path.join( # pylint: disable=invalid-name
HEXAGON_TOOLCHAIN, "bin", "hexagon-link"
)
HEXAGON_CLANG_PLUS = os.path.join( # pylint: disable=invalid-name
HEXAGON_TOOLCHAIN, "bin", "hexagon-clang++"
)
HEXAGON_SDK_INCLUDE_DIRS = [ # pylint: disable=invalid-name
pathlib.Path(HEXAGON_SDK_PATH) / "incs",
pathlib.Path(HEXAGON_SDK_PATH) / "incs" / "stddef",
Expand All @@ -62,15 +62,15 @@ def register_linker(f):


@register_func("tvm.contrib.hexagon.hexagon.hexagon_link")
def hexagon_link():
def hexagon_link() -> str:
"""Return path to the Hexagon linker."""
return HEXAGON_LINK_MAIN
return str(HEXAGON_LINK_MAIN)


@register_func("tvm.contrib.hexagon.hexagon.hexagon_clang_plus")
def hexagon_clang_plus():
def hexagon_clang_plus() -> str:
"""Return path to the Hexagon clang++."""
return HEXAGON_CLANG_PLUS
return str(HEXAGON_CLANG_PLUS)


@register_func("tvm.contrib.hexagon.hexagon.link_shared")
Expand Down Expand Up @@ -267,25 +267,42 @@ def ir_lower_vtcm_pass():
return [(3, ir_lower_vtcm())]


@register_func("tvm.contrib.hexagon.hexagon.aot_export")
def aot_export(so_name, files, **kwargs):
def create_aot_shared(so_name: Union[str, pathlib.Path], files, hexagon_arch: str, options=None):
"""Export Hexagon AOT module."""
if not os.access(str(HEXAGON_CLANG_PLUS), os.X_OK):
raise Exception(
'The Clang++ "' + str(HEXAGON_CLANG_PLUS) + '" does not exist or is not executable.'
)
if not HEXAGON_TOOLCHAIN:
raise Exception(
" The environment variable HEXAGON_TOOLCHAIN is unset. Please export "
+ "HEXAGON_TOOLCHAIN in your environment."
)
if not HEXAGON_SDK_PATH:
raise Exception(
" The environment variable HEXAGON_SDK_PATH is unset. Please export "
+ "HEXAGON_SDK_PATH in your environment."
)

tvm_dir = pathlib.Path(os.path.dirname(os.path.realpath(__file__))) / ".." / ".." / ".." / ".."
options = [
compute_arch = f"compute{hexagon_arch}"
compile_options = [
f"-I{tvm_dir / 'include'}",
f"-I{tvm_dir / '3rdparty' / 'dlpack' / 'include'}",
f"-I{tvm_dir / '3rdparty' / 'dmlc-core' / 'include'}",
f"-I{tvm_dir / 'src' / 'runtime' / 'hexagon' / 'android' / 'sim' / 'driver'}",
f"-I{pathlib.Path(HEXAGON_SDK_PATH) / 'rtos' / 'qurt' / compute_arch / 'include'/ 'posix'}",
f"-I{pathlib.Path(HEXAGON_SDK_PATH) / 'rtos' / 'qurt' / compute_arch / 'include' / 'qurt'}",
f"-DDMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>",
f"-D_MACH_I32=int",
]

# For debugging
for path in HEXAGON_SDK_INCLUDE_DIRS:
options.append(f"-I{str(path)}")
compile_options.append(f"-I{str(path)}")

cross_compile = cc.cross_compiler(
compile_func=tvm.get_global_func("tvm.contrib.hexagon.hexagon.hexagon_clang_plus")()
)
cross_compile.output_format = "o"
c_files = [str(file) for file in files]
cross_compile(so_name, c_files, options=options)
cross_compile(str(so_name), c_files, options=compile_options)
2 changes: 1 addition & 1 deletion python/tvm/runtime/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No
# The imports could contain a c module but the object format could be tar
# Thus, it would not recognize the following include paths as options
# which are there assuming a c compiler is the fcompile.
if has_c_module and not file_name.endswith(".tar") and not file_name.endswith(".so"):
if has_c_module and not file_name.endswith(".tar"):
options = []
if "options" in kwargs:
opts = kwargs["options"]
Expand Down
12 changes: 8 additions & 4 deletions tests/python/contrib/test_hexagon/test_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,9 @@ def test_aot_executor_conv2d(tvm_tracker_host, tvm_tracker_port, android_serial_
runtime=Runtime("cpp"),
executor=Executor("aot", {"unpacked-api": False, "interface-api": "c"}),
)
lowered.export_library(dso_binary_path, fcompile=tvm.contrib.hexagon.hexagon.aot_export)
lowered.export_library(
dso_binary_path, fcompile=hexagon.create_aot_shared, hexagon_arch="v68"
)

if not android_serial_number:
pytest.skip(msg="Skip hardware test since ANDROID_SERIAL_NUMBER is not set.")
Expand All @@ -426,7 +428,7 @@ def test_aot_executor_conv2d(tvm_tracker_host, tvm_tracker_port, android_serial_
launcher.hexagon_session_setup(remote_kw)
launcher.upload(dso_binary_path, dso_binary)

aot_mod = launcher.get_aot_executor(lowered, dso_binary)
aot_mod = launcher.get_aot_executor(dso_binary)
aot_mod.set_input(**inputs)
aot_mod.run()
hexagon_output = aot_mod.get_output(0).numpy()
Expand Down Expand Up @@ -506,7 +508,9 @@ def test_aot_executor_multiple_conv2d(tvm_tracker_host, tvm_tracker_port, androi
runtime=Runtime("cpp"),
executor=Executor("aot", {"unpacked-api": False, "interface-api": "c"}),
)
lowered.export_library(dso_binary_path, fcompile=tvm.contrib.hexagon.hexagon.aot_export)
lowered.export_library(
dso_binary_path, fcompile=hexagon.create_aot_shared, hexagon_arch="v68"
)

if not android_serial_number:
pytest.skip(msg="Skip hardware test since ANDROID_SERIAL_NUMBER is not set.")
Expand All @@ -523,7 +527,7 @@ def test_aot_executor_multiple_conv2d(tvm_tracker_host, tvm_tracker_port, androi
launcher.hexagon_session_setup(remote_kw)
launcher.upload(dso_binary_path, dso_binary)

aot_mod = launcher.get_aot_executor(lowered, dso_binary)
aot_mod = launcher.get_aot_executor(dso_binary)
aot_mod.set_input(**inputs)
aot_mod.run()
hexagon_output = aot_mod.get_output(0).numpy()
Expand Down
11 changes: 9 additions & 2 deletions tests/python/relay/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,12 +922,14 @@ def test_get_input_index(target, dev):
assert vm_factory.get_input_index(data_0) == 0
assert vm_factory.get_input_index("invalid") == -1


def get_one_input_relay_mod(tensor_type, shape, data_name):
x = relay.var(data_name, shape = shape, dtype = tensor_type)
x = relay.var(data_name, shape=shape, dtype=tensor_type)
y = relay.exp(x)
f = relay.Function([x], y)
return IRModule.from_expr(f)


@tvm.testing.parametrize_targets("llvm")
def test_one_set_input(target, dev):
dtype = "float32"
Expand Down Expand Up @@ -956,11 +958,13 @@ def test_one_set_input(target, dev):
assert output.dtype == ref_res.dtype
tvm.testing.assert_allclose(ref_res_core, output.numpy())


def get_multiple_input_relay_mod(tensor_type, shape, data_name0, data_name1):
x, y = [relay.var(c, shape=shape, dtype = tensor_type) for c in [data_name0, data_name1]]
x, y = [relay.var(c, shape=shape, dtype=tensor_type) for c in [data_name0, data_name1]]
f = relay.Function([x, y], x + y)
return IRModule.from_expr(f)


@tvm.testing.parametrize_targets("llvm")
def test_multiple_set_input(target, dev):
dtype = "float32"
Expand Down Expand Up @@ -992,6 +996,7 @@ def test_multiple_set_input(target, dev):
assert output.dtype == ref_res.dtype
tvm.testing.assert_allclose(ref_res_core, output.numpy())


@tvm.testing.parametrize_targets("llvm")
def test_one_set_one_input(target, dev):
dtype = "float32"
Expand Down Expand Up @@ -1025,6 +1030,7 @@ def test_one_set_one_input(target, dev):
assert output.dtype == ref_res.dtype
tvm.testing.assert_allclose(ref_res_core, output.numpy())


@tvm.testing.parametrize_targets("llvm")
def test_multiple_set_one_input(target, dev):
dtype = "float32"
Expand Down Expand Up @@ -1065,6 +1071,7 @@ def test_multiple_set_one_input(target, dev):
assert output.dtype == ref_res.dtype
tvm.testing.assert_allclose(ref_res_core, output.numpy())


@tvm.testing.parametrize_targets("llvm")
def test_benchmark(target, dev):
mod, params = mlp.get_workload(1)
Expand Down

0 comments on commit f373d8a

Please sign in to comment.