Commit

fix pylint
Laurawly committed Apr 14, 2020
1 parent 6fce2c9 commit 62fb1df
Showing 1 changed file with 291 additions and 17 deletions.
308 changes: 291 additions & 17 deletions topi/python/topi/cuda/nms.py
@@ -17,12 +17,11 @@
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison
# pylint: disable=bad-continuation, unused-argument
"""Non-maximum suppression operator"""
import math
import tvm
from tvm import te

from tvm.tir import if_then_else
-from .sort import argsort, argsort_thrust
from .sort import argsort
from .. import tag


@@ -44,7 +43,7 @@ def atomic_add(x, y):
return tvm.tir.call_pure_intrin(y.dtype, "atomic_add", x, y)


-def get_valid_counts_ir(data, valid_count, out, score_threshold, id_index, score_index):
def get_valid_counts_ir(data, valid_count, flag, score_threshold, id_index, score_index):
"""Low level IR to get valid count of bounding boxes
given a score threshold. Also prepares to move valid boxes to the
top of input data.
@@ -83,11 +82,10 @@ def get_valid_counts_ir(data, valid_count, out, score_threshold, id_index, score
data = ib.buffer_ptr(data)

valid_count = ib.buffer_ptr(valid_count)
-    out = ib.buffer_ptr(out)
flag = ib.buffer_ptr(flag)
atomic_add_return = ib.allocate(
valid_count.dtype, (1,), name='atomic_add_return', scope='local')
one_count = tvm.tir.const(1, dtype=valid_count.dtype)
-    one = tvm.tir.const(1, dtype=out.dtype)
score_threshold = tvm.ir.make_node(
"FloatImm", dtype="float32", value=score_threshold)
id_index = tvm.ir.make_node("IntImm", dtype="int32", value=id_index)
@@ -107,16 +105,132 @@ def get_valid_counts_ir(data, valid_count, out, score_threshold, id_index, score
# initialize valid_count
with ib.if_scope(tid < batch_size):
valid_count[tid] = 0
# initialize flag
with ib.if_scope(tid < batch_size * num_anchors):
flag[tid] = 0
with ib.if_scope(tid < batch_size * num_anchors):
i = idxd(tid, num_anchors)
with ib.if_scope(
tvm.tir.all(data[tid * elem_length + score_index] > score_threshold,
tvm.tir.any(id_index < 0, data[tid * elem_length + id_index] >= 0))):
flag[tid] = 1
atomic_add_return[0] = atomic_add(tvm.tir.call_pure_intrin("handle", "tvm_address_of",
valid_count[i]), one_count)

return ib.get()
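
For intuition, the semantics of this kernel can be sketched in plain NumPy
(an illustrative reference only, not code from this commit; the atomic_add
serializes on the GPU what the vectorized sum computes here):

import numpy as np

def get_valid_counts_ir_reference(data, score_threshold, id_index, score_index):
    # data: [batch_size, num_anchors, elem_length]. Returns the per-batch
    # count of valid boxes and a 0/1 flag per anchor, mirroring the IR above.
    score_ok = data[:, :, score_index] > score_threshold
    if id_index >= 0:
        score_ok &= data[:, :, id_index] >= 0
    flag = score_ok.astype("int32")
    valid_count = flag.sum(axis=1).astype("int32")  # total of the atomic adds
    return valid_count, flag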


def flag_scan(flag, prefix_sum):
"""Low level IR to calculate correct positions for valid boxes.
Parameters
----------
flag : Buffer
    2D Buffer of flags indicating valid data, with shape [batch_size, num_anchors].
prefix_sum : Buffer
    2D Buffer of the prefix sum of the flags, indicating the new locations of
    valid boxes; same shape as flag.
Returns
-------
stmt : Stmt
The result IR statement.
"""
batch_size = flag.shape[0]
num_anchors = flag.shape[1]

ib = tvm.tir.ir_builder.create()

flag = ib.buffer_ptr(flag)
prefix_sum = ib.buffer_ptr(prefix_sum)

max_threads = int(tvm.target.Target.current(
allow_none=False).max_num_threads)
nthread_tx = max_threads
nthread_bx = batch_size * num_anchors // max_threads + 1
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
idxd = tvm.tir.indexdiv
idxm = tvm.tir.indexmod

# initialize prefix_sum
with ib.if_scope(tid < batch_size * num_anchors):
prefix_sum[tid] = 0
with ib.if_scope(tid < batch_size * num_anchors):
i = idxd(tid, num_anchors)
j = idxm(tid, num_anchors)
with ib.for_range(0, j) as r:
prefix_sum[tid] += flag[i * num_anchors + r]

return ib.get()
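
A rough NumPy equivalent of this scan (illustrative only): each position
receives the exclusive prefix sum of its row, i.e. the number of valid boxes
before it, which is the slot a valid box is later compacted into. Note that
every thread re-reads all flags preceding its position, so the kernel does
quadratic total work per batch row in exchange for a very simple
implementation:

import numpy as np

def flag_scan_reference(flag):
    # Exclusive prefix sum along the anchor axis: prefix_sum[i, j] counts the
    # valid boxes in row i strictly before position j.
    prefix_sum = np.zeros_like(flag)
    prefix_sum[:, 1:] = np.cumsum(flag[:, :-1], axis=1)
    return prefix_sum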


def out_rewrite(data, flag, prefix_sum, valid_count, out):
"""Low level IR to move valid boxes to the
top of input data.
Parameters
----------
data : Buffer
Input data. 3-D Buffer with shape [batch_size, num_anchors, elem_length].
flag : Buffer
    2D Buffer of flags indicating valid data, with shape [batch_size, num_anchors].
prefix_sum : Buffer
    2D Buffer of the prefix sum of the flags, indicating the new locations of
    valid boxes; same shape as flag.
valid_count : Buffer
1D buffer for valid number of boxes with shape [batch_size, ].
out : Buffer
Rearranged data buffer.
Returns
-------
stmt : Stmt
The result IR statement.
"""
batch_size = out.shape[0]
num_anchors = out.shape[1]
elem_length = out.shape[2]

ib = tvm.tir.ir_builder.create()

one = tvm.tir.const(1, dtype=out.dtype)
data = ib.buffer_ptr(data)
flag = ib.buffer_ptr(flag)
valid_count = ib.buffer_ptr(valid_count)
prefix_sum = ib.buffer_ptr(prefix_sum)
out = ib.buffer_ptr(out)

max_threads = int(tvm.target.Target.current(
allow_none=False).max_num_threads)
nthread_tx = max_threads
nthread_bx = batch_size * num_anchors // max_threads + 1
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
idxd = tvm.tir.indexdiv
idxm = tvm.tir.indexmod

with ib.if_scope(tid < batch_size * num_anchors):
i = idxd(tid, num_anchors)
j = idxm(tid, num_anchors)
base_idx = i * num_anchors * elem_length
with ib.if_scope(tvm.tir.all(flag[tid] > 0, prefix_sum[tid] >= 0,
prefix_sum[tid] < num_anchors)):
with ib.for_range(0, elem_length) as k:
-                out[tid * elem_length + k] = data[tid * elem_length + k]
-        with ib.else_scope():
out[base_idx + prefix_sum[tid] * elem_length +
k] = data[tid * elem_length + k]
with ib.if_scope(j >= valid_count[i]):
with ib.for_range(0, elem_length) as k:
out[tid * elem_length + k] = -one

@@ -150,23 +264,47 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
Rearranged data tensor.
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]
data_buf = tvm.tir.decl_buffer(
data.shape, data.dtype, "data_buf", data_alignment=8)
valid_count_buf = tvm.tir.decl_buffer(
(batch_size,), "int32", "valid_count_buf", data_alignment=8)
temp_flag_buf = tvm.tir.decl_buffer(
(batch_size, num_anchors,), "int32", "temp_flag", data_alignment=8)
temp_partial_buf = tvm.tir.decl_buffer(
(batch_size, num_anchors), "int32", "temp_partial", data_alignment=8)
out_buf = tvm.tir.decl_buffer(
data.shape, data.dtype, "out_buf", data_alignment=8)

-    valid_count, out = \
-        te.extern([(batch_size,), data.shape], [data],
valid_count, temp_flag = \
te.extern([(batch_size,), (batch_size, num_anchors)], [data],
lambda ins, outs: get_valid_counts_ir(
ins[0], outs[0], outs[1], score_threshold, id_index, score_index),
dtype=["int32", data.dtype],
dtype=["int32", "int32"],
in_buffers=[data_buf],
-                  out_buffers=[valid_count_buf, out_buf],
out_buffers=[valid_count_buf, temp_flag_buf],
name="get_valid_counts",
tag="get_valid_counts_gpu")

temp_partial = \
te.extern([(batch_size, num_anchors)], [temp_flag],
lambda ins, outs: flag_scan(
ins[0], outs[0]),
dtype=["int32"],
in_buffers=[temp_flag_buf],
out_buffers=[temp_partial_buf],
name="flag_scan")

out = \
te.extern([data.shape], [data, temp_flag, temp_partial, valid_count],
lambda ins, outs: out_rewrite(
ins[0], ins[1], ins[2], ins[3], outs[0]),
dtype=[data.dtype],
in_buffers=[data_buf, temp_flag_buf,
temp_partial_buf, valid_count_buf],
out_buffers=[out_buf],
name="out_rewrite")

return [valid_count, out]
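
Taken together, the three kernels behave like the following NumPy reference
(a sketch reusing the hypothetical helpers above, not part of this file):
count the valid boxes, scan the flags for destination slots, then compact
valid boxes to the top and pad the tail with -1:

def get_valid_counts_reference(data, score_threshold=0, id_index=0, score_index=1):
    valid_count, flag = get_valid_counts_ir_reference(
        data, score_threshold, id_index, score_index)  # kernel 1
    prefix_sum = flag_scan_reference(flag)             # kernel 2
    out = np.full_like(data, -1)                       # kernel 3: out_rewrite
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            if flag[i, j] > 0:
                out[i, prefix_sum[i, j]] = data[i, j]
    return valid_count, out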


@@ -336,6 +474,117 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
return ib.get()


def invalid_to_bottom_pre(data, flag, idx):
"""Low level IR to rearrange nms output to move all valid entries to top.
Parameters
----------
data: Buffer
3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
flag : Buffer
1D Buffer of flag indicating valid data with [num_anchors].
idx : Buffer
1D Buffer of valid data indices with [num_anchors].
Returns
-------
stmt : Stmt
The result IR statement.
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]
elem_length = data.shape[2]

ib = tvm.tir.ir_builder.create()

data = ib.buffer_ptr(data)
flag = ib.buffer_ptr(flag)
idx = ib.buffer_ptr(idx)

max_threads = int(math.sqrt(
tvm.target.Target.current(allow_none=False).max_num_threads))
nthread_tx = max_threads
nthread_bx = num_anchors // max_threads + 1
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
j = bx * max_threads + tx

with ib.for_range(0, batch_size, for_type="unroll") as i:
base_idx = i * num_anchors * elem_length
with ib.if_scope(j < num_anchors):
with ib.if_scope(data[base_idx + j * elem_length] >= 0):
flag[i * num_anchors + j] = 1
idx[i * num_anchors + j] = 1
with ib.else_scope():
flag[i * num_anchors + j] = 0
idx[i * num_anchors + j] = 0

with ib.if_scope(j < batch_size):
with ib.for_range(0, num_anchors) as k:
with ib.if_scope(k > 0):
idx[j * num_anchors + k] += idx[j * num_anchors + k - 1]
return ib.get()
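
Semantically, this first phase marks an nms output row as valid when its
first field (the class id) is non-negative and then builds an inclusive
prefix sum of the flags; a hedged NumPy sketch (not code from this commit):

import numpy as np

def invalid_to_bottom_pre_reference(data):
    # flag[i, j] = 1 iff row j of batch i survived nms; idx is the inclusive
    # prefix sum of flag along the anchor axis.
    flag = (data[:, :, 0] >= 0).astype("int32")
    idx = np.cumsum(flag, axis=1).astype("int32")
    return flag, idx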


def invalid_to_bottom_ir(data, flag, idx, out):
"""Low level IR to rearrange nms output to move all valid entries to top.
Parameters
----------
data: Buffer
3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
flag : Buffer
1D Buffer of flag indicating valid data with [num_anchors].
idx : Buffer
1D Buffer of valid data indices with [num_anchors].
out : Buffer
3D Buffer of rearranged nms output with shape [batch_size, num_anchors, elem_length].
Returns
-------
stmt : Stmt
The result IR statement.
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]
elem_length = data.shape[2]

ib = tvm.tir.ir_builder.create()

data = ib.buffer_ptr(data)
flag = ib.buffer_ptr(flag)
idx = ib.buffer_ptr(idx)
out = ib.buffer_ptr(out)

max_threads = int(math.sqrt(
tvm.target.Target.current(allow_none=False).max_num_threads))
nthread_tx = max_threads
nthread_bx = num_anchors // max_threads + 1
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
j = bx * max_threads + tx

with ib.for_range(0, batch_size, for_type="unroll") as i:
base_idx = i * num_anchors * elem_length
with ib.if_scope(j < num_anchors):
with ib.for_range(0, elem_length) as k:
out[base_idx + j * elem_length + k] = -1.0
with ib.if_scope(flag[i * num_anchors + j] > 0):
with ib.for_range(0, elem_length) as k:
out[base_idx + (idx[i * num_anchors + j] - 1) * elem_length + k] \
= data[base_idx + j * elem_length + k]
return ib.get()
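
The second phase is a stable compaction: the output starts filled with -1 and
each valid row is scattered to slot idx - 1, so valid entries keep their
relative order at the top. A rough NumPy counterpart (illustrative only):

import numpy as np

def invalid_to_bottom_ir_reference(data, flag, idx):
    out = np.full_like(data, -1.0)
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            if flag[i, j] > 0:
                # idx holds an inclusive scan, so idx - 1 is the destination.
                out[i, idx[i, j] - 1] = data[i, j]
    return out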


def non_max_suppression(data, valid_count, max_output_size=-1,
iou_threshold=0.5, force_suppress=False, top_k=-1,
coord_start=2, score_index=1, id_index=0,
@@ -418,19 +667,18 @@ def non_max_suppression(data, valid_count, max_output_size=-1,
score_shape = (batch_size, num_anchors)
score_tensor = te.compute(
score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
-    if tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True):
-        sort_tensor = argsort_thrust(
-            score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype)
-    else:
-        sort_tensor = argsort(
-            score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype)
sort_tensor = argsort(
score_tensor, valid_count=valid_count, axis=1, is_ascend=False)

sort_tensor_buf = tvm.tir.decl_buffer(sort_tensor.shape, sort_tensor.dtype,
"sort_tensor_buf", data_alignment=8)

data_buf = tvm.tir.decl_buffer(
data.shape, data.dtype, "data_buf", data_alignment=8)

out_buf = tvm.tir.decl_buffer(
data.shape, data.dtype, "out_buf", data_alignment=8)

out, box_indices = \
te.extern([data.shape, score_shape],
[data, sort_tensor, valid_count],
@@ -446,4 +694,30 @@
if return_indices:
return box_indices

if invalid_to_bottom:
output_buf = tvm.tir.decl_buffer(
data.shape, data.dtype, "output_buf", data_alignment=8)
temp_flag_buf = tvm.tir.decl_buffer(
score_shape, valid_count_dtype, "temp_flag", data_alignment=8)
temp_idx_buf = tvm.tir.decl_buffer(
score_shape, valid_count_dtype, "temp_idx", data_alignment=8)
temp_flag, temp_idx = te.extern([score_shape, score_shape], [out],
lambda ins, outs: invalid_to_bottom_pre(
ins[0], outs[0], outs[1]),
dtype=["int32", "int32"],
in_buffers=[out_buf],
out_buffers=[
temp_flag_buf, temp_idx_buf],
name="invalid_to_bottom_phase_one")

output = te.extern([data.shape], [out, temp_flag, temp_idx],
lambda ins, outs: invalid_to_bottom_ir(
ins[0], ins[1], ins[2], outs[0]),
dtype=[data.dtype],
in_buffers=[out_buf, temp_flag_buf, temp_idx_buf],
out_buffers=[output_buf],
name="invalid_to_bottom",
tag="invalid_to_bottom")
return output

return out
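
For readers new to the operator, the suppression rule implemented by the IR
above can be summarized by a short Python sketch (illustrative only; it
assumes the default [id, score, x1, y1, x2, y2] box layout implied by
coord_start=2, score_index=1 and id_index=0, and boxes already sorted by
descending score):

def iou(a, b):
    # Intersection over union of two [x1, y1, x2, y2] boxes.
    w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = w * h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / max(area_a + area_b - inter, 1e-12)

def nms_reference(boxes, valid_count, iou_threshold=0.5, force_suppress=False):
    # boxes: list of [id, score, x1, y1, x2, y2], sorted by descending score.
    keep = []
    for j in range(valid_count):
        suppressed = False
        for k in keep:
            same_class = boxes[k][0] == boxes[j][0]
            if (force_suppress or same_class) and \
                    iou(boxes[k][2:6], boxes[j][2:6]) >= iou_threshold:
                suppressed = True
                break
        if not suppressed:
            keep.append(j)
    return keep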
