Skip to content

Commit

Permalink
[AOT] Name mangling in AOT
Browse files Browse the repository at this point in the history
Mini-RFC is here: https://discuss.tvm.apache.org/t/mini-rfc-name-mangling-in-aot

With this change we'll mangle the name of global symbols so that we can bundle
together multiple models in the same application.

The relay.build interface has been left unchanged, which means I am
reusing mod_name as a prefix for all functions. If mod_name is None then
a "tvmgen" prefix is used.

I had to add two different compilation functions:
- _CompileEngineLowerWithModuleName to mangle all the operators with the mod_name
- PartitionGraphWithModName to mangle all the operators produced by BYOC

I could have changed signature of both, but that would have meant a very
invasive refactoring.

I refactored the aot test utils and added some tests for multiple
models.

Change-Id: I30e93fa075f660054577ea36cf9268ec0c6eebcb
  • Loading branch information
Giuseppe Rossini committed May 27, 2021
1 parent 69e56c6 commit b0bcfd4
Show file tree
Hide file tree
Showing 30 changed files with 552 additions and 191 deletions.
2 changes: 1 addition & 1 deletion include/tvm/runtime/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ constexpr const char* tvm_param_prefix = "__tvm_param__";
/*! \brief A PackedFunc that looks up linked parameters by storage_id. */
constexpr const char* tvm_lookup_linked_param = "_lookup_linked_param";
/*! \brief The main AOT executor function */
constexpr const char* tvm_run_func_prefix = "tvm__run_func";
constexpr const char* tvm_run_func_suffix = "run_model";
} // namespace symbol

// implementations of inline functions.
Expand Down
16 changes: 10 additions & 6 deletions python/tvm/micro/model_library_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class UnsupportedInModelLibraryFormatError(Exception):
"""Raised when export_model_library_format does not support the given Module tree."""


def _populate_codegen_dir(mod, codegen_dir: str):
def _populate_codegen_dir(mod, codegen_dir: str, module_name: str = None):
"""Populate the codegen sub-directory as part of a Model Library Format export.
Parameters
Expand All @@ -44,6 +44,9 @@ def _populate_codegen_dir(mod, codegen_dir: str):
Module which should be written to codegen_dir.
codegen_dir : str
Path to the codegen directory on disk.
module_name: Optional[str]
Name used to prefix the generated source files
"""
dso_modules = mod._collect_dso_modules()
dso_module_handles = [m.handle.value for m in dso_modules]
Expand All @@ -55,17 +58,19 @@ def _populate_codegen_dir(mod, codegen_dir: str):

mod_indices = {"lib": 0, "src": 0}
host_codegen_dir = os.path.join(codegen_dir, "host")
lib_name = f"{module_name}_lib" if module_name else "lib"

