Skip to content

Commit

Permalink
[Relay] Shape func fix for all_class_nms and where op (apache#7910)
Browse files Browse the repository at this point in the history
* fix missing cast to int64 in all_class_nms shape func

* fix scalar in where shape func

* add add test

* update test

* minor fix

* add where scalar shape func test
  • Loading branch information
masahi authored and Trevor Morris committed May 6, 2021
1 parent 5a6ca12 commit e82187f
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 8 deletions.
12 changes: 9 additions & 3 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,9 +1029,15 @@ def where_shape_func(attrs, inputs, _):
"""
Shape func for where.
"""
cond_shape = inputs[0]
x_shape = inputs[1]
y_shape = inputs[2]

def ensure_tensor(tensor):
if len(tensor.shape) == 0:
return topi.full((1,), "int64", 1)
return tensor

cond_shape = ensure_tensor(inputs[0])
x_shape = ensure_tensor(inputs[1])
y_shape = ensure_tensor(inputs[2])

bcast_shape = _broadcast_shape_tensors(x_shape, y_shape)
out_shape = _broadcast_shape_tensors(bcast_shape, cond_shape)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/vision/_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _all_class_nms_shape_func(boxes_shape, scores_shape):
count_shape = output_tensor((1,), "int64")

out_shape[0] = boxes_shape[0] * scores_shape[1] * boxes_shape[1]
out_shape[1] = 3
out_shape[1] = int64(3)
count_shape[0] = int64(1)
return out_shape, count_shape

Expand Down
124 changes: 120 additions & 4 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -1512,9 +1512,26 @@ def test_any_where():
any_dims(2), any_dims(2), any_dims(2), (3, 4), (3, 1), (1, 4), y_np_shape_invalid=(2, 4)
)

# Test scalar where in a dynamically shaped graph
x = relay.var("x", shape=any_dims(1), dtype="int64")
y = relay.var("y", shape=any_dims(2), dtype="float32")

# TODO(kevinthesun): enable gpu test when Thrust is available in ci.
# @tvm.testing.uses_gpu
left = relay.take(x, relay.const(1, dtype="int32")) + relay.const(4, "int64")
right = relay.const(4, "int64")
where = relay.where(relay.const(False, "bool"), left, right)
z = relay.take(y, where, axis=1)

mod = tvm.IRModule()
mod["main"] = relay.Function([x, y], z)

x_np = np.random.randn(2).astype("int64")
y_np = np.random.randn(2, 6).astype("float32")
expected = y_np[:, 4]

check_result([x_np, y_np], mod, expected)


@tvm.testing.uses_gpu
def test_non_max_suppression():
x0 = relay.var("x0", relay.ty.TensorType((1, relay.Any(), 6), "float32"))
x1 = relay.var("x1", relay.ty.TensorType((1,), "int32"))
Expand Down Expand Up @@ -1558,7 +1575,6 @@ def test_non_max_suppression():
mod,
[np_indices_result, np_valid_box_count],
only_vm=False,
disable_targets=["nvptx"],
)

np_data = np.zeros((1, 0, 6)).astype("float32")
Expand All @@ -1573,7 +1589,107 @@ def test_non_max_suppression():
mod,
[np_indices_result, np_valid_box_count],
only_vm=False,
disable_targets=["nvptx"],
)


@tvm.testing.uses_gpu
def test_all_class_non_max_suppression():
def verify_all_class_non_max_suppression(
boxes_np,
scores_np,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
expected_indices,
):
batch_size = boxes_np.shape[0]
num_classes = scores_np.shape[1]
num_boxes = relay.Any()
boxes = relay.var("boxes", relay.ty.TensorType((batch_size, num_boxes, 4), "float32"))
scores = relay.var(
"scores", relay.ty.TensorType((batch_size, num_classes, num_boxes), "float32")
)

nms_out = relay.vision.all_class_non_max_suppression(
boxes,
scores,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
)

three = relay.const(np.array([3]), dtype="int64")
begin = relay.const(np.array([0, 0]), dtype="int64")
end = relay.op.concatenate([nms_out[1], three], axis=0)
strides = relay.const(np.array([1, 1]), dtype="int64")
out = relay.op.strided_slice(nms_out[0], begin, end, strides)

mod = tvm.IRModule()
mod["main"] = relay.Function([boxes, scores], out)

check_result([boxes_np, scores_np], mod, [expected_indices])

boxes = np.array(
[
[
[0.0, 0.0, 0.3, 0.3],
[0.5, 0.5, 0.4, 0.4],
[0.0, 0.0, 0.5, 0.5],
[0.5, 0.5, 0.9, 0.9],
[0.5, 0.5, 1.0, 1.0],
],
]
).astype("float32")

scores = np.array(
[
[[0.1, 0.2, 0.6, 0.3, 0.9], [0.8, 0.2, 0.6, 0.3, 0.9]],
]
).astype("float32")

max_output_boxes_per_class = 2
iou_threshold = 0.8
score_threshold = 0.4

expected = np.array([[0, 0, 4], [0, 0, 2], [0, 1, 4], [0, 1, 0]])

verify_all_class_non_max_suppression(
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, expected
)

boxes = np.array(
[
[
[0.0, 0.0, 1.0, 1.0],
[0.0, 0.1, 0.9, 1.2],
]
]
).astype(np.float32)
scores = np.array([[[0.2, 0.3], [0.3, 0.2]]]).astype(np.float32)
iou_threshold = 0.3
score_threshold = 0.15

expected = np.array([[0, 0, 1], [0, 1, 0]])

verify_all_class_non_max_suppression(
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, expected
)

# zero box detection case
boxes = np.array(
[
[
[0.0, 0.0, 1.0, 1.0],
]
]
).astype(np.float32)
scores = np.array([[[0.2]]]).astype(np.float32)
score_threshold = 0.4

expected = np.zeros((0, 3))

verify_all_class_non_max_suppression(
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, expected
)


Expand Down

0 comments on commit e82187f

Please sign in to comment.