From ce95a825e09e17831af6663a42031c247b19ec7d Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Wed, 15 Apr 2020 08:32:50 -0700 Subject: [PATCH] [TOPI] Improve get_valid_count and nms performance for CUDA (#5339) * get_valid_count updated to have correct results * speedup nms * update nms * revert back nms * recover one test for get_valid_count --- python/tvm/relay/frontend/mxnet.py | 1 - tests/python/relay/test_op_level5.py | 3 + topi/python/topi/cuda/nms.py | 302 +---------------------- topi/python/topi/nn/deformable_conv2d.py | 2 +- topi/tests/python/test_topi_vision.py | 4 +- 5 files changed, 17 insertions(+), 295 deletions(-) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index e6aa5f1fb05d..4edf0b80de4c 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -853,7 +853,6 @@ def _mx_smooth_l1(inputs, attrs): def _mx_deformable_convolution(inputs, attrs): new_attrs = {} - assert attrs.get_bool("no_bias") new_attrs["kernel_size"] = attrs.get_int_tuple("kernel") new_attrs["strides"] = attrs.get_int_tuple("stride") new_attrs["padding"] = attrs.get_int_tuple("pad") diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 2dc549f2f354..b29b69667653 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -225,6 +225,9 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): intrp = relay.create_executor("debug", ctx=ctx, target=target) out = intrp.evaluate(func)(np_data) tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3, atol=1e-04) + # get_valid_count for cuda doesn't do data rearrangement + if target == 'cuda': + return tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3, atol=1e-04) verify_get_valid_counts((1, 2500, 6), 0, 0, 1) diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index d295116e72fd..d8be3bd1b886 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -17,7 +17,6 @@ # pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison # pylint: disable=bad-continuation, unused-argument """Non-maximum suppression operator""" -import math import tvm from tvm import te @@ -44,7 +43,7 @@ def atomic_add(x, y): return tvm.tir.call_pure_intrin(y.dtype, "atomic_add", x, y) -def get_valid_counts_ir(data, valid_count, flag, score_threshold, id_index, score_index): +def get_valid_counts_ir(data, valid_count, out, score_threshold, id_index, score_index): """Low level IR to get valid count of bounding boxes given a score threshold. Also prepares to move valid boxes to the top of input data. @@ -83,10 +82,11 @@ def get_valid_counts_ir(data, valid_count, flag, score_threshold, id_index, scor data = ib.buffer_ptr(data) valid_count = ib.buffer_ptr(valid_count) - flag = ib.buffer_ptr(flag) + out = ib.buffer_ptr(out) atomic_add_return = ib.allocate( valid_count.dtype, (1,), name='atomic_add_return', scope='local') one_count = tvm.tir.const(1, dtype=valid_count.dtype) + one = tvm.tir.const(1, dtype=out.dtype) score_threshold = tvm.ir.make_node( "FloatImm", dtype="float32", value=score_threshold) id_index = tvm.ir.make_node("IntImm", dtype="int32", value=id_index) @@ -106,132 +106,16 @@ def get_valid_counts_ir(data, valid_count, flag, score_threshold, id_index, scor # initialize valid_count with ib.if_scope(tid < batch_size): valid_count[tid] = 0 - # initialize flag - with ib.if_scope(tid < batch_size * num_anchors): - flag[tid] = 0 with ib.if_scope(tid < batch_size * num_anchors): i = idxd(tid, num_anchors) with ib.if_scope( tvm.tir.all(data[tid * elem_length + score_index] > score_threshold, tvm.tir.any(id_index < 0, data[tid * elem_length + id_index] >= 0))): - flag[tid] = 1 atomic_add_return[0] = atomic_add(tvm.tir.call_pure_intrin("handle", "tvm_address_of", valid_count[i]), one_count) - - return ib.get() - - -def flag_scan(flag, prefix_sum): - """Low level IR to calculate correct positions for valid boxes. - - Parameters - ---------- - flag : Buffer - 2D Buffer of flag indicating valid data with shape [batch_size, num_anchors]. - - prefix_sum : Buffer - 2D Buffer of prefix sum of flags indicating new locations of valid boxes - with same shape as flag. - - Returns - ------- - stmt : Stmt - The result IR statement. - """ - batch_size = flag.shape[0] - num_anchors = flag.shape[1] - - ib = tvm.tir.ir_builder.create() - - flag = ib.buffer_ptr(flag) - prefix_sum = ib.buffer_ptr(prefix_sum) - - max_threads = int(tvm.target.Target.current( - allow_none=False).max_num_threads) - nthread_tx = max_threads - nthread_bx = batch_size * num_anchors // max_threads + 1 - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - tid = bx * max_threads + tx - idxd = tvm.tir.indexdiv - idxm = tvm.tir.indexmod - - # initialize prefix_sum - with ib.if_scope(tid < batch_size * num_anchors): - prefix_sum[tid] = 0 - with ib.if_scope(tid < batch_size * num_anchors): - i = idxd(tid, num_anchors) - j = idxm(tid, num_anchors) - with ib.for_range(0, j) as r: - prefix_sum[tid] += flag[i * num_anchors + r] - - return ib.get() - - -def out_rewrite(data, flag, prefix_sum, valid_count, out): - """Low level IR to move valid boxes to the - top of input data. - - Parameters - ---------- - data : Buffer - Input data. 3-D Buffer with shape [batch_size, num_anchors, elem_length]. - - flag : Buffer - 2D Buffer of flag indicating valid data with shape [batch_size, num_anchors]. - - prefix_sum : Buffer - 2D Buffer of prefix sum of flags indicating new locations of valid boxes - with same shape as flag. - - valid_count : Buffer - 1D buffer for valid number of boxes with shape [batch_size, ]. - - out : Buffer - Rearranged data buffer. - - Returns - ------- - stmt : Stmt - The result IR statement. - """ - batch_size = out.shape[0] - num_anchors = out.shape[1] - elem_length = out.shape[2] - - ib = tvm.tir.ir_builder.create() - - one = tvm.tir.const(1, dtype=out.dtype) - data = ib.buffer_ptr(data) - flag = ib.buffer_ptr(flag) - valid_count = ib.buffer_ptr(valid_count) - prefix_sum = ib.buffer_ptr(prefix_sum) - out = ib.buffer_ptr(out) - - max_threads = int(tvm.target.Target.current( - allow_none=False).max_num_threads) - nthread_tx = max_threads - nthread_bx = batch_size * num_anchors // max_threads + 1 - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - tid = bx * max_threads + tx - idxd = tvm.tir.indexdiv - idxm = tvm.tir.indexmod - - with ib.if_scope(tid < batch_size * num_anchors): - i = idxd(tid, num_anchors) - j = idxm(tid, num_anchors) - base_idx = i * num_anchors * elem_length - with ib.if_scope(tvm.tir.all(flag[tid] > 0, prefix_sum[tid] >= 0, - prefix_sum[tid] < num_anchors)): with ib.for_range(0, elem_length) as k: - out[base_idx + prefix_sum[tid] * elem_length + - k] = data[tid * elem_length + k] - with ib.if_scope(j >= valid_count[i]): + out[tid * elem_length + k] = data[tid * elem_length + k] + with ib.else_scope(): with ib.for_range(0, elem_length) as k: out[tid * elem_length + k] = -one @@ -265,47 +149,23 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): Rearranged data tensor. """ batch_size = data.shape[0] - num_anchors = data.shape[1] data_buf = tvm.tir.decl_buffer( data.shape, data.dtype, "data_buf", data_alignment=8) valid_count_buf = tvm.tir.decl_buffer( (batch_size,), "int32", "valid_count_buf", data_alignment=8) - temp_flag_buf = tvm.tir.decl_buffer( - (batch_size, num_anchors,), "int32", "temp_flag", data_alignment=8) - temp_partial_buf = tvm.tir.decl_buffer( - (batch_size, num_anchors), "int32", "temp_partial", data_alignment=8) out_buf = tvm.tir.decl_buffer( data.shape, data.dtype, "out_buf", data_alignment=8) - valid_count, temp_flag = \ - te.extern([(batch_size,), (batch_size, num_anchors)], [data], + valid_count, out = \ + te.extern([(batch_size,), data.shape], [data], lambda ins, outs: get_valid_counts_ir( ins[0], outs[0], outs[1], score_threshold, id_index, score_index), - dtype=["int32", "int32"], + dtype=["int32", data.dtype], in_buffers=[data_buf], - out_buffers=[valid_count_buf, temp_flag_buf], + out_buffers=[valid_count_buf, out_buf], name="get_valid_counts", tag="get_valid_counts_gpu") - temp_partial = \ - te.extern([(batch_size, num_anchors)], [temp_flag], - lambda ins, outs: flag_scan( - ins[0], outs[0]), - dtype=["int32"], - in_buffers=[temp_flag_buf], - out_buffers=[temp_partial_buf], - name="flag_scan") - - out = \ - te.extern([data.shape], [data, temp_flag, temp_partial, valid_count], - lambda ins, outs: out_rewrite( - ins[0], ins[1], ins[2], ins[3], outs[0]), - dtype=[data.dtype], - in_buffers=[data_buf, temp_flag_buf, - temp_partial_buf, valid_count_buf], - out_buffers=[out_buf], - name="out_rewrite") - return [valid_count, out] @@ -475,117 +335,6 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): return ib.get() -def invalid_to_bottom_pre(data, flag, idx): - """Low level IR to rearrange nms output to move all valid entries to top. - - Parameters - ---------- - data: Buffer - 3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms. - - flag : Buffer - 1D Buffer of flag indicating valid data with [num_anchors]. - - idx : Buffer - 1D Buffer of valid data indices with [num_anchors]. - - Returns - ------- - stmt : Stmt - The result IR statement. - """ - batch_size = data.shape[0] - num_anchors = data.shape[1] - elem_length = data.shape[2] - - ib = tvm.tir.ir_builder.create() - - data = ib.buffer_ptr(data) - flag = ib.buffer_ptr(flag) - idx = ib.buffer_ptr(idx) - - max_threads = int(math.sqrt( - tvm.target.Target.current(allow_none=False).max_num_threads)) - nthread_tx = max_threads - nthread_bx = num_anchors // max_threads + 1 - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - j = bx * max_threads + tx - - with ib.for_range(0, batch_size, for_type="unroll") as i: - base_idx = i * num_anchors * elem_length - with ib.if_scope(j < num_anchors): - with ib.if_scope(data[base_idx + j * elem_length] >= 0): - flag[i * num_anchors + j] = 1 - idx[i * num_anchors + j] = 1 - with ib.else_scope(): - flag[i * num_anchors + j] = 0 - idx[i * num_anchors + j] = 0 - - with ib.if_scope(j < batch_size): - with ib.for_range(0, num_anchors) as k: - with ib.if_scope(k > 0): - idx[j * num_anchors + k] += idx[j * num_anchors + k - 1] - return ib.get() - - -def invalid_to_bottom_ir(data, flag, idx, out): - """Low level IR to rearrange nms output to move all valid entries to top. - - Parameters - ---------- - data: Buffer - 3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms. - - flag : Buffer - 1D Buffer of flag indicating valid data with [num_anchors]. - - idx : Buffer - 1D Buffer of valid data indices with [num_anchors]. - - out : Buffer - 3D Buffer of rearranged nms output with shape [batch_size, num_anchors, elem_length]. - - Returns - ------- - stmt : Stmt - The result IR statement. - """ - batch_size = data.shape[0] - num_anchors = data.shape[1] - elem_length = data.shape[2] - - ib = tvm.tir.ir_builder.create() - - data = ib.buffer_ptr(data) - flag = ib.buffer_ptr(flag) - idx = ib.buffer_ptr(idx) - out = ib.buffer_ptr(out) - - max_threads = int(math.sqrt( - tvm.target.Target.current(allow_none=False).max_num_threads)) - nthread_tx = max_threads - nthread_bx = num_anchors // max_threads + 1 - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - j = bx * max_threads + tx - - with ib.for_range(0, batch_size, for_type="unroll") as i: - base_idx = i * num_anchors * elem_length - with ib.if_scope(j < num_anchors): - with ib.for_range(0, elem_length) as k: - out[base_idx + j * elem_length + k] = -1.0 - with ib.if_scope(flag[i * num_anchors + j] > 0): - with ib.for_range(0, elem_length) as k: - out[base_idx + (idx[i * num_anchors + j] - 1) * elem_length + k] \ - = data[base_idx + j * elem_length + k] - return ib.get() - - def non_max_suppression(data, valid_count, max_output_size=-1, iou_threshold=0.5, force_suppress=False, top_k=-1, coord_start=2, score_index=1, id_index=0, @@ -670,10 +419,10 @@ def non_max_suppression(data, valid_count, max_output_size=-1, score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE) if tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True): sort_tensor = argsort_thrust( - score_tensor, valid_count=valid_count, axis=1, is_ascend=False) + score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype) else: sort_tensor = argsort( - score_tensor, valid_count=valid_count, axis=1, is_ascend=False) + score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype) sort_tensor_buf = tvm.tir.decl_buffer(sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8) @@ -681,9 +430,6 @@ def non_max_suppression(data, valid_count, max_output_size=-1, data_buf = tvm.tir.decl_buffer( data.shape, data.dtype, "data_buf", data_alignment=8) - out_buf = tvm.tir.decl_buffer( - data.shape, data.dtype, "out_buf", data_alignment=8) - out, box_indices = \ te.extern([data.shape, score_shape], [data, sort_tensor, valid_count], @@ -699,30 +445,4 @@ def non_max_suppression(data, valid_count, max_output_size=-1, if return_indices: return box_indices - if invalid_to_bottom: - output_buf = tvm.tir.decl_buffer( - data.shape, data.dtype, "output_buf", data_alignment=8) - temp_flag_buf = tvm.tir.decl_buffer( - score_shape, valid_count_dtype, "temp_flag", data_alignment=8) - temp_idx_buf = tvm.tir.decl_buffer( - score_shape, valid_count_dtype, "temp_idx", data_alignment=8) - temp_flag, temp_idx = te.extern([score_shape, score_shape], [out], - lambda ins, outs: invalid_to_bottom_pre( - ins[0], outs[0], outs[1]), - dtype=["int32", "int32"], - in_buffers=[out_buf], - out_buffers=[ - temp_flag_buf, temp_idx_buf], - name="invalid_to_bottom_phase_one") - - output = te.extern([data.shape], [out, temp_flag, temp_idx], - lambda ins, outs: invalid_to_bottom_ir( - ins[0], ins[1], ins[2], outs[0]), - dtype=[data.dtype], - in_buffers=[out_buf, temp_flag_buf, temp_idx_buf], - out_buffers=[output_buf], - name="invalid_to_bottom", - tag="invalid_to_bottom") - return output - return out diff --git a/topi/python/topi/nn/deformable_conv2d.py b/topi/python/topi/nn/deformable_conv2d.py index 9f95fd1ae790..39be6d6d861f 100644 --- a/topi/python/topi/nn/deformable_conv2d.py +++ b/topi/python/topi/nn/deformable_conv2d.py @@ -106,7 +106,7 @@ def _bilinear(n, c, h, w): (kh * kernel_w + kw) * 2, y, x], x * stride_w - pad_left + kw * dilation_w + offset[n, c // ic_per_dgroup * (kernel_w*kernel_h*2) + - (kh * kernel_w + kw) * 2 + 1, y, x])) + (kh * kernel_w + kw) * 2 + 1, y, x]), tag="data_deform") return te.compute( (batch, out_channel, out_height, out_width), lambda n, f, y, x: te.sum( diff --git a/topi/tests/python/test_topi_vision.py b/topi/tests/python/test_topi_vision.py index fe94a4ca9138..3ccb44d0f47c 100644 --- a/topi/tests/python/test_topi_vision.py +++ b/topi/tests/python/test_topi_vision.py @@ -106,8 +106,8 @@ def check_device(device): """ Skip this test as it is intermittent see https://github.com/apache/incubator-tvm/pull/4901#issuecomment-595040094 for device in ['llvm', 'cuda', 'opencl']: - # Disable opencl test for now - if device != "llvm" and device != "cuda": + # Disable gpu test for now + if device != "llvm": continue check_device(device) """