Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Torch] Restore class-aware NMS for detection models by graph rewrite #7154

Merged
merged 9 commits into from
Jan 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1857,26 +1857,26 @@ def nms(self, inputs, input_types):
scores = inputs[1]
iou_threshold = inputs[2]

num_boxes = _op.shape_of(scores)

# TVM NMS assumes score > 0
scores = scores - _op.min(scores) + _op.const(1.0)

num_boxes = _op.shape_of(scores)
# PyTorch NMS doesn't have score_threshold, so no need to run get_valid_count
indices = _op.transform.arange(_op.squeeze(num_boxes), dtype="int32")
indices = _op.expand_dims(indices, 0, 1)

# Generate data with shape (1, num_anchors, 5)
scores = AttrCvt(op_name="expand_dims", extras={"axis": -1, "num_newaxis": 1})([scores], {})
data = _op.concatenate([scores, boxes], -1)
data = _op.expand_dims(data, 0, 1)
# PyTorch NMS doesn't have score_threshold, so no need to run get_valid_count
indices = _op.transform.arange(_op.squeeze(num_boxes), dtype="int32")
indices = _op.expand_dims(indices, 0, 1)
ct = num_boxes

# Perform Non-Maximum Suppression,
# PyTorch NMS doesn't have parameter top_k and max_output_size
score_index = 0
top_k = max_out_size = -1
nms_ret = get_relay_op("non_max_suppression")(
data=data,
valid_count=ct,
valid_count=num_boxes,
indices=indices,
max_output_size=max_out_size,
iou_threshold=iou_threshold,
Expand Down
153 changes: 152 additions & 1 deletion python/tvm/relay/frontend/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-outside-toplevel
# pylint: disable=import-outside-toplevel, unused-argument, invalid-name
""" Common utilities used by PyTorch frontend """
from .. import op
from ..dataflow_pattern import (
is_constant,
is_op,
rewrite,
is_tuple,
wildcard,
DFPatternCallback,
)


def is_version_greater_than(ver):
Expand All @@ -25,3 +34,145 @@ def is_version_greater_than(ver):
return "".join(re.findall(r"(\d+\.)(\d+\.)(\d)", torch.__version__)[0]) > "".join(
re.findall(r"(\d+\.)(\d+\.)(\d)", ver)[0]
)


def batched_nms_pattern(boxes, scores, idxs, iou_threshold, num_boxes, indices):
"""A pattern to detect batched_nms function in torchvision

The inputs to this function, boxes, scores, idxs, iou_threshold are wildcard
patterns which can be used later in the rewriting to extract matched Relay fragments.

We want to detect the following PyTorch code snippet:

def batched_nms(boxes, scores, idxs, iou_threshold):
max_coordinate = boxes.max()
offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
boxes_for_nms = boxes + offsets[:, None]
keep = nms(boxes_for_nms, scores, iou_threshold)
return keep

Here is how PyTorch frontend lowers above PyTorch code. For simplicity, Relay ops for
dealing with dynamic strided_slice are omitted. %num_boxes, %indices are complex
expressions, but since we can use the wildcard part for them, we do not need to construct
their patterns.

%2 = expand_dims(%scores, axis=-1);
%3 = cast(%idxs, dtype="float32");
%4 = max(%boxes);
%5 = add(%4, 1f);
%6 = multiply(%3, %5);
%7 = strided_slice(%6, begin=[0], end=[4507], strides=[1]);
%8 = expand_dims(%7, axis=1);
%9 = add(%boxes, %8);
%10 = (%2, %9);
%11 = concatenate(%10, axis=-1);
%12 = expand_dims(%11, axis=0);
...
...
%17 = vision.non_max_suppression(%12, %num_boxes, %indices, -1, 0.7f, ...);

"""
one = is_constant()
zero = is_constant()

# Equivelent PyTorch code from above snippet
# offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
cast = is_op("cast")(idxs)
mx = is_op("max")(boxes)
add = is_op("add")(mx, one)
mul = is_op("multiply")(cast, add)

# The following doesn't appear in the above Relay snippet. It is required for dynamic
# stride_slice handling
cast_like = is_op("cast_like")(zero, is_constant())
less = is_op("less")(is_constant(), cast_like)
shape_of = is_op("shape_of")(mul)
cast_like = is_op("cast_like")(shape_of, is_constant())
add = is_op("add")(is_constant(), cast_like)
where = is_op("where")(less, add, is_constant())
shape_of = is_op("shape_of")(mul)
cast = is_op("cast")(shape_of)

# This corresponds to offsets[:, None], where offsets is the result of multiplication
dyn_strided_slice = is_op("dyn.strided_slice")(mul, where, cast, is_constant())

# Add offsets to the boxes
expand_dims = is_op("expand_dims")(dyn_strided_slice)
add = is_op("add")(boxes, expand_dims)

# The rest of patterns correspond to the PyTorch frontend conversion
# function for torchvision::nms
score_expand_dims = is_op("expand_dims")(scores)
tup = is_tuple([score_expand_dims, add])
concat = is_op("concatenate")(tup)
data = is_op("expand_dims")(concat)

return is_op("vision.non_max_suppression")(
data, num_boxes, indices, is_constant(), iou_threshold
)


class NMSRewrite(DFPatternCallback):
"""A callback to rewrite nms and restore batched nms"""

def __init__(self):
super().__init__()
# exprs to extract
self.boxes = wildcard()
self.scores = wildcard()
self.idxs = wildcard()
self.iou_threshold = wildcard()
self.num_boxes = wildcard()
self.indices = wildcard()

self.pattern = batched_nms_pattern(
self.boxes,
self.scores,
self.idxs,
self.iou_threshold,
self.num_boxes,
self.indices,
)

def convert_batched_nms(self, boxes, scores, idxs, iou_thres, num_boxes, indices):
"""Restore class-aware NMS using extracted class indices"""
scores = op.expand_dims(scores, axis=-1, num_newaxis=1)
idxs = op.expand_dims(idxs, axis=-1, num_newaxis=1)
idxs = op.cast(idxs, "float32")
data = op.concatenate([idxs, scores, boxes], -1)
data = op.expand_dims(data, 0, 1)

top_k = max_out_size = -1
out = op.vision.non_max_suppression(
data=data,
valid_count=num_boxes,
indices=indices,
max_output_size=max_out_size,
iou_threshold=iou_thres,
force_suppress=False,
top_k=top_k,
coord_start=2,
score_index=1,
id_index=0,
return_indices=True,
invalid_to_bottom=False,
)
return out.tuple_value

def callback(self, pre, post, node_map):
boxes = node_map[self.boxes][0]
scores = node_map[self.scores][0]
idxs = node_map[self.idxs][0]
iou_thres = node_map[self.iou_threshold][0]
num_boxes = node_map[self.num_boxes][0]
indices = node_map[self.indices][0]
return self.convert_batched_nms(boxes, scores, idxs, iou_thres, num_boxes, indices)


def rewrite_nms_to_batched_nms(mod):
"""Rewrite the input graph to replace non maximum surpression
in torchvision that does not take class id into account with the one
that avoids IOU tests between different classes.
"""
mod["main"] = rewrite(NMSRewrite(), mod["main"])
return mod
20 changes: 17 additions & 3 deletions tests/python/frontend/pytorch/test_object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import tvm.testing
from tvm import relay
from tvm.runtime.vm import VirtualMachine
from tvm.relay.frontend.pytorch_utils import rewrite_nms_to_batched_nms
from tvm.contrib.download import download


Expand Down Expand Up @@ -108,15 +109,17 @@ def test_detection_models():
with torch.no_grad():
pt_res = scripted_model(data)

for target in ["llvm", "cuda"]:
def compile_and_run_vm(mod, params, data_np, target):
with tvm.transform.PassContext(opt_level=3):
vm_exec = relay.vm.compile(mod, target=target, params=params)

ctx = tvm.context(target, 0)
vm = VirtualMachine(vm_exec, ctx)

vm.set_input("main", **{input_name: data_np})
tvm_res = vm.run()
return vm.run()

for target in ["cuda", "llvm"]:
tvm_res = compile_and_run_vm(mod, params, data_np, target)

# Bounding boxes
tvm.testing.assert_allclose(
Expand All @@ -132,3 +135,14 @@ def test_detection_models():
score_threshold = 0.9
print("Num boxes:", pt_res[0].cpu().numpy().shape[0])
print("Num valid boxes:", np.sum(pt_res[1].cpu().numpy() >= score_threshold))

before = mod["main"]
mod = rewrite_nms_to_batched_nms(mod)
after = mod["main"]
assert not tvm.ir.structural_equal(after, before)

tvm_res_after_rewrite = compile_and_run_vm(mod, params, data_np, "llvm")

# Results should be equivalent after rewriting
for res1, res2 in zip(tvm_res, tvm_res_after_rewrite):
tvm.testing.assert_allclose(res1.asnumpy(), res2.asnumpy())