Skip to content

Commit

Permalink
fix pylint
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 27, 2021
1 parent 7063a09 commit 9bcb2ad
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,7 +1141,7 @@ def _gather_nd_shape(data_shape, indices_shape, batch_dims, num_indices_per_tupl
kdim = indices_shape.shape[0] - 1
out_shape = output_tensor((kdim + ndim - (mdim + batch_dims),), "int64")
for i in range(1, kdim + 1):
out_shape[i-1] = indices_shape[i]
out_shape[i - 1] = indices_shape[i]
for i in range(mdim + batch_dims, ndim):
out_shape[kdim + i - (mdim + batch_dims)] = data_shape[i]
return out_shape
Expand All @@ -1154,5 +1154,11 @@ def gather_nd_shape_func(attrs, inputs, _):
"""
batch_dims = get_const_int(attrs.batch_dims)
num_indices_per_tuple = get_const_int(attrs.num_indices_per_tuple)
assert num_indices_per_tuple > 0, "num_indices_per_tuple needs to be specified for dynamic gather_nd"
return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(num_indices_per_tuple))]

assert (
num_indices_per_tuple > 0
), "num_indices_per_tuple needs to be specified for dynamic gather_nd"

return [
_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(num_indices_per_tuple))
]

0 comments on commit 9bcb2ad

Please sign in to comment.