integrate new cumsum change
masahi committed Dec 24, 2020
1 parent f89684d commit eb142d3
Showing 2 changed files with 103 additions and 66 deletions.
1 change: 1 addition & 0 deletions python/tvm/topi/cuda/__init__.py
@@ -55,3 +55,4 @@
 from .correlation import *
 from .sparse import *
 from .argwhere import *
+from .scan import *
168 changes: 102 additions & 66 deletions python/tvm/topi/cuda/scan.py
@@ -22,44 +22,47 @@
 from ..utils import ceil_div
 
 
-def exclusive_sum_scan2d_ir(data, output):
+def exclusive_sum_scan2d_ir(data, output, reduction=None):
     """
     TODO
     """
-    num_rows = data.shape[0]
-    scan_size = data.shape[1]
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
 
     ib = tvm.tir.ir_builder.create()
 
     data = ib.buffer_ptr(data)
     output = ib.buffer_ptr(output)
 
+    if reduction is not None:
+        reduction = ib.buffer_ptr(reduction)
+
     max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
 
     # Copy boxes to output
     with ib.new_scope():
         nthread_tx = max_threads
-        nthread_bx = ceil_div(scan_size, max_threads)
-        nthread_by = num_rows
+        nthread_bx = ceil_div(num_anchors, max_threads)
+        nthread_by = batch_size
         tx = te.thread_axis("threadIdx.x")
         bx = te.thread_axis("blockIdx.x")
         by = te.thread_axis("blockIdx.y")
         ib.scope_attr(tx, "thread_extent", nthread_tx)
         ib.scope_attr(bx, "thread_extent", nthread_bx)
         ib.scope_attr(by, "thread_extent", nthread_by)
         tid = bx * nthread_tx + tx
-        with ib.if_scope(tid == 0):
-            output[by, 0] = 0
-        with ib.else_scope():
-            with ib.if_scope(tid < scan_size):
-                output[by, tid] = data[by, tid - 1]
+        with ib.if_scope(tid < num_anchors):
+            output[by, tid] = data[by, tid]
 
     nthread_tx = max_threads
-    nthread_bx = ceil_div(scan_size, max_threads)
-    nthread_by = num_rows
+    nthread_bx = ceil_div(num_anchors, max_threads)
+    nthread_by = batch_size
 
-    # Up Sweep of prefix sum
+    ## The following algorithm performs parallel exclusive scan to get
+    ## a tensor that can later be used to select valid indices
+    # Up Sweep of exclusive scan
     lim = tvm.tir.generic.cast(
-        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_size, "float64"))), "int64"
+        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))), "int64"
     )
     with ib.for_range(0, lim, dtype="int64") as l2_width:
         width = 2 << l2_width
@@ -71,7 +74,7 @@ def exclusive_sum_scan2d_ir(data, output):
             ib.scope_attr(
                 bx,
                 "thread_extent",
-                tvm.tir.generic.cast(ceil_div(scan_size, max_threads * width), "int32"),
+                tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
             )
             tid = bx * nthread_tx + tx
 
@@ -81,15 +84,25 @@ def exclusive_sum_scan2d_ir(data, output):
             middle = ib.allocate("int64", (1,), name="middle", scope="local")
             end = ib.allocate("int64", (1,), name="end", scope="local")
             start[0] = width * tid
-            with ib.if_scope(start[0] < scan_size):
+            with ib.if_scope(start[0] < num_anchors):
                 middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
-                end[0] = tvm.te.min(start[0] + width, scan_size)
-                with ib.if_scope(middle[0] < scan_size):
-                    output[by * scan_size + end[0] - 1] += output[by * scan_size + middle[0] - 1]
+                end[0] = tvm.te.min(start[0] + width, num_anchors)
+                with ib.if_scope(middle[0] < num_anchors):
+                    output[by * num_anchors + end[0] - 1] += output[
+                        by * num_anchors + middle[0] - 1
+                    ]
 
-    # Down Sweep of prefix sum
-    with ib.for_range(0, lim - 1, dtype="int64") as l2_width:
-        width = 2 << (lim - l2_width - 2)
+    # Down Sweep of exclusive scan
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", batch_size)
+        with ib.if_scope(bx < batch_size):
+            if reduction is not None:
+                reduction[bx] = output[(bx + 1) * num_anchors - 1]
+            output[(bx + 1) * num_anchors - 1] = 0
+
+    with ib.for_range(0, lim, dtype="int64") as l2_width:
+        width = 2 << (lim - l2_width - 1)
 
         with ib.new_scope():
             tx = te.thread_axis("threadIdx.x")
@@ -98,7 +111,7 @@ def exclusive_sum_scan2d_ir(data, output):
             ib.scope_attr(
                 bx,
                 "thread_extent",
-                tvm.tir.generic.cast(ceil_div(scan_size, max_threads * width), "int32"),
+                tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
             )
             tid = bx * nthread_tx + tx
 
@@ -107,39 +120,19 @@ def exclusive_sum_scan2d_ir(data, output):
             start = ib.allocate("int64", (1,), name="start", scope="local")
             middle = ib.allocate("int64", (1,), name="middle", scope="local")
             end = ib.allocate("int64", (1,), name="end", scope="local")
+            tmp = ib.allocate("int32", (1,), name="end", scope="local")
             start[0] = width * tid
-            with ib.if_scope(tvm.tir.all(start[0] > 0, start[0] < scan_size)):
+            with ib.if_scope(tvm.tir.all(start[0] < num_anchors)):
                 middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
-                with ib.if_scope(middle[0] < scan_size):
-                    output[by * scan_size + middle[0] - 1] += output[by * scan_size + start[0] - 1]
+                end[0] = tvm.tir.min(start[0] + width, num_anchors)
+                with ib.if_scope(middle[0] < num_anchors):
+                    tmp[0] = output[by * num_anchors + middle[0] - 1]
+                    output[by * num_anchors + middle[0] - 1] = output[by * num_anchors + end[0] - 1]
+                    output[by * num_anchors + end[0] - 1] += tmp[0]
 
     return ib.get()
 
 
-def is_thrust_available():
-    """
-    Test if thrust based scan ops are available.
-    """
-    return get_global_func("tvm.contrib.thrust.sum_scan", allow_missing=True) is not None
-
-
-def scan_thrust(data, exclusive=True):
-    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
-    output_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "output_buf", data_alignment=8)
-    return te.extern(
-        [data.shape],
-        [data],
-        lambda ins, outs: tvm.tir.call_packed(
-            "tvm.contrib.thrust.sum_scan", ins[0], outs[0], exclusive
-        ),
-        dtype=[data.dtype],
-        in_buffers=[data_buf],
-        out_buffers=[output_buf],
-        name="exclusive_sum_scan2d",
-        tag="exclusive_sum_scan2d_gpu",
-    )
-
-
 def get_reduction_from_exclusive_scan_ir(data, data_ex_scan, reduction):
     """TODO"""
     batch_size = data.shape[0]
@@ -185,6 +178,42 @@ def get_reduction_from_exclusive_scan(data, ex_scan_output):
     )
 
 
+def is_thrust_available():
+    """
+    Test if thrust based scan ops are available.
+    """
+    return get_global_func("tvm.contrib.thrust.sum_scan", allow_missing=True) is not None
+
+
+def scan_thrust(data, exclusive=True, return_reduction=False):
+    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
+    output_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "output_buf", data_alignment=8)
+    output = te.extern(
+        [data.shape],
+        [data],
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.thrust.sum_scan", ins[0], outs[0], exclusive
+        ),
+        dtype=[data.dtype],
+        in_buffers=[data_buf],
+        out_buffers=[output_buf],
+        name="exclusive_sum_scan2d",
+        tag="exclusive_sum_scan2d_gpu",
+    )
+
+    if return_reduction:
+        ndim = len(data.shape)
+        if ndim == 1:
+            output = expand_dims(output, axis=0)
+            reduction = get_reduction_from_exclusive_scan(data, output)
+            reduction = squeeze(reduction, 0)
+        else:
+            reduction = get_reduction_from_exclusive_scan(data, output)
+        return output, reduction
+
+    return output
+
+
 def exclusive_scan(data, axis=-1, return_reduction=False):
     # TODO(masahi): support other binary associative operators
     ndim = len(data.shape)
@@ -194,17 +223,27 @@ def exclusive_scan(data, axis=-1, return_reduction=False):
 
     target = tvm.target.Target.current()
     if target and target.kind.name == "cuda" and is_thrust_available():
-        output = scan_thrust(data, exclusive=True)
-        if ndim == 1 and return_reduction:
-            output = expand_dims(data, axis=0)
-    else:
-        if ndim == 1:
-            data = expand_dims(data, axis=0)
+        return scan_thrust(data, exclusive=True, return_reduction=return_reduction)
+
+    if ndim == 1:
+        data = expand_dims(data, axis=0)
 
-        data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
-        output_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "output_buf", data_alignment=8)
+    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
+    output_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "output_buf", data_alignment=8)
 
-        if ndim == 2:
+    if ndim == 2:
+        if return_reduction:
+            output, reduction = te.extern(
+                [data.shape, (data.shape[0],)],
+                [data],
+                lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0], outs[1]),
+                dtype=[data.dtype],
+                in_buffers=[data_buf],
+                out_buffers=[output_buf],
+                name="exclusive_scan",
+                tag="exclusive_scan_gpu",
+            )
+        else:
             output = te.extern(
                 [data.shape],
                 [data],
@@ -215,19 +254,16 @@ def exclusive_scan(data, axis=-1, return_reduction=False):
                 name="exclusive_scan",
                 tag="exclusive_scan_gpu",
             )
-        else:
-            assert False, "Unsupported dimension {}".format(ndim)
-
-        if return_reduction:
-            reduction = get_reduction_from_exclusive_scan(data, output)
+            reduction = None
+    else:
+        assert False, "Unsupported dimension {}".format(ndim)
 
     if ndim == 1:
         output = squeeze(output, 0)
         if return_reduction:
             reduction = squeeze(reduction, 0)
-            return output, reduction
-        return reduction
+
     if return_reduction:
         return output, reduction
 
     return output
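
The scan kernel above (exclusive_sum_scan2d_ir) is a work-efficient two-phase Blelloch scan: the up sweep folds partial sums over power-of-two windows, the new init scope saves each row's total into `reduction` and zeroes the last slot, and the down sweep rotates the partial sums into an exclusive prefix sum. A minimal NumPy sketch of the same loop structure (illustrative only, not part of the commit; the helper name exclusive_scan_rows is invented here):

import numpy as np

def exclusive_scan_rows(data):
    # Per-row exclusive scan mirroring exclusive_sum_scan2d_ir.
    # Returns (scan, reduction), like exclusive_scan(..., return_reduction=True).
    out = data.astype(np.int64).copy()
    _, num_anchors = out.shape
    lim = int(np.ceil(np.log2(num_anchors)))

    # Up sweep: accumulate sums over windows of width 2, 4, 8, ...
    for l2_width in range(lim):
        width = 2 << l2_width
        for row in out:
            for start in range(0, num_anchors, width):
                middle = start + width // 2
                end = min(start + width, num_anchors)
                if middle < num_anchors:
                    row[end - 1] += row[middle - 1]

    # Save the row totals and zero the last slot, as the kernel's
    # "Down Sweep" init scope does.
    reduction = out[:, -1].copy()
    out[:, -1] = 0

    # Down sweep: rotate partial sums down into an exclusive scan.
    for l2_width in range(lim):
        width = 2 << (lim - l2_width - 1)
        for row in out:
            for start in range(0, num_anchors, width):
                middle = start + width // 2
                end = min(start + width, num_anchors)
                if middle < num_anchors:
                    tmp = row[middle - 1]
                    row[middle - 1] = row[end - 1]
                    row[end - 1] += tmp

    return out, reduction

For data = np.array([[1, 2, 3, 4]]) this yields [[0, 1, 3, 6]] with reduction [10].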

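exclusive_scan is the public entry point; it dispatches to the Thrust path when tvm.contrib.thrust.sum_scan is registered and to the TIR kernel otherwise. A usage sketch under an assumed CUDA target (graph construction only; scheduling and compilation are omitted):

import tvm
from tvm import te
from tvm.topi.cuda.scan import exclusive_scan

data = te.placeholder((4, 100), "int32", name="data")
with tvm.target.Target("cuda"):
    # scan[i, j] == sum(data[i, :j]); reduction[i] == sum(data[i, :])
    scan, reduction = exclusive_scan(data, axis=-1, return_reduction=True)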