Skip to content

Commit

Permalink
[ONNX] Support NMS Center Box (apache#7900)
Browse files Browse the repository at this point in the history
* [ONNX] Support NMS Center Box

* fix silly mistake in contional
  • Loading branch information
Matthew Brookhart authored and Trevor Morris committed May 6, 2021
1 parent f46b556 commit e43b80d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
16 changes: 11 additions & 5 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2543,11 +2543,17 @@ def _impl_v10(cls, inputs, attr, params):
iou_threshold = inputs[3]
score_threshold = inputs[4]

if "center_point_box" in attr:
if attr["center_point_box"] != 0:
raise NotImplementedError(
"Only support center_point_box = 0 in ONNX NonMaxSuprresion"
)
boxes_dtype = infer_type(boxes).checked_type.dtype

if attr.get("center_point_box", 0) != 0:
xc, yc, w, h = _op.split(boxes, 4, axis=2)
half_w = w / _expr.const(2.0, boxes_dtype)
half_h = h / _expr.const(2.0, boxes_dtype)
x1 = xc - half_w
x2 = xc + half_w
y1 = yc - half_h
y2 = yc + half_h
boxes = _op.concatenate([y1, x1, y2, x2], axis=2)

if iou_threshold is None:
iou_threshold = _expr.const(0.0, dtype="float32")
Expand Down
1 change: 0 additions & 1 deletion tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4215,7 +4215,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
"test_maxpool_with_argmax_2d_precomputed_strides/",
"test_maxunpool_export_with_output_shape/",
"test_mvn/",
"test_nonmaxsuppression_center_point_box_format/",
"test_qlinearconv/",
"test_qlinearmatmul_2D/",
"test_qlinearmatmul_3D/",
Expand Down

0 comments on commit e43b80d

Please sign in to comment.