Skip to content

Commit

Permalink
[TEST] Refactor RPC test to isolate runs into a sub-function (apache#…
Browse files Browse the repository at this point in the history
…8656)

We kill the rpc server in the del function. When a server
co-exist with remote resources in the same function scope,
the destruction order is not determined.

This can cause server to be destructed before the actual remote array.
As a side effect, it can cause sometime test to timeout due to
waiting on the socket.
  • Loading branch information
tqchen authored and ylc committed Jan 13, 2022
1 parent 1b8e470 commit 2ee9094
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 117 deletions.
7 changes: 3 additions & 4 deletions tests/python/contrib/test_edgetpu_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def init_interpreter(model_path, target_edgetpu):
interpreter = tflite.Interpreter(model_path=model_path)
return interpreter

def check_remote(target_edgetpu=False):
def check_remote(server, target_edgetpu=False):
tflite_model_path = get_tflite_model_path(target_edgetpu)

# inference via tflite interpreter python apis
Expand All @@ -67,7 +67,6 @@ def check_remote(target_edgetpu=False):
tflite_output = interpreter.get_tensor(output_details[0]["index"])

# inference via remote tvm tflite runtime
server = rpc.Server("127.0.0.1")
remote = rpc.connect(server.host, server.port)
dev = remote.cpu(0)
if target_edgetpu:
Expand All @@ -83,9 +82,9 @@ def check_remote(target_edgetpu=False):
np.testing.assert_equal(out.numpy(), tflite_output)

# Target CPU on coral board
check_remote()
check_remote(rpc.Server("127.0.0.1"))
# Target EdgeTPU on coral board
check_remote(target_edgetpu=True)
check_remote(rpc.Server("127.0.0.1"), target_edgetpu=True)


if __name__ == "__main__":
Expand Down
21 changes: 12 additions & 9 deletions tests/python/contrib/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,17 +122,20 @@ def test_rpc(dtype):
return

np_ones = np.ones((512, 512), dtype=dtype)
server = rpc.Server("127.0.0.1")
remote = rpc.connect(server.host, server.port)
value = tvm.nd.empty((512, 512), dtype, remote.cpu())
random_fill = remote.get_function("tvm.contrib.random.random_fill")
random_fill(value)

assert np.count_nonzero(value.numpy()) == 512 * 512
def check_remote(server):
remote = rpc.connect(server.host, server.port)
value = tvm.nd.empty((512, 512), dtype, remote.cpu())
random_fill = remote.get_function("tvm.contrib.random.random_fill")
random_fill(value)

# make sure arithmentic doesn't overflow too
np_values = value.numpy()
assert np.isfinite(np_values * np_values + np_values).any()
assert np.count_nonzero(value.numpy()) == 512 * 512

# make sure arithmentic doesn't overflow too
np_values = value.numpy()
assert np.isfinite(np_values * np_values + np_values).any()

check_remote(rpc.Server("127.0.0.1"))

for dtype in [
"bool",
Expand Down
24 changes: 12 additions & 12 deletions tests/python/contrib/test_tflite_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,18 +128,18 @@ def test_remote():
tflite_output = interpreter.get_tensor(output_details[0]["index"])

# inference via remote tvm tflite runtime
server = rpc.Server("127.0.0.1")
remote = rpc.connect(server.host, server.port)
a = remote.upload(tflite_model_path)

with open(tflite_model_path, "rb") as model_fin:
runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0))
runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0)))
runtime.invoke()
out = runtime.get_output(0)
np.testing.assert_equal(out.numpy(), tflite_output)

server.terminate()
def check_remote(server):
remote = rpc.connect(server.host, server.port)
a = remote.upload(tflite_model_path)

with open(tflite_model_path, "rb") as model_fin:
runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0))
runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0)))
runtime.invoke()
out = runtime.get_output(0)
np.testing.assert_equal(out.numpy(), tflite_output)

check_remote(rpc.Server("127.0.0.1"))


if __name__ == "__main__":
Expand Down
42 changes: 19 additions & 23 deletions tests/python/relay/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,29 +879,25 @@ def test_vm_rpc():
# Use local rpc server for testing.
# Server must use popen so it doesn't inherit the current process state. It
# will crash otherwise.
server = rpc.Server("localhost", port=9120)
remote = rpc.connect(server.host, server.port, session_timeout=10)

# Upload the serialized Executable.
remote.upload(path)
# Get a handle to remote Executable.
rexec = remote.load_module("vm_library.so")

ctx = remote.cpu()
# Build a VM out of the executable and context.
vm_factory = runtime.vm.VirtualMachine(rexec, ctx)
np_input = np.random.uniform(size=(10, 1)).astype("float32")
input_tensor = tvm.nd.array(np_input, ctx)
# Invoke its "main" function.
out = vm_factory.invoke("main", input_tensor)
# Check the result.
np.testing.assert_allclose(out.numpy(), np_input + np_input)

