rebase and fix comments
zhiics committed Sep 3, 2020
1 parent df47ba4 commit 8f572b3
Showing 9 changed files with 23 additions and 29 deletions.
3 changes: 2 additions & 1 deletion python/tvm/runtime/vm.py
@@ -310,7 +310,8 @@ def _setup_ctx(self, ctx, memory_cfg):
         ctxs = ctx
         if not isinstance(ctx, (list, tuple)):
             if not isinstance(ctx, tvm.runtime.TVMContext):
-                raise TypeError("ctx is expected to be TVMContex")
+                raise TypeError("ctx is expected to be TVMContext or \
+List[TVMContext]")
             ctxs = [ctx]

         # CPU is required for executing shape functions
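The widened message matches the two shapes of input _setup_ctx accepts. A minimal sketch, assuming the VirtualMachine constructor forwards its ctx argument to _setup_ctx as in this change (module and target below are illustrative):

    # Hedged sketch: a single TVMContext or a list of TVMContexts is accepted;
    # anything else now raises the clearer TypeError shown above.
    import tvm
    from tvm import relay
    from tvm.runtime.vm import VirtualMachine

    mod = tvm.IRModule.from_expr(relay.Function([], relay.const(1.0, "float32")))
    exe = relay.vm.compile(mod, target="llvm")

    vm = VirtualMachine(exe, tvm.cpu(0))                # single TVMContext
    vm = VirtualMachine(exe, [tvm.cpu(0), tvm.gpu(0)])  # List[TVMContext] (needs a CUDA build)
    # VirtualMachine(exe, "cpu") would fail the isinstance checks and raise the TypeError.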
6 changes: 1 addition & 5 deletions src/relay/analysis/context_analysis.cc
@@ -451,7 +451,6 @@ class ContextAnalyzer : public ExprVisitor {
       inps.push_back(fn->params[0]);
       outs.push_back(call->op);
       Expr body = fn->body;
-      // outs.push_back(fn->body);
       CHECK(body->IsInstance<CallNode>() && IsDeviceCopy(body));
       Call call_body = Downcast<Call>(body);
       attrs = call_body->attrs.as<DeviceCopyAttrs>();
@@ -715,10 +714,7 @@ PackedAnalysisResultMap ContextAnalysisPacked(const IRModule& mod,
   return ret;
 }

-TVM_REGISTER_GLOBAL("relay.analysis.ContextAnalysis")
-    .set_body_typed([](IRModule mod, TVMContext default_context) {
-      return ContextAnalysisPacked(mod, default_context);
-    });
+TVM_REGISTER_GLOBAL("relay.analysis.ContextAnalysis").set_body_typed(ContextAnalysisPacked);

 }  // namespace relay
 }  // namespace tvm
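Registering ContextAnalysisPacked directly drops the redundant lambda wrapper; set_body_typed derives the packed-function signature from the function itself. For reference, a hedged sketch of reaching the analysis from Python, assuming the context_analysis wrapper imported in the tests below simply calls this global with a module and a default context:

    # Hedged sketch: context_analysis(mod, default_ctx) is assumed to wrap the
    # "relay.analysis.ContextAnalysis" global and return a per-expression device map.
    import tvm
    from tvm import relay
    from tvm.relay.analysis import context_analysis

    x = relay.var("x", shape=(2, 3), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.Function([x], relay.add(x, x)))
    device_map = context_analysis(mod, tvm.cpu(0))  # unannotated exprs land on the default ctx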
14 changes: 6 additions & 8 deletions src/relay/backend/vm/compiler.cc
@@ -531,14 +531,12 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
       const auto& it = targets_.begin();
       target = (*it).second;
     } else {
-      if (expr_device_map_.count(func) == 0 ||
-          targets_.count(expr_device_map_[func].device_type) == 0) {
-        int fallback_dev = GetFallbackDevice();
-        auto dev_name = runtime::DeviceName(fallback_dev);
-        if (expr_device_map_.count(func) == 0) {
-          LOG(WARNING) << "The function is not annotated. Fallback to " << dev_name;
-        }
-        target = CreateDefaultTarget(fallback_dev);
+      CHECK_GT(expr_device_map_.count(func), 0U)
+          << "Found not annotated expression, please make sure "
+             "context analysis has been executed";
+      int dev_type = expr_device_map_[func].device_type;
+      if (targets_.count(dev_type) == 0) {
+        target = CreateDefaultTarget(dev_type);
       } else {
         target = targets_[expr_device_map_[func].device_type];
       }
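With this change the compiler expects context analysis to have already assigned a device to every expression, so the per-function warning and fallback are gone. A hedged sketch of the caller-side view, with illustrative targets and the fallback device (device type 1, CPU) supplied once through the pass context rather than per expression:

    # Hedged sketch: target names and the pass-context option are illustrative;
    # context analysis seeds unannotated expressions with the fallback device.
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(2, 3), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.Function([x], relay.add(x, x)))

    targets = {"cpu": "llvm", "cuda": "cuda"}
    with tvm.transform.PassContext(opt_level=3,
                                   config={"relay.fallback_device_type": 1}):
        exe = relay.vm.compile(mod, target=targets)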
2 changes: 1 addition & 1 deletion tests/python/relay/benchmarking/benchmark_vm.py
@@ -79,7 +79,7 @@ def get_vm_output(mod, data, params, target, ctx, dtype='float32',
     # random input
     data = np.random.uniform(size=data_shape).astype(dtype)

-    for target, ctx in testing.ctx_list():
+    for target, ctx in testing.enabled_targets():
         tvm_out = get_graph_runtime_output(mod, tvm.nd.array(data.astype(dtype)),
                                            params, target, ctx, dtype)
         vm_out = get_vm_output(mod, tvm.nd.array(data.astype(dtype)), params,
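This and the remaining test updates swap the old ctx_list() helper for tvm.testing.enabled_targets(), which yields only (target, ctx) pairs whose backend is actually enabled in the current build. A small hedged sketch of the new pattern:

    # Hedged sketch: enabled_targets() is assumed to return (target, TVMContext)
    # pairs for every backend enabled in this TVM build/test environment.
    import tvm
    import tvm.testing

    for target, ctx in tvm.testing.enabled_targets():
        print("would run the benchmark on", target, ctx)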
2 changes: 1 addition & 1 deletion tests/python/relay/test_adt.py
@@ -721,7 +721,7 @@ def test_iterate():

 def check_tensor_array(ta_mod, ref_res, *args, dtype="float32", rtol=1e-5):
     for kind in ["debug", "vm"]:
-        for target, ctx in testing.ctx_list():
+        for target, ctx in testing.enabled_targets():
             if kind == "debug" and ctx.device_type != tvm.cpu().device_type:
                 continue
             ex = relay.create_executor(kind, mod=ta_mod, ctx=ctx, target=target)
3 changes: 1 addition & 2 deletions tests/python/relay/test_any.py
@@ -23,7 +23,6 @@
 from tvm.relay.loops import while_loop
 from tvm.relay.testing import run_infer_type as infer_type
 import tvm.topi.testing
-from tvm.relay.testing.config import ctx_list

 def int32(val):
     return relay.const(val, 'int32')
@@ -37,7 +36,7 @@ def any_dims(ndim):
 def check_result(args, mod, expected, flatten=False, assert_shape=False,
                  only_vm=False):
     for kind in ["debug", "vm"]:
-        for tgt, ctx in ctx_list():
+        for tgt, ctx in tvm.testing.enabled_targets():
             if kind == "debug" and (only_vm or ctx.device_type !=
                                     tvm.cpu().device_type):
                 continue
16 changes: 8 additions & 8 deletions tests/python/relay/test_pass_context_analysis.py
@@ -21,12 +21,12 @@

 import tvm
 from tvm import relay
-from tvm.relay import expr as _expr, transform
+from tvm.relay import expr as _expr
 from tvm.relay.analysis import context_analysis


 def test_device_copy():
-    if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist:
+    if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist:
         return

     mod = tvm.IRModule()
@@ -49,7 +49,7 @@ def test_device_copy():


 def test_shape_func():
-    if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist:
+    if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist:
         return

     mod = tvm.IRModule()
@@ -78,7 +78,7 @@ def test_shape_func():


 def test_vm_shape_of():
-    if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist:
+    if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist:
         return

     mod = tvm.IRModule()
@@ -96,7 +96,7 @@ def test_vm_shape_of():


 def test_alloc_storage():
-    if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist:
+    if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist:
         return

     mod = tvm.IRModule()
@@ -126,7 +126,7 @@ def test_alloc_storage():


 def test_alloc_tensor():
-    if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist:
+    if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist:
         return

     mod = tvm.IRModule()
@@ -155,7 +155,7 @@ def test_alloc_tensor():


 def test_vm_reshape_tensor():
-    if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist:
+    if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist:
         return

     x = relay.var("x", shape=(2, 8), dtype="float32")
@@ -181,7 +181,7 @@ def test_vm_reshape_tensor():


 def test_dynamic_input():
-    if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist:
+    if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist:
         return

     mod = tvm.IRModule()
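These tests also move from tvm.runtime.enabled("cuda") to tvm.testing.device_enabled("cuda"); the extra tvm.gpu(0).exist check still guards machines where CUDA is compiled in but no physical GPU is attached. A hedged sketch of the skip guard, with a hypothetical helper name:

    # Hedged sketch: cuda_available is a hypothetical helper mirroring the guard above.
    import tvm
    import tvm.testing

    def cuda_available():
        # built with the CUDA runtime and a physical device is present
        return tvm.testing.device_enabled("cuda") and tvm.gpu(0).exist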
2 changes: 1 addition & 1 deletion tests/python/relay/test_vm_serialization.py
@@ -307,7 +307,7 @@ def test_dynamic_bcast():
     x_data = np.random.uniform(size=(1, 2)).astype(dtype)
     y_data = np.random.uniform(size=(3, 2)).astype(dtype)
     res_np = np.add(x_data, y_data)
-    for target, ctx in testing.ctx_list():
+    for target, ctx in testing.enabled_targets():
         res = get_serialized_output(mod, *(x_data, y_data), target=target,
                                     ctx=ctx)
         tvm.testing.assert_allclose(res.asnumpy(), res_np)
4 changes: 2 additions & 2 deletions tests/python/unittest/test_runtime_vm_profiler.py
@@ -18,14 +18,14 @@

 from tvm.runtime import profiler_vm
 from tvm import relay
-from tvm.relay.testing import resnet, ctx_list
+from tvm.relay.testing import resnet, enabled_targets

 def test_basic():
     mod, params = resnet.get_workload()
     if not profiler_vm.enabled():
         return

-    for target, ctx in ctx_list():
+    for target, ctx in enabled_targets():
         exe = relay.vm.compile(mod, target, params=params)
         vm = profiler_vm.VirtualMachineProfiler(exe, ctx)

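For context, a hedged sketch of exercising the profiler VM that the updated loop constructs, assuming invoke and get_stat behave as in this version's profiler_vm API and using the default ResNet workload input:

    # Hedged sketch: compile the default ResNet workload for CPU, run one batch,
    # and print the per-op timing report collected by the profiler VM.
    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay.testing import resnet
    from tvm.runtime import profiler_vm

    mod, params = resnet.get_workload()
    exe = relay.vm.compile(mod, "llvm", params=params)
    vm = profiler_vm.VirtualMachineProfiler(exe, tvm.cpu(0))

    data = np.random.rand(1, 3, 224, 224).astype("float32")
    vm.invoke("main", tvm.nd.array(data))
    print(vm.get_stat())  # per-op time breakdown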
