Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Torch] Restore class-aware NMS for detection models by graph rewrite #7154

Merged
merged 9 commits into from
Jan 13, 2021

Conversation

masahi
Copy link
Member

@masahi masahi commented Dec 22, 2020

NMS used by PyTorch detection model actually performs multiclass NMS in one go, by adding different offsets to boxes from different classes so that two boxes from different classes never overlap. See

https://github.com/pytorch/vision/blob/3d60f498e71ba63b428edb184c9ac38fa3737fa6/torchvision/ops/boxes.py#L80-L89

But this means most of O(N**2) IOU tests we do in the NMS triangle loop are useless. The goal of this PR is to restore class indices which is one of the inputs to batched_nms function above and perform class-aware NMS for TVM-compiled detection models.

I did this by pattern matching and rewriting after model import. Specifically, I pattern match against this subgraph corresponding to PyTorch batched_nms used by maskrcnn / faster rcnn.

Unfortunately, this optimization didn't yield speedup I hoped: On GPU it only makes 70ms faster, and on CPU it actually makes it slightly slower (?) for some reason. I haven't looked into why it is not going much faster.

nvprof output from running MaskRCNN on GPU

Before

            Type  Time(%)      Time     Calls       Avg       Min       Max  Name                                                                                            
 GPU activities:   56.57%  711.34ms         2  355.67ms  15.296ms  696.05ms  fused_vision_non_max_suppression_kernel2                                                        
                   17.38%  218.54ms         1  218.54ms  218.54ms  218.54ms  fused_nn_dense_add_nn_relu_kernel0

After

            Type  Time(%)      Time     Calls       Avg       Min       Max  Name                                                                                            
 GPU activities:   54.52%  645.28ms         2  322.64ms  15.498ms  629.78ms  fused_vision_non_max_suppression_kernel2                                                        
                   18.44%  218.27ms         1  218.27ms  218.27ms  218.27ms  fused_nn_dense_add_nn_relu_kernel0     

On CPU, the output from VM profiler:

Before

#OpName                         #InvokeCount    #Duration(us): Sum/Mean/Min/Max
...
fused_vision_non_max_suppression        2               7902.54/3951.27/339.03/7563.51

After

#OpName                         #InvokeCount    #Duration(us): Sum/Mean/Min/Max
fused_vision_non_max_suppression        2               8129.3/4064.65/304.878/7824.42

So performance wise this change doesn't matter much, but I hope this also serves as a non trivial use of pattern matching and rewrite.

cc @kevinthesun @mbrookhart @zhiics @t-vi What do you think?

@zhiics
Copy link
Member

zhiics commented Dec 23, 2020

@masahi Thanks for the perf improvement. Could you provide the CPU numbers as well?

@masahi
Copy link
Member Author

masahi commented Dec 23, 2020

@zhiics Sure updated the description. Unfortunately I cannot claim that this is perf improvement. The regression is only 200 us on CPU, so it may be just a measurement noise, though.

I have no idea why I'm not getting good speed up. IOU tests, including memory access to boxes should be definitely reduced. The only additional overhead I think of is that the input to NMS is one column wider, due to storing class ids.

Performance is not great, but I believe having access to class ids should not be a bad idea...

@zhiics
Copy link
Member

zhiics commented Dec 23, 2020

@masahi I think this is an plausible as well particularly it is only in the parser. @kevinthesun please help take a look as well. Thanks.

@masahi
Copy link
Member Author

masahi commented Dec 23, 2020

I should mention that this rewrite is not run by default, so there is no perf risk.

@mbrookhart
Copy link
Contributor

This is a bit of a shot in the dark.

I wonder if we're memory access limited, and so that's why you don't see a performance improvement.

When we do the nested loop, we always have to check if the id of instance k matches the id of instance j. Since the input shape is (batch_size, num_anchors, features), and features = 6 here, I wouldn't be surprised if checking the instance of k ends up reading all of the features of k into registers, and that memory read is the expensive operation. Once it's in memory, actually doing the iou calculation is relatively cheap, so skipping it doesn't help that much.

@masahi
Copy link
Member Author

masahi commented Dec 23, 2020

When we do the nested loop, we always have to check if the id of instance k matches the id of instance j. Since the input shape is (batch_size, num_anchors, features), and features = 6 here, I wouldn't be surprised if checking the instance of k ends up reading all of the features of k into registers, and that memory read is the expensive operation. Once it's in memory, actually doing the iou calculation is relatively cheap, so skipping it doesn't help that much.

That's highly possible. Looking at this if condition in the triangle inner loop:

tvm.tir.any(
force_suppress > 0,
id_index < 0,
out[base_idx + offset_k + id_index]
== out[base_idx + offset_j + id_index],
),

previously, force_suppress is always True, so this condition short circuit and access to out[base_idx + offset_k + id_index] and out[base_idx + offset_j + id_index] just below never happen. But now, to make NMS class aware, I had to change force_suppress to False, and now access to out[base_idx + offset_k + id_index] and out[base_idx + offset_j + id_index] always happen. This may be cancelling the speedup from reduced IOU tests. Storing the class IDs in a different 1D tensor may help.

