Commit

Parallelize cumsum in get_valid_counts
mbrookhart committed Dec 17, 2020
1 parent 829be98 commit e162f84
Showing 3 changed files with 106 additions and 17 deletions.
2 changes: 1 addition & 1 deletion python/tvm/driver/build_module.py
@@ -277,7 +277,7 @@ def _build_for_device(input_mod, target, target_host):
                 lambda f: "calling_conv" not in f.attrs
                 or f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH
             ),
-            tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)),
+            tvm.tir.transform.Apply(lambda f: f.with_attr("target", target_host)),
             tvm.tir.transform.LowerTVMBuiltin(),
             tvm.tir.transform.LowerDeviceStorageAccessInfo(),
             tvm.tir.transform.LowerCustomDatatypes(),
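
Editor's note: the Filter pass above keeps only functions that are not device kernel launches, i.e. host-side code, so these functions should carry target_host rather than the device target in their "target" attribute, which is what this one-line fix does. Presumably the multi-kernel lowering introduced in nms.py below surfaced the mis-tagging.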
117 changes: 104 additions & 13 deletions python/tvm/topi/cuda/nms.py
@@ -211,28 +211,119 @@ def get_valid_indices_ir(valid_boxes, valid_count, valid_indices):
     valid_count = ib.buffer_ptr(valid_count)
     valid_indices = ib.buffer_ptr(valid_indices)
 
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+
+    def ceil_div(a, b):
+        return tvm.tir.indexdiv(a + b - 1, b)
+
+    # Copy boxes to valid_indices
+    with ib.new_scope():
+        nthread_tx = max_threads
+        nthread_bx = ceil_div(batch_size * num_anchors, max_threads)
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        tid = bx * max_threads + tx
+        with ib.if_scope(tid < batch_size * num_anchors):
+            valid_indices[tid] = valid_boxes[tid]
+
+    nthread_tx = max_threads
+    nthread_bx = ceil_div(num_anchors, max_threads)
+    nthread_by = batch_size
+
+    ## The following algorithm performs parallel prefix sum to get
+    ## a tensor that can later be used to select valid indices
+    # Up Sweep of prefix sum
+    lim = tvm.tir.generic.cast(
+        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))), "int64"
+    )
+    with ib.for_range(0, lim, dtype="int64") as l2_width:
+        width = 2 << l2_width
+
+        with ib.new_scope():
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+            ib.scope_attr(tx, "thread_extent", nthread_tx)
+            ib.scope_attr(
+                bx,
+                "thread_extent",
+                tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
+            )
+            tid = bx * nthread_tx + tx
+
+            by = te.thread_axis("blockIdx.y")
+            ib.scope_attr(by, "thread_extent", nthread_by)
+            start = ib.allocate("int64", (1,), name="start", scope="local")
+            middle = ib.allocate("int64", (1,), name="middle", scope="local")
+            end = ib.allocate("int64", (1,), name="end", scope="local")
+            start[0] = width * tid
+            with ib.if_scope(start[0] < num_anchors):
+                middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
+                end[0] = tvm.te.min(start[0] + width, num_anchors)
+                with ib.if_scope(middle[0] < num_anchors):
+                    valid_indices[by * num_anchors + end[0] - 1] += valid_indices[
+                        by * num_anchors + middle[0] - 1
+                    ]
+
+    # Down Sweep of prefix sum
+    with ib.for_range(0, lim - 1, dtype="int64") as l2_width:
+        width = 2 << (lim - l2_width - 2)
+
+        with ib.new_scope():
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+            ib.scope_attr(tx, "thread_extent", nthread_tx)
+            ib.scope_attr(
+                bx,
+                "thread_extent",
+                tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
+            )
+            tid = bx * nthread_tx + tx
+
+            by = te.thread_axis("blockIdx.y")
+            ib.scope_attr(by, "thread_extent", nthread_by)
+            start = ib.allocate("int64", (1,), name="start", scope="local")
+            middle = ib.allocate("int64", (1,), name="middle", scope="local")
+            end = ib.allocate("int64", (1,), name="end", scope="local")
+            start[0] = width * tid
+            with ib.if_scope(tvm.tir.all(start[0] > 0, start[0] < num_anchors)):
+                middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
+                with ib.if_scope(middle[0] < num_anchors):
+                    valid_indices[by * num_anchors + middle[0] - 1] += valid_indices[
+                        by * num_anchors + start[0] - 1
+                    ]
+
     ## Write Sum to valid_count
     max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     with ib.new_scope():
         nthread_tx = max_threads
-        nthread_bx = batch_size // max_threads + 1
+        nthread_bx = ceil_div(batch_size, max_threads)
         tx = te.thread_axis("threadIdx.x")
         bx = te.thread_axis("blockIdx.x")
         ib.scope_attr(tx, "thread_extent", nthread_tx)
         ib.scope_attr(bx, "thread_extent", nthread_bx)
         tid = bx * max_threads + tx
-        # TODO(mbrookhart): Parallelize the sum and cumsum here
-        current_index = ib.allocate("int32", (1,), name="current_index", scope="local")
         with ib.if_scope(tid < batch_size):
-            current_index[0] = 0
-            valid_count[tid] = 0
-            with ib.for_range(0, num_anchors) as j:
-                idx = tid * num_anchors + j
-                valid_count[tid] = valid_count[tid] + valid_boxes[idx]
-                with ib.if_scope(valid_boxes[idx] == 1):
-                    valid_indices[idx] = current_index[0]
-                    current_index[0] = current_index[0] + 1
-                with ib.else_scope():
-                    valid_indices[idx] = -1
+            valid_count[tid] = valid_indices[tid * num_anchors + num_anchors - 1]
+
+    ## Remove invalid indices
+    with ib.new_scope():
+        nthread_tx = max_threads
+        nthread_bx = ceil_div(batch_size * num_anchors, max_threads)
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        tid = bx * max_threads + tx
+        with ib.if_scope(tid < batch_size * num_anchors):
+            with ib.if_scope(valid_boxes[tid] < 1):
+                # if this is an invalid box, mark -1
+                valid_indices[tid] = -1
+            with ib.else_scope():
+                # if this is a valid box, subtract 1 to get 0-based indexing
+                valid_indices[tid] += -1
 
     return ib.get()


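Editor's note: the kernels above replace the old serial per-row cumsum. One kernel copies the 0/1 validity flags into valid_indices, the up-sweep/down-sweep pair turns each batch row into an inclusive prefix sum in parallel, and the final kernels read off valid_count and convert the sums into 0-based indices (or -1 for invalid boxes). For intuition, here is a minimal NumPy sketch of the same sweep schedule, with each kernel's per-thread body unrolled into a sequential loop; the function name, example data, and comments are illustrative only, not part of the commit:

import math

import numpy as np


def inclusive_scan_sweeps(flags):
    # Same schedule as the IR: lim = ceil(log2(n)) levels, with the
    # window width doubling on the way up and halving on the way down.
    a = flags.copy()
    n = len(a)
    lim = math.ceil(math.log2(n))
    # Up sweep: fold each window's left-half sum into its last element.
    for l2_width in range(lim):
        width = 2 << l2_width
        for start in range(0, n, width):  # one GPU thread per window
            middle = start + width // 2
            if middle < n:
                end = min(start + width, n)
                a[end - 1] += a[middle - 1]
    # Down sweep: propagate each window's prefix into its midpoint.
    for l2_width in range(lim - 1):
        width = 2 << (lim - l2_width - 2)
        for start in range(width, n, width):  # start > 0, as in the kernel
            middle = start + width // 2
            if middle < n:
                a[middle - 1] += a[start - 1]
    return a


flags = np.array([1, 0, 1, 1, 0, 1, 0])    # validity mask for one batch row
scan = inclusive_scan_sweeps(flags)
print(scan)                                # [1 1 2 3 3 4 4] == np.cumsum(flags)
print(scan[-1])                            # valid_count for this row: 4
print(np.where(flags == 1, scan - 1, -1))  # valid_indices: [ 0 -1  1  2 -1  3 -1]

The last element of each row's scan is that row's box count, and subtracting one from the scan at valid positions yields the 0-based output slot for each valid box, which is exactly what the "Write Sum" and "Remove invalid indices" kernels compute.
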
4 changes: 1 addition & 3 deletions tests/python/relay/test_op_level5.py
@@ -313,10 +313,8 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index):
     for target, ctx in tvm.testing.enabled_targets():
         intrp = relay.create_executor("debug", ctx=ctx, target=target)
         out = intrp.evaluate(func)(np_data)
+
         tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3, atol=1e-04)
-        # get_valid_count for opencl doesn't do data rearrangement
-        if target in ["opencl"]:
-            return
         tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3, atol=1e-04)
         tvm.testing.assert_allclose(out[2].asnumpy(), np_out3, rtol=1e-3, atol=1e-04)
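Editor's note: with the rewritten kernels now performing the data rearrangement, the opencl early-return that skipped the out[1] and out[2] checks is dropped, so all three outputs are verified on every enabled target.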
