From 9876c901ee8b406bc9d75ba91c4734d55f85811b Mon Sep 17 00:00:00 2001
From: masa
Date: Fri, 18 Dec 2020 20:55:14 +0900
Subject: [PATCH] introduce prefix_scan.py and move scan IR out of nms.py

---
 python/tvm/topi/cuda/nms.py         | 130 ++++-----------------
 python/tvm/topi/cuda/prefix_scan.py | 132 ++++++++++++++++++++++++++++
 2 files changed, 148 insertions(+), 114 deletions(-)
 create mode 100644 python/tvm/topi/cuda/prefix_scan.py

diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index 72e5f6c42e78..fcf93f705df7 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -22,6 +22,7 @@
 from tvm.tir import if_then_else
 
 from .sort import argsort, argsort_thrust
+from .prefix_scan import exclusive_sum_scan2d
 
 
 def cuda_atomic_add_rule(op):
@@ -123,122 +124,22 @@ def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index
     return ib.get()
 
 
-def get_valid_indices_ir(valid_boxes, valid_count, valid_indices):
-    """Low level IR to get the ouput indices of valid boxes
-    and the count of valid boxes
-
-    Parameters
-    ----------
-    valid_boxes: Buffer
-        2D Buffer indicating valid boxes with shape [batch_size, num_anchors].
-
-    Returns
-    -------
-    valid_count: Buffer
-        1D Buffer of number of valid boxes per batch [batch_size].
-
-    valid_indices: Buffer
-        2D Buffer indicating output sorted indcies of valid boxes [batch_size, num_anchors].
-    """
+def get_num_valid_boxes_ir(valid_boxes, valid_boxes_ex_scan, valid_count):
+    """Low level IR to get the number of valid boxes per batch
+
+    Parameters
+    ----------
+    valid_boxes: Buffer
+        2D Buffer indicating valid boxes with shape [batch_size, num_anchors].
+
+    valid_boxes_ex_scan: Buffer
+        2D Buffer, the exclusive sum scan of valid_boxes along each row,
+        with shape [batch_size, num_anchors].
+
+    Returns
+    -------
+    valid_count: Buffer
+        1D Buffer of number of valid boxes per batch [batch_size].
+    """
     batch_size = valid_boxes.shape[0]
     num_anchors = valid_boxes.shape[1]
 
     ib = tvm.tir.ir_builder.create()
 
     valid_boxes = ib.buffer_ptr(valid_boxes)
-
+    valid_boxes_ex_scan = ib.buffer_ptr(valid_boxes_ex_scan)
     valid_count = ib.buffer_ptr(valid_count)
-    valid_indices = ib.buffer_ptr(valid_indices)
 
     max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
 
     def ceil_div(a, b):
         return tvm.tir.indexdiv(a + b - 1, b)
 
-    # Copy boxes to valid_indices
-    with ib.new_scope():
-        nthread_tx = max_threads
-        nthread_bx = ceil_div(num_anchors, max_threads)
-        nthread_by = batch_size
-        tx = te.thread_axis("threadIdx.x")
-        bx = te.thread_axis("blockIdx.x")
-        by = te.thread_axis("blockIdx.y")
-        ib.scope_attr(tx, "thread_extent", nthread_tx)
-        ib.scope_attr(bx, "thread_extent", nthread_bx)
-        ib.scope_attr(by, "thread_extent", nthread_by)
-        tid = bx * nthread_tx + tx
-        with ib.if_scope(tid == 0):
-            valid_indices[by, 0] = 0
-        with ib.else_scope():
-            with ib.if_scope(tid < num_anchors):
-                valid_indices[by, tid] = valid_boxes[by, tid - 1]
-
-    nthread_tx = max_threads
-    nthread_bx = ceil_div(num_anchors, max_threads)
-    nthread_by = batch_size
-
-    ## The following algorithm performs parallel prefix sum to get
-    ## a tensor that can later be used to select valid indices
-    # Up Sweep of prefix sum
-    lim = tvm.tir.generic.cast(
-        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))), "int64"
-    )
-    with ib.for_range(0, lim, dtype="int64") as l2_width:
-        width = 2 << l2_width
-
-        with ib.new_scope():
-            tx = te.thread_axis("threadIdx.x")
-            bx = te.thread_axis("blockIdx.x")
-            ib.scope_attr(tx, "thread_extent", nthread_tx)
-            ib.scope_attr(
-                bx,
-                "thread_extent",
-                tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
-            )
-            tid = bx * nthread_tx + tx
-
-            by = te.thread_axis("blockIdx.y")
-            ib.scope_attr(by, "thread_extent", nthread_by)
-            start = ib.allocate("int64", (1,), name="start", scope="local")
-            middle = ib.allocate("int64", (1,), name="middle", scope="local")
-            end = ib.allocate("int64", (1,), name="end", scope="local")
name="end", scope="local") - start[0] = width * tid - with ib.if_scope(start[0] < num_anchors): - middle[0] = start[0] + tvm.tir.indexdiv(width, 2) - end[0] = tvm.te.min(start[0] + width, num_anchors) - with ib.if_scope(middle[0] < num_anchors): - valid_indices[by * num_anchors + end[0] - 1] += valid_indices[ - by * num_anchors + middle[0] - 1 - ] - - # Down Sweep of prefix sum - with ib.for_range(0, lim - 1, dtype="int64") as l2_width: - width = 2 << (lim - l2_width - 2) - - with ib.new_scope(): - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr( - bx, - "thread_extent", - tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"), - ) - tid = bx * nthread_tx + tx - - by = te.thread_axis("blockIdx.y") - ib.scope_attr(by, "thread_extent", nthread_by) - start = ib.allocate("int64", (1,), name="start", scope="local") - middle = ib.allocate("int64", (1,), name="middle", scope="local") - end = ib.allocate("int64", (1,), name="end", scope="local") - start[0] = width * tid - with ib.if_scope(tvm.tir.all(start[0] > 0, start[0] < num_anchors)): - middle[0] = start[0] + tvm.tir.indexdiv(width, 2) - with ib.if_scope(middle[0] < num_anchors): - valid_indices[by * num_anchors + middle[0] - 1] += valid_indices[ - by * num_anchors + start[0] - 1 - ] - ## Write Sum to valid_count max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) with ib.new_scope(): @@ -250,10 +151,8 @@ def ceil_div(a, b): ib.scope_attr(bx, "thread_extent", nthread_bx) tid = bx * max_threads + tx with ib.if_scope(tid < batch_size): - # Add valid_boxes[tid, num_anchors - 1] because valid_indices is - # an exclusive scan of valid_boxes valid_count[tid] = ( - valid_indices[tid, num_anchors - 1] + valid_boxes[tid, num_anchors - 1] + valid_boxes_ex_scan[tid, num_anchors - 1] + valid_boxes[tid, num_anchors - 1] ) return ib.get() @@ -388,15 +287,18 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): valid_count_buf = tvm.tir.decl_buffer( (batch_size,), "int32", "valid_count_buf", data_alignment=8 ) - valid_count, valid_indices = te.extern( - [(batch_size,), (batch_size, num_anchors)], - [valid_boxes], - lambda ins, outs: get_valid_indices_ir(ins[0], outs[0], outs[1]), + + valid_indices = exclusive_sum_scan2d(valid_boxes) + + valid_count = te.extern( + [(batch_size,)], + [valid_boxes, valid_indices], + lambda ins, outs: get_num_valid_boxes_ir(ins[0], ins[1], outs[0]), dtype=["int32"], - in_buffers=[valid_boxes_buf], - out_buffers=[valid_count_buf, valid_indices_buf], - name="get_valid_indices", - tag="get_valid_indices_gpu", + in_buffers=[valid_boxes_buf, valid_indices_buf], + out_buffers=[valid_count_buf], + name="get_valid_indices_sum", + tag="get_valid_indices_sum_gpu", ) out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf", data_alignment=8) diff --git a/python/tvm/topi/cuda/prefix_scan.py b/python/tvm/topi/cuda/prefix_scan.py new file mode 100644 index 000000000000..b4815cb2531a --- /dev/null +++ b/python/tvm/topi/cuda/prefix_scan.py @@ -0,0 +1,132 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"Scan related operators"
+import tvm
+from tvm import te
+
+
+def exclusive_sum_scan2d_ir(data, output):
+    """Low level IR to compute an exclusive sum scan along each row of a 2D input
+
+    Parameters
+    ----------
+    data: Buffer
+        2D input Buffer with shape [num_rows, scan_size].
+
+    output: Buffer
+        2D output Buffer of the same shape, where output[i, 0] = 0 and
+        output[i, j] = data[i, 0] + ... + data[i, j - 1].
+    """
+    num_rows = data.shape[0]
+    scan_size = data.shape[1]
+
+    ib = tvm.tir.ir_builder.create()
+
+    data = ib.buffer_ptr(data)
+    output = ib.buffer_ptr(output)
+
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+
+    def ceil_div(a, b):
+        return tvm.tir.indexdiv(a + b - 1, b)
+
+    # Copy the input into the output, shifted right by one element per row,
+    # to seed the exclusive scan
+    with ib.new_scope():
+        nthread_tx = max_threads
+        nthread_bx = ceil_div(scan_size, max_threads)
+        nthread_by = num_rows
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        by = te.thread_axis("blockIdx.y")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        ib.scope_attr(by, "thread_extent", nthread_by)
+        tid = bx * nthread_tx + tx
+        with ib.if_scope(tid == 0):
+            output[by, 0] = 0
+        with ib.else_scope():
+            with ib.if_scope(tid < scan_size):
+                output[by, tid] = data[by, tid - 1]
+
+    nthread_tx = max_threads
+    nthread_bx = ceil_div(scan_size, max_threads)
+    nthread_by = num_rows
+
+    # Up Sweep of prefix sum
+    lim = tvm.tir.generic.cast(
+        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_size, "float64"))), "int64"
+    )
+    with ib.for_range(0, lim, dtype="int64") as l2_width:
+        width = 2 << l2_width
+
+        with ib.new_scope():
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+            ib.scope_attr(tx, "thread_extent", nthread_tx)
+            ib.scope_attr(
+                bx,
+                "thread_extent",
+                tvm.tir.generic.cast(ceil_div(scan_size, max_threads * width), "int32"),
+            )
+            tid = bx * nthread_tx + tx
+
+            by = te.thread_axis("blockIdx.y")
+            ib.scope_attr(by, "thread_extent", nthread_by)
+            start = ib.allocate("int64", (1,), name="start", scope="local")
+            middle = ib.allocate("int64", (1,), name="middle", scope="local")
+            end = ib.allocate("int64", (1,), name="end", scope="local")
+            start[0] = width * tid
+            with ib.if_scope(start[0] < scan_size):
+                middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
+                end[0] = tvm.te.min(start[0] + width, scan_size)
+                with ib.if_scope(middle[0] < scan_size):
+                    output[by * scan_size + end[0] - 1] += output[by * scan_size + middle[0] - 1]
+
+    # Down Sweep of prefix sum
+    with ib.for_range(0, lim - 1, dtype="int64") as l2_width:
+        width = 2 << (lim - l2_width - 2)
+
+        with ib.new_scope():
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+            ib.scope_attr(tx, "thread_extent", nthread_tx)
+            ib.scope_attr(
+                bx,
+                "thread_extent",
+                tvm.tir.generic.cast(ceil_div(scan_size, max_threads * width), "int32"),
+            )
+            tid = bx * nthread_tx + tx
+
+            by = te.thread_axis("blockIdx.y")
+            ib.scope_attr(by, "thread_extent", nthread_by)
+            start = ib.allocate("int64", (1,), name="start", scope="local")
+            middle = ib.allocate("int64", (1,), name="middle", scope="local")
+            end = ib.allocate("int64", (1,), name="end", scope="local")
+            start[0] = width * tid
+            with ib.if_scope(tvm.tir.all(start[0] > 0, start[0] < scan_size)):
+                middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
+                with ib.if_scope(middle[0] < scan_size):
+                    output[by * scan_size + middle[0] - 1] += output[by * scan_size + start[0] - 1]
+
+    return ib.get()
+
+
+def exclusive_sum_scan2d(data):
+    """Return the exclusive sum scan of a 2D tensor computed along each row"""
+    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
+    output_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "output_buf", data_alignment=8)
+
+    return te.extern(
+        [data.shape],
+        [data],
+        lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0]),
+        dtype=[data.dtype],
+        in_buffers=[data_buf],
+        out_buffers=[output_buf],
+        name="exclusive_sum_scan2d",
+        tag="exclusive_sum_scan2d_gpu",
+    )
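
Reviewer note, not part of the patch: a minimal NumPy sketch of the semantics
the new kernels are meant to implement. The helper name exclusive_sum_scan2d_ref
and the sample values are illustrative only.

import numpy as np

def exclusive_sum_scan2d_ref(data):
    # Row-wise exclusive prefix sum: out[i, 0] = 0 and
    # out[i, j] = data[i, 0] + ... + data[i, j - 1].
    out = np.zeros_like(data)
    out[:, 1:] = np.cumsum(data[:, :-1], axis=1)
    return out

valid_boxes = np.array([[1, 0, 1, 1],
                        [0, 0, 1, 0]], dtype="int32")
ex_scan = exclusive_sum_scan2d_ref(valid_boxes)
# get_num_valid_boxes_ir recovers the per-batch count as the last scan
# element plus the last input element:
valid_count = ex_scan[:, -1] + valid_boxes[:, -1]
print(ex_scan)      # [[0 1 1 2]
                    #  [0 0 0 1]]
print(valid_count)  # [3 1]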