From e162f840b8b5df3a2eea2424e6e802644d2cc6e9 Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Thu, 17 Dec 2020 10:01:11 -0700 Subject: [PATCH] Parallelize cumsum in get_valid_counts --- python/tvm/driver/build_module.py | 2 +- python/tvm/topi/cuda/nms.py | 117 ++++++++++++++++++++++++--- tests/python/relay/test_op_level5.py | 4 +- 3 files changed, 106 insertions(+), 17 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 058bd62d6226..dc9d741b2726 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -277,7 +277,7 @@ def _build_for_device(input_mod, target, target_host): lambda f: "calling_conv" not in f.attrs or f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH ), - tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)), + tvm.tir.transform.Apply(lambda f: f.with_attr("target", target_host)), tvm.tir.transform.LowerTVMBuiltin(), tvm.tir.transform.LowerDeviceStorageAccessInfo(), tvm.tir.transform.LowerCustomDatatypes(), diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index d0915d9aa55f..116852e7bf01 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -211,28 +211,119 @@ def get_valid_indices_ir(valid_boxes, valid_count, valid_indices): valid_count = ib.buffer_ptr(valid_count) valid_indices = ib.buffer_ptr(valid_indices) + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + + def ceil_div(a, b): + return tvm.tir.indexdiv(a + b - 1, b) + + # Copy boxes to valid_indices + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size * num_anchors, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size * num_anchors): + valid_indices[tid] = valid_boxes[tid] + + nthread_tx = max_threads + nthread_bx = ceil_div(num_anchors, max_threads) + nthread_by = batch_size + + ## The following algorithm performs parallel prefix sum to get + ## a tensor that can later be used to select valid indices + # Up Sweep of prefix sum + lim = tvm.tir.generic.cast( + tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))), "int64" + ) + with ib.for_range(0, lim, dtype="int64") as l2_width: + width = 2 << l2_width + + with ib.new_scope(): + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr( + bx, + "thread_extent", + tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"), + ) + tid = bx * nthread_tx + tx + + by = te.thread_axis("blockIdx.y") + ib.scope_attr(by, "thread_extent", nthread_by) + start = ib.allocate("int64", (1,), name="start", scope="local") + middle = ib.allocate("int64", (1,), name="middle", scope="local") + end = ib.allocate("int64", (1,), name="end", scope="local") + start[0] = width * tid + with ib.if_scope(start[0] < num_anchors): + middle[0] = start[0] + tvm.tir.indexdiv(width, 2) + end[0] = tvm.te.min(start[0] + width, num_anchors) + with ib.if_scope(middle[0] < num_anchors): + valid_indices[by * num_anchors + end[0] - 1] += valid_indices[ + by * num_anchors + middle[0] - 1 + ] + + # Down Sweep of prefix sum + with ib.for_range(0, lim - 1, dtype="int64") as l2_width: + width = 2 << (lim - l2_width - 2) + + with ib.new_scope(): + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr( + bx, + "thread_extent", + tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"), + ) + tid = bx * nthread_tx + tx + + by = te.thread_axis("blockIdx.y") + ib.scope_attr(by, "thread_extent", nthread_by) + start = ib.allocate("int64", (1,), name="start", scope="local") + middle = ib.allocate("int64", (1,), name="middle", scope="local") + end = ib.allocate("int64", (1,), name="end", scope="local") + start[0] = width * tid + with ib.if_scope(tvm.tir.all(start[0] > 0, start[0] < num_anchors)): + middle[0] = start[0] + tvm.tir.indexdiv(width, 2) + with ib.if_scope(middle[0] < num_anchors): + valid_indices[by * num_anchors + middle[0] - 1] += valid_indices[ + by * num_anchors + start[0] - 1 + ] + + ## Write Sum to valid_count max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) with ib.new_scope(): nthread_tx = max_threads - nthread_bx = batch_size // max_threads + 1 + nthread_bx = ceil_div(batch_size, max_threads) tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "thread_extent", nthread_bx) tid = bx * max_threads + tx - # TODO(mbrookhart): Parallelize the sum and cumsum here - current_index = ib.allocate("int32", (1,), name="current_index", scope="local") with ib.if_scope(tid < batch_size): - current_index[0] = 0 - valid_count[tid] = 0 - with ib.for_range(0, num_anchors) as j: - idx = tid * num_anchors + j - valid_count[tid] = valid_count[tid] + valid_boxes[idx] - with ib.if_scope(valid_boxes[idx] == 1): - valid_indices[idx] = current_index[0] - current_index[0] = current_index[0] + 1 - with ib.else_scope(): - valid_indices[idx] = -1 + valid_count[tid] = valid_indices[tid * num_anchors + num_anchors - 1] + + ## Remove invalid indices + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size * num_anchors, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size * num_anchors): + with ib.if_scope(valid_boxes[tid] < 1): + # if this is an invalid box, mark -1 + valid_indices[tid] = -1 + with ib.else_scope(): + # if this is a valid box, subtract 1 to get 0-based indexing + valid_indices[tid] += -1 + return ib.get() diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 1ce8a182f034..cdf3b240507b 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -313,10 +313,8 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): for target, ctx in tvm.testing.enabled_targets(): intrp = relay.create_executor("debug", ctx=ctx, target=target) out = intrp.evaluate(func)(np_data) + tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3, atol=1e-04) - # get_valid_count for opencl doesn't do data rearrangement - if target in ["opencl"]: - return tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3, atol=1e-04) tvm.testing.assert_allclose(out[2].asnumpy(), np_out3, rtol=1e-3, atol=1e-04)