# delete tensors before the server shuts down so we don't throw errors.
del input_tensor
del out

server.terminate()
def check_remote(server):
remote = rpc.connect(server.host, server.port, session_timeout=10)

# Upload the serialized Executable.
remote.upload(path)
# Get a handle to remote Executable.
rexec = remote.load_module("vm_library.so")

ctx = remote.cpu()
# Build a VM out of the executable and context.
vm_factory = runtime.vm.VirtualMachine(rexec, ctx)
np_input = np.random.uniform(size=(10, 1)).astype("float32")
input_tensor = tvm.nd.array(np_input, ctx)
# Invoke its "main" function.
out = vm_factory.invoke("main", input_tensor)
# Check the result.
np.testing.assert_allclose(out.numpy(), np_input + np_input)

check_remote(rpc.Server("127.0.0.1"))


def test_get_output_single():
Expand Down
5 changes: 2 additions & 3 deletions tests/python/unittest/test_runtime_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,8 @@ def check_verify():
out = mod.get_output(0, tvm.nd.empty((n,)))
np.testing.assert_equal(out.numpy(), a + 1)

def check_remote():
def check_remote(server):
mlib = tvm.build(s, [A, B], "llvm", name="myadd")
server = rpc.Server("127.0.0.1")
remote = rpc.connect(server.host, server.port)
temp = utils.tempdir()
dev = remote.cpu(0)
Expand Down Expand Up @@ -115,7 +114,7 @@ def check_sharing():
del mod

check_verify()
check_remote()
check_remote(rpc.Server("127.0.0.1"))
check_sharing()


Expand Down
6 changes: 3 additions & 3 deletions tests/python/unittest/test_runtime_graph_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@


@tvm.testing.requires_llvm
@tvm.testing.requires_rpc
def test_graph_simple():
n = 4
A = te.placeholder((n,), name="A")
Expand Down Expand Up @@ -160,9 +161,8 @@ def split_debug_line(i):
# verify dump root delete after cleanup
assert not os.path.exists(directory)

def check_remote():
def check_remote(server):
mlib = tvm.build(s, [A, B], "llvm", name="myadd")
server = rpc.Server("127.0.0.1")
remote = rpc.connect(server.host, server.port)
temp = utils.tempdir()
dev = remote.cpu(0)
Expand All @@ -182,7 +182,7 @@ def check_remote():
np.testing.assert_equal(out.numpy(), a + 1)

check_verify()
check_remote()
check_remote(rpc.Server("127.0.0.1"))


if __name__ == "__main__":
Expand Down
48 changes: 25 additions & 23 deletions tests/python/unittest/test_runtime_module_based_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,29 +275,31 @@ def verify_rpc_gpu_export(obj_format):

from tvm import rpc

server = rpc.Server("127.0.0.1", port=9094)
remote = rpc.connect(server.host, server.port)
remote.upload(path_lib)
loaded_lib = remote.load_module(path_lib)
data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32")
dev = remote.cuda()

# raw api
gmod = loaded_lib["default"](dev)
set_input = gmod["set_input"]
run = gmod["run"]
get_output = gmod["get_output"]
set_input("data", tvm.nd.array(data, device=dev))
run()
out = get_output(0).numpy()
tvm.testing.assert_allclose(out, verify(data), atol=1e-5)

# graph executor wrapper
gmod = graph_executor.GraphModule(loaded_lib["default"](dev))
gmod.set_input("data", data)
gmod.run()
out = gmod.get_output(0).numpy()
tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
def check_remote(server):
remote = rpc.connect(server.host, server.port)
remote.upload(path_lib)
loaded_lib = remote.load_module(path_lib)
data = np.random.uniform(-1, 1, size=input_shape(mod)).astype("float32")
dev = remote.cuda()

# raw api
gmod = loaded_lib["default"](dev)
set_input = gmod["set_input"]
run = gmod["run"]
get_output = gmod["get_output"]
set_input("data", tvm.nd.array(data, device=dev))
run()
out = get_output(0).numpy()
tvm.testing.assert_allclose(out, verify(data), atol=1e-5)

# graph executor wrapper
gmod = graph_executor.GraphModule(loaded_lib["default"](dev))
gmod.set_input("data", data)
gmod.run()
out = gmod.get_output(0).numpy()
tvm.testing.assert_allclose(out, verify(data), atol=1e-5)

check_remote(rpc.Server("127.0.0.1"))

for obj_format in [".so", ".tar"]:
verify_cpu_export(obj_format)
Expand Down
Loading

0 comments on commit 2ee9094

Please sign in to comment.