[Frontend][Tensorflow] Add unique operator (apache#7441)

* Initial commit of the unique operator

Add unit tests for unique operator

* Add tensorflow unique op

* Refactor unique to use sort-based algorithm

* Change relay.unique test to run only on cpu

* Change topi.unique test to run only on cpu

* Change range to parallel for parallelizable loops

* Add return_counts option for relay.unique and topi.unique, add pytorch frontend

* Fix pylint

* Patch pytorch frontend

* Initial support of topi.cuda.unique

* Refactor to use ir_builder directly

* Modularize adjacent difference

* Refactor to simplify

* Fix typo

* Combine _unique and _unique_with_counts

* Reuse indices_ptr to remove arange_ptr

Co-authored-by: Yanming Wang <yanmwang@amazon.com>
2 people authored and Lokiiiiii committed Mar 1, 2021
1 parent 7d75e6e commit 7492019
Showing 17 changed files with 1,199 additions and 1 deletion.
12 changes: 12 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -460,6 +460,18 @@ struct CumsumAttrs : public tvm::AttrsNode<CumsumAttrs> {
}
};

/*! \brief Attributes used in unique operator */
struct UniqueAttrs : public tvm::AttrsNode<UniqueAttrs> {
bool sorted;
bool return_counts;
TVM_DECLARE_ATTRS(UniqueAttrs, "relay.attrs.UniqueAttrs") {
TVM_ATTR_FIELD(sorted).describe("Whether the unique elements are sorted").set_default(true);
TVM_ATTR_FIELD(return_counts)
.describe("Whether to return an additional tensor with counts of each unique element")
.set_default(false);
}
}; // struct UniqueAttrs

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
19 changes: 19 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
@@ -2157,6 +2157,24 @@ def is_floating_point(self, inputs, input_types):
is_float = input_type in ["float32", "float64", "float16", "bfloat16"]
return _expr.const(is_float)

def unique(self, inputs, input_types):
assert len(inputs) == 4
[data, is_sorted, return_inverse, return_counts] = inputs
if not is_sorted:
logging.warning("TVM always assumes sorted=True for torch.unique")
is_sorted = True
if return_counts:
[unique, indices, num_uniq, counts] = _op.unique(
data, is_sorted=is_sorted, return_counts=True
)
unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
counts_sliced = _op.strided_slice(counts, begin=[0], end=num_uniq, slice_mode="size")
return (unique_sliced, indices, counts_sliced)
else:
[unique, indices, num_uniq] = _op.unique(data, is_sorted=is_sorted, return_counts=False)
unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
return (unique_sliced, indices)

# Operator mappings
def create_convert_map(self):
self.convert_map = {
@@ -2363,6 +2381,7 @@ def create_convert_map(self):
"aten::masked_select": self.masked_select,
"aten::argsort": self.argsort,
"aten::sort": self.sort,
"aten::_unique2": self.unique,
}

def update_convert_map(self, custom_map):
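For reference, a minimal sketch of how this converter is exercised; the module name and the `input_infos` layout are assumptions for illustration, not part of this change. `torch.unique` dispatches to `aten::_unique2` in TorchScript, which the mapping above picks up:

import torch
from tvm import relay

class UniqueModule(torch.nn.Module):  # hypothetical example module
    def forward(self, x):
        # torch.unique dispatches to aten::_unique2 in TorchScript
        return torch.unique(x, sorted=True, return_counts=True)

scripted = torch.jit.script(UniqueModule())
# the (name, (shape, dtype)) input spec is an assumed format for from_pytorch
mod, params = relay.frontend.from_pytorch(scripted, [("x", ((8,), "int64"))])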
26 changes: 26 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
@@ -2471,6 +2471,30 @@ def _impl(inputs, attr, params, mod):
return _impl


def _unique(return_counts=True):
def _impl(inputs, attr, params, mod):
assert len(inputs) == 1
data = inputs[0]
if return_counts:
[unique, indices, num_uniq, counts] = _op.unique(
data, is_sorted=False, return_counts=True
)
unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
counts_sliced = _op.strided_slice(counts, begin=[0], end=num_uniq, slice_mode="size")
return _expr.TupleWrapper(
_expr.Tuple([unique_sliced, indices, counts_sliced]),
3,
)
[unique, indices, num_uniq] = _op.unique(data, is_sorted=False, return_counts=False)
unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
return _expr.TupleWrapper(
_expr.Tuple([unique_sliced, indices]),
2,
)

return _impl


# compatible operators that do NOT require any conversion.
_identity_list = []

@@ -2650,6 +2674,8 @@ def _impl(inputs, attr, params, mod):
"TopKV2": _topk(),
"Transpose": _transpose(),
"TruncateMod": _elemwise("mod"),
"Unique": _unique(False),
"UniqueWithCounts": _unique(True),
"Unpack": _unpack(),
"UnravelIndex": _unravel_index(),
"Where": _where(),
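For reference, a minimal TF1-style sketch (using `tf.compat.v1`, an assumption for illustration): `tf.unique` produces a `Unique` node and `tf.unique_with_counts` a `UniqueWithCounts` node, routed through `_unique(False)` and `_unique(True)` above:

import tensorflow as tf
from tvm import relay

tf1 = tf.compat.v1
with tf1.Graph().as_default() as graph:
    x = tf1.placeholder(tf.int32, shape=(8,), name="x")
    tf1.unique(x)  # emits a "Unique" node -> (values, indices)
    graph_def = graph.as_graph_def()

mod, params = relay.frontend.from_tensorflow(graph_def, shape={"x": (8,)})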
44 changes: 44 additions & 0 deletions python/tvm/relay/op/_transform.py
@@ -151,6 +151,15 @@ def compute_cumsum(attrs, inputs, output_type):
_reg.register_strategy("cumsum", strategy.cumsum_strategy)
_reg.register_shape_func("cumsum", False, elemwise_shape_func)


@_reg.register_compute("unique")
def compute_unique(attrs, inputs, output_type):
"""Compute definition of unique"""
return topi.unique(inputs[0], attrs.sorted, attrs.return_counts)


_reg.register_strategy("unique", strategy.unique_strategy)

#####################
# Shape functions #
#####################
@@ -957,3 +966,38 @@ def where_shape_func(attrs, inputs, _):
out_shape = _broadcast_shape_tensors(bcast_shape, cond_shape)

return [out_shape]


@script
def _unique_shape(data_shape):
unique_shape = output_tensor((1,), "int64")
indices_shape = output_tensor((1,), "int64")
num_unique_shape = output_tensor((1,), "int64")
unique_shape[0] = data_shape[0]
indices_shape[0] = data_shape[0]
num_unique_shape[0] = int64(1)
return (unique_shape, indices_shape, num_unique_shape)


@script
def _unique_with_counts_shape(data_shape):
unique_shape = output_tensor((1,), "int64")
indices_shape = output_tensor((1,), "int64")
num_unique_shape = output_tensor((1,), "int64")
counts_shape = output_tensor((1,), "int64")
unique_shape[0] = data_shape[0]
indices_shape[0] = data_shape[0]
num_unique_shape[0] = int64(1)
counts_shape[0] = data_shape[0]
return (unique_shape, indices_shape, num_unique_shape, counts_shape)


@_reg.register_shape_func("unique", False)
def unique_shape_func(attrs, inputs, _):
"""
Shape func for unique operator.
"""
if attrs.return_counts:
return _unique_with_counts_shape(inputs[0])
else:
return _unique_shape(inputs[0])
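The shape functions above deliberately return the full input length for `output`, `indices`, and `counts`: the true number of unique elements is only known at runtime, so only `num_unique` gets a static size of 1. As a rough illustration of the sort-based data flow the commit log describes (argsort, adjacent difference, inclusive scan, scatter), here is a NumPy sketch — an illustration only, not the actual topi kernel:

# Illustrative NumPy version of the sort-based flow (not the topi kernel):
# argsort -> adjacent difference -> inclusive scan -> scatter.
import numpy as np

def unique_sketch(data):
    order = np.argsort(data, kind="stable")      # sort by value, stable
    sorted_data = data[order]
    # adjacent difference: 1 where a new value starts in the sorted array
    adj = np.concatenate(([0], sorted_data[1:] != sorted_data[:-1])).astype(np.int64)
    ids = np.cumsum(adj)                         # inclusive scan -> unique id per slot
    num_unique = int(ids[-1]) + 1
    unique = np.empty_like(data)                 # padded; the tail stays undefined
    unique[ids] = sorted_data                    # scatter unique values
    indices = np.empty(len(data), dtype=np.int64)
    indices[order] = ids                         # map each input element to its id
    counts = np.zeros(len(data), dtype=np.int64)
    np.add.at(counts, ids, 1)                    # occurrences per unique value
    return unique, indices, np.array([num_unique]), counts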
12 changes: 12 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -1009,3 +1009,15 @@ def cumsum_strategy_cuda(attrs, inputs, out_type, target):
name="cumsum.cuda",
)
return strategy


@unique_strategy.register(["cuda", "gpu"])
def unique_strategy_cuda(attrs, inputs, out_type, target):
"""unique cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_unique(topi.cuda.unique),
wrap_topi_schedule(topi.cuda.schedule_scan),
name="unique.cuda",
)
return strategy
21 changes: 21 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -1440,3 +1440,24 @@ def cumsum_strategy(attrs, inputs, out_type, target):
name="cumsum.generic",
)
return strategy


def wrap_compute_unique(topi_compute):
"""Wrap unique topi compute"""

def _compute_unique(attrs, inputs, _):
return topi_compute(inputs[0], attrs.sorted, attrs.return_counts)

return _compute_unique


@override_native_generic_func("unique_strategy")
def unique_strategy(attrs, inputs, out_type, target):
"""unique generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_unique(topi.unique),
wrap_topi_schedule(topi.generic.schedule_unique),
name="unique.generic",
)
return strategy
54 changes: 54 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -1502,3 +1502,57 @@ def cumsum(data, axis=None, dtype=None, exclusive=None):
-> [1, 1, 2, 2, 3, 4, 4]
"""
return _make.cumsum(data, axis, dtype, exclusive)


def unique(data, is_sorted=True, return_counts=False):
"""
Find the unique elements of a 1-D tensor. Note that `output` and `counts` are padded to
the same length as `data`; elements with index >= num_unique[0] have undefined values.

Parameters
----------
data : relay.Expr
A 1-D tensor of integers.
is_sorted : bool
Whether to sort the unique elements in ascending order before returning as output.
return_counts : bool
Whether to return the count of each unique element.

Returns
-------
output : relay.Expr
A 1-D tensor containing the unique elements of the input data tensor.
indices : relay.Expr
A 1-D tensor containing the index of each data element in the output tensor.
num_unique : relay.Expr
A 1-D tensor with size=1 containing the number of unique elements in the input data tensor.
counts (optional) : relay.Expr
A 1-D tensor containing the count of each unique element in the output.

Examples
--------
.. code-block:: python
[output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, False)
output = [4, 5, 1, 2, 3, ?, ?, ?]
indices = [0, 1, 2, 3, 4, 4, 0, 1]
num_unique = [5]

[output, indices, num_unique, counts] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, True)
output = [4, 5, 1, 2, 3, ?, ?, ?]
indices = [0, 1, 2, 3, 4, 4, 0, 1]
num_unique = [5]
counts = [2, 2, 1, 1, 2, ?, ?, ?]

[output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], True)
output = [1, 2, 3, 4, 5, ?, ?, ?]
indices = [3, 4, 0, 1, 2, 2, 3, 4]
num_unique = [5]
"""
if return_counts:
return TupleWrapper(_make.unique(data, is_sorted, return_counts), 4)
return TupleWrapper(_make.unique(data, is_sorted, return_counts), 3)
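A minimal usage sketch, assuming an LLVM-enabled CPU build of TVM: since the outputs are padded, callers trim them with `num_unique` via `strided_slice(..., slice_mode="size")`, exactly as the frontends above do:

# Minimal usage sketch, assuming an LLVM-enabled build of TVM.
import numpy as np
import tvm
from tvm import relay

data = relay.var("data", shape=(8,), dtype="int32")
unique, indices, num_unique = relay.unique(data, is_sorted=True)
# trim the padded output down to the real number of unique elements
trimmed = relay.strided_slice(unique, begin=[0], end=num_unique, slice_mode="size")
func = relay.Function([data], relay.Tuple([trimmed, indices, num_unique]))

ex = relay.create_executor("vm", mod=tvm.IRModule.from_expr(func), target="llvm")
out = ex.evaluate()(np.array([4, 5, 1, 2, 3, 3, 4, 5], dtype="int32"))
# expected: trimmed == [1, 2, 3, 4, 5], num_unique == [5]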
1 change: 1 addition & 0 deletions python/tvm/topi/__init__.py
@@ -44,6 +44,7 @@
from .interpolate import *
from .cumsum import *
from .einsum import *
from .unique import *
from . import generic
from . import nn
from . import x86
1 change: 1 addition & 0 deletions python/tvm/topi/cuda/__init__.py
@@ -58,3 +58,4 @@
from . import tensorcore_alter_op
from .argwhere import *
from .scan import *
from .unique import *