
[Relay, TOPI] Add numpy style cumsum op #7334

Merged: 27 commits, Jan 26, 2021
10 changes: 10 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -438,6 +438,16 @@ struct MatrixSetDiagAttrs : public tvm::AttrsNode<MatrixSetDiagAttrs> {
}
}; // struct MatrixSetDiagAttrs

/*! \brief Attributes used in cumsum operator */
struct CumsumAttrs : public tvm::AttrsNode<CumsumAttrs> {
Integer axis;
DataType dtype;
TVM_DECLARE_ATTRS(CumsumAttrs, "relay.attrs.CumsumAttrs") {
TVM_ATTR_FIELD(axis).describe("The axis to sum over").set_default(NullValue<Integer>());
TVM_ATTR_FIELD(dtype).describe("Output data type").set_default(NullValue<DataType>());
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
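For orientation, a minimal sketch of how these attrs surface on the Python side once the op and its relay.cumsum wrapper (added later in this PR) are registered; the printed form is illustrative:

from tvm import relay

x = relay.var("x", shape=(2, 3), dtype="int32")
# axis and dtype become CumsumAttrs fields on the resulting call node.
y = relay.cumsum(x, axis=0, dtype="int64")
print(y)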
12 changes: 11 additions & 1 deletion python/tvm/relay/op/_transform.py
@@ -103,7 +103,7 @@ def compute_scatter_add(attrs, inputs, output_type):

_reg.register_strategy("scatter_add", strategy.scatter_add_strategy)

# scatter
# scatter_nd
@_reg.register_compute("scatter_nd")
def compute_scatter_nd(attrs, inputs, output_type):
"""Compute definition of scatter_nd"""
@@ -112,6 +112,16 @@ def compute_scatter_nd(attrs, inputs, output_type):

_reg.register_strategy("scatter_nd", strategy.scatter_nd_strategy)

# cumsum
@_reg.register_compute("cumsum")
def compute_cumsum(attrs, inputs, output_type):
"""Compute definition of cumsum"""
return [topi.cumsum(inputs[0], attrs.axis, attrs.dtype)]


_reg.register_strategy("cumsum", strategy.cumsum_strategy)
_reg.register_shape_func("cumsum", False, elemwise_shape_func)

#####################
# Shape functions #
#####################
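With the compute, strategy, and shape function registered, the op lowers end to end. A hedged sketch in the test style of this era of TVM (executor API details vary across versions):

import numpy as np
from tvm import relay

x = relay.var("x", shape=(2, 3), dtype="int32")
func = relay.Function([x], relay.cumsum(x, axis=1))
data = np.array([[1, 2, 3], [4, 5, 6]], dtype="int32")
# Expected: [[1, 3, 6], [4, 9, 15]]
op_res = relay.create_executor("graph", target="llvm").evaluate(func)(data)
np.testing.assert_equal(op_res.asnumpy(), np.cumsum(data, axis=1))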
12 changes: 12 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -996,3 +996,15 @@ def argwhere_strategy_cuda(attrs, inputs, out_type, target):
name="argwhere.cuda",
)
return strategy


@cumsum_strategy.register(["cuda", "gpu"])
def cumsum_strategy_cuda(attrs, inputs, out_type, target):
"""cumsum cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_cumsum(topi.cuda.cumsum),
wrap_topi_schedule(topi.cuda.schedule_scan),
name="cumsum.cuda",
)
return strategy
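On a CUDA-enabled build, the same Relay program dispatches to topi.cuda.cumsum under this strategy, scheduled by the shared scan schedule. A hedged sketch using only standard APIs (requires a GPU build):

import tvm
from tvm import relay

x = relay.var("x", shape=(8,), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.cumsum(x)))
# Strategy dispatch happens at build time: cumsum.cuda + schedule_scan.
lib = relay.build(mod, target="cuda")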
21 changes: 21 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -1361,3 +1361,24 @@ def threefry_split_strategy(attrs, inputs, out_type, target):
name="threefry_split.generic",
)
return strategy


def wrap_compute_cumsum(topi_compute):
"""Wrap cumsum topi compute"""

def _compute_cumsum(attrs, inputs, _):
return [topi_compute(inputs[0], attrs.axis, attrs.dtype)]

return _compute_cumsum


@override_native_generic_func("cumsum_strategy")
def cumsum_strategy(attrs, inputs, out_type, target):
"""cumsum generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_cumsum(topi.cumsum),
wrap_topi_schedule(topi.generic.schedule_extern),
name="cumsum.generic",
)
return strategy
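Pairing topi.cumsum with schedule_extern suggests the TOPI compute is an extern (IR-builder) op. A hedged TE-level sketch of what this strategy wires together, under that assumption:

import tvm
from tvm import te, topi

data = te.placeholder((2, 3), name="data", dtype="int32")
out = topi.cumsum(data, axis=1)
# schedule_extern leaves the extern op opaque and schedules around it.
s = topi.generic.schedule_extern([out])
f = tvm.build(s, [data, out], target="llvm")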
49 changes: 49 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -1320,3 +1320,52 @@ def adv_index(inputs):
Output tensor.
"""
return _make.adv_index(Tuple(inputs))


def cumsum(data, axis=None, dtype=None):
"""Numpy style cumsum op. Return the cumulative inclusive sum of the elements along
a given axis.

Parameters
----------
data : relay.Expr
The input data to the operator.

axis : int, optional
Axis along which the cumulative sum is computed. The default (None) is to compute
the cumsum over the flattened array.

dtype : string, optional
Type of the returned array and of the accumulator in which the elements are summed.
If dtype is not specified, it defaults to the dtype of data.

Returns
-------
result : relay.Expr
The result has the same size as data, and the same shape as data if axis is not None.
If axis is None, the result is a 1-d array.

Examples
--------
.. code-block:: python
a = [[1,2,3], [4,5,6]]

cumsum(a) # if axis is not provided, cumsum is done over the flattened input.
-> [ 1, 3, 6, 10, 15, 21]

cumsum(a, dtype="float32")
-> [ 1., 3., 6., 10., 15., 21.]

cumsum(a, axis=0) # sum over rows for each of the 3 columns
-> [[1, 2, 3],
[5, 7, 9]]

cumsum(a, axis=1)
-> [[ 1, 3, 6],
[ 4, 9, 15]]

a = [1, 0, 1, 0, 1, 1, 0] # a is a boolean-valued (0/1) array
cumsum(a, dtype="int32") # dtype should be provided to get the expected results
-> [1, 1, 2, 2, 3, 4, 4]
"""
return _make.cumsum(data, axis, dtype)
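A hedged cross-check of the docstring examples against NumPy semantics, covering the flattened and per-axis cases (executor API follows the conventions of this era):

import numpy as np
from tvm import relay

a = np.array([[1, 2, 3], [4, 5, 6]], dtype="int32")
x = relay.var("x", shape=a.shape, dtype="int32")
for axis in (None, 0, 1):
    func = relay.Function([x], relay.cumsum(x, axis=axis))
    res = relay.create_executor("debug", target="llvm").evaluate(func)(a)
    np.testing.assert_equal(res.asnumpy(), np.cumsum(a, axis=axis))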
1 change: 1 addition & 0 deletions python/tvm/topi/__init__.py
@@ -40,6 +40,7 @@
from .scatter import *
from .scatter_add import *
from .argwhere import *
from .cumsum import *
from . import generic
from . import nn
from . import x86
1 change: 1 addition & 0 deletions python/tvm/topi/cuda/__init__.py
@@ -56,3 +56,4 @@
from .correlation import *
from .sparse import *
from .argwhere import *
from .scan import *
3 changes: 2 additions & 1 deletion python/tvm/topi/cuda/nms.py
@@ -609,7 +609,8 @@ def _get_sorted_indices(data, data_buf, score_index, score_shape):
tag="fetch_score",
)

if is_thrust_available():
target = tvm.target.Target.current()
if target and target.kind.name == "cuda" and is_thrust_available():
sort_tensor = argsort_thrust(score_tensor, axis=1, is_ascend=False, dtype="int32")
else:
sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype="int32")
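The added target check matters because Target.current() is None outside a target scope, and non-CUDA targets that may reuse this TOPI path must not select the CUDA-only thrust kernel. A hedged illustration of the guard's two prongs:

import tvm

# Outside any target scope the guard short-circuits on None.
assert tvm.target.Target.current() is None

with tvm.target.Target("rocm"):
    # kind.name is "rocm", not "cuda", so the thrust branch is skipped
    # even if TVM happens to be built with thrust support.
    assert tvm.target.Target.current().kind.name == "rocm"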