Skip to content

Commit

Permalink
[onnx] fix onnx where broadcast (apache#10106)
Browse files Browse the repository at this point in the history
* fix onnx where bcast

* jostle ci

* jostle ci

* jostle ci
  • Loading branch information
lazycal authored and ylc committed Feb 16, 2022
1 parent 3695897 commit c97e3ed
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 22 deletions.
23 changes: 1 addition & 22 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2206,28 +2206,7 @@ class Where(OnnxOpConverter):

@classmethod
def _impl_v9(cls, inputs, attr, params):
condition_rank = len(infer_shape(inputs[0]))
x_rank = len(infer_shape(inputs[1]))
y_rank = len(infer_shape(inputs[2]))
ranks = [condition_rank, x_rank, y_rank]

# If one rank is longer than others, then we can broadcast
# to that shape.
max_rank = max(ranks)
max_rank_idxs = [i for i, x in enumerate(ranks) if x == max_rank]
broadcast_shape = shape_of(inputs[max_rank_idxs[0]])
# If two or more inputs have the same rank, compute the broadcast
# shape by taking the maximum value of each dimensions.
if len(max_rank_idxs) > 1:
for idx in max_rank_idxs:
broadcast_shape = _op.maximum(broadcast_shape, shape_of(inputs[idx]))

broadcast_shape = fold_constant(broadcast_shape)

condition = _op.broadcast_to(inputs[0], broadcast_shape)
x = _op.broadcast_to(inputs[1], broadcast_shape)
y = _op.broadcast_to(inputs[2], broadcast_shape)
return _op.where(condition, x, y)
return _op.where(*inputs)


class Or(Elemwise):
Expand Down
6 changes: 6 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2510,6 +2510,12 @@ def verify_where(condition, x, y, dtype, outdata, dynamic=False):
verify_where(condition, x, y, TensorProto.FLOAT, outdata)
verify_where(condition, x, y, TensorProto.FLOAT, outdata, dynamic=True)

condition = np.random.uniform(size=(3, 1)) < 0.5
x = np.random.uniform(size=2).astype(np.float32)
y = np.random.uniform(size=2).astype(np.float32)
outdata = np.where(condition, x, y)
verify_where(condition, x, y, TensorProto.FLOAT, outdata)


@tvm.testing.parametrize_targets
def test_or(target, dev):
Expand Down

0 comments on commit c97e3ed

Please sign in to comment.