[Op][Topi] 5 ops can accept unsigned integers as indices (#10098)
* tests passed

* reformat

* add uint test for unravel_index
yuanfz98 committed Jan 30, 2022
1 parent 538347e commit 1f9c76b
Showing 3 changed files with 95 additions and 80 deletions.
src/relay/op/tensor/transform.cc (12 additions, 6 deletions)
@@ -1105,7 +1105,8 @@ bool ScatterRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   if (updates == nullptr) {
     return false;
   }
-  ICHECK(indices->dtype.is_int()) << "indices of scatter must be tensor of integer";
+  ICHECK(indices->dtype.is_int() || indices->dtype.is_uint())
+      << "indices of scatter must be tensor of integer";
   const auto param = attrs.as<ScatterAttrs>();
   ICHECK(param != nullptr);
   reporter->Assign(types[3], TensorType(data->shape, data->dtype));
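Note: the relaxation is the same in every relation touched by this commit: the one-sided `ICHECK(dtype.is_int())` becomes `is_int() || is_uint()`. As a minimal sketch of what this enables at the Relay level (hypothetical usage, not part of the diff; variable names are illustrative):

    import tvm
    from tvm import relay

    # uint32 indices used to trip the ICHECK in ScatterRel during type inference.
    d = relay.var("d", relay.TensorType((4, 4), "float32"))
    i = relay.var("i", relay.TensorType((4, 4), "uint32"))
    u = relay.var("u", relay.TensorType((4, 4), "float32"))
    f = relay.Function([d, i, u], relay.op.scatter(d, i, u, axis=0))
    mod = relay.transform.InferType()(tvm.IRModule.from_expr(f))  # now passes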
@@ -1152,7 +1153,8 @@ bool ScatterAddRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   if (updates == nullptr) {
     return false;
   }
-  ICHECK(indices->dtype.is_int()) << "indices of scatter_add must be tensor of integer";
+  ICHECK(indices->dtype.is_int() || indices->dtype.is_uint())
+      << "indices of scatter_add must be tensor of integer";
   const auto param = attrs.as<ScatterAddAttrs>();
   ICHECK(param != nullptr);
   reporter->Assign(types[3], TensorType(data->shape, data->dtype));
@@ -1204,7 +1206,8 @@ bool ScatterNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
         << "ScatterND: expect updates type to be TensorType but got " << types[2];
     return false;
   }
-  ICHECK(indices->dtype.is_int()) << "ScatterND: indices must be a tensor of integers.";
+  ICHECK(indices->dtype.is_int() || indices->dtype.is_uint())
+      << "ScatterND: indices must be a tensor of integers.";
 
   const auto out_shape = data->shape;
   const IntImmNode* mdim = indices->shape[0].as<IntImmNode>();
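For scatter_nd, the leading axis of `indices` selects which output dimension each coordinate addresses. A NumPy reference sketch of the "update" semantics with unsigned indices, mirroring the first test case added below (the loop is an illustrative reference, not TVM code):

    import numpy as np

    data = np.zeros((2, 2), dtype="int64")
    indices = np.array([[1, 1, 0], [0, 1, 0]], dtype="uint8")  # shape (M=2, N=3)
    updates = np.array([2, 3, 0], dtype="int64")

    out = data.copy()
    for n in range(indices.shape[1]):           # one (row, col) pair per column
        out[tuple(indices[:, n])] = updates[n]
    assert (out == np.array([[0, 0], [2, 3]])).all()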
@@ -3656,15 +3659,17 @@ bool UnRavelIndexRel(const Array<Type>& types, int num_inputs, const Attrs& attr
         << "unravel_index: expect input type to be TensorType but get " << types[0];
     return false;
   }
-  ICHECK(indices->dtype.is_int()) << "indices of unravel_index must be tensor of integer";
+  ICHECK(indices->dtype.is_int() || indices->dtype.is_uint())
+      << "indices of unravel_index must be tensor of integer";
 
   const auto* shape = types[1].as<TensorTypeNode>();
   if (shape == nullptr) {
     ICHECK(types[1].as<IncompleteTypeNode>())
         << "unravel_index: expect input type to be TensorType but get " << types[1];
     return false;
   }
-  ICHECK(indices->dtype.is_int()) << "shape of unravel_index must be tensor of integer";
+  ICHECK(shape->dtype.is_int() || shape->dtype.is_uint())
+      << "shape of unravel_index must be tensor of integer";
 
   Array<IndexExpr> indices_shape;
   Array<IndexExpr> shape_shape;
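Besides admitting unsigned dtypes, the second check now reads `shape->dtype` where the old line re-tested `indices->dtype`, fixing a pre-existing copy-paste slip along the way. A small hypothetical sketch of unravel_index with unsigned indices (the NumPy call shows the expected semantics):

    import numpy as np
    import tvm
    from tvm import relay

    i = relay.var("i", relay.TensorType((4,), "uint16"))
    s = relay.var("s", relay.TensorType((2,), "int64"))
    f = relay.Function([i, s], relay.op.unravel_index(i, s))
    mod = relay.transform.InferType()(tvm.IRModule.from_expr(f))  # type-checks

    # Reference: flat index -> row-major coordinates.
    np.unravel_index(np.array([0, 1, 2, 3]), (2, 2))
    # -> (array([0, 0, 1, 1]), array([0, 1, 0, 1]))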
@@ -3894,7 +3899,8 @@ bool AdvIndexRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   if (index_type == nullptr) {
     return false;
   }
-  ICHECK(index_type->dtype.is_int()) << "indices must be tensor of integers";
+  ICHECK(index_type->dtype.is_int() || index_type->dtype.is_uint())
+      << "indices must be tensor of integers";
 
   int64_t flatten_len = 1;
   bool has_dyn_shape = false;
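AdvIndexRel applies this check to each index tensor in the tuple, so every one may now be signed or unsigned. A TOPI-level sketch under the same assumptions as the tests below (shapes and names are illustrative):

    from tvm import te, topi

    # data[idx0, idx1]: two indexed axes, the trailing axis passes through.
    data = te.placeholder((3, 4, 5), name="data", dtype="float32")
    idx0 = te.placeholder((2,), name="idx0", dtype="uint8")
    idx1 = te.placeholder((2,), name="idx1", dtype="uint8")
    out = topi.adv_index(data, [idx0, idx1])  # result shape: (2, 5)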
tests/python/relay/test_op_level3.py (66 additions, 61 deletions)
@@ -999,17 +999,17 @@ def ref_scatter(data, indices, updates, axis=0):
 
 
 def test_scatter(target, dev, executor_kind):
-    def verify_scatter(dshape, ishape, axis=0):
+    def verify_scatter(dshape, ishape, axis=0, indices_dtype="int64"):
         d = relay.var("d", relay.TensorType(dshape, "float32"))
-        i = relay.var("i", relay.TensorType(ishape, "int64"))
+        i = relay.var("i", relay.TensorType(ishape, indices_dtype))
         u = relay.var("u", relay.TensorType(ishape, "float32"))
         z = relay.op.scatter(d, i, u, axis)
 
         func = relay.Function([d, i, u], z)
 
         data_np = np.random.uniform(size=dshape).astype("float32")
         updates_np = np.random.uniform(size=ishape).astype("float32")
-        indices_np = np.random.randint(-dshape[axis], dshape[axis] - 1, ishape).astype("int64")
+        indices_np = np.random.randint(0, dshape[axis] - 1, ishape).astype(indices_dtype)
 
         ref_res = ref_scatter(data_np, indices_np, updates_np, axis)
 
@@ -1031,6 +1031,7 @@ def verify_scatter(dshape, ishape, axis=0):
     verify_scatter((6, 3, 4, 5), (2, 3, 4, 5), 1)
     verify_scatter((2, 3, 8, 5), (2, 3, 1, 1), 2)
     verify_scatter((16, 16, 4, 5), (16, 16, 4, 5), 3)
+    verify_scatter((16, 16, 4, 5), (16, 16, 4, 5), 3, indices_dtype="uint32")
 
 
 class TestDynamicScatter:
@@ -1073,27 +1074,28 @@ def test_dynamic_scatter(self, target, dev, executor_kind, dshape, ishape, axis)
 
 
 class TestScatterAdd:
-    dshape, ishape, axis, dtype = tvm.testing.parameters(
-        ((10,), (10,), 0, "int32"),
-        ((1000,), (1000,), 0, "int32"),
-        ((10, 5), (10, 5), -2, "float32"),
-        ((10, 5), (10, 5), -1, "float32"),
-        ((10, 5), (3, 5), 0, "float32"),
-        ((12, 4), (7, 2), 1, "float32"),
-        ((2, 3, 4), (1, 3, 4), 0, "float32"),
-        ((2, 3, 4), (2, 1, 4), 1, "float32"),
-        ((2, 3, 4), (2, 3, 1), 2, "float32"),
-        ((2, 3, 4, 5), (1, 3, 4, 5), 0, "float32"),
-        ((6, 3, 4, 5), (2, 3, 4, 5), 1, "float32"),
-        ((2, 3, 8, 5), (2, 3, 1, 1), 2, "float32"),
-        ((16, 16, 4, 5), (16, 16, 4, 5), 3, "float32"),
+    dshape, ishape, axis, dtype, indice_dtype = tvm.testing.parameters(
+        ((10,), (10,), 0, "int32", "int64"),
+        ((1000,), (1000,), 0, "int32", "int64"),
+        ((10, 5), (10, 5), -2, "float32", "int64"),
+        ((10, 5), (10, 5), -1, "float32", "int64"),
+        ((10, 5), (3, 5), 0, "float32", "int64"),
+        ((12, 4), (7, 2), 1, "float32", "int64"),
+        ((2, 3, 4), (1, 3, 4), 0, "float32", "int64"),
+        ((2, 3, 4), (2, 1, 4), 1, "float32", "int64"),
+        ((2, 3, 4), (2, 3, 1), 2, "float32", "int64"),
+        ((2, 3, 4, 5), (1, 3, 4, 5), 0, "float32", "int64"),
+        ((6, 3, 4, 5), (2, 3, 4, 5), 1, "float32", "int64"),
+        ((2, 3, 8, 5), (2, 3, 1, 1), 2, "float32", "int64"),
+        ((16, 16, 4, 5), (16, 16, 4, 5), 3, "float32", "int64"),
+        ((16, 16, 4, 5), (16, 16, 4, 5), 3, "float32", "uint32"),
     )
 
     @tvm.testing.fixture(cache_return_value=True)
-    def ref_data(self, dshape, ishape, axis, dtype):
+    def ref_data(self, dshape, ishape, axis, dtype, indice_dtype):
         data_np = np.random.uniform(size=dshape).astype(dtype)
         updates_np = np.random.uniform(size=ishape).astype(dtype)
-        indices_np = np.random.randint(-dshape[axis], dshape[axis] - 1, ishape).astype("int64")
+        indices_np = np.random.randint(0, dshape[axis] - 1, ishape).astype(indice_dtype)
 
         out_np = np.copy(data_np)
         for index in np.ndindex(*indices_np.shape):
@@ -1105,9 +1107,11 @@ def ref_data(self, dshape, ishape, axis, dtype):
     # Optimization can produce tir.atomic_add, not currently supported
     # on vulkan runtime.
     @tvm.testing.known_failing_targets("vulkan")
-    def test_scatter_add(self, target, dev, ref_data, dshape, ishape, axis, dtype):
+    def test_scatter_add(self, target, dev, ref_data, dshape, ishape, axis, dtype, indice_dtype):
         d = relay.var("d", relay.TensorType(shape=[relay.Any() for _ in dshape], dtype=dtype))
-        i = relay.var("i", relay.TensorType(shape=[relay.Any() for _ in ishape], dtype="int64"))
+        i = relay.var(
+            "i", relay.TensorType(shape=[relay.Any() for _ in ishape], dtype=indice_dtype)
+        )
         u = relay.var("u", relay.TensorType(shape=[relay.Any() for _ in ishape], dtype=dtype))
         z = relay.op.scatter_add(d, i, u, axis)
 
@@ -1955,47 +1959,48 @@ def verify_scatter_nd_with_stack(
         )
         tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=rtol, atol=atol)
 
-    data = np.zeros((2, 2)).astype("int64")
-    indices = np.array([[1, 1, 0], [0, 1, 0]])
-    updates = np.array([2, 3, 0])
-    out = np.array([[0, 0], [2, 3]])
-    verify_scatter_nd(data, indices, updates, out)
-    verify_scatter_nd_with_stack(data, indices, updates, out)
-
-    data = np.zeros((2, 2, 2, 2)).astype("int64")
-    indices = np.array([[0, 1], [1, 1]])
-    updates = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
-    out = np.array([[[[0, 0], [0, 0]], [[1, 2], [3, 4]]], [[[0, 0], [0, 0]], [[5, 6], [7, 8]]]])
-    verify_scatter_nd(data, indices, updates, out)
-    verify_scatter_nd_with_stack(data, indices, updates, out)
-
-    indices = np.array([[1, 0, 0]])
-    updates = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32")
-    shape = (2, 1560)
-    data = np.zeros(shape).astype("float32")
-    out = data.copy()
-    out[1, :] += updates[0, :]
-    out[0, :] += updates[1, :]
-    out[0, :] += updates[2, :]
-    verify_scatter_nd(data, indices, updates, out, mode="add")
-    verify_scatter_nd_with_stack(data, indices, updates, out)
-
-    for mode in ["add", "update"]:
-        indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype(
-            "int64"
-        )
-        updates = np.ones((5, 3)).astype("float64")
-        shape = (2, 7, 3)
-        data = np.random.random(shape).astype("float64")
-        out = data.copy()
-        for i in range(indices.shape[1]):
-            for j in range(updates.shape[1]):
-                if mode == "add":
-                    out[indices[0, i], indices[1, i], j] += updates[i, j]
-                elif mode == "update":
-                    out[indices[0, i], indices[1, i], j] = updates[i, j]
-        verify_scatter_nd(data, indices, updates, out, mode)
-        verify_scatter_nd_with_stack(data, indices, updates, out, mode)
+    for indice_dtype in ["uint8", "uint16", "uint32"]:
+        data = np.zeros((2, 2)).astype("int64")
+        indices = np.array([[1, 1, 0], [0, 1, 0]]).astype(indice_dtype)
+        updates = np.array([2, 3, 0])
+        out = np.array([[0, 0], [2, 3]])
+        verify_scatter_nd(data, indices, updates, out)
+        verify_scatter_nd_with_stack(data, indices, updates, out)
+
+        data = np.zeros((2, 2, 2, 2)).astype("int64")
+        indices = np.array([[0, 1], [1, 1]]).astype(indice_dtype)
+        updates = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+        out = np.array([[[[0, 0], [0, 0]], [[1, 2], [3, 4]]], [[[0, 0], [0, 0]], [[5, 6], [7, 8]]]])
+        verify_scatter_nd(data, indices, updates, out)
+        verify_scatter_nd_with_stack(data, indices, updates, out)
+
+        indices = np.array([[1, 0, 0]]).astype(indice_dtype)
+        updates = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32")
+        shape = (2, 1560)
+        data = np.zeros(shape).astype("float32")
+        out = data.copy()
+        out[1, :] += updates[0, :]
+        out[0, :] += updates[1, :]
+        out[0, :] += updates[2, :]
+        verify_scatter_nd(data, indices, updates, out, mode="add")
+        verify_scatter_nd_with_stack(data, indices, updates, out)
+
+        for mode in ["add", "update"]:
+            indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype(
+                indice_dtype
+            )
+            updates = np.ones((5, 3)).astype("float64")
+            shape = (2, 7, 3)
+            data = np.random.random(shape).astype("float64")
+            out = data.copy()
+            for i in range(indices.shape[1]):
+                for j in range(updates.shape[1]):
+                    if mode == "add":
+                        out[indices[0, i], indices[1, i], j] += updates[i, j]
+                    elif mode == "update":
+                        out[indices[0, i], indices[1, i], j] = updates[i, j]
+            verify_scatter_nd(data, indices, updates, out, mode)
+            verify_scatter_nd_with_stack(data, indices, updates, out, mode)
 
 
 def test_unique(target, dev):
tests/python/topi/python/test_topi_transform.py (17 additions, 13 deletions)
@@ -674,15 +674,15 @@ def check_device(target, dev):
         check_device(target, dev)
 
 
-def verify_unravel_index(indices, shape, dtype):
-    x_data = np.array(indices).astype(dtype)
+def verify_unravel_index(indices, shape, dtype, indice_dtype="int64"):
+    x_data = np.array(indices).astype(indice_dtype)
     y_data = np.array(shape).astype(dtype)
     if len(x_data.shape) == 1:
         dst_shape = [y_data.shape[0], x_data.shape[0]]
     else:
         dst_shape = [y_data.shape[0]]
 
-    X = te.placeholder(shape=x_data.shape, dtype=dtype, name="X")
+    X = te.placeholder(shape=x_data.shape, dtype=indice_dtype, name="X")
     Y = te.placeholder(shape=y_data.shape, dtype=dtype, name="Y")
     Z = topi.unravel_index(X, Y)
 
@@ -771,16 +771,18 @@ def check_device(target, dev):
         check_device(target, dev)
 
 
-def verify_adv_index(data_shape, index_shapes):
+def verify_adv_index(data_shape, index_shapes, indice_dtype="int64"):
     dtype = "float32"
     data = te.placeholder(shape=data_shape, name="data", dtype=dtype)
     indices = []
     np_data = np.random.uniform(size=data_shape).astype(dtype)
     np_indices = []
     for i, index_shape in enumerate(index_shapes):
         limit = data_shape[i]
-        np_indices.append(np.random.uniform(0, limit - 1, size=index_shape).astype("int64"))
-        indices.append(te.placeholder(shape=index_shape, name="index_{}".format(i), dtype="int64"))
+        np_indices.append(np.random.uniform(0, limit - 1, size=index_shape).astype(indice_dtype))
+        indices.append(
+            te.placeholder(shape=index_shape, name="index_{}".format(i), dtype=indice_dtype)
+        )
     np_out = np_data[tuple(np_indices)]
     out = topi.adv_index(data, indices)
 
@@ -1207,10 +1209,11 @@ def test_one_hot():
 @tvm.testing.uses_gpu
 def test_unravel_index():
     for dtype in ["int32", "int64"]:
-        verify_unravel_index([0, 1, 2, 3], [2, 2], dtype)
-        verify_unravel_index([144], [5, 5, 5, 2], dtype)
-        verify_unravel_index(144, [5, 5, 5, 2], dtype)
-        verify_unravel_index([100, 13, 5], [5, 5, 5, 2], dtype)
+        for indice_dtype in ["int64", "uint8", "uint16", "uint32"]:
+            verify_unravel_index([0, 1, 2, 3], [2, 2], dtype, indice_dtype)
+            verify_unravel_index([144], [5, 5, 5, 2], dtype, indice_dtype)
+            verify_unravel_index(144, [5, 5, 5, 2], dtype, indice_dtype)
+            verify_unravel_index([100, 13, 5], [5, 5, 5, 2], dtype, indice_dtype)
 
 
 @tvm.testing.uses_gpu
@@ -1254,9 +1257,10 @@ def test_matrix_set_diag():
 
 @tvm.testing.uses_gpu
 def test_adv_index():
-    verify_adv_index((3, 4, 5), [(2,), (2,), (1,)])
-    verify_adv_index((10, 15, 5), [(1, 1), (2, 7)])
-    verify_adv_index((10, 5, 15), [(1, 2, 1), (1, 2, 7)])
+    for indice_dtype in ["int32", "int64", "uint8", "uint16", "uint32"]:
+        verify_adv_index((3, 4, 5), [(2,), (2,), (1,)], indice_dtype=indice_dtype)
+        verify_adv_index((10, 15, 5), [(1, 1), (2, 7)], indice_dtype=indice_dtype)
+        verify_adv_index((10, 5, 15), [(1, 2, 1), (1, 2, 7)], indice_dtype=indice_dtype)
 
 
 if __name__ == "__main__":
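End to end, a hedged sketch (executor kind, target, and shapes are assumptions, patterned on the Relay tests above) that compiles and runs scatter with uint32 indices on CPU:

    import numpy as np
    import tvm
    from tvm import relay

    dshape, ishape, axis = (10, 5), (3, 5), 0
    d = relay.var("d", relay.TensorType(dshape, "float32"))
    i = relay.var("i", relay.TensorType(ishape, "uint32"))
    u = relay.var("u", relay.TensorType(ishape, "float32"))
    func = relay.Function([d, i, u], relay.op.scatter(d, i, u, axis))

    data_np = np.random.uniform(size=dshape).astype("float32")
    updates_np = np.random.uniform(size=ishape).astype("float32")
    indices_np = np.random.randint(0, dshape[axis] - 1, ishape).astype("uint32")

    op_res = relay.create_executor("graph", device=tvm.cpu(), target="llvm").evaluate(func)(
        data_np, indices_np, updates_np
    )
    print(op_res.numpy().shape)  # (10, 5): scatter preserves the data shape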
