diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index d11c835cfe99..0fe29f315b43 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks """Scatter operator""" -from ..tir import decl_buffer, ir_builder, AssertStmt, StringImm, Evaluate +from ..tir import decl_buffer, ir_builder, AssertStmt, StringImm, Evaluate, expr from ..te import extern, hybrid @@ -200,20 +200,22 @@ def scatter(data, indices, updates, axis=0): def _verify_scatter_nd_inputs(data, indices, updates): - # TODO(masahi): revisit - return mdim = int(indices.shape[0]) assert mdim <= len(data.shape), ( f"The first dimension of the indices ({mdim}) must be less than or equal to " f"the length of the shape of the output ({len(shape)})." ) for i in range(len(indices.shape) - 1): + if isinstance(indices.shape[i + 1], expr.Var) or isinstance(updates.shape[i], expr.Var): + continue assert indices.shape[i + 1] == updates.shape[i], ( f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of " f"updates[{i}] ({updates.shape[i]})." ) for i in range(mdim, len(data.shape)): data_ind = i - mdim + len(indices.shape) - 1 + if isinstance(updates.shape[data_ind], expr.Var) or isinstance(data.shape[i], expr.Var): + continue assert updates.shape[data_ind] == data.shape[i], ( f"Dimension of updates[{data_ind}] ({updates.shape[data_ind]}) must equal dimension " f"of out_shape[{i}] ({data.shape[i]})." diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 2add0739b901..8016e435618a 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -1728,5 +1728,28 @@ def verify_gather_nd(data_shape, indices_shape, data_shape_np, indices_shape_np, ) +@tvm.testing.uses_gpu +def test_scatter_nd(): + def verify_scatter_nd(data_np, indices_np, updates_np, ref_res): + indices_shape = (2, relay.Any()) + updates_shape = (relay.Any(),) + data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype)) + indices = relay.var("indices", relay.TensorType(indices_shape, str(indices_np.dtype))) + updates = relay.var("updates", relay.TensorType(updates_shape, str(updates_np.dtype))) + + out = relay.op.scatter_nd(data, indices, updates, "add") + + mod = tvm.IRModule() + mod["main"] = relay.Function([data, indices, updates], out) + + check_result([data_np, indices_np, updates_np], mod, [ref_res]) + + data = np.zeros((2, 2)).astype("int64") + indices = np.array([[1, 1, 0], [0, 1, 0]]) + updates = np.array([2, 3, 0]) + out = np.array([[0, 0], [2, 3]]) + verify_scatter_nd(data, indices, updates, out) + + if __name__ == "__main__": pytest.main([__file__])