for dso_mod in dso_modules:
if dso_mod.type_key == "c":
index = mod_indices["src"]
mod_indices["src"] += 1
parent_dir = os.path.join(host_codegen_dir, "src")
file_name = os.path.join(parent_dir, f"lib{index}.c")
file_name = os.path.join(parent_dir, f"{lib_name}{index}.c")
elif dso_mod.type_key == "llvm":
index = mod_indices["lib"]
mod_indices["lib"] += 1
parent_dir = os.path.join(host_codegen_dir, "lib")
file_name = os.path.join(parent_dir, f"lib{index}.o")
file_name = os.path.join(parent_dir, f"{lib_name}{index}.o")
else:
assert (
False
Expand Down Expand Up @@ -98,7 +103,6 @@ def _build_sid_map(graph_json):
A list with one entry per storage id describing that memory.
"""
graph = json.loads(graph_json)

seen_storage_ids = set()
memory_map = []
for node_id, storage_id in enumerate(graph["attrs"]["storage_id"][1]):
Expand Down Expand Up @@ -227,7 +231,7 @@ def export_model_library_format(mod: executor_factory.ExecutorFactoryModule, fil
runtime = ["aot"] if is_aot else ["graph"]

metadata = {
"version": 2,
"version": 3,
"model_name": mod.libmod_name,
"export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"),
"memory": _build_memory_map(mod),
Expand All @@ -240,7 +244,7 @@ def export_model_library_format(mod: executor_factory.ExecutorFactoryModule, fil

codegen_dir_path = tempdir.relpath("codegen")
os.mkdir(codegen_dir_path)
_populate_codegen_dir(mod.lib, codegen_dir_path)
_populate_codegen_dir(mod.lib, codegen_dir_path, mod.libmod_name)

parameters_dir_path = tempdir.relpath("parameters")
os.mkdir(parameters_dir_path)
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/relay/backend/compile_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from tvm.runtime import Object
from tvm.support import libinfo
from tvm.target import Target
from ..backend.utils import mangle_module_name
from .. import function as _function
from .. import ty as _ty
from . import _backend
Expand Down Expand Up @@ -328,7 +329,7 @@ class CompileEngine(Object):
def __init__(self):
raise RuntimeError("Cannot construct a CompileEngine")

def lower(self, source_func, target=None):
def lower(self, source_func, target=None, mod_name="default"):
"""Lower a source_func to a CachedFunc.
Parameters
Expand All @@ -346,8 +347,9 @@ def lower(self, source_func, target=None):
"""
# pylint: disable=broad-except, import-outside-toplevel
try:
mod_name = mangle_module_name(mod_name)
key = _get_cache_key(source_func, target)
return _backend._CompileEngineLower(self, key)
return _backend._CompileEngineLower(self, key, mod_name)
except Exception:
import traceback

Expand Down
37 changes: 37 additions & 0 deletions python/tvm/relay/backend/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Utility backend functions."""


def _is_valid_modname(mod_name):
"""Determine if mod_name is a valid string to use inside function names"""
if mod_name:
try:
mod_name.encode("ascii")
return True
except UnicodeEncodeError:
return False

return True


def mangle_module_name(mod_name):
    """Build the prefix used to mangle generated symbol names.

    Parameters
    ----------
    mod_name : Optional[str]
        The user-supplied module name; may be empty or None.

    Returns
    -------
    str
        "tvmgen_" followed by mod_name, or plain "tvmgen" when mod_name is
        empty or None.

    Raises
    ------
    ValueError
        If _is_valid_modname rejects mod_name.
    """
    if not _is_valid_modname(mod_name):
        raise ValueError(mod_name + " contains invalid characters")
    return "tvmgen_" + mod_name if mod_name else "tvmgen"
15 changes: 12 additions & 3 deletions python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from . import expr as _expr
from . import function as _function
from .transform import InferType
from .backend.utils import mangle_module_name
from .backend import executor_factory as _executor_factory
from .backend import interpreter as _interpreter
from .backend.vm import VMExecutor
Expand Down Expand Up @@ -85,7 +86,9 @@ def __init__(self):
self._get_params_func = self.mod["get_params"]
self._get_function_metadata = self.mod["get_function_metadata"]

def build(self, mod, target=None, target_host=None, params=None, executor="graph"):
def build(
self, mod, target=None, target_host=None, params=None, executor="graph", mod_name=None
):
"""
Parameters
----------
Expand Down Expand Up @@ -115,6 +118,9 @@ def build(self, mod, target=None, target_host=None, params=None, executor="graph
- If "graph" is specified, then the graph_executor will be used
- If "aot" is specified, then the aot_executor will be used
mod_name: Optional[str]
The module name we will build
Returns
-------
graph_json : str
Expand Down Expand Up @@ -145,7 +151,9 @@ def build(self, mod, target=None, target_host=None, params=None, executor="graph
old_autotvm_silent = autotvm.GLOBAL_SCOPE.silent
autotvm.GLOBAL_SCOPE.silent = use_auto_scheduler

self._build(mod, target, target_host, executor)
mod_name = mangle_module_name(mod_name)

self._build(mod, target, target_host, executor, mod_name)
autotvm.GLOBAL_SCOPE.silent = old_autotvm_silent

# Get artifacts
Expand Down Expand Up @@ -295,6 +303,7 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"
"""
# pylint: enable=line-too-long
# fmt: on

if not isinstance(ir_mod, (IRModule, _function.Function)):
raise ValueError("Type of input parameter mod must be tvm.IRModule")

Expand Down Expand Up @@ -330,7 +339,7 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"
with tophub_context:
bld_mod = BuildModule()
executor_config, runtime_mod, params = bld_mod.build(
mod=ir_mod, target=target, params=params, executor=executor
mod=ir_mod, target=target, params=params, executor=executor, mod_name=mod_name
)
func_metadata = bld_mod.get_function_metadata()

Expand Down
6 changes: 4 additions & 2 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from tvm import relay
from . import _ffi_api
from ..backend.utils import mangle_module_name


def build_config(opt_level=2, required_pass=None, disabled_pass=None, trace=None):
Expand Down Expand Up @@ -714,7 +715,7 @@ def LambdaLift():
return _ffi_api.LambdaLift()


def PartitionGraph():
def PartitionGraph(mod_name="default"):
"""Partition a Relay program into regions that can be executed on different
backends.
Expand All @@ -723,7 +724,8 @@ def PartitionGraph():
ret: tvm.transform.Pass
The registered pass that partitions the Relay program.
"""
return _ffi_api.PartitionGraph()
mod_name = mangle_module_name(mod_name)
return _ffi_api.PartitionGraph(mod_name)


def AnnotateTarget(targets, include_non_call_ops=True):
Expand Down
30 changes: 20 additions & 10 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -359,11 +359,12 @@ class AOTExecutorCodegen : public ExprVisitor {
auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
Target target;

// Handle external function
if (func->GetAttr<String>(attr::kCompiler).defined()) {
target = Target("ext_dev");
CCacheKey key = (*pf0)(func, target);
CachedFunc ext_func = (*pf1)(compile_engine_, key);
CachedFunc ext_func = (*pf1)(compile_engine_, key, mod_name_);
ICHECK(ext_func.defined()) << "External function is not defined.";
UpdateConstants(func, &params_);

Expand Down Expand Up @@ -394,7 +395,7 @@ class AOTExecutorCodegen : public ExprVisitor {
target = targets_[call_dev_type];
}
CCacheKey key = (*pf0)(func, target);
CachedFunc lowered_func = (*pf1)(compile_engine_, key);
CachedFunc lowered_func = (*pf1)(compile_engine_, key, mod_name_);
if (!lowered_funcs_.count(target->str())) {
lowered_funcs_[target->str()] = IRModule(Map<GlobalVar, BaseFunc>({}));
}
Expand Down Expand Up @@ -517,7 +518,10 @@ class AOTExecutorCodegen : public ExprVisitor {

// Define the PrimFunc attributes
Map<String, ObjectRef> dict_attrs;
dict_attrs.Set("global_symbol", runtime::String(runtime::symbol::tvm_run_func_prefix));
String run_func_name =
runtime::get_name_mangled(mod_name_, runtime::symbol::tvm_run_func_suffix);
dict_attrs.Set("global_symbol", run_func_name);
dict_attrs.Set("runner_function", Bool(true));

// Make the PrimFunc
return tir::PrimFunc(main_signature_, body, VoidType(), Map<tir::Var, tir::Buffer>(),
Expand Down Expand Up @@ -561,6 +565,8 @@ class AOTExecutorCodegen : public ExprVisitor {
std::vector<tir::Stmt> stmts_;
/*! \brief the list of return sids (note that the function might return more than one output) */
IntegerArray return_sid_;
/*! \brief the module name we use to mangle the function names */
String mod_name_;

public:
AOTExecutorCodegen(runtime::Module* mod, const TargetsMap& targets, Target target_host)
Expand All @@ -570,10 +576,11 @@ class AOTExecutorCodegen : public ExprVisitor {
target_host_ = target_host;
}

LoweredOutput Codegen(relay::Function func) {
LoweredOutput Codegen(relay::Function func, String mod_name) {
// Get the module, storage map and token sizes
auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
storage_device_map_ = (*pf)(func);
mod_name_ = mod_name;

int input_index = 0;
for (auto input : func->params) {
Expand Down Expand Up @@ -621,15 +628,15 @@ class AOTExecutorCodegen : public ExprVisitor {
auto target_host_str = target_host_->str();
if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) {
ret.lowered_funcs[target_host_str]->Add(
GlobalVar(::tvm::runtime::symbol::tvm_run_func_prefix), prim_func);
GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func);
} else {
Map<GlobalVar, BaseFunc> symbol_map;
symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_prefix), prim_func);
symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func);
ret.lowered_funcs.Set(target_host_str, IRModule(symbol_map));
}
ret.function_metadata = std::move(function_metadata_);
ret.metadata =
runtime::Metadata(input_vars_.size(), return_sid_.size(), runtime::kTvmExecutorAot);
ret.metadata = runtime::Metadata(input_vars_.size(), return_sid_.size(),
runtime::kTvmExecutorAot, mod_name);
return ret;
}
};
Expand All @@ -649,7 +656,8 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {
} else if (name == "codegen") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Function func = args[0];
this->output_ = codegen(func);
String mod_name = args[1];
this->output_ = codegen(func, mod_name);
});
} else if (name == "list_params_name") {
return PackedFunc(
Expand Down Expand Up @@ -700,7 +708,9 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {
targets, target_host);
}

LoweredOutput codegen(Function func) { return this->codegen_->Codegen(func); }
LoweredOutput codegen(Function func, String mod_name) {
return this->codegen_->Codegen(func, mod_name);
}

Array<runtime::String> list_params_name() {
Array<runtime::String> ret;
Expand Down
15 changes: 8 additions & 7 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ struct BuildOutput {
struct ExecutorCodegen {
void Init(runtime::Module* m, TargetsMap targets) { CallFunc("init", m, targets); }

void Codegen(const Function& func) { CallFunc("codegen", func); }
void Codegen(const Function& func, String mod_name) { CallFunc("codegen", func, mod_name); }

virtual void UpdateOutput(BuildOutput* ret) = 0;

Expand Down Expand Up @@ -177,8 +177,8 @@ class RelayBuildModule : public runtime::ModuleNode {
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetModule(); });
} else if (name == "build") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
ICHECK_EQ(args.num_args, 4);
this->Build(args[0], args[1], args[2], args[3]);
ICHECK_EQ(args.num_args, 5);
this->Build(args[0], args[1], args[2], args[3], args[4]);
});
} else if (name == "list_params") {
return PackedFunc(
Expand Down Expand Up @@ -279,13 +279,13 @@ class RelayBuildModule : public runtime::ModuleNode {
* \param target_host Host target device
*/
void Build(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host,
const String executor) {
const String executor, const String mod_name) {
// Create protected variable targets_ from ground up
targets_ = targets;
target_host_ = target_host;
executor_ = executor;
CheckAndUpdateHostConsistency(&targets_, &target_host_);
BuildRelay(mod, params_);
BuildRelay(mod, params_, mod_name);
// Clear compile engine so that tuning schedules can be changed between runs. See issue #6096.
CompileEngine::Global()->Clear();
}
Expand Down Expand Up @@ -508,7 +508,8 @@ class RelayBuildModule : public runtime::ModuleNode {
* \param params The parameters.
*/
void BuildRelay(IRModule relay_module,
const std::unordered_map<std::string, tvm::runtime::NDArray>& params) {
const std::unordered_map<std::string, tvm::runtime::NDArray>& params,
const String mod_name) {
Target target_host = GetTargetHost();
// If no target_host has been set, we choose a default one, which is
// llvm if "codegen.LLVMModuleCreate" is accessible.
Expand All @@ -527,7 +528,7 @@ class RelayBuildModule : public runtime::ModuleNode {
// Generate code for the updated function.
executor_codegen_ = MakeExecutorCodegen(executor_);
executor_codegen_->Init(nullptr, targets_);
executor_codegen_->Codegen(func);
executor_codegen_->Codegen(func, mod_name);
executor_codegen_->UpdateOutput(&ret_);
ret_.params = executor_codegen_->GetParams();

Expand Down
Loading

0 comments on commit b0bcfd4

Please sign in to comment.