Skip to content

Commit

Permalink
minor improvement when topk is available
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 28, 2020
1 parent 9b42008 commit 0a9e4ab
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
"int32", (1,), name="num_valid_boxes_local", scope="local"
)
num_valid_boxes_local[0] = 0
nkeep = if_then_else(tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i])

def nms_inner_loop(ib, j):
# The box j is valid, invalidate other boxes that overlap with j above iou_threshold
Expand All @@ -545,15 +546,15 @@ def nms_inner_loop(ib, j):
num_valid_boxes_local[0] += 1

offset_j = j * box_data_length
num_iter_per_thread = ceil_div(valid_count[i] - (j + 1), nthread_tx)
num_iter_per_thread = ceil_div(nkeep - (j + 1), nthread_tx)

with ib.for_range(0, num_iter_per_thread) as _k:
k = j + 1 + _k * nthread_tx + tx
offset_k = k * box_data_length

with ib.if_scope(
tvm.tir.all(
k < num_anchors,
k < nkeep,
out[base_idx + offset_k + score_index] > 0, # is the box k still valid?
tvm.tir.any(
force_suppress > 0,
Expand Down Expand Up @@ -582,7 +583,7 @@ def nms_inner_loop(ib, j):

with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
# Apply nms
with ib.for_range(0, valid_count[i]) as j:
with ib.for_range(0, nkeep) as j:
# Proceed to the inner loop if the box j is still valid
with ib.if_scope(out[base_idx + (j * box_data_length) + score_index] > -1.0):
with ib.if_scope(max_output_size > 0):
Expand Down

0 comments on commit 0a9e4ab

Please sign in to comment.