From 62fb1df9b8f97de11e1531bec05c239fc82d701a Mon Sep 17 00:00:00 2001
From: Leyuan Wang
Date: Tue, 14 Apr 2020 21:03:07 +0000
Subject: [PATCH] fix pylint

---
 topi/python/topi/cuda/nms.py | 308 +++++++++++++++++++++++++++++++++--
 1 file changed, 291 insertions(+), 17 deletions(-)

diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py
index 13f97987094ef..8ea62d12c969d 100644
--- a/topi/python/topi/cuda/nms.py
+++ b/topi/python/topi/cuda/nms.py
@@ -17,12 +17,11 @@
 # pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison
 # pylint: disable=bad-continuation, unused-argument
 """Non-maximum suppression operator"""
-import math
 import tvm
 
 from tvm import te
 from tvm.tir import if_then_else
-from .sort import argsort, argsort_thrust
+from .sort import argsort
 from .. import tag
 
 
@@ -44,7 +43,7 @@ def atomic_add(x, y):
     return tvm.tir.call_pure_intrin(y.dtype, "atomic_add", x, y)
 
 
-def get_valid_counts_ir(data, valid_count, out, score_threshold, id_index, score_index):
+def get_valid_counts_ir(data, valid_count, flag, score_threshold, id_index, score_index):
     """Low level IR to get valid count of bounding boxes given a score threshold.
     Also prepares to move valid boxes to the top of input data.
 
@@ -83,11 +82,10 @@ def get_valid_counts_ir(data, valid_count, out, score_threshold, id_index, score
 
     data = ib.buffer_ptr(data)
     valid_count = ib.buffer_ptr(valid_count)
-    out = ib.buffer_ptr(out)
+    flag = ib.buffer_ptr(flag)
     atomic_add_return = ib.allocate(
         valid_count.dtype, (1,), name='atomic_add_return', scope='local')
     one_count = tvm.tir.const(1, dtype=valid_count.dtype)
-    one = tvm.tir.const(1, dtype=out.dtype)
     score_threshold = tvm.ir.make_node(
         "FloatImm", dtype="float32", value=score_threshold)
     id_index = tvm.ir.make_node("IntImm", dtype="int32", value=id_index)
@@ -107,16 +105,132 @@ def get_valid_counts_ir(data, valid_count, out, score_threshold, id_index, score
     # initialize valid_count
     with ib.if_scope(tid < batch_size):
         valid_count[tid] = 0
+    # initialize flag
+    with ib.if_scope(tid < batch_size * num_anchors):
+        flag[tid] = 0
     with ib.if_scope(tid < batch_size * num_anchors):
         i = idxd(tid, num_anchors)
         with ib.if_scope(
                 tvm.tir.all(data[tid * elem_length + score_index] > score_threshold,
                             tvm.tir.any(id_index < 0, data[tid * elem_length + id_index] >= 0))):
+            flag[tid] = 1
             atomic_add_return[0] = atomic_add(tvm.tir.call_pure_intrin("handle", "tvm_address_of",
                                                                        valid_count[i]), one_count)
+
+    return ib.get()
+
+
+def flag_scan(flag, prefix_sum):
+    """Low level IR to calculate correct positions for valid boxes.
+
+    Parameters
+    ----------
+    flag : Buffer
+        2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].
+
+    prefix_sum : Buffer
+        2D Buffer of prefix sum of flags indicating new locations of valid boxes
+        with same shape as flag.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+ """ + batch_size = flag.shape[0] + num_anchors = flag.shape[1] + + ib = tvm.tir.ir_builder.create() + + flag = ib.buffer_ptr(flag) + prefix_sum = ib.buffer_ptr(prefix_sum) + + max_threads = int(tvm.target.Target.current( + allow_none=False).max_num_threads) + nthread_tx = max_threads + nthread_bx = batch_size * num_anchors // max_threads + 1 + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + idxd = tvm.tir.indexdiv + idxm = tvm.tir.indexmod + + # initialize prefix_sum + with ib.if_scope(tid < batch_size * num_anchors): + prefix_sum[tid] = 0 + with ib.if_scope(tid < batch_size * num_anchors): + i = idxd(tid, num_anchors) + j = idxm(tid, num_anchors) + with ib.for_range(0, j) as r: + prefix_sum[tid] += flag[i * num_anchors + r] + + return ib.get() + + +def out_rewrite(data, flag, prefix_sum, valid_count, out): + """Low level IR to move valid boxes to the + top of input data. + + Parameters + ---------- + data : Buffer + Input data. 3-D Buffer with shape [batch_size, num_anchors, elem_length]. + + flag : Buffer + 2D Buffer of flag indicating valid data with shape [batch_size, num_anchors]. + + prefix_sum : Buffer + 2D Buffer of prefix sum of flags indicating new locations of valid boxes + with same shape as flag. + + valid_count : Buffer + 1D buffer for valid number of boxes with shape [batch_size, ]. + + out : Buffer + Rearranged data buffer. + + Returns + ------- + stmt : Stmt + The result IR statement. + """ + batch_size = out.shape[0] + num_anchors = out.shape[1] + elem_length = out.shape[2] + + ib = tvm.tir.ir_builder.create() + + one = tvm.tir.const(1, dtype=out.dtype) + data = ib.buffer_ptr(data) + flag = ib.buffer_ptr(flag) + valid_count = ib.buffer_ptr(valid_count) + prefix_sum = ib.buffer_ptr(prefix_sum) + out = ib.buffer_ptr(out) + + max_threads = int(tvm.target.Target.current( + allow_none=False).max_num_threads) + nthread_tx = max_threads + nthread_bx = batch_size * num_anchors // max_threads + 1 + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + idxd = tvm.tir.indexdiv + idxm = tvm.tir.indexmod + + with ib.if_scope(tid < batch_size * num_anchors): + i = idxd(tid, num_anchors) + j = idxm(tid, num_anchors) + base_idx = i * num_anchors * elem_length + with ib.if_scope(tvm.tir.all(flag[tid] > 0, prefix_sum[tid] >= 0, + prefix_sum[tid] < num_anchors)): with ib.for_range(0, elem_length) as k: - out[tid * elem_length + k] = data[tid * elem_length + k] - with ib.else_scope(): + out[base_idx + prefix_sum[tid] * elem_length + + k] = data[tid * elem_length + k] + with ib.if_scope(j >= valid_count[i]): with ib.for_range(0, elem_length) as k: out[tid * elem_length + k] = -one @@ -150,23 +264,47 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): Rearranged data tensor. 
""" batch_size = data.shape[0] + num_anchors = data.shape[1] data_buf = tvm.tir.decl_buffer( data.shape, data.dtype, "data_buf", data_alignment=8) valid_count_buf = tvm.tir.decl_buffer( (batch_size,), "int32", "valid_count_buf", data_alignment=8) + temp_flag_buf = tvm.tir.decl_buffer( + (batch_size, num_anchors,), "int32", "temp_flag", data_alignment=8) + temp_partial_buf = tvm.tir.decl_buffer( + (batch_size, num_anchors), "int32", "temp_partial", data_alignment=8) out_buf = tvm.tir.decl_buffer( data.shape, data.dtype, "out_buf", data_alignment=8) - valid_count, out = \ - te.extern([(batch_size,), data.shape], [data], + valid_count, temp_flag = \ + te.extern([(batch_size,), (batch_size, num_anchors)], [data], lambda ins, outs: get_valid_counts_ir( ins[0], outs[0], outs[1], score_threshold, id_index, score_index), - dtype=["int32", data.dtype], + dtype=["int32", "int32"], in_buffers=[data_buf], - out_buffers=[valid_count_buf, out_buf], + out_buffers=[valid_count_buf, temp_flag_buf], name="get_valid_counts", tag="get_valid_counts_gpu") + temp_partial = \ + te.extern([(batch_size, num_anchors)], [temp_flag], + lambda ins, outs: flag_scan( + ins[0], outs[0]), + dtype=["int32"], + in_buffers=[temp_flag_buf], + out_buffers=[temp_partial_buf], + name="flag_scan") + + out = \ + te.extern([data.shape], [data, temp_flag, temp_partial, valid_count], + lambda ins, outs: out_rewrite( + ins[0], ins[1], ins[2], ins[3], outs[0]), + dtype=[data.dtype], + in_buffers=[data_buf, temp_flag_buf, + temp_partial_buf, valid_count_buf], + out_buffers=[out_buf], + name="out_rewrite") + return [valid_count, out] @@ -336,6 +474,117 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): return ib.get() +def invalid_to_bottom_pre(data, flag, idx): + """Low level IR to rearrange nms output to move all valid entries to top. + + Parameters + ---------- + data: Buffer + 3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms. + + flag : Buffer + 1D Buffer of flag indicating valid data with [num_anchors]. + + idx : Buffer + 1D Buffer of valid data indices with [num_anchors]. + + Returns + ------- + stmt : Stmt + The result IR statement. + """ + batch_size = data.shape[0] + num_anchors = data.shape[1] + elem_length = data.shape[2] + + ib = tvm.tir.ir_builder.create() + + data = ib.buffer_ptr(data) + flag = ib.buffer_ptr(flag) + idx = ib.buffer_ptr(idx) + + max_threads = int(math.sqrt( + tvm.target.Target.current(allow_none=False).max_num_threads)) + nthread_tx = max_threads + nthread_bx = num_anchors // max_threads + 1 + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + j = bx * max_threads + tx + + with ib.for_range(0, batch_size, for_type="unroll") as i: + base_idx = i * num_anchors * elem_length + with ib.if_scope(j < num_anchors): + with ib.if_scope(data[base_idx + j * elem_length] >= 0): + flag[i * num_anchors + j] = 1 + idx[i * num_anchors + j] = 1 + with ib.else_scope(): + flag[i * num_anchors + j] = 0 + idx[i * num_anchors + j] = 0 + + with ib.if_scope(j < batch_size): + with ib.for_range(0, num_anchors) as k: + with ib.if_scope(k > 0): + idx[j * num_anchors + k] += idx[j * num_anchors + k - 1] + return ib.get() + + +def invalid_to_bottom_ir(data, flag, idx, out): + """Low level IR to rearrange nms output to move all valid entries to top. + + Parameters + ---------- + data: Buffer + 3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms. 
+ + flag : Buffer + 1D Buffer of flag indicating valid data with [num_anchors]. + + idx : Buffer + 1D Buffer of valid data indices with [num_anchors]. + + out : Buffer + 3D Buffer of rearranged nms output with shape [batch_size, num_anchors, elem_length]. + + Returns + ------- + stmt : Stmt + The result IR statement. + """ + batch_size = data.shape[0] + num_anchors = data.shape[1] + elem_length = data.shape[2] + + ib = tvm.tir.ir_builder.create() + + data = ib.buffer_ptr(data) + flag = ib.buffer_ptr(flag) + idx = ib.buffer_ptr(idx) + out = ib.buffer_ptr(out) + + max_threads = int(math.sqrt( + tvm.target.Target.current(allow_none=False).max_num_threads)) + nthread_tx = max_threads + nthread_bx = num_anchors // max_threads + 1 + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + j = bx * max_threads + tx + + with ib.for_range(0, batch_size, for_type="unroll") as i: + base_idx = i * num_anchors * elem_length + with ib.if_scope(j < num_anchors): + with ib.for_range(0, elem_length) as k: + out[base_idx + j * elem_length + k] = -1.0 + with ib.if_scope(flag[i * num_anchors + j] > 0): + with ib.for_range(0, elem_length) as k: + out[base_idx + (idx[i * num_anchors + j] - 1) * elem_length + k] \ + = data[base_idx + j * elem_length + k] + return ib.get() + + def non_max_suppression(data, valid_count, max_output_size=-1, iou_threshold=0.5, force_suppress=False, top_k=-1, coord_start=2, score_index=1, id_index=0, @@ -418,12 +667,8 @@ def non_max_suppression(data, valid_count, max_output_size=-1, score_shape = (batch_size, num_anchors) score_tensor = te.compute( score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE) - if tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True): - sort_tensor = argsort_thrust( - score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype) - else: - sort_tensor = argsort( - score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype) + sort_tensor = argsort( + score_tensor, valid_count=valid_count, axis=1, is_ascend=False) sort_tensor_buf = tvm.tir.decl_buffer(sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8) @@ -431,6 +676,9 @@ def non_max_suppression(data, valid_count, max_output_size=-1, data_buf = tvm.tir.decl_buffer( data.shape, data.dtype, "data_buf", data_alignment=8) + out_buf = tvm.tir.decl_buffer( + data.shape, data.dtype, "out_buf", data_alignment=8) + out, box_indices = \ te.extern([data.shape, score_shape], [data, sort_tensor, valid_count], @@ -446,4 +694,30 @@ def non_max_suppression(data, valid_count, max_output_size=-1, if return_indices: return box_indices + if invalid_to_bottom: + output_buf = tvm.tir.decl_buffer( + data.shape, data.dtype, "output_buf", data_alignment=8) + temp_flag_buf = tvm.tir.decl_buffer( + score_shape, valid_count_dtype, "temp_flag", data_alignment=8) + temp_idx_buf = tvm.tir.decl_buffer( + score_shape, valid_count_dtype, "temp_idx", data_alignment=8) + temp_flag, temp_idx = te.extern([score_shape, score_shape], [out], + lambda ins, outs: invalid_to_bottom_pre( + ins[0], outs[0], outs[1]), + dtype=["int32", "int32"], + in_buffers=[out_buf], + out_buffers=[ + temp_flag_buf, temp_idx_buf], + name="invalid_to_bottom_phase_one") + + output = te.extern([data.shape], [out, temp_flag, temp_idx], + lambda ins, outs: invalid_to_bottom_ir( + ins[0], ins[1], ins[2], outs[0]), + dtype=[data.dtype], + in_buffers=[out_buf, 
+                                     temp_flag_buf, temp_idx_buf],
+                          out_buffers=[output_buf],
+                          name="invalid_to_bottom",
+                          tag="invalid_to_bottom")
+        return output
+
     return out
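For reference, the three kernels this patch introduces for get_valid_counts (get_valid_counts_ir, flag_scan, out_rewrite) together implement a flag / exclusive-scan / scatter compaction of valid boxes. The sketch below is a minimal NumPy model of that semantics, not code from the patch: the function name get_valid_counts_ref and the use of NumPy are illustrative assumptions, and the per-thread O(num_anchors) loop in flag_scan is replaced by a plain cumsum on the host.

# Host-side reference model (illustrative only, not part of the patch).
import numpy as np

def get_valid_counts_ref(data, score_threshold=0.0, id_index=0, score_index=1):
    """data: float array of shape (batch_size, num_anchors, elem_length)."""
    batch_size, num_anchors, _ = data.shape

    # Stage 1 (get_valid_counts_ir): per-box validity flag and per-batch count.
    flag = data[..., score_index] > score_threshold
    if id_index >= 0:
        flag &= data[..., id_index] >= 0
    flag = flag.astype(np.int32)                      # (batch_size, num_anchors)
    valid_count = flag.sum(axis=1).astype(np.int32)   # (batch_size,)

    # Stage 2 (flag_scan): exclusive prefix sum gives each valid box its new slot.
    prefix_sum = np.cumsum(flag, axis=1) - flag

    # Stage 3 (out_rewrite): scatter valid boxes to the top, pad the rest with -1.
    out = np.full_like(data, -1.0)
    for i in range(batch_size):
        for j in range(num_anchors):
            if flag[i, j]:
                out[i, prefix_sum[i, j]] = data[i, j]
    return valid_count, out

Under this model, valid_count[i] is the number of boxes in batch i that pass the score (and optional id) check, and out keeps those boxes in their original order at the top of each batch with the remaining slots set to -1, which is the effect of the two guarded writes in out_rewrite.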