diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index dd9d3f8a1d0e..2dc177a0fae8 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -95,7 +95,7 @@ def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index
     max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     with ib.new_scope():
         nthread_tx = max_threads
-        nthread_bx = num_anchors // max_threads + 1
+        nthread_bx = ceil_div(num_anchors, max_threads)
         nthread_by = batch_size
         tx = te.thread_axis("threadIdx.x")
         bx = te.thread_axis("blockIdx.x")
@@ -151,31 +151,103 @@ def get_valid_indices_ir(valid_boxes, valid_count, valid_indices):
     valid_indices = ib.buffer_ptr(valid_indices)
 
     max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+
+    # Copy boxes to valid_indices
     with ib.new_scope():
         nthread_tx = max_threads
-        nthread_bx = batch_size // max_threads + 1
+        nthread_bx = ceil_div(num_anchors, max_threads)
+        nthread_by = batch_size
         tx = te.thread_axis("threadIdx.x")
         bx = te.thread_axis("blockIdx.x")
+        by = te.thread_axis("blockIdx.y")
         ib.scope_attr(tx, "thread_extent", nthread_tx)
         ib.scope_attr(bx, "thread_extent", nthread_bx)
-        tid = bx * max_threads + tx
-        # TODO(mbrookhart): Parallelize the sum and cumsum here
-        current_index = ib.allocate("int32", (1,), name="current_index", scope="local")
-        with ib.if_scope(tid < batch_size):
-            current_index[0] = 0
-            valid_count[tid] = 0
-            with ib.for_range(0, num_anchors) as j:
-                idx = tid * num_anchors + j
-                valid_count[tid] = valid_count[tid] + valid_boxes[idx]
-                with ib.if_scope(valid_boxes[idx] == 1):
-                    valid_indices[idx] = current_index[0]
-                    current_index[0] = current_index[0] + 1
-                with ib.else_scope():
-                    valid_indices[idx] = -1
+        ib.scope_attr(by, "thread_extent", nthread_by)
+        tid = bx * nthread_tx + tx
+        with ib.if_scope(tid < num_anchors):
+            valid_indices[by, tid] = valid_boxes[by, tid]
+
+    nthread_tx = max_threads
+    nthread_bx = ceil_div(num_anchors, max_threads)
+    nthread_by = batch_size
+
+    ## The following algorithm performs parallel exclusive scan to get
+    ## a tensor that can later be used to select valid indices
+    # Up Sweep of exclusive scan
+    lim = tvm.tir.generic.cast(
+        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))), "int64"
+    )
+    with ib.for_range(0, lim, dtype="int64") as l2_width:
+        width = 2 << l2_width
+
+        with ib.new_scope():
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+            ib.scope_attr(tx, "thread_extent", nthread_tx)
+            ib.scope_attr(
+                bx,
+                "thread_extent",
+                tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
+            )
+            tid = bx * nthread_tx + tx
+
+            by = te.thread_axis("blockIdx.y")
+            ib.scope_attr(by, "thread_extent", nthread_by)
+            start = ib.allocate("int64", (1,), name="start", scope="local")
+            middle = ib.allocate("int64", (1,), name="middle", scope="local")
+            end = ib.allocate("int64", (1,), name="end", scope="local")
+            start[0] = width * tid
+            with ib.if_scope(start[0] < num_anchors):
+                middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
+                end[0] = tvm.te.min(start[0] + width, num_anchors)
+                with ib.if_scope(middle[0] < num_anchors):
+                    valid_indices[by * num_anchors + end[0] - 1] += valid_indices[
+                        by * num_anchors + middle[0] - 1
+                    ]
+
+    # Down Sweep of exclusive scan
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", batch_size)
+        with ib.if_scope(bx < batch_size):
+            valid_count[bx] = valid_indices[(bx + 1) * num_anchors - 1]
+            valid_indices[(bx + 1) * num_anchors - 1] = 0
+
+    with ib.for_range(0, lim, dtype="int64") as l2_width:
+        width = 2 << (lim - l2_width - 1)
+
+        with ib.new_scope():
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+            ib.scope_attr(tx, "thread_extent", nthread_tx)
+            ib.scope_attr(
+                bx,
+                "thread_extent",
+                tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
+            )
+            tid = bx * nthread_tx + tx
+
+            by = te.thread_axis("blockIdx.y")
+            ib.scope_attr(by, "thread_extent", nthread_by)
+            start = ib.allocate("int64", (1,), name="start", scope="local")
+            middle = ib.allocate("int64", (1,), name="middle", scope="local")
+            end = ib.allocate("int64", (1,), name="end", scope="local")
+            tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
+            start[0] = width * tid
+            with ib.if_scope(tvm.tir.all(start[0] < num_anchors)):
+                middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
+                end[0] = tvm.tir.min(start[0] + width, num_anchors)
+                with ib.if_scope(middle[0] < num_anchors):
+                    tmp[0] = valid_indices[by * num_anchors + middle[0] - 1]
+                    valid_indices[by * num_anchors + middle[0] - 1] = valid_indices[
+                        by * num_anchors + end[0] - 1
+                    ]
+                    valid_indices[by * num_anchors + end[0] - 1] += tmp[0]
+
     return ib.get()
 
 
-def get_valid_counts_ir(data, valid_indices, out, out_indices):
+def get_valid_counts_ir(data, valid_indices, valid_boxes, out, out_indices):
     """Low level IR to get valid count of bounding boxes given a score threshold.
     Also prepares to move valid boxes to the top of input data.
@@ -203,8 +275,9 @@ def get_valid_counts_ir(data, valid_indices, out, out_indices):
     ib = tvm.tir.ir_builder.create()
 
     data = ib.buffer_ptr(data)
-
     valid_indices = ib.buffer_ptr(valid_indices)
+    valid_boxes = ib.buffer_ptr(valid_boxes)
+
     out = ib.buffer_ptr(out)
     out_indices = ib.buffer_ptr(out_indices)
     one = tvm.tir.const(1, dtype=out.dtype)
@@ -213,41 +286,36 @@ def get_valid_counts_ir(data, valid_indices, out, out_indices):
     nthread_tx = max_threads
     nthread_bx = num_anchors // max_threads + 1
     nthread_by = batch_size
-    nthread_bz = elem_length
     with ib.new_scope():
         tx = te.thread_axis("threadIdx.x")
         bx = te.thread_axis("blockIdx.x")
         by = te.thread_axis("blockIdx.y")
-        bz = te.thread_axis("blockIdx.z")
         ib.scope_attr(tx, "thread_extent", nthread_tx)
         ib.scope_attr(bx, "thread_extent", nthread_bx)
         ib.scope_attr(by, "thread_extent", nthread_by)
-        ib.scope_attr(bz, "thread_extent", nthread_bz)
         tid = bx * max_threads + tx
         with ib.if_scope(tid < num_anchors):
             i = by
             j = tid
-            k = bz
-            out[(i * num_anchors + j) * elem_length + k] = -one
+            with ib.for_range(0, elem_length) as k:
+                out[(i * num_anchors + j) * elem_length + k] = -one
             out_indices[i * num_anchors + j] = -1
     with ib.new_scope():
         tx = te.thread_axis("threadIdx.x")
         bx = te.thread_axis("blockIdx.x")
         by = te.thread_axis("blockIdx.y")
-        bz = te.thread_axis("blockIdx.z")
         ib.scope_attr(tx, "thread_extent", nthread_tx)
         ib.scope_attr(bx, "thread_extent", nthread_bx)
         ib.scope_attr(by, "thread_extent", nthread_by)
-        ib.scope_attr(bz, "thread_extent", nthread_bz)
         tid = bx * max_threads + tx
        with ib.if_scope(tid < num_anchors):
             i = by
             j = tid
-            k = bz
-            with ib.if_scope(valid_indices[i, tid] >= 0):
-                out[(i * num_anchors + valid_indices[i, tid]) * elem_length + k] = data[
-                    (i * num_anchors + j) * elem_length + k
-                ]
+            with ib.if_scope(valid_boxes[i, tid] > 0):
+                with ib.for_range(0, elem_length) as k:
+                    out[(i * num_anchors + valid_indices[i, tid]) * elem_length + k] = data[
+                        (i * num_anchors + j) * elem_length + k
+                    ]
+                out_indices[i * num_anchors + valid_indices[i, tid]] = j
     return ib.get()
 
 
@@ -321,10 +389,10 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
     out, out_indices = te.extern(
         [data.shape, (batch_size, num_anchors)],
-        [data, valid_indices],
-        lambda ins, outs: get_valid_counts_ir(ins[0], ins[1], outs[0], outs[1]),
+        [data, valid_indices, valid_boxes],
+        lambda ins, outs: get_valid_counts_ir(ins[0], ins[1], ins[2], outs[0], outs[1]),
         dtype=["int32", data.dtype],
-        in_buffers=[data_buf, valid_indices_buf],
+        in_buffers=[data_buf, valid_indices_buf, valid_boxes_buf],
         out_buffers=[out_buf, out_indices_buf],
         name="get_valid_counts",
         tag="get_valid_counts_gpu",
diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py
index 035d19f25ec7..cbf136a5552c 100644
--- a/python/tvm/topi/vision/nms.py
+++ b/python/tvm/topi/vision/nms.py
@@ -213,7 +213,7 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
     out_indices: tvm.te.Tensor or numpy NDArray
         Related index in input data.
     """
-    if isinstance(score_threshold, float):
+    if isinstance(score_threshold, (float, int)):
         score_threshold = tvm.tir.const(score_threshold, dtype=data.dtype)
     id_index_const = tvm.tir.const(id_index, "int32")
     score_index_const = tvm.tir.const(score_index, "int32")
diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py
index 1ce8a182f034..cdf3b240507b 100644
--- a/tests/python/relay/test_op_level5.py
+++ b/tests/python/relay/test_op_level5.py
@@ -313,10 +313,8 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index):
     for target, ctx in tvm.testing.enabled_targets():
         intrp = relay.create_executor("debug", ctx=ctx, target=target)
         out = intrp.evaluate(func)(np_data)
+        tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3, atol=1e-04)
-        # get_valid_count for opencl doesn't do data rearrangement
-        if target in ["opencl"]:
-            return
         tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3, atol=1e-04)
         tvm.testing.assert_allclose(out[2].asnumpy(), np_out3, rtol=1e-3, atol=1e-04)
diff --git a/tests/python/topi/python/test_topi_vision.py b/tests/python/topi/python/test_topi_vision.py
index 778843be37de..697ef8a24f67 100644
--- a/tests/python/topi/python/test_topi_vision.py
+++ b/tests/python/topi/python/test_topi_vision.py
@@ -105,27 +105,18 @@ def check_device(device):
         tvm_out1 = tvm.nd.array(np.zeros(np_out1.shape, dtype="int32"), ctx)
         tvm_out2 = tvm.nd.array(np.zeros(np_out2.shape, dtype=dtype), ctx)
         tvm_out3 = tvm.nd.array(np.zeros(np_out3.shape, dtype="int32"), ctx)
-        if device == "llvm":
-            f = tvm.build(s, [data, outs[0], outs[1], outs[2]], device)
-            f(tvm_input_data, tvm_out1, tvm_out2, tvm_out3)
-            tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3)
-            tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)
-            tvm.testing.assert_allclose(tvm_out3.asnumpy(), np_out3, rtol=1e-3)
-        else:
-            f = tvm.build(s, [data, outs[0], outs[1]], device)
-            f(tvm_input_data, tvm_out1, tvm_out2)
-            tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3)
-            tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)
+
+        f = tvm.build(s, [data, outs[0], outs[1], outs[2]], device)
+        f(tvm_input_data, tvm_out1, tvm_out2, tvm_out3)
+        tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3)
+        tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)
+        tvm.testing.assert_allclose(tvm_out3.asnumpy(), np_out3, rtol=1e-3)
 
     for device in ["llvm", "cuda", "opencl"]:
         check_device(device)
 
 
 @tvm.testing.uses_gpu
-@pytest.mark.skip(
-    "Skip this test as it is intermittent."
-    "See https://github.com/apache/tvm/pull/4901#issuecomment-595040094"
-)
 def test_get_valid_counts():
     verify_get_valid_counts((1, 1000, 5), 0.5, -1, 0)
     verify_get_valid_counts((1, 2500, 6), 0, 0, 1)
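
Reviewer note, not part of the patch: the new get_valid_indices_ir replaces the per-batch serial cumsum with a work-efficient (Blelloch-style) exclusive scan over the 0/1 valid_boxes flags. An up sweep of ceil(log2(num_anchors)) passes builds partial sums in place, the last element then holds the per-batch total (stored into valid_count and zeroed), and a down sweep turns the tree of partial sums into exclusive prefix sums that give each valid box its destination slot. The NumPy sketch below mirrors that scheme for reference only; the helper name exclusive_scan_ref is ours and does not exist in TVM.

import numpy as np


def exclusive_scan_ref(valid_boxes):
    # valid_boxes: int32 array of shape (batch_size, num_anchors) with 0/1 flags.
    # Returns (valid_indices, valid_count): valid_indices[b, j] is the exclusive
    # prefix sum (the output slot of box j), valid_count[b] the number of valid boxes.
    valid_indices = np.array(valid_boxes, dtype="int32")
    _, num_anchors = valid_indices.shape
    lim = int(np.ceil(np.log2(num_anchors)))

    # Up sweep: for each doubling width, add the sum of the left half-block
    # (sitting at its last element) into the last element of the whole block.
    for l2_width in range(lim):
        width = 2 << l2_width
        for start in range(0, num_anchors, width):
            middle = start + width // 2
            end = min(start + width, num_anchors)
            if middle < num_anchors:
                valid_indices[:, end - 1] += valid_indices[:, middle - 1]

    # The last element now holds the per-batch total; record it and zero it
    # so the down sweep yields an exclusive scan.
    valid_count = valid_indices[:, -1].copy()
    valid_indices[:, -1] = 0

    # Down sweep: walk the widths back down, swapping and accumulating so each
    # element ends up with the sum of everything strictly to its left.
    for l2_width in reversed(range(lim)):
        width = 2 << l2_width
        for start in range(0, num_anchors, width):
            middle = start + width // 2
            end = min(start + width, num_anchors)
            if middle < num_anchors:
                tmp = valid_indices[:, middle - 1].copy()
                valid_indices[:, middle - 1] = valid_indices[:, end - 1]
                valid_indices[:, end - 1] += tmp

    return valid_indices, valid_count


flags = np.array([[1, 0, 1, 1, 0]], dtype="int32")
print(exclusive_scan_ref(flags))  # -> [[0, 1, 1, 2, 3]], [3]

In the patch, the inner loop over blocks is what each CUDA thread does for its own block (tid picks the block of size width), and get_valid_counts_ir then uses valid_indices to scatter the surviving boxes to the top of the output.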