diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 0c01cc9fbbdf..32691da90ecc 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -22,6 +22,8 @@ from tvm.tir import if_then_else from .sort import argsort, argsort_thrust, is_thrust_available +from .scan import exclusive_scan +from ..utils import ceil_div def cuda_atomic_add_rule(op): @@ -51,10 +53,6 @@ def atomic_add(x, y): return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y) -def ceil_div(a, b): - return tvm.tir.indexdiv(a + b - 1, b) - - def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index): """Low level IR to identify bounding boxes given a score threshold. @@ -123,136 +121,6 @@ def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index return ib.get() -def get_valid_indices_ir(valid_boxes, valid_count, valid_indices): - """Low level IR to get the ouput indices of valid boxes - and the count of valid boxes - - Parameters - ---------- - valid_boxes: Buffer - 2D Buffer indicating valid boxes with shape [batch_size, num_anchors]. - - Returns - ------- - valid_count: Buffer - 1D Buffer of number of valid boxes per batch [batch_size]. - - valid_indices: Buffer - 2D Buffer indicating output sorted indcies of valid boxes [batch_size, num_anchors]. - """ - batch_size = valid_boxes.shape[0] - num_anchors = valid_boxes.shape[1] - - ib = tvm.tir.ir_builder.create() - - valid_boxes = ib.buffer_ptr(valid_boxes) - - valid_count = ib.buffer_ptr(valid_count) - valid_indices = ib.buffer_ptr(valid_indices) - - max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - with ib.if_scope(num_anchors > 0): - # Copy boxes to valid_indices - with ib.new_scope(): - nthread_tx = max_threads - nthread_bx = ceil_div(num_anchors, max_threads) - nthread_by = batch_size - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - by = te.thread_axis("blockIdx.y") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - ib.scope_attr(by, "thread_extent", nthread_by) - tid = bx * nthread_tx + tx - with ib.if_scope(tid < num_anchors): - valid_indices[by, tid] = valid_boxes[by, tid] - - nthread_tx = max_threads - nthread_bx = ceil_div(num_anchors, max_threads) - nthread_by = batch_size - - ## The following algorithm performs parallel exclusive scan to get - ## a tensor that can later be used to select valid indices - # Up Sweep of exclusive scan - lim = tvm.tir.generic.cast( - tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))), "int64" - ) - with ib.for_range(0, lim, dtype="int64") as l2_width: - width = 2 << l2_width - - with ib.new_scope(): - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr( - bx, - "thread_extent", - tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"), - ) - tid = bx * nthread_tx + tx - - by = te.thread_axis("blockIdx.y") - ib.scope_attr(by, "thread_extent", nthread_by) - start = ib.allocate("int64", (1,), name="start", scope="local") - middle = ib.allocate("int64", (1,), name="middle", scope="local") - end = ib.allocate("int64", (1,), name="end", scope="local") - start[0] = width * tid - with ib.if_scope(start[0] < num_anchors): - middle[0] = start[0] + tvm.tir.indexdiv(width, 2) - end[0] = tvm.te.min(start[0] + width, num_anchors) - with ib.if_scope(middle[0] < num_anchors): - valid_indices[by * num_anchors + end[0] - 1] += 
valid_indices[ - by * num_anchors + middle[0] - 1 - ] - - # Down Sweep of exclusive scan - with ib.new_scope(): - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(bx, "thread_extent", batch_size) - with ib.if_scope(bx < batch_size): - valid_count[bx] = valid_indices[(bx + 1) * num_anchors - 1] - valid_indices[(bx + 1) * num_anchors - 1] = 0 - - with ib.for_range(0, lim, dtype="int64") as l2_width: - width = 2 << (lim - l2_width - 1) - - with ib.new_scope(): - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr( - bx, - "thread_extent", - tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"), - ) - tid = bx * nthread_tx + tx - - by = te.thread_axis("blockIdx.y") - ib.scope_attr(by, "thread_extent", nthread_by) - start = ib.allocate("int64", (1,), name="start", scope="local") - middle = ib.allocate("int64", (1,), name="middle", scope="local") - end = ib.allocate("int64", (1,), name="end", scope="local") - tmp = ib.allocate("int32", (1,), name="end", scope="local") - start[0] = width * tid - with ib.if_scope(tvm.tir.all(start[0] < num_anchors)): - middle[0] = start[0] + tvm.tir.indexdiv(width, 2) - end[0] = tvm.tir.min(start[0] + width, num_anchors) - with ib.if_scope(middle[0] < num_anchors): - tmp[0] = valid_indices[by * num_anchors + middle[0] - 1] - valid_indices[by * num_anchors + middle[0] - 1] = valid_indices[ - by * num_anchors + end[0] - 1 - ] - valid_indices[by * num_anchors + end[0] - 1] += tmp[0] - with ib.else_scope(): - with ib.new_scope(): - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(bx, "thread_extent", batch_size) - with ib.if_scope(bx < batch_size): - valid_count[bx] = 0 - - return ib.get() - - def get_valid_counts_ir(data, valid_indices, valid_boxes, out, out_indices): """Low level IR to get valid count of bounding boxes given a score threshold. Also prepares to move valid boxes to the @@ -374,19 +242,8 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): valid_indices_buf = tvm.tir.decl_buffer( (batch_size, num_anchors), "int32", "valid_indices_buf", data_alignment=8 ) - valid_count_buf = tvm.tir.decl_buffer( - (batch_size,), "int32", "valid_count_buf", data_alignment=8 - ) - valid_count, valid_indices = te.extern( - [(batch_size,), (batch_size, num_anchors)], - [valid_boxes], - lambda ins, outs: get_valid_indices_ir(ins[0], outs[0], outs[1]), - dtype=["int32"], - in_buffers=[valid_boxes_buf], - out_buffers=[valid_count_buf, valid_indices_buf], - name="get_valid_indices", - tag="get_valid_indices_gpu", - ) + + valid_indices, valid_count = exclusive_scan(valid_boxes, axis=1, return_reduction=True) out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf", data_alignment=8) out_indices_buf = tvm.tir.decl_buffer( diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py new file mode 100644 index 000000000000..f19e4a14239a --- /dev/null +++ b/python/tvm/topi/cuda/scan.py @@ -0,0 +1,406 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, too-many-locals, too-many-statements +"Scan related operators" +import tvm +from tvm import te +from tvm._ffi import get_global_func +from ..transform import expand_dims, squeeze +from ..utils import ceil_div +from ..math import cast +from .. import tag +from .injective import schedule_injective_from_existing + + +def exclusive_sum_scan2d_ir(data, output, reduction=None): + """Low level IR to do exclusive sum scan along rows of 2D input. + + Parameters + ---------- + data : Buffer + Input data. 2-D Buffer with shape [batch_size, scan_axis_size]. + + output: Buffer + A buffer to store the output scan, of the same size as data + + reduction: Buffer, optional + 1D Buffer of size [batch_size], to store the sum of each row. + """ + + batch_size = data.shape[0] + scan_axis_size = data.shape[1] + + ib = tvm.tir.ir_builder.create() + + data = ib.buffer_ptr(data) + output = ib.buffer_ptr(output) + + out_dtype = output.dtype + + if reduction is not None: + reduction = ib.buffer_ptr(reduction) + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + + with ib.if_scope(scan_axis_size == 0): + with ib.new_scope(): + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(bx, "thread_extent", batch_size) + with ib.if_scope(bx < batch_size): + if reduction is not None: + reduction[bx] = 0 + with ib.else_scope(): + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(scan_axis_size, max_threads) + nthread_by = batch_size + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + by = te.thread_axis("blockIdx.y") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + ib.scope_attr(by, "thread_extent", nthread_by) + tid = bx * nthread_tx + tx + with ib.if_scope(tid < scan_axis_size): + output[by, tid] = data[by, tid] + + nthread_tx = max_threads + nthread_bx = ceil_div(scan_axis_size, max_threads) + nthread_by = batch_size + + # The following algorithm performs parallel exclusive scan + # Up Sweep of exclusive scan + lim = tvm.tir.generic.cast( + tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_axis_size, "float64"))), "int64" + ) + with ib.for_range(0, lim, dtype="int64") as l2_width: + width = 2 << l2_width + + with ib.new_scope(): + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr( + bx, + "thread_extent", + tvm.tir.generic.cast(ceil_div(scan_axis_size, max_threads * width), "int32"), + ) + tid = bx * nthread_tx + tx + + by = te.thread_axis("blockIdx.y") + ib.scope_attr(by, "thread_extent", nthread_by) + start = ib.allocate("int64", (1,), name="start", scope="local") + middle = ib.allocate("int64", (1,), name="middle", scope="local") + end = ib.allocate("int64", (1,), name="end", scope="local") + start[0] = width * tid + with ib.if_scope(start[0] < scan_axis_size): + middle[0] = start[0] + tvm.tir.indexdiv(width, 2) + end[0] = tvm.te.min(start[0] + width, scan_axis_size) + with ib.if_scope(middle[0] < scan_axis_size): + output[by * scan_axis_size + end[0] - 1] += output[ + by 
* scan_axis_size + middle[0] - 1 + ] + + # Down Sweep of exclusive scan + with ib.new_scope(): + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(bx, "thread_extent", batch_size) + with ib.if_scope(bx < batch_size): + if reduction is not None: + reduction[bx] = output[(bx + 1) * scan_axis_size - 1] + output[(bx + 1) * scan_axis_size - 1] = cast(0, out_dtype) + + with ib.for_range(0, lim, dtype="int64") as l2_width: + width = 2 << (lim - l2_width - 1) + + with ib.new_scope(): + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr( + bx, + "thread_extent", + tvm.tir.generic.cast(ceil_div(scan_axis_size, max_threads * width), "int32"), + ) + tid = bx * nthread_tx + tx + + by = te.thread_axis("blockIdx.y") + ib.scope_attr(by, "thread_extent", nthread_by) + start = ib.allocate("int64", (1,), name="start", scope="local") + middle = ib.allocate("int64", (1,), name="middle", scope="local") + end = ib.allocate("int64", (1,), name="end", scope="local") + tmp = ib.allocate(out_dtype, (1,), name="tmp", scope="local") + start[0] = width * tid + with ib.if_scope(tvm.tir.all(start[0] < scan_axis_size)): + middle[0] = start[0] + tvm.tir.indexdiv(width, 2) + end[0] = tvm.tir.min(start[0] + width, scan_axis_size) + with ib.if_scope(middle[0] < scan_axis_size): + tmp[0] = output[by * scan_axis_size + middle[0] - 1] + output[by * scan_axis_size + middle[0] - 1] = output[ + by * scan_axis_size + end[0] - 1 + ] + output[by * scan_axis_size + end[0] - 1] += tmp[0] + return ib.get() + + +def get_reduction_from_exclusive_scan(data, ex_scan_output): + """Return the sum of the last element of data and the exclusive scan output. + This is the reduction of data along each row (for the 2-D case). + + Parameters + ---------- + data : tvm.te.Tensor + Input data. 1-D tensor with shape [scan_axis_size], or + 2-D tensor with shape [batch_size, scan_axis_size]. + + ex_scan_output : tvm.te.Tensor + 1-D tensor that is the exclusive scan of the input, or + 2-D tensor storing the exclusive scan of each row. + + Returns + ------- + reduction : tvm.te.Tensor + 1-D tensor storing the reduction of each row. 
+ """ + ndim = len(data.shape) + if ndim == 1: + data = expand_dims(data, axis=0) + ex_scan_output = expand_dims(ex_scan_output, axis=0) + + def ir(data, data_ex_scan, reduction): + batch_size = data.shape[0] + num_anchors = data.shape[1] + + ib = tvm.tir.ir_builder.create() + + data = ib.buffer_ptr(data) + data_ex_scan = ib.buffer_ptr(data_ex_scan) + reduction = ib.buffer_ptr(reduction) + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + with ib.if_scope(num_anchors > 0): + reduction[tid] = data_ex_scan[tid, num_anchors - 1] + data[tid, num_anchors - 1] + with ib.else_scope(): + reduction[tid] = 0 + + return ib.get() + + assert len(data.shape) == 2, "Only 2D input supported for now" + data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "valid_indices_buf", data_alignment=8) + ex_scan_output_buf = tvm.tir.decl_buffer( + ex_scan_output.shape, ex_scan_output.dtype, "ex_scan_output_buf", data_alignment=8 + ) + + reduction = te.extern( + [(data.shape[0],)], + [data, ex_scan_output], + lambda ins, outs: ir(ins[0], ins[1], outs[0]), + dtype=[ex_scan_output.dtype], + in_buffers=[data_buf, ex_scan_output_buf], + name="ex_scan_reduction", + tag="ex_scan_reduction_gpu", + ) + + if ndim == 1: + return squeeze(reduction, 0) + + return reduction + + +def is_thrust_available(): + """Test if thrust based scan ops are available.""" + return get_global_func("tvm.contrib.thrust.sum_scan", allow_missing=True) is not None + + +def scan_thrust(data, output_dtype, exclusive=True, return_reduction=False): + """Do exclusive scan on 1D input or along rows of 2D input, using thrust. + + Parameters + ---------- + data : tvm.te.Tensor + Input data. 1-D tensor with shape [scan_axis_size], or + 2-D tensor with shape [batch_size, scan_axis_size]. + + output_dtype: string + The dtype of the output scan tensor. + + exclusive: bool, optional + Whether or not do exclusive or inclusive scan. + + return_reduction: bool, optional + Whether or not return a 1-D tensor storing the reduction of each row. + Reductions are computed as part of the upsweep pass, so there is no extra cost. + If False, reductions are ignored. + + Returns + ------- + output : tvm.te.Tensor + 1-D tensor that is the exclusive scan of the input, or + 2-D tensor storing the exclusive scan of each row. + + reduction : tvm.te.Tensor, optional + 1-D tensor storing the reduction of each row. + Returned if return_reduction is True. 
+ """ + data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + output_buf = tvm.tir.decl_buffer(data.shape, output_dtype, "output_buf", data_alignment=8) + output = te.extern( + [data.shape], + [data], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.thrust.sum_scan", ins[0], outs[0], exclusive + ), + dtype=[output_dtype], + in_buffers=[data_buf], + out_buffers=[output_buf], + name="exclusive_sum_scan2d", + tag="exclusive_sum_scan2d_gpu", + ) + + if return_reduction: + assert exclusive, "return_reduction should be False for inclusive scan" + reduction = get_reduction_from_exclusive_scan(data, output) + return output, reduction + + return output + + +def exclusive_scan(data, axis=-1, return_reduction=False, output_dtype=None): + """Do exclusive scan on 1D input or along rows of 2D input. + + Parameters + ---------- + data : tvm.te.Tensor + Input data. 1-D tensor with shape [scan_axis_size], or + 2-D tensor with shape [batch_size, scan_axis_size]. + + axis: int, optional + The axis to do scan on. For now, only the inner most axis is supported. + + return_reduction: bool, optional + Whether or not return a 1-D tensor storing the reduction of each row. + Reductions are computed as part of the upsweep pass, so there is no extra cost. + If False, reductions are ignored. + + output_dtype: string, optional + The dtype of the output scan tensor. If not provided, the dtype of the input is used. + + Returns + ------- + output : tvm.te.Tensor + 1-D tensor that is the exclusive scan of the input, or + 2-D tensor storing the exclusive scan of each row. + + reduction : tvm.te.Tensor, optional + 1-D tensor storing the reduction of each row. + Returned if return_reduction is True. + """ + # TODO(masahi): Support other binary operators + ndim = len(data.shape) + if axis < 0: + axis += ndim + assert axis == ndim - 1, "Only support scan on the inner most axis." + + if output_dtype is None: + output_dtype = data.dtype + + target = tvm.target.Target.current() + if target and target.kind.name == "cuda" and is_thrust_available(): + return scan_thrust(data, output_dtype, exclusive=True, return_reduction=return_reduction) + + if ndim == 1: + # TIR exclusive scan accepts only 2D inputs. + data = expand_dims(data, axis=0) + + data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + output_buf = tvm.tir.decl_buffer(data.shape, output_dtype, "output_buf", data_alignment=8) + + if len(data.shape) == 2: + if return_reduction: + output, reduction = te.extern( + [data.shape, (data.shape[0],)], + [data], + lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0], outs[1]), + dtype=[data.dtype, output_dtype], + in_buffers=[data_buf], + name="exclusive_scan", + tag="exclusive_scan_gpu", + ) + else: + output = te.extern( + [data.shape], + [data], + lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0]), + dtype=[output_dtype], + in_buffers=[data_buf], + out_buffers=[output_buf], + name="exclusive_scan", + tag="exclusive_scan_gpu", + ) + reduction = None + else: + assert False, "Unsupported dimension {}".format(ndim) + + if ndim == 1: + output = squeeze(output, 0) + if return_reduction: + reduction = squeeze(reduction, 0) + + if return_reduction: + return output, reduction + + return output + + +def schedule_scan(outs): + """Schedule for scan operator. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of scan + in the format of an array of tensors. 
+ + Returns + ------- + s: Schedule + The computation schedule for the op. + """ + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + scheduled_ops = [] + + def traverse(op): + if tag.is_injective(op.tag): + schedule_injective_from_existing(s, op.output(0)) + for tensor in op.input_tensors: + if tensor.op.input_tensors and tensor.op not in scheduled_ops: + traverse(tensor.op) + scheduled_ops.append(op) + + for out in outs: + traverse(out.op) + return s diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index b34bd1df14e4..444fb25cc34b 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -22,11 +22,7 @@ from ..generic import schedule_extern from .nms import atomic_add from .sort import stable_sort_by_key_thrust, is_thrust_available -from ..utils import prod - - -def ceil_div(a, b): - return (a + b - 1) // b +from ..utils import prod, ceil_div def _memcpy_ir(ib, out_ptr, data_ptr, shape): diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 9b6a18a8b06b..18340385205e 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -23,6 +23,7 @@ from .injective import schedule_injective_from_existing from ..transform import strided_slice, transpose from .. import tag +from ..utils import ceil_div def swap(arr, axis): @@ -61,10 +62,6 @@ def traverse(op): return s -def ceil_div(a, b): - return tvm.tir.indexdiv(a + b - 1, b) - - def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_func=None): """Initialize the output buffers by copying from inputs""" axis_mul_before = 1 diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py index cb61d9686919..0b46cf0f9f97 100644 --- a/python/tvm/topi/cuda/sparse.py +++ b/python/tvm/topi/cuda/sparse.py @@ -23,7 +23,7 @@ from tvm import relay, te from .. import nn -from ..utils import traverse_inline, get_const_tuple, prod, get_const_int +from ..utils import traverse_inline, get_const_tuple, prod, get_const_int, ceil_div def sparse_dense(data, weight_data, weight_indices, weight_indptr, sparse_lhs=False): @@ -162,9 +162,6 @@ def sparse_dense_tir(data, w_data, w_indices, w_indptr): default_function_kernel1 for the multiply. """ - def ceil_div(a, b): - return (a + (b - 1)) // b - def gen_ir(data, w_data, w_indices, w_indptr, out): # pylint: disable=invalid-name # TODO(tkonolige): use tensorcores for block multiply diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index c3e14eff3919..dfc226f0c331 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -487,3 +487,8 @@ def is_empty_shape(shape): Whether input shape is empty or has dimesion with size 0. 
""" return cpp.utils.is_empty_shape(shape) + + +def ceil_div(a, b): + """Return ceil division of a by b""" + return tvm.tir.indexdiv(a + (b - 1), b) diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index 6a48f1ad876a..4e3e3a81af1a 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -264,5 +265,80 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key") } }); +template +void thrust_scan(DLTensor* data, + DLTensor* output, + bool exclusive) { + thrust::device_ptr data_ptr(static_cast(data->data)); + thrust::device_ptr output_ptr(static_cast(output->data)); + const auto scan_size = data->shape[data->ndim - 1]; + + if (scan_size == 0) return; + + if (data->ndim == 1 || (data->ndim == 2 && data->shape[0] == 1)) { + if (exclusive) { + thrust::exclusive_scan(data_ptr, data_ptr + scan_size, output_ptr); + } else { + thrust::inclusive_scan(data_ptr, data_ptr + scan_size, output_ptr); + } + } else { + // Use thrust segmented scan to compute scan on the inner most axis + // data->shape[0] * data->shape[1] * ... * data->shape[ndim - 2] scans are + // computed in parallel + + // This is for constructing a sequence 0, 0, 0,...,1, 1, 1,...,2, 2, 2,..., + // without materializing the sequence vector + auto counting_iter = thrust::counting_iterator(0); + // Without __host__ annotation, cub crashes + auto linear_index_to_scan_key = [scan_size] __host__ __device__(int64_t i) { + return i / scan_size; + }; // NOLINT(*) + auto key_iter = thrust::make_transform_iterator(counting_iter, linear_index_to_scan_key); + int64_t size = 1; + for (int i = 0; i < data->ndim; ++i) size *= data->shape[i]; + + if (exclusive) { + thrust::exclusive_scan_by_key(key_iter, key_iter + size, data_ptr, output_ptr); + } else { + thrust::inclusive_scan_by_key(key_iter, key_iter + size, data_ptr, output_ptr); + } + } +} + +TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan") +.set_body([](TVMArgs args, TVMRetValue* ret) { + ICHECK_EQ(args.num_args, 3); + DLTensor* data = args[0]; + DLTensor* output = args[1]; + bool exclusive = args[2]; + + auto in_dtype = DLDataType2String(data->dtype); + auto out_dtype = DLDataType2String(output->dtype); + + if (in_dtype == "int32") { + if (out_dtype == "int32") { + thrust_scan(data, output, exclusive); + } else if (out_dtype == "int64") { + thrust_scan(data, output, exclusive); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (in_dtype == "int64") { + if (out_dtype == "int64") { + thrust_scan(data, output, exclusive); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (in_dtype == "float32") { + if (out_dtype == "float32") { + thrust_scan(data, output, exclusive); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else { + LOG(FATAL) << "Unsupported input dtype: " << in_dtype; + } +}); + } // namespace contrib } // namespace tvm diff --git a/tests/python/contrib/test_sort.py b/tests/python/contrib/test_sort.py index f338276ca118..a049602ac265 100644 --- a/tests/python/contrib/test_sort.py +++ b/tests/python/contrib/test_sort.py @@ -17,7 +17,7 @@ import tvm import tvm.testing from tvm import te -from tvm.topi.cuda import stable_sort_by_key_thrust, is_thrust_available, sort_by_key +from tvm.topi.cuda import sort_by_key import numpy as np @@ -91,38 +91,6 @@ def test_sort_np(): tvm.testing.assert_allclose(c.asnumpy(), np_out, rtol=1e-5) -def 
test_thrust_stable_sort_by_key(): - if not is_thrust_available(): - print("skip because thrust is not enabled...") - return - - size = 6 - keys = te.placeholder((size,), name="keys", dtype="int32") - values = te.placeholder((size,), name="values", dtype="int32") - - keys_out, values_out = stable_sort_by_key_thrust(keys, values) - - ctx = tvm.gpu(0) - target = "cuda" - s = te.create_schedule([keys_out.op, values_out.op]) - f = tvm.build(s, [keys, values, keys_out, values_out], target) - - keys_np = np.array([1, 4, 2, 8, 2, 7], np.int32) - values_np = np.random.randint(0, 10, size=(size,)).astype(np.int32) - keys_np_out = np.zeros(keys_np.shape, np.int32) - values_np_out = np.zeros(values_np.shape, np.int32) - keys_in = tvm.nd.array(keys_np, ctx) - values_in = tvm.nd.array(values_np, ctx) - keys_out = tvm.nd.array(keys_np_out, ctx) - values_out = tvm.nd.array(values_np_out, ctx) - f(keys_in, values_in, keys_out, values_out) - - ref_keys_out = np.sort(keys_np) - ref_values_out = np.array([values_np[i] for i in np.argsort(keys_np)]) - tvm.testing.assert_allclose(keys_out.asnumpy(), ref_keys_out, rtol=1e-5) - tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) - - def test_sort_by_key_gpu(): size = 6 keys = te.placeholder((size,), name="keys", dtype="int32") @@ -158,5 +126,4 @@ def test_sort_by_key_gpu(): if __name__ == "__main__": test_sort() test_sort_np() - test_thrust_stable_sort_by_key() test_sort_by_key_gpu() diff --git a/tests/python/contrib/test_thrust.py b/tests/python/contrib/test_thrust.py new file mode 100644 index 000000000000..5f66d465bf17 --- /dev/null +++ b/tests/python/contrib/test_thrust.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import tvm +import tvm.testing +from tvm import te +from tvm.topi.cuda import stable_sort_by_key_thrust, is_thrust_available +from tvm.topi.cuda.scan import exclusive_scan, scan_thrust, schedule_scan +import numpy as np + + +def test_stable_sort_by_key(): + if not is_thrust_available(): + print("skip because thrust is not enabled...") + return + + size = 6 + keys = te.placeholder((size,), name="keys", dtype="int32") + values = te.placeholder((size,), name="values", dtype="int32") + + keys_out, values_out = stable_sort_by_key_thrust(keys, values) + + ctx = tvm.gpu(0) + target = "cuda" + s = te.create_schedule([keys_out.op, values_out.op]) + f = tvm.build(s, [keys, values, keys_out, values_out], target) + + keys_np = np.array([1, 4, 2, 8, 2, 7], np.int32) + values_np = np.random.randint(0, 10, size=(size,)).astype(np.int32) + keys_np_out = np.zeros(keys_np.shape, np.int32) + values_np_out = np.zeros(values_np.shape, np.int32) + keys_in = tvm.nd.array(keys_np, ctx) + values_in = tvm.nd.array(values_np, ctx) + keys_out = tvm.nd.array(keys_np_out, ctx) + values_out = tvm.nd.array(values_np_out, ctx) + f(keys_in, values_in, keys_out, values_out) + + ref_keys_out = np.sort(keys_np) + ref_values_out = np.array([values_np[i] for i in np.argsort(keys_np)]) + tvm.testing.assert_allclose(keys_out.asnumpy(), ref_keys_out, rtol=1e-5) + tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) + + +def test_exclusive_scan(): + if not is_thrust_available(): + print("skip because thrust is not enabled...") + return + + for ishape in [(1,), (10, 10)]: + values = te.placeholder(ishape, name="values", dtype="int32") + + with tvm.target.Target("cuda"): + scan, reduction = exclusive_scan(values, return_reduction=True) + s = schedule_scan([scan, reduction]) + + ctx = tvm.gpu(0) + f = tvm.build(s, [values, scan, reduction], "cuda") + + values_np = np.random.randint(0, 10, size=ishape).astype(np.int32) + values_np_out = np.zeros(values_np.shape, np.int32) + + if len(ishape) == 1: + reduction_shape = () + else: + reduction_shape = (ishape[0],) + + reduction_np_out = np.zeros(reduction_shape, np.int32) + + values_in = tvm.nd.array(values_np, ctx) + values_out = tvm.nd.array(values_np_out, ctx) + reduction_out = tvm.nd.array(reduction_np_out, ctx) + f(values_in, values_out, reduction_out) + + ref_values_out = np.cumsum(values_np, axis=-1, dtype="int32") - values_np + tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) + ref_reduction_out = np.sum(values_np, axis=-1) + tvm.testing.assert_allclose(reduction_out.asnumpy(), ref_reduction_out, rtol=1e-5) + + +def test_inclusive_scan(): + if not is_thrust_available(): + print("skip because thrust is not enabled...") + return + + out_dtype = "int64" + + for ishape in [(10,), (10, 10)]: + values = te.placeholder(ishape, name="values", dtype="int32") + + with tvm.target.Target("cuda"): + scan = scan_thrust(values, out_dtype, exclusive=False) + s = tvm.te.create_schedule([scan.op]) + + ctx = tvm.gpu(0) + f = tvm.build(s, [values, scan], "cuda") + + values_np = np.random.randint(0, 10, size=ishape).astype(np.int32) + values_np_out = np.zeros(values_np.shape, out_dtype) + values_in = tvm.nd.array(values_np, ctx) + values_out = tvm.nd.array(values_np_out, ctx) + f(values_in, values_out) + + ref_values_out = np.cumsum(values_np, axis=-1, dtype=out_dtype) + tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) + + +if __name__ == "__main__": + test_stable_sort_by_key() + test_exclusive_scan() + 
test_inclusive_scan() diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index d30e7873dae7..a537782355d2 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -879,6 +879,54 @@ def test_any_topk(): verify_any_topk(any_dims(1), 0, (0,), "float32", ret_type="both") +def verify_any_get_valid_counts(num_anchor_real, dtype, targets=None): + mod = tvm.IRModule() + batch_size = 1 + num_anchor = relay.Any() + data = relay.var("data", shape=(batch_size, num_anchor, 5), dtype=dtype) + np_data = np.random.uniform(size=(batch_size, num_anchor_real, 5)).astype(dtype) + + np_out1 = np.zeros(shape=(batch_size,)) + np_out2 = np.zeros(shape=np_data.shape).astype(dtype) + np_out3 = np.zeros(shape=(batch_size, num_anchor_real)) + score_threshold = 0.95 + + for i in range(batch_size): + np_out1[i] = 0 + inter_idx = 0 + for j in range(num_anchor_real): + score = np_data[i, j, 0] + if score > score_threshold: + for k in range(5): + np_out2[i, inter_idx, k] = np_data[i, j, k] + np_out1[i] += 1 + np_out3[i, inter_idx] = j + inter_idx += 1 + if j >= np_out1[i]: + for k in range(5): + np_out2[i, j, k] = -1.0 + np_out3[i, j] = -1 + + z = relay.vision.get_valid_counts(data, score_threshold, 0, score_index=0) + + mod["main"] = relay.Function([data], z.astuple()) + + check_result([np_data], mod, [np_out1, np_out2, np_out3], targets=targets) + + +@tvm.testing.uses_gpu +def test_any_get_valid_counts(): + verify_any_get_valid_counts(10, "float32") + # opencl seems to have issues with empty size buffer + # Check failed: err_code == CL_SUCCESS == false: OpenCL Error, + # code=-61: CL_INVALID_BUFFER_SIZE + targets = [] + for tgt, ctx in tvm.testing.enabled_targets(): + if "opencl" not in tgt: + targets.append((tgt, ctx)) + verify_any_get_valid_counts(0, "float32", targets=targets) + + @tvm.testing.uses_gpu def test_fused_ops(): x = relay.var("x", shape=(relay.Any(), relay.Any()), dtype="float32")
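
For reference, the row-wise semantics that exclusive_scan(data, return_reduction=True) is expected to provide (and that test_exclusive_scan above checks) amount to an exclusive prefix sum along the last axis plus a per-row total. The NumPy sketch below is illustrative only and is not part of the patch; exclusive_scan_reference is a hypothetical helper name, and the 2-D int32 input mirrors the valid-box mask that get_valid_counts now feeds to exclusive_scan.

import numpy as np


def exclusive_scan_reference(values, out_dtype="int32"):
    # Exclusive prefix sum along the last axis: scan[i, j] = sum(values[i, :j]).
    # This matches the reference used in test_exclusive_scan above
    # (np.cumsum shifted right by subtracting the input).
    scan = np.cumsum(values, axis=-1, dtype=out_dtype) - values
    # Per-row total, i.e. the optional `reduction` output returned when
    # return_reduction=True; in get_valid_counts this becomes valid_count,
    # while the scan itself becomes valid_indices.
    reduction = np.sum(values, axis=-1, dtype=out_dtype)
    return scan, reduction


# Hypothetical example: a 0/1 "valid box" mask for one batch of six anchors.
mask = np.array([[1, 0, 1, 1, 0, 1]], dtype="int32")
scan, count = exclusive_scan_reference(mask)
# scan  -> [[0, 1, 1, 2, 3, 3]]  (output slot of each valid box)
# count -> [4]                   (number of valid boxes in the batch)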