Skip to content

Commit

Permalink
[VM] Avoid round-trip Target->str->Target conversions (apache#8161)
Browse files Browse the repository at this point in the history
Currently, in some cases this round-trip cannot be completed.  For
example, if an Integer value has a value outside a 32-bit signed
integer range, or if a String value contains spaces.

Co-authored-by: Eric Lunderberg <elunderberg@octoml.ai>
  • Loading branch information
Lunderberg and Lunderberg committed May 30, 2021
1 parent 8b5d843 commit e535ec8
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 21 deletions.
24 changes: 15 additions & 9 deletions python/tvm/relay/backend/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,20 +198,26 @@ def _update_target(self, target):
target = target if target else tvm.target.Target.current()
if target is None:
raise ValueError("Target is not set in env or passed as argument.")
tgts = {}
if isinstance(target, (str, tvm.target.Target)):
dev_type = tvm.tir.IntImm("int32", tvm.nd.device(str(target)).device_type)
tgts[dev_type] = tvm.target.Target(target)
elif isinstance(target, dict):
for dev, tgt in target.items():
dev_type = tvm.tir.IntImm("int32", tvm.nd.device(dev).device_type)
tgts[dev_type] = tvm.target.Target(tgt)
else:

if isinstance(target, str):
target = {target: target}
elif isinstance(target, tvm.target.Target):
target = {target.kind.name: target}
elif not isinstance(target, dict):
raise TypeError(
"target is expected to be str, tvm.target.Target, "
+ "or dict of str to str/tvm.target.Target, but received "
+ "{}".format(type(target))
)

tgts = {}
for dev, tgt in target.items():
dev_type = tvm.tir.IntImm("int32", tvm.nd.device(dev).device_type)
if isinstance(tgt, str):
tgt = tvm.target.Target(tgt)

tgts[dev_type] = tgt

return tgts

def _update_target_host(self, target, target_host):
Expand Down
19 changes: 7 additions & 12 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1156,37 +1156,32 @@ void VMCompiler::Codegen() {
if (cached_funcs.size() == 0) {
return;
}
std::unordered_map<std::string, IRModule> funcs;
Map<Target, IRModule> funcs;

for (auto& cfunc : cached_funcs) {
std::string target_str = cfunc->target->str();
Target target = cfunc->target;
// NOTE: because module, is mutable, we need to make an
// explicit copy of the IRModule.
IRModule mod = cfunc->funcs;
mod.CopyOnWrite();

if (target_str == "ext_dev") {
if (target->kind->device_type == kDLExtDev) {
// Collect metadata in functions that are handled by external codegen.
ICHECK(mod->ContainGlobalVar(cfunc->func_name));
Function func = Downcast<Function>(mod->Lookup(cfunc->func_name));
backend::UpdateConstants(func, &params_);
continue;
} else if (funcs.count(target_str) == 0) {
funcs.emplace(target_str, mod);
} else if (funcs.count(target) == 0) {
funcs.Set(target, mod);
} else {
funcs[target_str]->Update(mod);
funcs[target]->Update(mod);
}
}

auto compile_engine = CompileEngine::Global();
auto ext_mods = compile_engine->LowerExternalFunctions();
runtime::Module lib;
if (funcs.size() > 0) {
Map<String, IRModule> build_funcs;
for (const auto& i : funcs) {
build_funcs.Set(i.first, i.second);
}
lib = tvm::build(build_funcs, target_host_);
lib = tvm::build(funcs, target_host_);
} else {
// There is no function handled by TVM. We create a virtual main module
// to make sure a DSO module will be also available.
Expand Down

0 comments on commit e535ec8

Please sign in to comment.