Skip to content

Commit

Permalink
Onnx eyelike (apache#8191)
Browse files Browse the repository at this point in the history
* add ONNX EyeLike converter

* need to implement k

* test pass

* eyelike tests all pass

* Revert "test pass"

This reverts commit 0aa7347.

* removed comments, black'd, lint

* changed == to is in onnx.py

Co-authored-by: Masahiro Masuda <masahi129@gmail.com>
Co-authored-by: Jocelyn <jocelyn@pop-os.localdomain>
  • Loading branch information
3 people authored and trevor-m committed Jun 17, 2021
1 parent fd4a1b0 commit e57780c
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 3 deletions.
22 changes: 22 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1475,6 +1475,27 @@ def _impl_v11(cls, inputs, attr, params):
)


class EyeLike(OnnxOpConverter):
"""Operator converter for EyeLike."""

@classmethod
def _impl_v9(cls, inputs, attr, params):
in_checked_type = infer_type(inputs[0]).checked_type
in_dtype = in_checked_type.dtype
in_shape = list(get_const_tuple(in_checked_type.shape))
dtype = attr.get("dtype", None)
if dtype is None:
dtype = in_dtype
else:
dtype = get_type(dtype)
zeros = _op.zeros(in_shape, dtype)
dim = in_shape[0]
indices = _op.arange(_op.const(0), _op.const(dim), dtype="int32")
ones = _op.full(_op.const(1), (dim,), dtype=dtype)
k = _op.const(attr.get("k", 0), dtype="int32")
return _op.scatter_nd(zeros, _op.stack([indices, indices + k], axis=0), ones, "update")


class Greater(OnnxOpConverter):
"""Operator logical greater."""

Expand Down Expand Up @@ -3158,6 +3179,7 @@ def _get_convert_map(opset):
"Scatter": Scatter.get_converter(opset),
"ScatterElements": Scatter.get_converter(opset),
"ScatterND": ScatterND.get_converter(opset),
"EyeLike": EyeLike.get_converter(opset),
"Squeeze": AttrCvt("squeeze", {"axes": "axis"}),
"Unsqueeze": Unsqueeze.get_converter(opset),
"Pad": Pad.get_converter(opset),
Expand Down
29 changes: 26 additions & 3 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4129,6 +4129,7 @@ def verify_softplus(indata):
verify_softplus(input_data)


@tvm.testing.uses_gpu
def test_cumsum():
def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
cumsum_node = onnx.helper.make_node(
Expand Down Expand Up @@ -4205,6 +4206,30 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
verify_cumsum(data, 1, 1, 1, type="int32")


@tvm.testing.uses_gpu
def test_eyelike():
def verify_eyelike(indata):
node = helper.make_node(
"EyeLike",
inputs=["X"],
outputs=["Y"],
)

graph = helper.make_graph(
[node],
"eyelike_test",
inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(indata.shape))],
outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, list(indata.shape))],
)

model = helper.make_model(graph, producer_name="eyelike_test")

verify_with_ort_with_inputs(model, [indata], dtype="float32", opset=9)

input_data = np.zeros((5, 5), dtype=np.float32)
verify_eyelike(input_data)


"""
The following parameterized tests loads the tests that ONNX ships as
serialized ONNX files, inputs, and outputs. The goal of this test
Expand Down Expand Up @@ -4241,9 +4266,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
"test_cumsum_2d_negative_axis/",
"test_det_2d/",
"test_det_nd/",
"test_eyelike_populate_off_main_diagonal/",
"test_eyelike_with_dtype/",
"test_eyelike_without_dtype/",
"test_matmulinteger/",
"test_maxpool_2d_same_lower/",
"test_maxpool_2d_same_upper/",
Expand Down Expand Up @@ -4680,4 +4702,5 @@ def repeat(N, D):
test_wrong_input()
test_aten()
test_reverse_sequence()
test_eyelike()
test_qlinearconv()

0 comments on commit e57780c

Please sign in to comment.