From eb142d3ee9bb16ddf8d37fdec10c1bcda209deaa Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 25 Dec 2020 07:22:00 +0900
Subject: [PATCH] integrate new cumsum change

---
 python/tvm/topi/cuda/__init__.py |   1 +
 python/tvm/topi/cuda/scan.py     | 168 +++++++++++++++++++------------
 2 files changed, 103 insertions(+), 66 deletions(-)

diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py
index 23c625ae7ff7..f407e885d3e8 100644
--- a/python/tvm/topi/cuda/__init__.py
+++ b/python/tvm/topi/cuda/__init__.py
@@ -55,3 +55,4 @@
 from .correlation import *
 from .sparse import *
 from .argwhere import *
+from .scan import *
diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py
index 8d058d783b9a..566aad4d9957 100644
--- a/python/tvm/topi/cuda/scan.py
+++ b/python/tvm/topi/cuda/scan.py
@@ -22,24 +22,28 @@ from ..utils import ceil_div
 
 
-def exclusive_sum_scan2d_ir(data, output):
-    """
-    TODO
-    """
-    num_rows = data.shape[0]
-    scan_size = data.shape[1]
+def exclusive_sum_scan2d_ir(data, output, reduction=None):
+    """Low-level IR for an exclusive sum scan along each row of 2D data.
+
+    If a `reduction` buffer is given, the total sum of each row is also
+    written to it (one value per row).
+    """
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
 
     ib = tvm.tir.ir_builder.create()
 
     data = ib.buffer_ptr(data)
     output = ib.buffer_ptr(output)
 
+    if reduction is not None:
+        reduction = ib.buffer_ptr(reduction)
+
     max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
 
+    # Copy the input to output; the scan below is then done in place on output
     with ib.new_scope():
         nthread_tx = max_threads
-        nthread_bx = ceil_div(scan_size, max_threads)
-        nthread_by = num_rows
+        nthread_bx = ceil_div(num_anchors, max_threads)
+        nthread_by = batch_size
         tx = te.thread_axis("threadIdx.x")
         bx = te.thread_axis("blockIdx.x")
         by = te.thread_axis("blockIdx.y")
@@ -47,19 +51,18 @@ def exclusive_sum_scan2d_ir(data, output):
         ib.scope_attr(bx, "thread_extent", nthread_bx)
         ib.scope_attr(by, "thread_extent", nthread_by)
         tid = bx * nthread_tx + tx
-        with ib.if_scope(tid == 0):
-            output[by, 0] = 0
-        with ib.else_scope():
-            with ib.if_scope(tid < scan_size):
-                output[by, tid] = data[by, tid - 1]
+        with ib.if_scope(tid < num_anchors):
+            output[by, tid] = data[by, tid]
 
     nthread_tx = max_threads
-    nthread_bx = ceil_div(scan_size, max_threads)
-    nthread_by = num_rows
+    nthread_bx = ceil_div(num_anchors, max_threads)
+    nthread_by = batch_size
 
-    # Up Sweep of prefix sum
+    # The following kernels perform a work-efficient parallel exclusive scan
+    # (a Blelloch scan); NMS-style ops use the result to select valid indices
+    # Up Sweep of exclusive scan
     lim = tvm.tir.generic.cast(
-        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_size, "float64"))), "int64"
+        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))), "int64"
     )
     with ib.for_range(0, lim, dtype="int64") as l2_width:
         width = 2 << l2_width
@@ -71,7 +74,7 @@ def exclusive_sum_scan2d_ir(data, output):
             ib.scope_attr(
                 bx,
                 "thread_extent",
-                tvm.tir.generic.cast(ceil_div(scan_size, max_threads * width), "int32"),
+                tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
             )
             tid = bx * nthread_tx + tx
 
@@ -81,15 +84,25 @@ def exclusive_sum_scan2d_ir(data, output):
             start = ib.allocate("int64", (1,), name="start", scope="local")
             middle = ib.allocate("int64", (1,), name="middle", scope="local")
             end = ib.allocate("int64", (1,), name="end", scope="local")
             start[0] = width * tid
-            with ib.if_scope(start[0] < scan_size):
+            with ib.if_scope(start[0] < num_anchors):
                 middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
-                end[0] = tvm.te.min(start[0] + width, scan_size)
-                with ib.if_scope(middle[0] < scan_size):
-                    output[by * scan_size + end[0] - 1] += output[by * scan_size + middle[0] - 1]
+                end[0] = tvm.te.min(start[0] + width, num_anchors)
+                with ib.if_scope(middle[0] < num_anchors):
+                    output[by * num_anchors + end[0] - 1] += output[
+                        by * num_anchors + middle[0] - 1
+                    ]
+
+    # Down Sweep of exclusive scan: first store each row's total into
+    # `reduction` (if requested) and zero the last element, then propagate
+    # the partial sums back down the tree
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", batch_size)
+        with ib.if_scope(bx < batch_size):
+            if reduction is not None:
+                reduction[bx] = output[(bx + 1) * num_anchors - 1]
+            output[(bx + 1) * num_anchors - 1] = 0
 
-    # Down Sweep of prefix sum
-    with ib.for_range(0, lim - 1, dtype="int64") as l2_width:
-        width = 2 << (lim - l2_width - 2)
+    with ib.for_range(0, lim, dtype="int64") as l2_width:
+        width = 2 << (lim - l2_width - 1)
 
         with ib.new_scope():
             tx = te.thread_axis("threadIdx.x")
