Skip to content

Commit

Permalink
add thrust scan python stub
Browse files Browse the repository at this point in the history
  • Loading branch information
masa authored and masahi committed Dec 24, 2020
1 parent 9876c90 commit 65634e8
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 5 deletions.
4 changes: 2 additions & 2 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from tvm.tir import if_then_else
from .sort import argsort, argsort_thrust
from .prefix_scan import exclusive_sum_scan2d
from .prefix_scan import exclusive_scan


def cuda_atomic_add_rule(op):
Expand Down Expand Up @@ -288,7 +288,7 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
(batch_size,), "int32", "valid_count_buf", data_alignment=8
)

valid_indices = exclusive_sum_scan2d(valid_boxes)
valid_indices = exclusive_scan(valid_boxes, axis=1)

valid_count = te.extern(
[(batch_size,)],
Expand Down
52 changes: 49 additions & 3 deletions python/tvm/topi/cuda/prefix_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
"Scan related operators"
import tvm
from tvm import te
from tvm._ffi import get_global_func
from ..transform import expand_dims, squeeze


def exclusive_sum_scan2d_ir(data, output):
Expand Down Expand Up @@ -116,17 +118,61 @@ def ceil_div(a, b):
return ib.get()


def exclusive_sum_scan2d(data):
def is_thrust_available():
"""
Test if thrust based scan op is available.
"""
return get_global_func("tvm.contrib.thrust.scan", allow_missing=True) is not None


def scan_thrust(data, axis, exclusive=True):
data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
output_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "output_buf", data_alignment=8)

return te.extern(
[data.shape],
[data],
lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0]),
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.thrust.scan", ins[0], outs[0], axis, exclusive
),
dtype=[data.dtype],
in_buffers=[data_buf],
out_buffers=[output_buf],
name="exclusive_sum_scan2d",
tag="exclusive_sum_scan2d_gpu",
)


def exclusive_scan(data, axis=-1):
# TODO(masahi): support other binary associative operators
# TODO(masahi): support inclusive scan
ndim = len(data.shape)
if axis < 0:
axis += ndim
assert axis == ndim - 1, "Only support scan on the inner most axis."

target = tvm.target.Target.current()
if target and target.kind.name == "cuda" and is_thrust_available():
return scan_thrust(data, axis)

if ndim == 1:
data = expand_dims(data, axis=0)

data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
output_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "output_buf", data_alignment=8)

if ndim == 2:
output = te.extern(
[data.shape],
[data],
lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0]),
dtype=[data.dtype],
in_buffers=[data_buf],
out_buffers=[output_buf],
name="exclusive_scan",
tag="exclusive_scan_gpu",
)
if ndim == 1:
return squeeze(output, 0)
return output
else:
assert False, "Unsupported dimension {}".format(ndim)

0 comments on commit 65634e8

Please sign in to comment.