Fix trt Test (apache#7016)
* Fix trt Test

* Fixed stuff

* Done

* fix 0

* Trigger Build

Co-authored-by: Ubuntu <ubuntu@ip-172-31-27-149.us-east-2.compute.internal>
2 people authored and trevor-m committed Dec 4, 2020
1 parent 637aec5 commit 4a76a07
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions tests/python/contrib/test_tensorrt.py
@@ -1050,26 +1050,26 @@ def test_tensorrt_dynamic_batch():
     batches_to_test = [1, 1, 0, 2, 3, 0, 1, 3, 2]
     x_shape = (relay.Any(), 1, 8, 8)
     x_data = np.ones([max(batches_to_test)] + list(x_shape)[1:]).astype("float32")
-    result_dict = {}
+    result_arr = [{} for _ in range(len(batches_to_test))]
     for use_trt in [True, False]:
         x = relay.var("x", shape=x_shape, dtype="float32")
         out = relay.nn.relu(x)
         f = relay.Function([x], out)
         mod = tvm.IRModule()
         mod["main"] = f
         if use_trt:
-            mod = relay.tensorrt.EnableTrt(mod)
+            mod, _ = tensorrt.partition_for_tensorrt(mod)
 
         if not skip_runtime_test():
             with relay.build_config(opt_level=3):
                 relay_exec = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
 
             for i, batch_size in enumerate(batches_to_test):
-                result_dict[(i, use_trt)] = relay_exec.evaluate()(x_data[:batch_size, ...])
+                result_arr[i][use_trt] = relay_exec.evaluate()(x_data[:batch_size, ...])
 
     if not skip_runtime_test():
         for i in range(len(batches_to_test)):
-            assert_result_matches(result_dict[(i, True)], result_dict[(i, False)])
+            assert_result_dict_holds(result_arr[i])
 
 
 def test_tensorrt_dynamic_batch_conv():
@@ -1080,7 +1080,7 @@ def test_tensorrt_dynamic_batch_conv():
     x_data = np.ones([max(batches_to_test)] + list(x_shape)[1:]).astype("float32")
     k_shape = (16, 32, 3, 3)
     params = {"kernel": np.random.uniform(-1, 1, k_shape).astype("float32")}
-    result_dict = {}
+    result_arr = [{} for _ in range(len(batches_to_test))]
     for use_trt in [True, False]:
         x = relay.var("x", shape=x_shape, dtype="float32")
         kernel = relay.var("kernel", shape=k_shape, dtype="float32")
@@ -1089,20 +1089,18 @@ def test_tensorrt_dynamic_batch_conv():
         mod = tvm.IRModule()
         mod["main"] = f
         if use_trt:
-            mod = tensorrt.partition_for_tensorrt(mod, params)
+            mod, _ = tensorrt.partition_for_tensorrt(mod, params)
 
         if not skip_runtime_test():
             with relay.build_config(opt_level=3):
                 relay_exec = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
 
             for i, batch_size in enumerate(batches_to_test):
-                result_dict[(i, use_trt)] = relay_exec.evaluate()(
-                    x=x_data[:batch_size, ...], **params
-                )
+                result_arr[i][use_trt] = relay_exec.evaluate()(x_data[:batch_size, ...], **params)
 
     if not skip_runtime_test():
         for i in range(len(batches_to_test)):
-            assert_result_matches(result_dict[(i, True)], result_dict[(i, False)])
+            assert_result_dict_holds(result_arr[i])
 
 
 def test_maskrcnn_resnet50() -> None:
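Note on the new comparison pattern: each test now collects one output per `use_trt` setting into `result_arr[i]` and defers the check to `assert_result_dict_holds`, a helper defined elsewhere in tests/python/contrib/test_tensorrt.py and not shown in this diff. Below is a minimal sketch of what such a helper could look like, assuming the executor outputs convert to NumPy arrays and that close numerical agreement between the TensorRT-partitioned and plain runs is the intended check:

import itertools
import numpy as np

def assert_result_dict_holds(result_dict, rtol=1e-3, atol=1e-3):
    # Hypothetical sketch only; the real helper in test_tensorrt.py may differ.
    # Compare every pair of results collected for one batch size (e.g. the
    # TensorRT-partitioned run vs. the plain LLVM run) and assert they agree.
    for key_a, key_b in itertools.combinations(result_dict, 2):
        res_a = np.asarray(result_dict[key_a])
        res_b = np.asarray(result_dict[key_b])
        np.testing.assert_allclose(res_a, res_b, rtol=rtol, atol=atol)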
