Skip to content

Commit

Permalink
gather_dim -> num_indices_per_tuple
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 27, 2021
1 parent a489375 commit d4a4db8
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 14 deletions.
4 changes: 2 additions & 2 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,11 @@ struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {

struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
Integer batch_dims;
Integer gather_dim;
Integer num_indices_per_tuple;

TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherNDAttrs") {
TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dimensions.");
TVM_ATTR_FIELD(gather_dim)
TVM_ATTR_FIELD(num_indices_per_tuple)
.set_default(Integer(-1))
.describe(
"The size of an indexing tuple, which is a fixed value. Only needed when the number of "
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1436,8 +1436,8 @@ def _impl_common(cls, data, indices, batch_dims=0):
indices_dims = len(infer_shape(indices))
indices_shape = infer_shape(indices)
indices = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1)))
gather_dim = indices_shape[-1]
return _op.gather_nd(data, indices, batch_dims, gather_dim)
num_indices_per_tuple = indices_shape[-1]
return _op.gather_nd(data, indices, batch_dims, num_indices_per_tuple)

@classmethod
def _impl_v1(cls, inputs, attr, params):
Expand Down
10 changes: 5 additions & 5 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,14 +1130,14 @@ def unique_shape_func(attrs, inputs, _):


@script
def _gather_nd_shape(data_shape, indices_shape, batch_dims, gather_dim):
def _gather_nd_shape(data_shape, indices_shape, batch_dims, num_indices_per_tuple):
ndim = data_shape.shape[0]
# using mdim = indices_shape[0] wouldn't work because a rank cannot
# depend on a runtime shape dimension of indices tensor, even if the
# dimension is always a known, fixed value. As a workaround, we assume that
# the fixed gather dimension (the size of an indexing tuple) is recorded
# in `gather_nd` op attribute.
mdim = gather_dim
mdim = num_indices_per_tuple
kdim = indices_shape.shape[0] - 1
out_shape = output_tensor((kdim + ndim - (mdim + batch_dims),), "int64")
for i in range(1, kdim + 1):
Expand All @@ -1153,6 +1153,6 @@ def gather_nd_shape_func(attrs, inputs, _):
Shape func for ghater_nd operator.
"""
batch_dims = get_const_int(attrs.batch_dims)
gather_dim = get_const_int(attrs.gather_dim)
assert gather_dim > 0, "gather_dim needs to be specified for dynamic gather_nd"
return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(gather_dim))]
num_indices_per_tuple = get_const_int(attrs.num_indices_per_tuple)
assert num_indices_per_tuple > 0, "num_indices_per_tuple needs to be specified for dynamic gather_nd"
return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(num_indices_per_tuple))]
6 changes: 3 additions & 3 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,7 +1075,7 @@ def gather(data, axis, indices):
return _make.gather(data, axis, indices)


def gather_nd(data, indices, batch_dims=0, gather_dim=-1):
def gather_nd(data, indices, batch_dims=0, num_indices_per_tuple=-1):
"""Gather elements or slices from data and store to a tensor whose shape is
defined by indices.
Expand All @@ -1090,7 +1090,7 @@ def gather_nd(data, indices, batch_dims=0, gather_dim=-1):
batch_dims : int
The number of batch dimensions.
gather_dim : int
num_indices_per_tuple : int
The size of an indexing tuple, which is a fixed value and the same as indices.shape[0]
Only needed when other dimensions of indices are dynamic.
Expand All @@ -1115,7 +1115,7 @@ def gather_nd(data, indices, batch_dims=0, gather_dim=-1):
indices = [[1, 0]]
relay.gather_nd(data, indices, batch_dims=1) = [[2,3],[4,5]]
"""
return _make.gather_nd(data, indices, batch_dims, gather_dim)
return _make.gather_nd(data, indices, batch_dims, num_indices_per_tuple)


def sequence_mask(data, valid_length, mask_value=0, axis=0):
Expand Down
4 changes: 2 additions & 2 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3600,11 +3600,11 @@ Array<te::Tensor> GatherNDCompute(const Attrs& attrs, const Array<te::Tensor>& i
return {topi::gather_nd(inputs[0], inputs[1], param->batch_dims)};
}

Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0, int gather_dim = -1) {
Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0, int num_indices_per_tuple = -1) {
static const Op& op = Op::Get("gather_nd");
auto attrs = make_object<GatherNDAttrs>();
attrs->batch_dims = batch_dims;
attrs->gather_dim = gather_dim;
attrs->num_indices_per_tuple = num_indices_per_tuple;
return Call(op, {data, indices}, Attrs(attrs));
}

Expand Down

0 comments on commit d4a4db8

Please sign in to comment.