[TOPI] Make cumsum IR reusable, add thrust scan (apache#7303)
* import changes from scan branch

commit cf0d4fd
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri Dec 25 10:12:01 2020 +0900

    get valid count test working

commit eb142d3
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri Dec 25 07:22:00 2020 +0900

    integrate new cumsum change

commit f89684d
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri Dec 25 06:56:46 2020 +0900

    remove ceil_div from nms

commit a2ad4de
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun Dec 20 20:36:34 2020 +0900

    add api for returning reduction from ex scan output

commit b7f4ef7
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun Dec 20 19:49:07 2020 +0900

    move ceil_div to utils

commit a9a57e3
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun Dec 20 19:38:15 2020 +0900

    rename prefix_scan.py to scan.py

commit 03ed43f
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat Dec 19 06:12:55 2020 +0900

    suppress cpplint

commit abceac9
Author: masa <masa@pop-os.localdomain>
Date:   Fri Dec 18 20:36:24 2020 +0900

    support more data types

commit 3e7d1f8
Author: masa <masa@pop-os.localdomain>
Date:   Fri Dec 18 20:09:51 2020 +0900

    1d thrust scan working

commit ac13b40
Author: masa <masa@pop-os.localdomain>
Date:   Fri Dec 18 19:49:25 2020 +0900

    adding thrust scan support

commit 65634e8
Author: masa <masa@pop-os.localdomain>
Date:   Fri Dec 18 19:01:11 2020 +0900

    add thrust scan python stub

commit 9876c90
Author: masa <masa@pop-os.localdomain>
Date:   Fri Dec 18 20:55:14 2020 +0900

    introduce prefix_scan.py and move the scan IR out of nms.py

commit 667bdd3
Author: masa <masa@pop-os.localdomain>
Date:   Fri Dec 18 15:06:18 2020 +0900

    make the scan loop exclusive

commit 480787b
Author: mbrookhart <mbrookhart@octoml.ai>
Date:   Thu Dec 17 10:01:11 2020 -0700

    Parallelize cumsum in get_valid_counts

* fix for 1d scan

* rename

* cast to out dtype

* do not run return reduction for inclusive scan

* remove another ceil_div definition

* adding scan test

* add scheduling for scan op, fix scan 1d test

* pylint fix

* add doc string

* add more thrust scan test

* add dynamic get valid count test, including empty size tensor

* fix hard coded gpu targets for cpu only env

* try returning early if scan_size is 0

* another change for empty tensor and thrust path

Co-authored-by: masa <masa@pop-os.localdomain>
2 people authored and Tushar Dey committed Jan 20, 2021
1 parent 1346826 commit 7a379ee
Showing 10 changed files with 666 additions and 194 deletions.
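
Most of the diff below removes the hand-written exclusive-scan IR from python/tvm/topi/cuda/nms.py; get_valid_counts now calls the reusable exclusive_scan primitive from the new scan module (Thrust-backed when available), scanning along axis 1 and also returning the per-row reduction as valid_count. The semantics of exclusive_scan(valid_boxes, axis=1, return_reduction=True) can be modeled in NumPy roughly as follows (an illustrative sketch, not the TOPI implementation):

import numpy as np

def exclusive_scan_with_reduction(valid_boxes):
    # Exclusive cumsum along axis=1 plus the per-row totals.
    inclusive = np.cumsum(valid_boxes, axis=1)
    valid_count = inclusive[:, -1]            # reduction: number of valid boxes per batch
    valid_indices = inclusive - valid_boxes   # shift to make the scan exclusive
    return valid_indices, valid_count

boxes = np.array([[1, 0, 1, 1, 0]])
indices, count = exclusive_scan_with_reduction(boxes)
# indices -> [[0 1 1 2 3]], count -> [3]
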
151 changes: 4 additions & 147 deletions python/tvm/topi/cuda/nms.py
@@ -22,6 +22,8 @@

from tvm.tir import if_then_else
from .sort import argsort, argsort_thrust, is_thrust_available
from .scan import exclusive_scan
from ..utils import ceil_div


def cuda_atomic_add_rule(op):
@@ -51,10 +53,6 @@ def atomic_add(x, y):
return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y)


def ceil_div(a, b):
return tvm.tir.indexdiv(a + b - 1, b)


def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index):
"""Low level IR to identify bounding boxes given a score threshold.
@@ -123,136 +121,6 @@ def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index
return ib.get()


def get_valid_indices_ir(valid_boxes, valid_count, valid_indices):
"""Low level IR to get the ouput indices of valid boxes
and the count of valid boxes
Parameters
----------
valid_boxes: Buffer
2D Buffer indicating valid boxes with shape [batch_size, num_anchors].
Returns
-------
valid_count: Buffer
1D Buffer of number of valid boxes per batch [batch_size].
valid_indices: Buffer
2D Buffer indicating output sorted indices of valid boxes [batch_size, num_anchors].
"""
batch_size = valid_boxes.shape[0]
num_anchors = valid_boxes.shape[1]

ib = tvm.tir.ir_builder.create()

valid_boxes = ib.buffer_ptr(valid_boxes)

valid_count = ib.buffer_ptr(valid_count)
valid_indices = ib.buffer_ptr(valid_indices)

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
with ib.if_scope(num_anchors > 0):
# Copy boxes to valid_indices
with ib.new_scope():
nthread_tx = max_threads
nthread_bx = ceil_div(num_anchors, max_threads)
nthread_by = batch_size
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
by = te.thread_axis("blockIdx.y")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
ib.scope_attr(by, "thread_extent", nthread_by)
tid = bx * nthread_tx + tx
with ib.if_scope(tid < num_anchors):
valid_indices[by, tid] = valid_boxes[by, tid]

nthread_tx = max_threads
nthread_bx = ceil_div(num_anchors, max_threads)
nthread_by = batch_size

## The following algorithm performs parallel exclusive scan to get
## a tensor that can later be used to select valid indices
# Up Sweep of exclusive scan
lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))), "int64"
)
with ib.for_range(0, lim, dtype="int64") as l2_width:
width = 2 << l2_width

with ib.new_scope():
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(
bx,
"thread_extent",
tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
)
tid = bx * nthread_tx + tx

by = te.thread_axis("blockIdx.y")
ib.scope_attr(by, "thread_extent", nthread_by)
start = ib.allocate("int64", (1,), name="start", scope="local")
middle = ib.allocate("int64", (1,), name="middle", scope="local")
end = ib.allocate("int64", (1,), name="end", scope="local")
start[0] = width * tid
with ib.if_scope(start[0] < num_anchors):
middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
end[0] = tvm.te.min(start[0] + width, num_anchors)
with ib.if_scope(middle[0] < num_anchors):
valid_indices[by * num_anchors + end[0] - 1] += valid_indices[
by * num_anchors + middle[0] - 1
]

# Down Sweep of exclusive scan
with ib.new_scope():
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(bx, "thread_extent", batch_size)
with ib.if_scope(bx < batch_size):
valid_count[bx] = valid_indices[(bx + 1) * num_anchors - 1]
valid_indices[(bx + 1) * num_anchors - 1] = 0

with ib.for_range(0, lim, dtype="int64") as l2_width:
width = 2 << (lim - l2_width - 1)

with ib.new_scope():
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(
bx,
"thread_extent",
tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
)
tid = bx * nthread_tx + tx

by = te.thread_axis("blockIdx.y")
ib.scope_attr(by, "thread_extent", nthread_by)
start = ib.allocate("int64", (1,), name="start", scope="local")
middle = ib.allocate("int64", (1,), name="middle", scope="local")
end = ib.allocate("int64", (1,), name="end", scope="local")
tmp = ib.allocate("int32", (1,), name="end", scope="local")
start[0] = width * tid
with ib.if_scope(tvm.tir.all(start[0] < num_anchors)):
middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
end[0] = tvm.tir.min(start[0] + width, num_anchors)
with ib.if_scope(middle[0] < num_anchors):
tmp[0] = valid_indices[by * num_anchors + middle[0] - 1]
valid_indices[by * num_anchors + middle[0] - 1] = valid_indices[
by * num_anchors + end[0] - 1
]
valid_indices[by * num_anchors + end[0] - 1] += tmp[0]
with ib.else_scope():
with ib.new_scope():
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(bx, "thread_extent", batch_size)
with ib.if_scope(bx < batch_size):
valid_count[bx] = 0

return ib.get()
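
The function deleted above implemented a work-efficient (Blelloch-style) exclusive scan directly in TIR: an up-sweep accumulates segment sums in place, then a down-sweep zeroes the last element and pushes partial sums back down the tree, with the pre-zeroed total saved as valid_count. A minimal NumPy sketch of that two-phase pattern (illustrative only; it assumes a power-of-two length and is neither the deleted IR nor the new scan.py kernel):

import numpy as np

def blelloch_exclusive_scan(x):
    a = x.astype(np.int64).copy()
    n = len(a)
    # Up-sweep: the last element of each segment accumulates that segment's sum.
    width = 2
    while width <= n:
        for start in range(0, n, width):
            a[start + width - 1] += a[start + width // 2 - 1]
        width *= 2
    # Down-sweep: save the total, zero the last slot, then swap/add partial sums downward.
    total = int(a[-1])
    a[-1] = 0
    width = n
    while width >= 2:
        for start in range(0, n, width):
            mid, end = start + width // 2 - 1, start + width - 1
            a[mid], a[end] = a[end], a[end] + a[mid]
        width //= 2
    return a, total

scan, total = blelloch_exclusive_scan(np.array([1, 0, 1, 1, 0, 1, 1, 0]))
# scan -> [0 1 1 2 3 3 4 5], total -> 5
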


def get_valid_counts_ir(data, valid_indices, valid_boxes, out, out_indices):
"""Low level IR to get valid count of bounding boxes
given a score threshold. Also prepares to move valid boxes to the
@@ -374,19 +242,8 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
valid_indices_buf = tvm.tir.decl_buffer(
(batch_size, num_anchors), "int32", "valid_indices_buf", data_alignment=8
)
valid_count_buf = tvm.tir.decl_buffer(
(batch_size,), "int32", "valid_count_buf", data_alignment=8
)
valid_count, valid_indices = te.extern(
[(batch_size,), (batch_size, num_anchors)],
[valid_boxes],
lambda ins, outs: get_valid_indices_ir(ins[0], outs[0], outs[1]),
dtype=["int32"],
in_buffers=[valid_boxes_buf],
out_buffers=[valid_count_buf, valid_indices_buf],
name="get_valid_indices",
tag="get_valid_indices_gpu",
)

valid_indices, valid_count = exclusive_scan(valid_boxes, axis=1, return_reduction=True)

out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf", data_alignment=8)
out_indices_buf = tvm.tir.decl_buffer(
(Diffs for the remaining changed files are not shown.)
