Skip to content

Commit

Permalink
[AutoScheduler] Add custom build function (apache#7185)
Browse files Browse the repository at this point in the history
* [AutoScheduler] Add custom build function

Signed-off-by: leowang1225 <810916296@qq.com>

* [AutoScheduler] Add custom build function

Signed-off-by: leowang1225 <810916296@qq.com>

* cheduler] Add custom build function

* [AutoScheduler] Add custom build function

Signed-off-by: leowang1225 <810916296@qq.com>

* [AutoScheduler] Add custom build function

Signed-off-by: leowang1225 <810916296@qq.com>

* [AutoScheduler] Add custom build function

Signed-off-by: leowang1225 <810916296@qq.com>
  • Loading branch information
leowang1225 authored and masahi committed Jan 14, 2021
1 parent 3dfe8e9 commit cbd62b6
Showing 1 changed file with 35 additions and 9 deletions.
44 changes: 35 additions & 9 deletions python/tvm/auto_scheduler/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,18 @@
MAX_FLOAT = 1e10


class BuildFunc:
"""store build_func name and callable to class variable.
name: str = "default"
The name of registered build function.
build_func: callable = tar.tar
The callable of registered build function.
"""

name = "default"
build_func = tar.tar


@tvm._ffi.register_object("auto_scheduler.MeasureCallback")
class MeasureCallback(Object):
""" The base class of measurement callback functions. """
Expand Down Expand Up @@ -303,12 +315,28 @@ class LocalBuilder(ProgramBuilder):
This is used in a wrapper of the multiprocessing.Process.join().
n_parallel : int = multiprocessing.cpu_count()
Number of threads used to build in parallel.
build_func : str = 'default'
The name of registered build function.
build_func: callable or str = "default"
If is 'default', use default build function
If is 'ndk', use function for android ndk
If is callable, use it as custom build function, expect lib_format field.
"""

def __init__(self, timeout=15, n_parallel=multiprocessing.cpu_count(), build_func="default"):
self.__init_handle_by_constructor__(_ffi_api.LocalBuilder, timeout, n_parallel, build_func)
if build_func == "default":
BuildFunc.name = "default"
BuildFunc.build_func = tar.tar
elif build_func == "ndk":
BuildFunc.name = "ndk"
BuildFunc.build_func = ndk.create_shared
elif callable(build_func):
BuildFunc.name = "custom"
BuildFunc.build_func = build_func
else:
raise ValueError("Invalid build_func" + build_func)

self.__init_handle_by_constructor__(
_ffi_api.LocalBuilder, timeout, n_parallel, BuildFunc.name
)


@tvm._ffi.register_object("auto_scheduler.LocalRunner")
Expand Down Expand Up @@ -624,12 +652,10 @@ def local_build_worker(args):
The build result of this Builder thread.
"""
inp, build_func, timeout, verbose = args
if build_func == "default":
build_func = tar.tar
elif build_func == "ndk":
build_func = ndk.create_shared
else:
raise ValueError("Invalid build_func" + build_func)
assert build_func == BuildFunc.name, (
"BuildFunc.name: " + BuildFunc.name + ", but args is: " + build_func
)
build_func = BuildFunc.build_func

res = call_func_with_timeout(timeout, _timed_func, args=(inp, build_func, verbose))
if isinstance(res, TimeoutError):
Expand Down

0 comments on commit cbd62b6

Please sign in to comment.