Skip to content

Commit

Permalink
add dynamic strided slice to the onnx importer
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart committed Sep 11, 2020
1 parent 2892e6a commit 1fc3721
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 21 deletions.
47 changes: 28 additions & 19 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from .common import AttrCvt, Renamer
from .common import get_relay_op, new_var, infer_shape, infer_channels
from .common import infer_type, get_name, infer_value_simulated
from .common import infer_type, get_name

__all__ = ["from_onnx"]

Expand Down Expand Up @@ -945,7 +945,6 @@ def _impl_v9(cls, inputs, attr, params):
return out



class Shape(OnnxOpConverter):
"""Operator converter for Shape."""

Expand Down Expand Up @@ -1047,24 +1046,35 @@ def _impl_v1(cls, inputs, attr, params):

@classmethod
def _impl_v10(cls, inputs, attr, params):
attrs = {"starts": inputs[1], "ends": inputs[2]}
if len(inputs) >= 4:
attrs["axes"] = inputs[3]
attrs = {k: (v, get_name(v)) for (k, v) in attrs.items()}
attrs = {
k: params[v[1]].asnumpy()
if v[1] in params
else infer_value_simulated(v[0], params).asnumpy()
for (k, v) in attrs.items()
}
starts = inputs[1]
ends = inputs[2]
axes = inputs[3]
steps = inputs[4]

data_rank = len(infer_shape(inputs[0]))

# Update the starts and ends according to axes if required.
if "axes" in attrs:
if max(attrs["axes"] + 1) != len(attrs["axes"]):
new_starts, new_ends, _ = cls._common(attrs["starts"], attrs["ends"], attrs["axes"])
attrs["starts"] = new_starts
attrs["ends"] = new_ends
return _op.strided_slice(inputs[0], begin=list(attrs["starts"]), end=list(attrs["ends"]))
if axes is not None:
data_shape = _op.shape_of(inputs[0], dtype=infer_type(ends).checked_type.dtype)
starts = _op.scatter(
_op.const([0] * data_rank, dtype=infer_type(starts).checked_type.dtype),
axes,
starts,
axis=0,
)
ends = _op.scatter(data_shape, axes, ends, axis=0)
if steps is not None:
steps = _op.scatter(
_op.const([1] * data_rank, dtype=infer_type(steps).checked_type.dtype),
axes,
steps,
axis=0,
)

if steps is None:
steps = _op.const([1] * data_rank, dtype=infer_type(starts).checked_type.dtype)

return _op.strided_slice(inputs[0], starts, ends, steps)


class Gather(OnnxOpConverter):
Expand Down Expand Up @@ -1406,7 +1416,6 @@ def _impl_v6(cls, inputs, attr, params):
return _op.tile(inputs[0], inputs[1])



class Erf(OnnxOpConverter):
"""Operator converter for Erf"""

Expand Down
5 changes: 3 additions & 2 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,12 +599,13 @@ def add_noop_to_input_attr(attr_name, attr):
model = helper.make_model(graph, producer_name="slice_test")

for target, ctx in tvm.testing.enabled_targets():
tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, "float32", opset=10)
tvm_out = get_tvm_output_with_vm(model, indata, target, ctx, opset=10, freeze_params=True)

tvm.testing.assert_allclose(outdata, tvm_out)


@tvm.testing.uses_gpu
# TODO(mbrookhart): enable once VM supports heterogenous execution
# @tvm.testing.uses_gpu
def test_slice():
x = np.random.randn(20, 10, 5).astype(np.float32)
_test_slice_iteration_v1(x, x[0:3, 0:10], starts=(0, 0), ends=(3, 10), axes=(0, 1))
Expand Down

0 comments on commit 1fc3721

Please sign in to comment.