That brings me to one of my pain points with our NMS API: I belieave our NMS API needs to be reworked. The current way of packing class ids and scores together with bbox coordinates is a design mistake that we inherited from MXNet. To store class ids, I have to cast ids to float32, update and pass id_index appropriately. Since our NMS API also requires scores to be packed with bbox, I had to update score_index too and all frontends except MXNet needs to do this concatenation. The worst part is that in NMS IR, the very first thing we do is the extraction 1D score tensor from packed data. So I see no good reason to pack score tensor and bbox coordinates.

@mbrookhart
Copy link
Contributor

mbrookhart commented Dec 24, 2020

Sorry for the delay in responding to this, I wanted to look at the frameworks more closely. We currently have 5 importers that leverage NMS:

MXNET does multibox_transform_loc and then NMS on the outputs. multi_box_transform_loc converts a 3D array of scores with shape (batch_size, class_num, num_anchors) into a most likely class and score for that class, plus does some coordinate transforms on the box.

ONNX takes a 3D tensor of (batch_size, class, num_anchors), does slicing/concatenating with the boxes, and then does a per-class get_valid_counts->non_max.

Pytorch takes in a 1D tensor of scores and concats it with the boxes before performing get_valid_counts and nms. As @masahi shows in this PR, there is preprocessing to embed all classes into that 1D tensor outside of the op.

TF takes a 1D tensor of scores and concats it to the boxes before performing get_valid_counts and nms. I'm not sure if the rest of the TF graph is handling the loop over batch size and classes.

TFlite takes a 3D score tensor of shape (batch size, num_anchors, class_id), reorders it to (batch_size, class_id, num_anchors), performs multibox_transform_loc->nms, and strangely does get_valid_counts after NMS.

It looks like we're doing pre-processing in every framework to reduce the amount of score information and convert it to the 5 or 6 D form the nms API wants. None of the frameworks give us inputs in the packed form the API expects, and we jump through hoops in every importer to convert inputs into that form. Then in at least TFLite and ONNX, we perform further splitting/slicing/concatenating to restore the separate class ids.

I think I agree with @masahi, we seem to be jumping through a lot of hoops in the importers to support a TVM NMS API that's out of line with the frameworks, and that might be hurting our overall performance.

@masahi masahi force-pushed the torch-maskrcnn-rewrite branch 2 times, most recently from d9b9995 to 4d43fdc Compare December 26, 2020 02:31
@trevor-m
Copy link
Contributor

I highly agree with you guys.

For class-aware NMS, the [batch, num_anchors, 6] format seems very inefficient. It means all anchors need to be checked just to see if the classes match. A [batch, num_classes, num_anchors, 5] format would give us a nicely defined slice of memory where the same-class anchors are located.

TF takes a 1D tensor of scores and concats it to the boxes before performing get_valid_counts and nms. I'm not sure if the rest of the TF graph is handling the loop over batch size and classes.

That's correct, TF's NMS is only for single class and single batch, so the TF graph loops over batches and classes. To do that, they use tf.map_fn so the execution of each NMS can actually still run in parallel. However, this turns into a mess of control flow operators and TensorArrays, so Relay isn't able to do the same parallelization. This PR's graph rewrite could actually benefit TF OD models as well, but the pattern is a lot more complicated for TF.

@masahi
Copy link
Member Author

masahi commented Jan 12, 2021

@kevinthesun @zhiics @mbrookhart

As shown in my new NMS PR #7257, this rewrite results in a better speed up with improved memory layout. Can we merge this? I have new rewrites coming to further optimize PyTorch NMS and MaskRCNN / FasterRCNN.

Copy link
Contributor

@mbrookhart mbrookhart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhiics zhiics merged commit 86479ba into apache:main Jan 13, 2021
masahi added a commit to masahi/tvm that referenced this pull request Jan 14, 2021
…apache#7154)

* add a pattern to rewrite nms to batched nms

* update object detection test to add rewrite

* updated tutorial

* add doc

* fixed coord_start

* test fixed by setting force_surpress=False

* revert tutorial change

* add some comment to explain the pattern

* update NMS pattern following frontend change
masahi added a commit to masahi/tvm that referenced this pull request Jan 18, 2021
…apache#7154)

* add a pattern to rewrite nms to batched nms

* update object detection test to add rewrite

* updated tutorial

* add doc

* fixed coord_start

* test fixed by setting force_surpress=False

* revert tutorial change

* add some comment to explain the pattern

* update NMS pattern following frontend change
TusharKanekiDey pushed a commit to TusharKanekiDey/tvm that referenced this pull request Jan 20, 2021
…apache#7154)

* add a pattern to rewrite nms to batched nms

* update object detection test to add rewrite

* updated tutorial

* add doc

* fixed coord_start

* test fixed by setting force_surpress=False

* revert tutorial change

* add some comment to explain the pattern

* update NMS pattern following frontend change
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Jan 21, 2021
…apache#7154)

* add a pattern to rewrite nms to batched nms

* update object detection test to add rewrite

* updated tutorial

* add doc

* fixed coord_start

* test fixed by setting force_surpress=False

* revert tutorial change

* add some comment to explain the pattern

* update NMS pattern following frontend change
electriclilies pushed a commit to electriclilies/tvm that referenced this pull request Feb 18, 2021
…apache#7154)

* add a pattern to rewrite nms to batched nms

* update object detection test to add rewrite

* updated tutorial

* add doc

* fixed coord_start

* test fixed by setting force_surpress=False

* revert tutorial change

* add some comment to explain the pattern

* update NMS pattern following frontend change
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants