diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index ecd63f60f2b7..0a6654206006 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -1132,14 +1132,12 @@ def unique_shape_func(attrs, inputs, _):
 @script
 def _gather_nd_shape(data_shape, indices_shape, batch_dims, gather_dim):
     ndim = data_shape.shape[0]
-    mdim = gather_dim
     # using mdim = indices_shape[0] wouldn't work because a rank cannot
     # depend on a runtime shape dimension of indices tensor, even if the
     # dimension is always a known, fixed value. As a workaround, we assume that
     # the fixed gather dimension (the size of an indexing tuple) is recorded
     # in `gather_nd` op attribute.
-    err_msg = "The recorded gather dimension and the actual dimension are different"
-    assert mdim == indices_shape[0], err_msg
+    mdim = gather_dim
     kdim = indices_shape.shape[0] - 1
     out_shape = output_tensor((kdim + ndim - (mdim + batch_dims),), "int64")
     for i in range(1, kdim + 1):
@@ -1154,7 +1152,7 @@ def gather_nd_shape_func(attrs, inputs, _):
     """
     Shape func for ghater_nd operator.
     """
-    batch_dims = get_const_int(attrs.batch_dimss)
+    batch_dims = get_const_int(attrs.batch_dims)
     gather_dim = get_const_int(attrs.gather_dim)
     assert gather_dim > 0, "gather_dim needs to be specified for dynamic gather_nd"
     return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(gather_dim))]
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index fd6d7a9aeb14..07955943e341 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -26,6 +26,7 @@
 from tvm.error import TVMError
 from tvm.relay import create_executor, transform
 from tvm.relay.testing import check_grad, run_infer_type
+from utils import ref_funcs
 
 
 def test_zeros_ones():
@@ -1266,26 +1267,7 @@ def verify_gather_nd(xshape, yshape, y_data, batch_dims=0):
         else:
             y_data = np.random.randint(low=0, high=2, size=yshape, dtype="int32")
 
-        def gather_nd_batch_dims_1_ref(data, indices):
-            res = []
-            for i, row in enumerate(data):
-                indices_tuple = tuple(indices[:, i])  # the indices for the i-th batch
-                res.append(row[indices_tuple])
-            # stack on the batch dim
-            return np.stack(res, 0)
-
-        if batch_dims > 1:
-            x_data_reshape = np.reshape(x_data, (-1,) + xshape[batch_dims:])
-            y_data_reshape = np.reshape(y_data, (yshape[0], -1) + yshape[(batch_dims + 1) :])
-
-            ref_res = gather_nd_batch_dims_1_ref(x_data_reshape, y_data_reshape)
-
-            out_shape = yshape[1 : (batch_dims + 1)] + ref_res.shape[1:]
-            ref_res = np.reshape(ref_res, out_shape)
-        elif batch_dims == 1:
-            ref_res = gather_nd_batch_dims_1_ref(x_data, y_data)
-        else:
-            ref_res = x_data[tuple(y_data)]
+        ref_res = ref_funcs.gather_nd(x_data, y_data, batch_dims)
 
         for target, dev in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
diff --git a/tests/python/relay/utils/ref_funcs.py b/tests/python/relay/utils/ref_funcs.py
new file mode 100644
index 000000000000..924805b2295e
--- /dev/null
+++ b/tests/python/relay/utils/ref_funcs.py
@@ -0,0 +1,48 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+
+
+def gather_nd(data_np, indices_np, batch_dims=0):
+    """gather_nd implemented using numpy"""
+    data_shape = data_np.shape
+    indices_shape = indices_np.shape
+
+    def gather_nd_batch_dims_1_ref(data, indices):
+        res = []
+        for i, row in enumerate(data):
+            indices_tuple = tuple(indices[:, i])  # the indices for the i-th batch
+            res.append(row[indices_tuple])
+        # stack on the batch dim
+        return np.stack(res, 0)
+
+    if batch_dims > 1:
+        data_np_reshape = np.reshape(data_np, (-1,) + data_shape[batch_dims:])
+        indices_np_reshape = np.reshape(
+            indices_np, (indices_shape[0], -1) + indices_shape[(batch_dims + 1) :]
+        )
+
+        ref_res = gather_nd_batch_dims_1_ref(data_np_reshape, indices_np_reshape)
+
+        out_shape = indices_shape[1 : (batch_dims + 1)] + ref_res.shape[1:]
+        ref_res = np.reshape(ref_res, out_shape)
+    elif batch_dims == 1:
+        ref_res = gather_nd_batch_dims_1_ref(data_np, indices_np)
+    else:
+        ref_res = data_np[tuple(indices_np)]
+
+    return ref_res