Skip to content

Commit

Permalink
Add tensorflow unique op
Browse files Browse the repository at this point in the history
  • Loading branch information
ymwangg committed Feb 11, 2021
1 parent 1b62d2c commit 73330ec
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 3 deletions.
13 changes: 13 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2295,6 +2295,18 @@ def _impl(inputs, attr, params, mod):
return _impl


def _unique():
def _impl(inputs, attr, params, mod):
assert len(inputs) == 1
x = inputs[0]
[output, indices, counts, num_uniq] = _op.unique(x)
output_sliced = _op.strided_slice(
output, begin=[0], end=num_uniq, slice_mode="size"
)
return [output_sliced, indices]

return _impl

# compatible operators that do NOT require any conversion.
_identity_list = []

Expand Down Expand Up @@ -2471,6 +2483,7 @@ def _impl(inputs, attr, params, mod):
"TopKV2": _topk(),
"Transpose": _transpose(),
"TruncateMod": _elemwise("mod"),
"Unique": _unique(),
"Unpack": _unpack(),
"UnravelIndex": _unravel_index(),
"Where": _where(),
Expand Down
34 changes: 34 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4686,5 +4686,39 @@ def lstm_cell():
tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)


#######################################################################
# Unique
# ------------
def _test_unique(
n,
dtype,
is_dyn
):
""" One iteration of a Stridedslice """

tf.reset_default_graph()
np_data = np.random.randint(100, size=n).astype(dtype)
with tf.Graph().as_default():
if is_dyn:
in_data = tf.placeholder(dtype, [n], name="in_data")
else:
in_data = tf.constant(np_data, dtype, name="in_data")
tf.unique(in_data)
if is_dyn:
compare_tf_with_tvm(np_data, "in_data:0", [
"Unique:0", "Unique:1"], mode="vm")
else:
compare_tf_with_tvm(None, "", ["Unique:0", "Unique:1"])


def test_forward_unique():
"""test Unique"""

for dtype in ["int32", "int64"]:
for is_dyn in [False, True]:
_test_unique(50, dtype, is_dyn)
_test_unique(100, dtype, is_dyn)


if __name__ == "__main__":
pytest.main([__file__])
6 changes: 3 additions & 3 deletions tests/python/relay/test_op_level6.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,15 @@ def calc_unique(data):
num_uniq = np.array([len(uniq)]).astype("int32")
return uniq, inverse, counts, num_uniq

def verify_unique(len, dtype, is_dyn=False):
def verify_unique(n, dtype, is_dyn=False):
if is_dyn:
x = relay.var("x", relay.TensorType([relay.Any()], dtype))
else:
x = relay.var("x", relay.TensorType([len], dtype))
x = relay.var("x", relay.TensorType([n], dtype))
outs = relay.unique(x)
outs = outs.astuple()
func = relay.Function([x], outs)
x_data = np.random.randint(100, size=len).astype(dtype)
x_data = np.random.randint(100, size=n).astype(dtype)

if is_dyn:
backends = ["vm", "debug"]
Expand Down

0 comments on commit 73330ec

Please sign in to comment.