Skip to content

Commit

Permalink
Fix python format
Browse files Browse the repository at this point in the history
  • Loading branch information
ymwangg committed Feb 11, 2021
1 parent 73330ec commit a1d5c43
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 23 deletions.
5 changes: 2 additions & 3 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2300,13 +2300,12 @@ def _impl(inputs, attr, params, mod):
assert len(inputs) == 1
x = inputs[0]
[output, indices, counts, num_uniq] = _op.unique(x)
output_sliced = _op.strided_slice(
output, begin=[0], end=num_uniq, slice_mode="size"
)
output_sliced = _op.strided_slice(output, begin=[0], end=num_uniq, slice_mode="size")
return [output_sliced, indices]

return _impl


# compatible operators that do NOT require any conversion.
_identity_list = []

Expand Down
7 changes: 6 additions & 1 deletion python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,4 +911,9 @@ def unique_shape_func(attrs, inputs, _):
"""
Shape func for unique operator.
"""
return [_unique_shape_1(inputs[0]), _unique_shape_1(inputs[0]), _unique_shape_1(inputs[0]), _unique_shape_2(inputs[0])]
return [
_unique_shape_1(inputs[0]),
_unique_shape_1(inputs[0]),
_unique_shape_1(inputs[0]),
_unique_shape_2(inputs[0]),
]
4 changes: 1 addition & 3 deletions python/tvm/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,9 +962,7 @@ def unique(data):
return te.extern(
[data.shape, data.shape, data.shape, (1,)],
[data],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.algorithm.unique", ins[0], *outs
),
lambda ins, outs: tvm.tir.call_packed("tvm.contrib.algorithm.unique", ins[0], *outs),
dtype=[data.dtype, "int32", "int32", "int32", "int32"],
name="unique_cpu",
tag="unique_cpu",
Expand Down
9 changes: 2 additions & 7 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4689,11 +4689,7 @@ def lstm_cell():
#######################################################################
# Unique
# ------------
def _test_unique(
n,
dtype,
is_dyn
):
def _test_unique(n, dtype, is_dyn):
""" One iteration of a Stridedslice """

tf.reset_default_graph()
Expand All @@ -4705,8 +4701,7 @@ def _test_unique(
in_data = tf.constant(np_data, dtype, name="in_data")
tf.unique(in_data)
if is_dyn:
compare_tf_with_tvm(np_data, "in_data:0", [
"Unique:0", "Unique:1"], mode="vm")
compare_tf_with_tvm(np_data, "in_data:0", ["Unique:0", "Unique:1"], mode="vm")
else:
compare_tf_with_tvm(None, "", ["Unique:0", "Unique:1"])

Expand Down
16 changes: 7 additions & 9 deletions tests/python/relay/test_op_level6.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype):
def test_unique():
def calc_unique(data):
uniq, index, inverse, counts = np.unique(
data, return_index=True, return_inverse=True, return_counts=True)
data, return_index=True, return_inverse=True, return_counts=True
)
order = np.argsort(index)
reverse_order = dict(zip(order, np.arange(len(order))))
uniq = uniq[order].astype(data.dtype)
Expand All @@ -168,21 +169,18 @@ def verify_unique(n, dtype, is_dyn=False):
for target, ctx in tvm.testing.enabled_targets():
for kind in backends:
mod = tvm.ir.IRModule.from_expr(func)
intrp = relay.create_executor(
kind, mod=mod, ctx=ctx, target=target)
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x_data)
ref_res = calc_unique(x_data)
num_uniq = ref_res[3][0]
assert num_uniq == op_res[3].asnumpy()[0]
# output
tvm.testing.assert_allclose(
op_res[0].asnumpy()[:num_uniq], ref_res[0], rtol=1e-5)
tvm.testing.assert_allclose(op_res[0].asnumpy()[:num_uniq], ref_res[0], rtol=1e-5)
# inverse_indices
tvm.testing.assert_allclose(
op_res[1].asnumpy(), ref_res[1], rtol=1e-5)
tvm.testing.assert_allclose(op_res[1].asnumpy(), ref_res[1], rtol=1e-5)
# count
tvm.testing.assert_allclose(
op_res[2].asnumpy()[:num_uniq], ref_res[2], rtol=1e-5)
tvm.testing.assert_allclose(op_res[2].asnumpy()[:num_uniq], ref_res[2], rtol=1e-5)

for dtype in ["int32", "int64"]:
for is_dyn in [False, True]:
verify_unique((50), dtype, is_dyn=is_dyn)
Expand Down

0 comments on commit a1d5c43

Please sign in to comment.