Skip to content

Commit

Permalink
vm external codegen (apache#4544)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Dec 31, 2019
1 parent 3a0a606 commit 02c850f
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 56 deletions.
81 changes: 52 additions & 29 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -476,30 +476,39 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
argument_registers.push_back(reg->second);
}

// Next generate the invoke instruction.
Target target;
if (targets_.size() == 1) {
// homogeneous execution.
for (auto kv : targets_) {
target = kv.second;
}

if (!func->UseDefaultCompiler()) {
target = tvm::target::ext_dev();
} else {
// heterogeneous execution.
LOG(FATAL) << "Currently VM compiler doesn't support heterogeneous compilation";
// Next generate the invoke instruction.
if (targets_.size() == 1) {
// homogeneous execution.
const auto& it = targets_.begin();
target = (*it).second;
} else {
// heterogeneous execution.
LOG(FATAL) << "Currently VM compiler doesn't support heterogeneous compilation";
}
}

auto key = CCacheKeyNode::make(func, target);
auto cfunc = engine_->Lower(key);

// TODO(jroesch): support lowered funcs for multiple targets
CHECK_EQ(cfunc->funcs.size(), 1);
auto op_index = -1;
if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
if (!func->UseDefaultCompiler()) {
op_index = context_->cached_funcs.size();
context_->cached_funcs.push_back(cfunc);
context_->seen_funcs[cfunc->funcs[0]] = op_index;
} else {
op_index = context_->seen_funcs[cfunc->funcs[0]];
// TODO(jroesch): support lowered funcs for multiple targets
CHECK_EQ(cfunc->funcs.size(), 1);
if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
op_index = context_->cached_funcs.size();
context_->cached_funcs.push_back(cfunc);
context_->seen_funcs[cfunc->funcs[0]] = op_index;
} else {
op_index = context_->seen_funcs[cfunc->funcs[0]];
}
}

Emit(Instruction::InvokePacked(op_index,
Expand Down Expand Up @@ -950,32 +959,46 @@ void VMCompiler::LibraryCodegen() {
if (cached_funcs.size() == 0) {
return;
}
std::unordered_map<std::string, Array<LoweredFunc>> tgt_funcs;
for (auto &cfunc : cached_funcs) {
std::unordered_map<std::string, Array<LoweredFunc>> funcs;
for (auto& cfunc : cached_funcs) {
std::string target_str = cfunc->target->str();
if (tgt_funcs.count(target_str) == 0) {
tgt_funcs.emplace(target_str, Array<LoweredFunc>{cfunc->funcs[0]});
if (target_str == "ext_dev") {
continue;
} else if (funcs.count(target_str) == 0) {
funcs.emplace(target_str, Array<LoweredFunc>{cfunc->funcs[0]});
} else {
tgt_funcs[target_str].push_back(cfunc->funcs[0]);
funcs[target_str].push_back(cfunc->funcs[0]);
}
}
Map<Target, Array<LoweredFunc>> funcs;
for (auto &it : tgt_funcs) {
funcs.Set(Target::Create(it.first), it.second);
}

if (const auto *f = runtime::Registry::Get("relay.backend.build")) {
// The target is just a dummy arg because funcs already contains corresponding target
// therefore target won't be used in the build function
runtime::Module mod = (*f)(funcs, Target(), target_host_);
auto compile_engine = CompileEngine::Global();
auto ext_mods = compile_engine->LowerExternalFunctions();
runtime::Module mod;
if (funcs.size() > 0) {
mod = tvm::build(funcs, target_host_, tvm::BuildConfig::Current());
CHECK(mod.operator->());
exec_->lib = mod;
} else {
LOG(FATAL) << "relay.backend.build is not registered";
CHECK_EQ(ext_mods.size(), 1U)
<< "Expect to have a TVM DSOModule when multiple runtime modules exist";
}
if (!ext_mods.empty()) {
if (funcs.size() == 0) {
mod = ext_mods[0];
} else {
// Import all external runtime modules.
for (auto it : ext_mods) {
mod.Import(it);
}
}
}
exec_->lib = mod;
size_t primitive_index = 0;
for (auto cfunc : cached_funcs) {
exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
if (cfunc->target->str() == "ext_dev") {
exec_->primitive_map.insert({cfunc->func_name, primitive_index++});
} else {
exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
}
}
}

Expand Down
4 changes: 3 additions & 1 deletion src/runtime/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,9 @@ void VirtualMachine::LoadExecutable(const Executable* exec) {
if (packed_funcs_.size() <= packed_index) {
packed_funcs_.resize(packed_index + 1);
}
packed_funcs_[packed_index] = lib.GetFunction(packed_name);
tvm::runtime::PackedFunc pf = lib.GetFunction(packed_name, true);
CHECK(pf != nullptr) << "Cannot find function in module: " << packed_name;
packed_funcs_[packed_index] = pf;
}
}

Expand Down
70 changes: 44 additions & 26 deletions tests/python/relay/test_external_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,36 +26,54 @@
from tvm import relay
from tvm.contrib import util

def check_result(mod, map_inputs, out_shape, result, tol=1e-5):
def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
ctx=tvm.cpu()):
if sys.platform == "win32":
print("Skip test on Windows for now")
return

with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
json, lib, _ = relay.build(mod, "llvm")
test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
source_dir = os.path.join(test_dir, "..", "..", "..")
contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")

kwargs = {}
kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path]
tmp_path = util.tempdir()
lib_name = 'lib.so'
lib_path = tmp_path.relpath(lib_name)
lib.export_library(lib_path, fcompile=False, **kwargs)
lib = tvm.module.load(lib_path)

ctx = tvm.cpu()
rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)

for name, data in map_inputs.items():
rt_mod.set_input(name, data)

rt_mod.run()
out = tvm.nd.empty(out_shape, ctx=ctx)
out = rt_mod.get_output(0, out)

tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
def update_lib(lib):
test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
source_dir = os.path.join(test_dir, "..", "..", "..")
contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")

kwargs = {}
kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path]
tmp_path = util.tempdir()
lib_name = 'lib.so'
lib_path = tmp_path.relpath(lib_name)
lib.export_library(lib_path, fcompile=False, **kwargs)
lib = tvm.module.load(lib_path)

return lib

def check_vm_result():
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
exe = relay.vm.compile(mod, target=target)
code, lib = exe.save()
lib = update_lib(lib)
exe = relay.vm.Executable.load_exec(code, lib)
vm = relay.vm.VirtualMachine(exe)
vm.init(ctx)
out = vm.run(**map_inputs)
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)

def check_graph_runtime_result():
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
json, lib, _ = relay.build(mod, target=target)
lib = update_lib(lib)
rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)

for name, data in map_inputs.items():
rt_mod.set_input(name, data)
rt_mod.run()
out = tvm.nd.empty(out_shape, ctx=ctx)
out = rt_mod.get_output(0, out)

tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)

check_vm_result()
check_graph_runtime_result()


def set_external_func_attr(func, compiler, ext_symbol):
Expand Down

0 comments on commit 02c850f

Please sign in to comment.