[Frontend, Tensorflow] Support for broadcasting in batch_matmul when shapes differ (apache#8251)

* Support for broadcasting in batch_matmul when shapes differ

* refactor

* refactor logic for reshape in conditional

* refactor
rohanmukh authored and trevor-m committed Jun 17, 2021
1 parent d50f5f4 commit 7e09f1f
Showing 2 changed files with 26 additions and 7 deletions.
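
For context: the pattern this commit teaches the TensorFlow frontend to handle is a batched matmul whose right-hand operand has lower rank than the left, so the shapes only line up after broadcasting. A minimal sketch of such a graph (assuming TensorFlow 2.x; the shapes come from the new tests below):

import numpy as np
import tensorflow as tf

# tf.matmul lowers to BatchMatMulV2, which broadcasts batch dimensions;
# the rank-2 weight has no batch dimensions at all, so the frontend must
# promote it before calling Relay's batch_matmul.
x = tf.constant(np.random.rand(1, 8, 64, 2).astype("float32"))  # rank-4 input
w = tf.constant(np.random.rand(2, 1).astype("float32"))         # rank-2 weight
print(tf.matmul(x, w).shape)  # (1, 8, 64, 1)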
python/tvm/relay/frontend/tensorflow_ops.py (16 changes: 9 additions & 7 deletions)
@@ -1132,22 +1132,23 @@ def _impl(inputs, attr, params, mod):
         orig_shape_x = _infer_shape(input_x, mod)
         orig_shape_y = _infer_shape(input_y, mod)
         ndim = len(orig_shape_x)
+        ndim_y = len(orig_shape_y)
 
         is_static = not check_symbolic_shape(orig_shape_x)
 
-        if ndim > 3 and not is_static:
-            shape_of_x = list_shape_of(inputs[0], ndim)
-            shape_of_y = list_shape_of(inputs[1], ndim)
-
         # reshape n-dimensional batch matmul into 3d
         if ndim > 3:
             outer_dims = [orig_shape_x[i] for i in range(0, len(orig_shape_x) - 2)]
             if is_static:
                 num_outer_elts = np.prod(outer_dims)
                 new_shape_x = (num_outer_elts, orig_shape_x[-2], orig_shape_x[-1])
-                new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1])
+                if ndim_y > 2:
+                    new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1])
+                elif ndim_y == 2:
+                    new_shape_y = (1, orig_shape_y[-2], orig_shape_y[-1])
             else:  # handle dynamic shape (dyn.reshape op)
-                # new shape = [prod(shape[:-2]), -2, -1]
+                shape_of_x = list_shape_of(inputs[0], ndim)
+                shape_of_y = list_shape_of(inputs[1], ndim)
                 new_shape_x = [_op.const(1), shape_of_x[-2], shape_of_x[-1]]
                 new_shape_y = [_op.const(1), shape_of_y[-2], shape_of_y[-1]]
                 for i in range(ndim - 2):
@@ -1158,7 +1159,8 @@ def _impl(inputs, attr, params, mod):
 
             input_x = _op.reshape(input_x, newshape=new_shape_x)
             input_y = _op.reshape(input_y, newshape=new_shape_y)
-
+        elif ndim_y == 2:
+            input_y = _op.reshape(input_y, (1, orig_shape_y[-2], orig_shape_y[-1]))
         adj_x = attr["adj_x"]
         adj_y = attr["adj_y"]
         input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x
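
In NumPy terms, the static branch above performs the following shape arithmetic when the right operand is rank-2 (an illustrative stand-in, not the TVM API):

import numpy as np

x = np.random.rand(1, 8, 64, 2).astype("float32")  # ndim == 4, flattened to 3-D
y = np.random.rand(2, 1).astype("float32")         # ndim_y == 2, promoted to batch size 1

num_outer_elts = int(np.prod(x.shape[:-2]))                  # 1 * 8 == 8
x3 = x.reshape(num_outer_elts, x.shape[-2], x.shape[-1])     # (8, 64, 2)
y3 = y.reshape(1, y.shape[-2], y.shape[-1])                  # (1, 2, 1); batch dim of 1 broadcasts

out = np.matmul(x3, y3)                                      # (8, 64, 1)
out = out.reshape(*x.shape[:-2], x.shape[-2], y3.shape[-1])  # back to (1, 8, 64, 1)

Without the new ndim_y checks, a rank-2 y would have been reshaped to (num_outer_elts, k, n), which does not match its element count and fails.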
tests/python/frontend/tensorflow/test_forward.py (17 changes: 17 additions & 0 deletions)
@@ -1858,6 +1858,9 @@ def test_forward_batch_matmul():
     _test_batch_matmul((1, 2, 3, 4, 5, 6), (1, 2, 3, 4, 6, 5), "float32", True, True)
     _test_batch_matmul((3, 4, 5, 6), (3, 4, 5, 6), "int32", True, False)
     _test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6), (2, 3, 4, 2, 3, 4, 5, 6), "float32", False, True)
+    _test_batch_matmul((1, 8, 64, 2), (2, 1), "float32", False, False)
+    _test_batch_matmul((1, 8, 8, 64), (64, 1), "float32", False, False)
+    _test_batch_matmul((1, 8, 64), (64, 1), "float32", False, False)
 
 
 @tvm.testing.requires_cuda
@@ -1885,6 +1888,20 @@ def test_forward_batch_matmul_dynamic():
         (2, 3, 4, 6, 5),
         "float32",
     )
+    _test_batch_matmul_dynamic(
+        (None, None, None, 5, 6),
+        (6, None),
+        (2, 3, 4, 5, 6),
+        (6, 1),
+        "float32",
+    )
+    _test_batch_matmul_dynamic(
+        (None, 5, 6),
+        (6, None),
+        (24, 5, 6),
+        (6, 1),
+        "float32",
+    )
 
 
 #######################################################################
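
The dynamic cases exercise the dyn.reshape branch, since neither the batch dimensions nor the broadcast operand's free dimension are known at graph-construction time. A rough sketch of the kind of graph those tests feed the frontend (TF1-style placeholders; illustrative only, the _test_batch_matmul_dynamic helper is defined elsewhere in this test file):

import numpy as np
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

# (None, 5, 6) x (6, None): batch and output sizes are symbolic until run time.
A = tf.compat.v1.placeholder(tf.float32, shape=(None, 5, 6), name="A")
B = tf.compat.v1.placeholder(tf.float32, shape=(6, None), name="B")
result = tf.matmul(A, B)

with tf.compat.v1.Session() as sess:
    out = sess.run(
        result,
        feed_dict={
            A: np.random.rand(24, 5, 6).astype("float32"),
            B: np.random.rand(6, 1).astype("float32"),
        },
    )
print(out.shape)  # (24, 5, 1)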