Commit

introduce prefix_scan.py and move scan ir in nms.py

masa authored and masahi committed Dec 24, 2020
1 parent 667bdd3 commit 9876c90
Showing 2 changed files with 148 additions and 114 deletions.
130 changes: 16 additions & 114 deletions python/tvm/topi/cuda/nms.py
@@ -22,6 +22,7 @@

from tvm.tir import if_then_else
from .sort import argsort, argsort_thrust
from .prefix_scan import exclusive_sum_scan2d


def cuda_atomic_add_rule(op):
@@ -123,122 +124,22 @@ def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index
return ib.get()


def get_valid_indices_ir(valid_boxes, valid_count, valid_indices):
"""Low level IR to get the ouput indices of valid boxes
and the count of valid boxes
Parameters
----------
valid_boxes: Buffer
2D Buffer indicating valid boxes with shape [batch_size, num_anchors].
Returns
-------
valid_count: Buffer
1D Buffer of number of valid boxes per batch [batch_size].
valid_indices: Buffer
2D Buffer indicating output sorted indices of valid boxes [batch_size, num_anchors].
"""
def get_num_valid_boxes_ir(valid_boxes, valid_boxes_ex_scan, valid_count):
"""TODO"""
batch_size = valid_boxes.shape[0]
num_anchors = valid_boxes.shape[1]

ib = tvm.tir.ir_builder.create()

valid_boxes = ib.buffer_ptr(valid_boxes)

valid_boxes_ex_scan = ib.buffer_ptr(valid_boxes_ex_scan)
valid_count = ib.buffer_ptr(valid_count)
valid_indices = ib.buffer_ptr(valid_indices)

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)

def ceil_div(a, b):
return tvm.tir.indexdiv(a + b - 1, b)

# Copy boxes to valid_indices
with ib.new_scope():
nthread_tx = max_threads
nthread_bx = ceil_div(num_anchors, max_threads)
nthread_by = batch_size
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
by = te.thread_axis("blockIdx.y")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
ib.scope_attr(by, "thread_extent", nthread_by)
tid = bx * nthread_tx + tx
with ib.if_scope(tid == 0):
valid_indices[by, 0] = 0
with ib.else_scope():
with ib.if_scope(tid < num_anchors):
valid_indices[by, tid] = valid_boxes[by, tid - 1]

nthread_tx = max_threads
nthread_bx = ceil_div(num_anchors, max_threads)
nthread_by = batch_size

## The following algorithm performs parallel prefix sum to get
## a tensor that can later be used to select valid indices
# Up Sweep of prefix sum
lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))), "int64"
)
with ib.for_range(0, lim, dtype="int64") as l2_width:
width = 2 << l2_width

with ib.new_scope():
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(
bx,
"thread_extent",
tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
)
tid = bx * nthread_tx + tx

by = te.thread_axis("blockIdx.y")
ib.scope_attr(by, "thread_extent", nthread_by)
start = ib.allocate("int64", (1,), name="start", scope="local")
middle = ib.allocate("int64", (1,), name="middle", scope="local")
end = ib.allocate("int64", (1,), name="end", scope="local")
start[0] = width * tid
with ib.if_scope(start[0] < num_anchors):
middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
end[0] = tvm.te.min(start[0] + width, num_anchors)
with ib.if_scope(middle[0] < num_anchors):
valid_indices[by * num_anchors + end[0] - 1] += valid_indices[
by * num_anchors + middle[0] - 1
]

# Down Sweep of prefix sum
with ib.for_range(0, lim - 1, dtype="int64") as l2_width:
width = 2 << (lim - l2_width - 2)

with ib.new_scope():
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(
bx,
"thread_extent",
tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
)
tid = bx * nthread_tx + tx

by = te.thread_axis("blockIdx.y")
ib.scope_attr(by, "thread_extent", nthread_by)
start = ib.allocate("int64", (1,), name="start", scope="local")
middle = ib.allocate("int64", (1,), name="middle", scope="local")
end = ib.allocate("int64", (1,), name="end", scope="local")
start[0] = width * tid
with ib.if_scope(tvm.tir.all(start[0] > 0, start[0] < num_anchors)):
middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
with ib.if_scope(middle[0] < num_anchors):
valid_indices[by * num_anchors + middle[0] - 1] += valid_indices[
by * num_anchors + start[0] - 1
]

## Write Sum to valid_count
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
with ib.new_scope():
@@ -250,10 +151,8 @@ def ceil_div(a, b):
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
with ib.if_scope(tid < batch_size):
# Add valid_boxes[tid, num_anchors - 1] because valid_indices is
# an exclusive scan of valid_boxes
valid_count[tid] = (
valid_indices[tid, num_anchors - 1] + valid_boxes[tid, num_anchors - 1]
valid_boxes_ex_scan[tid, num_anchors - 1] + valid_boxes[tid, num_anchors - 1]
)

return ib.get()
@@ -388,15 +287,18 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
valid_count_buf = tvm.tir.decl_buffer(
(batch_size,), "int32", "valid_count_buf", data_alignment=8
)
valid_count, valid_indices = te.extern(
[(batch_size,), (batch_size, num_anchors)],
[valid_boxes],
lambda ins, outs: get_valid_indices_ir(ins[0], outs[0], outs[1]),

valid_indices = exclusive_sum_scan2d(valid_boxes)

valid_count = te.extern(
[(batch_size,)],
[valid_boxes, valid_indices],
lambda ins, outs: get_num_valid_boxes_ir(ins[0], ins[1], outs[0]),
dtype=["int32"],
in_buffers=[valid_boxes_buf],
out_buffers=[valid_count_buf, valid_indices_buf],
name="get_valid_indices",
tag="get_valid_indices_gpu",
in_buffers=[valid_boxes_buf, valid_indices_buf],
out_buffers=[valid_count_buf],
name="get_valid_indices_sum",
tag="get_valid_indices_sum_gpu",
)

out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf", data_alignment=8)
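
With the scan factored out into prefix_scan.py, get_valid_counts now runs valid_boxes through exclusive_sum_scan2d, and the remaining get_num_valid_boxes_ir kernel only reads the last scan column per batch, adding back the last valid_boxes column because the scan is exclusive. A minimal NumPy sketch of that relation (illustration only; the sample arrays below are made up, not part of the commit):

```python
# NumPy stand-in (not TVM IR) for the valid_count computation above.
import numpy as np

valid_boxes = np.array([[1, 0, 1, 1],
                        [0, 0, 1, 0]], dtype="int32")  # [batch_size, num_anchors]

# Row-wise exclusive prefix sum, i.e. what exclusive_sum_scan2d produces.
ex_scan = np.zeros_like(valid_boxes)
ex_scan[:, 1:] = np.cumsum(valid_boxes, axis=1)[:, :-1]

# The last scan entry excludes the last input element, so it is added back
# to obtain the total number of valid boxes per batch.
valid_count = ex_scan[:, -1] + valid_boxes[:, -1]
print(valid_count)  # [3 1]
```
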
132 changes: 132 additions & 0 deletions python/tvm/topi/cuda/prefix_scan.py
@@ -0,0 +1,132 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"Scan related operators"
import tvm
from tvm import te


def exclusive_sum_scan2d_ir(data, output):
"""
TODO
"""
num_rows = data.shape[0]
scan_size = data.shape[1]

ib = tvm.tir.ir_builder.create()

data = ib.buffer_ptr(data)
output = ib.buffer_ptr(output)

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)

def ceil_div(a, b):
return tvm.tir.indexdiv(a + b - 1, b)

with ib.new_scope():
nthread_tx = max_threads
nthread_bx = ceil_div(scan_size, max_threads)
nthread_by = num_rows
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
by = te.thread_axis("blockIdx.y")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
ib.scope_attr(by, "thread_extent", nthread_by)
tid = bx * nthread_tx + tx
with ib.if_scope(tid == 0):
output[by, 0] = 0
with ib.else_scope():
with ib.if_scope(tid < scan_size):
output[by, tid] = data[by, tid - 1]

nthread_tx = max_threads
nthread_bx = ceil_div(scan_size, max_threads)
nthread_by = num_rows

# Up Sweep of prefix sum
lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_size, "float64"))), "int64"
)
with ib.for_range(0, lim, dtype="int64") as l2_width:
width = 2 << l2_width

with ib.new_scope():
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(
bx,
"thread_extent",
tvm.tir.generic.cast(ceil_div(scan_size, max_threads * width), "int32"),
)
tid = bx * nthread_tx + tx

by = te.thread_axis("blockIdx.y")
ib.scope_attr(by, "thread_extent", nthread_by)
start = ib.allocate("int64", (1,), name="start", scope="local")
middle = ib.allocate("int64", (1,), name="middle", scope="local")
end = ib.allocate("int64", (1,), name="end", scope="local")
start[0] = width * tid
with ib.if_scope(start[0] < scan_size):
middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
end[0] = tvm.te.min(start[0] + width, scan_size)
with ib.if_scope(middle[0] < scan_size):
output[by * scan_size + end[0] - 1] += output[by * scan_size + middle[0] - 1]

# Down Sweep of prefix sum
with ib.for_range(0, lim - 1, dtype="int64") as l2_width:
width = 2 << (lim - l2_width - 2)

with ib.new_scope():
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(
bx,
"thread_extent",
tvm.tir.generic.cast(ceil_div(scan_size, max_threads * width), "int32"),
)
tid = bx * nthread_tx + tx

by = te.thread_axis("blockIdx.y")
ib.scope_attr(by, "thread_extent", nthread_by)
start = ib.allocate("int64", (1,), name="start", scope="local")
middle = ib.allocate("int64", (1,), name="middle", scope="local")
end = ib.allocate("int64", (1,), name="end", scope="local")
start[0] = width * tid
with ib.if_scope(tvm.tir.all(start[0] > 0, start[0] < scan_size)):
middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
with ib.if_scope(middle[0] < scan_size):
output[by * scan_size + middle[0] - 1] += output[by * scan_size + start[0] - 1]

return ib.get()


def exclusive_sum_scan2d(data):
data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
output_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "output_buf", data_alignment=8)

return te.extern(
[data.shape],
[data],
lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0]),
dtype=[data.dtype],
in_buffers=[data_buf],
out_buffers=[output_buf],
name="exclusive_sum_scan2d",
tag="exclusive_sum_scan2d_gpu",
)
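
The IR in exclusive_sum_scan2d_ir follows a work-efficient two-phase scan: the input is first copied one position to the right within each row, then an up-sweep accumulates block partial sums in place and a down-sweep propagates them back to the remaining positions. A rough pure-Python sketch of the same steps, for illustration only (exclusive_sum_scan2d_ref and the sample data are hypothetical, not part of this commit):

```python
# Pure-Python reference (not the TVM IR above) for a per-row exclusive sum scan
# using the same shift-then-scan structure: shift right by one, up-sweep, down-sweep.
import math
import numpy as np

def exclusive_sum_scan2d_ref(data):
    num_rows, scan_size = data.shape
    out = np.zeros_like(data)
    out[:, 1:] = data[:, :-1]          # shift right; out[:, 0] = 0
    lim = math.ceil(math.log2(scan_size))
    for row in out:
        # Up sweep: fold each block's partial sum into its last element,
        # doubling the block width every level.
        for l2_width in range(lim):
            width = 2 << l2_width
            for start in range(0, scan_size, width):
                middle = start + width // 2
                end = min(start + width, scan_size)
                if middle < scan_size:
                    row[end - 1] += row[middle - 1]
        # Down sweep: propagate each block's running total into the middle
        # of the next block, halving the width every level.
        for l2_width in range(lim - 2, -1, -1):
            width = 2 << l2_width
            for start in range(width, scan_size, width):
                middle = start + width // 2
                if middle < scan_size:
                    row[middle - 1] += row[start - 1]
    return out

data = np.array([[1, 0, 1, 1, 0, 1]], dtype="int32")
print(exclusive_sum_scan2d_ref(data))   # [[0 1 1 2 3 3]]
print(np.cumsum(data, axis=1) - data)   # same result
```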