@@ -98,7 +111,7 @@ def exclusive_sum_scan2d_ir(data, output):
             ib.scope_attr(
                 bx,
                 "thread_extent",
-                tvm.tir.generic.cast(ceil_div(scan_size, max_threads * width), "int32"),
+                tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
             )
             tid = bx * nthread_tx + tx
 
@@ -107,39 +120,19 @@ def exclusive_sum_scan2d_ir(data, output):
             start = ib.allocate("int64", (1,), name="start", scope="local")
             middle = ib.allocate("int64", (1,), name="middle", scope="local")
             end = ib.allocate("int64", (1,), name="end", scope="local")
+            # tmp holds the left child's partial sum during the swap
+            tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
             start[0] = width * tid
-            with ib.if_scope(tvm.tir.all(start[0] > 0, start[0] < scan_size)):
+            with ib.if_scope(start[0] < num_anchors):
                 middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
-                with ib.if_scope(middle[0] < scan_size):
-                    output[by * scan_size + middle[0] - 1] += output[by * scan_size + start[0] - 1]
+                end[0] = tvm.tir.min(start[0] + width, num_anchors)
+                with ib.if_scope(middle[0] < num_anchors):
+                    tmp[0] = output[by * num_anchors + middle[0] - 1]
+                    output[by * num_anchors + middle[0] - 1] = output[by * num_anchors + end[0] - 1]
+                    output[by * num_anchors + end[0] - 1] += tmp[0]
 
     return ib.get()
 
 
-def is_thrust_available():
-    """
-    Test if thrust based scan ops are available.
-    """
-    return get_global_func("tvm.contrib.thrust.sum_scan", allow_missing=True) is not None
-
-
-def scan_thrust(data, exclusive=True):
-    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
-    output_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "output_buf", data_alignment=8)
-    return te.extern(
-        [data.shape],
-        [data],
-        lambda ins, outs: tvm.tir.call_packed(
-            "tvm.contrib.thrust.sum_scan", ins[0], outs[0], exclusive
-        ),
-        dtype=[data.dtype],
-        in_buffers=[data_buf],
-        out_buffers=[output_buf],
-        name="exclusive_sum_scan2d",
-        tag="exclusive_sum_scan2d_gpu",
-    )
-
-
 def get_reduction_from_exclusive_scan_ir(data, data_ex_scan, reduction):
-    """TODO"""
+    """IR to recover each row's total sum from the row and its exclusive
+    scan: reduction[i] = data_ex_scan[i, -1] + data[i, -1]."""
     batch_size = data.shape[0]
@@ -185,6 +178,42 @@ def get_reduction_from_exclusive_scan(data, ex_scan_output):
     )
 
 
+def is_thrust_available():
+    """
+    Test if thrust based scan ops are available.
+ """ + return get_global_func("tvm.contrib.thrust.sum_scan", allow_missing=True) is not None + + +def scan_thrust(data, exclusive=True, return_reduction=False): + data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + output_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "output_buf", data_alignment=8) + output = te.extern( + [data.shape], + [data], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.thrust.sum_scan", ins[0], outs[0], exclusive + ), + dtype=[data.dtype], + in_buffers=[data_buf], + out_buffers=[output_buf], + name="exclusive_sum_scan2d", + tag="exclusive_sum_scan2d_gpu", + ) + + if return_reduction: + ndim = len(data.shape) + if ndim == 1: + output = expand_dims(output, axis=0) + reduction = get_reduction_from_exclusive_scan(data, output) + reduction = squeeze(reduction, 0) + else: + reduction = get_reduction_from_exclusive_scan(data, output) + return output, reduction + + return output + + def exclusive_scan(data, axis=-1, return_reduction=False): # TODO(masahi): support other binary associative operators ndim = len(data.shape) @@ -194,17 +223,27 @@ def exclusive_scan(data, axis=-1, return_reduction=False): target = tvm.target.Target.current() if target and target.kind.name == "cuda" and is_thrust_available(): - output = scan_thrust(data, exclusive=True) - if ndim == 1 and return_reduction: - output = expand_dims(data, axis=0) - else: - if ndim == 1: - data = expand_dims(data, axis=0) + return scan_thrust(data, exclusive=True, return_reduction=return_reduction) - data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) - output_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "output_buf", data_alignment=8) + if ndim == 1: + data = expand_dims(data, axis=0) + + data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + output_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "output_buf", data_alignment=8) - if ndim == 2: + if ndim == 2: + if return_reduction: + output, reduction = te.extern( + [data.shape, (data.shape[0],)], + [data], + lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0], outs[1]), + dtype=[data.dtype], + in_buffers=[data_buf], + out_buffers=[output_buf], + name="exclusive_scan", + tag="exclusive_scan_gpu", + ) + else: output = te.extern( [data.shape], [data], @@ -215,19 +254,16 @@ def exclusive_scan(data, axis=-1, return_reduction=False): name="exclusive_scan", tag="exclusive_scan_gpu", ) - else: - assert False, "Unsupported dimension {}".format(ndim) - - if return_reduction: - reduction = get_reduction_from_exclusive_scan(data, output) + reduction = None + else: + assert False, "Unsupported dimension {}".format(ndim) if ndim == 1: output = squeeze(output, 0) if return_reduction: reduction = squeeze(reduction, 0) - return output, reduction - return reduction if return_reduction: return output, reduction + return output