Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AOT] Name mangling in AOT #8014

Merged
merged 2 commits into from
Jun 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions apps/microtvm/zephyr/aot_demo/src/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
#define WORKSPACE_SIZE (270 * 1024)

static uint8_t g_aot_memory[WORKSPACE_SIZE];
extern tvm_model_t network;
extern tvm_model_t tvmgen_default_network;
tvm_workspace_t app_workspace;

// Wakeup sequence used to wake up QEMU on the host.
Expand Down Expand Up @@ -205,7 +205,7 @@ void main(void) {

double elapsed_time = 0;
TVMPlatformTimerStart();
int ret_val = tvm_runtime_run(&network, inputs, outputs);
int ret_val = tvm_runtime_run(&tvmgen_default_network, inputs, outputs);
TVMPlatformTimerStop(&elapsed_time);

if (ret_val != 0) {
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ constexpr const char* tvm_param_prefix = "__tvm_param__";
/*! \brief A PackedFunc that looks up linked parameters by storage_id. */
constexpr const char* tvm_lookup_linked_param = "_lookup_linked_param";
/*! \brief The main AOT executor function */
constexpr const char* tvm_run_func_prefix = "tvm__run_func";
constexpr const char* tvm_run_func_suffix = "run_model";
} // namespace symbol

// implementations of inline functions.
Expand Down
16 changes: 10 additions & 6 deletions python/tvm/micro/model_library_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class UnsupportedInModelLibraryFormatError(Exception):
"""Raised when export_model_library_format does not support the given Module tree."""


def _populate_codegen_dir(mod, codegen_dir: str):
def _populate_codegen_dir(mod, codegen_dir: str, module_name: str = None):
"""Populate the codegen sub-directory as part of a Model Library Format export.

Parameters
Expand All @@ -44,6 +44,9 @@ def _populate_codegen_dir(mod, codegen_dir: str):
Module which should be written to codegen_dir.
codegen_dir : str
Path to the codegen directory on disk.
module_name: Optional[str]
Name used to prefix the generated source files

"""
dso_modules = mod._collect_dso_modules()
dso_module_handles = [m.handle.value for m in dso_modules]
Expand All @@ -55,17 +58,19 @@ def _populate_codegen_dir(mod, codegen_dir: str):

mod_indices = {"lib": 0, "src": 0}
host_codegen_dir = os.path.join(codegen_dir, "host")
lib_name = f"{module_name}_lib" if module_name else "lib"
giuseros marked this conversation as resolved.
Show resolved Hide resolved

for dso_mod in dso_modules:
if dso_mod.type_key == "c":
index = mod_indices["src"]
mod_indices["src"] += 1
parent_dir = os.path.join(host_codegen_dir, "src")
file_name = os.path.join(parent_dir, f"lib{index}.c")
file_name = os.path.join(parent_dir, f"{lib_name}{index}.c")
giuseros marked this conversation as resolved.
Show resolved Hide resolved
elif dso_mod.type_key == "llvm":
index = mod_indices["lib"]
mod_indices["lib"] += 1
parent_dir = os.path.join(host_codegen_dir, "lib")
file_name = os.path.join(parent_dir, f"lib{index}.o")
file_name = os.path.join(parent_dir, f"{lib_name}{index}.o")
else:
assert (
False
Expand Down Expand Up @@ -98,7 +103,6 @@ def _build_sid_map(graph_json):
A list with one entry per storage id describing that memory.
"""
graph = json.loads(graph_json)

seen_storage_ids = set()
memory_map = []
for node_id, storage_id in enumerate(graph["attrs"]["storage_id"][1]):
Expand Down Expand Up @@ -227,7 +231,7 @@ def export_model_library_format(mod: executor_factory.ExecutorFactoryModule, fil
runtime = ["aot"] if is_aot else ["graph"]

metadata = {
"version": 2,
"version": 3,
"model_name": mod.libmod_name,
"export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"),
"memory": _build_memory_map(mod),
Expand All @@ -240,7 +244,7 @@ def export_model_library_format(mod: executor_factory.ExecutorFactoryModule, fil

codegen_dir_path = tempdir.relpath("codegen")
os.mkdir(codegen_dir_path)
_populate_codegen_dir(mod.lib, codegen_dir_path)
_populate_codegen_dir(mod.lib, codegen_dir_path, mod.libmod_name)

parameters_dir_path = tempdir.relpath("parameters")
os.mkdir(parameters_dir_path)
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/relay/backend/compile_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from tvm.runtime import Object
from tvm.support import libinfo
from tvm.target import Target
from ..backend.utils import mangle_module_name
from .. import function as _function
from .. import ty as _ty
from . import _backend
Expand Down Expand Up @@ -328,7 +329,7 @@ class CompileEngine(Object):
def __init__(self):
raise RuntimeError("Cannot construct a CompileEngine")

def lower(self, source_func, target=None):
def lower(self, source_func, target=None, mod_name="default"):
"""Lower a source_func to a CachedFunc.

Parameters
Expand All @@ -346,8 +347,9 @@ def lower(self, source_func, target=None):
"""
# pylint: disable=broad-except, import-outside-toplevel
try:
mod_name = mangle_module_name(mod_name)
key = _get_cache_key(source_func, target)
return _backend._CompileEngineLower(self, key)
return _backend._CompileEngineLower(self, key, mod_name)
except Exception:
import traceback

Expand Down
4 changes: 3 additions & 1 deletion python/tvm/relay/backend/graph_executor_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from tvm.relay import _build_module
from tvm.target import Target
from tvm.tir import expr as _expr
from .utils import mangle_module_name


class GraphExecutorCodegen(object):
Expand Down Expand Up @@ -80,7 +81,8 @@ def codegen(self, func):
params : Dict[str, tvm.nd.NDArray]
Additional constant parameters.
"""
self._codegen(func)
default_mod_name = mangle_module_name("default")
self._codegen(func, default_mod_name)
graph_json = self._get_graph_json()
lowered_func = self._get_irmodule()
param_names = self._list_params_name()
Expand Down
37 changes: 37 additions & 0 deletions python/tvm/relay/backend/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Utility backend functions."""


def _is_valid_modname(mod_name):
"""Determine if mod_name is a valid string to use inside function names"""
if mod_name:
try:
mod_name.encode("ascii")
return True
except UnicodeEncodeError:
return False

return True
giuseros marked this conversation as resolved.
Show resolved Hide resolved


def mangle_module_name(mod_name):
    """Mangle a user-supplied module name into the symbol prefix used by codegen.

    Parameters
    ----------
    mod_name : Optional[str]
        Module name to mangle; must be ASCII-only when non-empty.

    Returns
    -------
    str
        ``"tvmgen_<mod_name>"`` for a non-empty name, ``"tvmgen"`` otherwise.

    Raises
    ------
    ValueError
        If ``mod_name`` contains non-ASCII characters.
    """
    # Non-ASCII characters cannot appear in generated function symbols.
    if mod_name and not mod_name.isascii():
        raise ValueError(mod_name + " contains invalid characters")
    return "tvmgen_" + mod_name if mod_name else "tvmgen"
15 changes: 12 additions & 3 deletions python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from . import expr as _expr
from . import function as _function
from .transform import InferType
from .backend.utils import mangle_module_name
from .backend import executor_factory as _executor_factory
from .backend import interpreter as _interpreter
from .backend.vm import VMExecutor
Expand Down Expand Up @@ -85,7 +86,9 @@ def __init__(self):
self._get_params_func = self.mod["get_params"]
self._get_function_metadata = self.mod["get_function_metadata"]

def build(self, mod, target=None, target_host=None, params=None, executor="graph"):
def build(
self, mod, target=None, target_host=None, params=None, executor="graph", mod_name=None
):
"""
Parameters
----------
Expand Down Expand Up @@ -115,6 +118,9 @@ def build(self, mod, target=None, target_host=None, params=None, executor="graph
- If "graph" is specified, then the graph_executor will be used
- If "aot" is specified, then the aot_executor will be used

mod_name: Optional[str]
The module name we will build
giuseros marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
graph_json : str
Expand Down Expand Up @@ -145,7 +151,9 @@ def build(self, mod, target=None, target_host=None, params=None, executor="graph
old_autotvm_silent = autotvm.GLOBAL_SCOPE.silent
autotvm.GLOBAL_SCOPE.silent = use_auto_scheduler

self._build(mod, target, target_host, executor)
mod_name = mangle_module_name(mod_name)

self._build(mod, target, target_host, executor, mod_name)
autotvm.GLOBAL_SCOPE.silent = old_autotvm_silent

# Get artifacts
Expand Down Expand Up @@ -295,6 +303,7 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"
"""
# pylint: enable=line-too-long
# fmt: on

if not isinstance(ir_mod, (IRModule, _function.Function)):
raise ValueError("Type of input parameter mod must be tvm.IRModule")

Expand Down Expand Up @@ -330,7 +339,7 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"
with tophub_context:
bld_mod = BuildModule()
executor_config, runtime_mod, params = bld_mod.build(
mod=ir_mod, target=target, params=params, executor=executor
mod=ir_mod, target=target, params=params, executor=executor, mod_name=mod_name
)
func_metadata = bld_mod.get_function_metadata()

Expand Down
6 changes: 4 additions & 2 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from tvm.runtime import ndarray as _nd

from . import _ffi_api
from ..backend.utils import mangle_module_name


def build_config(opt_level=2, required_pass=None, disabled_pass=None, trace=None):
Expand Down Expand Up @@ -713,7 +714,7 @@ def LambdaLift():
return _ffi_api.LambdaLift()


def PartitionGraph():
def PartitionGraph(mod_name="default"):
"""Partition a Relay program into regions that can be executed on different
backends.

Expand All @@ -722,7 +723,8 @@ def PartitionGraph():
ret: tvm.transform.Pass
The registered pass that partitions the Relay program.
"""
return _ffi_api.PartitionGraph()
mod_name = mangle_module_name(mod_name)
return _ffi_api.PartitionGraph(mod_name)


def AnnotateTarget(targets, include_non_call_ops=True):
Expand Down
30 changes: 20 additions & 10 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -375,11 +375,12 @@ class AOTExecutorCodegen : public ExprVisitor {
auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
Target target;

// Handle external function
if (func->GetAttr<String>(attr::kCompiler).defined()) {
target = Target("ext_dev");
CCacheKey key = (*pf0)(func, target);
CachedFunc ext_func = (*pf1)(compile_engine_, key);
CachedFunc ext_func = (*pf1)(compile_engine_, key, mod_name_);
ICHECK(ext_func.defined()) << "External function is not defined.";
UpdateConstants(func, &params_);

Expand Down Expand Up @@ -410,7 +411,7 @@ class AOTExecutorCodegen : public ExprVisitor {
target = targets_[call_dev_type];
}
CCacheKey key = (*pf0)(func, target);
CachedFunc lowered_func = (*pf1)(compile_engine_, key);
CachedFunc lowered_func = (*pf1)(compile_engine_, key, mod_name_);
if (!lowered_funcs_.count(target->str())) {
lowered_funcs_[target->str()] = IRModule(Map<GlobalVar, BaseFunc>({}));
}
Expand Down Expand Up @@ -533,7 +534,10 @@ class AOTExecutorCodegen : public ExprVisitor {

// Define the PrimFunc attributes
Map<String, ObjectRef> dict_attrs;
dict_attrs.Set("global_symbol", runtime::String(runtime::symbol::tvm_run_func_prefix));
String run_func_name =
runtime::get_name_mangled(mod_name_, runtime::symbol::tvm_run_func_suffix);
dict_attrs.Set("global_symbol", run_func_name);
dict_attrs.Set("runner_function", Bool(true));

// Make the PrimFunc
return tir::PrimFunc(main_signature_, body, VoidType(), Map<tir::Var, tir::Buffer>(),
Expand Down Expand Up @@ -586,6 +590,8 @@ class AOTExecutorCodegen : public ExprVisitor {
std::vector<tir::Stmt> stmts_;
/*! \brief the list of return sids (note that the function might return more then one output */
IntegerArray return_sid_;
/*! \brief the module name we use to mangle the function names */
String mod_name_;

public:
AOTExecutorCodegen(runtime::Module* mod, const TargetsMap& targets, Target target_host)
Expand All @@ -595,10 +601,11 @@ class AOTExecutorCodegen : public ExprVisitor {
use_unpacked_api_(target_host->GetAttr<Bool>("unpacked-api").value_or(Bool(false))),
compile_engine_(CompileEngine::Global()) {}

LoweredOutput Codegen(relay::Function func) {
LoweredOutput Codegen(relay::Function func, String mod_name) {
// Get the module, storage map and token sizes
auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
storage_device_map_ = (*pf)(func);
mod_name_ = mod_name;

for (auto input : func->params) {
input_vars_.push_back(input);
Expand Down Expand Up @@ -645,15 +652,15 @@ class AOTExecutorCodegen : public ExprVisitor {
auto target_host_str = target_host_->str();
if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) {
ret.lowered_funcs[target_host_str]->Add(
GlobalVar(::tvm::runtime::symbol::tvm_run_func_prefix), prim_func);
GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func);
} else {
Map<GlobalVar, BaseFunc> symbol_map;
symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_prefix), prim_func);
symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func);
ret.lowered_funcs.Set(target_host_str, IRModule(symbol_map));
}
ret.function_metadata = std::move(function_metadata_);
ret.metadata =
runtime::Metadata(input_vars_.size(), return_sid_.size(), runtime::kTvmExecutorAot);
ret.metadata = runtime::Metadata(input_vars_.size(), return_sid_.size(),
runtime::kTvmExecutorAot, mod_name);
return ret;
}
};
Expand All @@ -673,7 +680,8 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {
} else if (name == "codegen") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Function func = args[0];
this->output_ = codegen(func);
String mod_name = args[1];
this->output_ = codegen(func, mod_name);
});
} else if (name == "list_params_name") {
return PackedFunc(
Expand Down Expand Up @@ -724,7 +732,9 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {
targets, target_host);
}

LoweredOutput codegen(Function func) { return this->codegen_->Codegen(func); }
/*! \brief Run codegen on \p func by delegating to the wrapped codegen_ object,
 *  forwarding \p mod_name (presumably used to mangle generated symbol names —
 *  confirm against AOTExecutorCodegen::Codegen). */
LoweredOutput codegen(Function func, String mod_name) {
return this->codegen_->Codegen(func, mod_name);
}

Array<runtime::String> list_params_name() {
Array<runtime::String> ret;
Expand Down
Loading