Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TE] Support negative indices #9023

Merged
merged 8 commits into from
Jan 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions include/tvm/te/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,15 @@ class TensorNode : public DataProducerNode {
* or intermediate computation result.
*/
class Tensor : public DataProducer {
private:
/*!
* \brief Helper for indexing operations into tensors
* \param indices The indices
* \param support_negative_indices Whether to normalize indices in the case of negative indices.
* \return the result expression representing tensor read.
*/
inline PrimExpr IndexTensor(Array<PrimExpr> indices, bool support_negative_indices) const;

public:
TVM_DLL Tensor(Array<PrimExpr> shape, DataType dtype, Operation op, int value_index);
/*!
Expand Down Expand Up @@ -138,6 +147,29 @@ class Tensor : public DataProducer {
* \return the result expression representing tensor read.
*/
TVM_DLL PrimExpr operator()(Array<Var> indices) const;
/*!
* \brief Take elements from the tensor with support for negative indices.
* \param args The indices
* \return the result expression representing tensor read.
*/
template <typename... Args>
TVM_DLL PrimExpr IndexWithNegativeIndices(Args&&... args) const {
Array<PrimExpr> indices{std::forward<Args>(args)...};
return IndexWithNegativeIndices(indices);
}
/*!
* \brief Take elements from the tensor with support for negative indices.
* \param indices the indices.
* \return the result expression representing tensor read.
*/
TVM_DLL PrimExpr IndexWithNegativeIndices(Array<PrimExpr> indices) const;
/*!
* \brief Take elements from the tensor with support for negative indices.
* \param indices the indices.
* \return the result expression representing tensor read.
*/
TVM_DLL PrimExpr IndexWithNegativeIndices(Array<Var> indices) const;

/*!
* \brief data structure to represent a slice that fixes first k coordinates.
* This is used to enable syntax sugar of Tensor[x][y][z] to get the element.
Expand Down
33 changes: 27 additions & 6 deletions src/te/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,39 @@ IterVar reduce_axis(Range dom, std::string name) { return IterVar(dom, Var(name)
Var var(std::string name_hint, DataType t) { return Var(name_hint, t); }

// Tensor
inline PrimExpr Tensor::IndexTensor(Array<PrimExpr> indices, bool support_negative_indices) const {
Array<PrimExpr> shape = (*this)->shape;

if (shape.size() != 0) {
ICHECK_EQ(shape.size(), indices.size())
<< "Tensor dimension mismatch in read "
<< "ndim = " << ndim() << ", indices.size=" << indices.size();
}

if (support_negative_indices) {
for (size_t i = 0; i < shape.size(); i++) {
PrimExpr new_index =
Select(indices[i] < make_const(indices[i]->dtype, 0), indices[i] + shape[i], indices[i]);
indices.Set(i, new_index);
}
}
return ProducerLoad((*this), indices);
}

PrimExpr Tensor::operator()(Array<Var> indices) const {
Array<PrimExpr> arr(indices.begin(), indices.end());
return operator()(arr);
}

PrimExpr Tensor::operator()(Array<PrimExpr> indices) const {
if (ndim() != 0) {
ICHECK_EQ(ndim(), indices.size()) << "Tensor dimension mismatch in read "
<< "ndim = " << ndim() << ", indices.size=" << indices.size();
}
PrimExpr Tensor::operator()(Array<PrimExpr> indices) const { return IndexTensor(indices, false); }

return ProducerLoad((*this), indices);
PrimExpr Tensor::IndexWithNegativeIndices(Array<Var> indices) const {
Array<PrimExpr> arr(indices.begin(), indices.end());
return IndexWithNegativeIndices(arr);
}

PrimExpr Tensor::IndexWithNegativeIndices(Array<PrimExpr> indices) const {
return IndexTensor(indices, true);
}

String TensorNode::GetNameHint() const {
Expand Down
11 changes: 11 additions & 0 deletions tests/cpp/tensor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,14 @@ TEST(Tensor, Reduce) {
{m, n}, [&](Var i, Var j) { return sum(max(1 + A[i][rv] + 1, B[j][rv]), {rv}); }, "C");
LOG(INFO) << C->op.as<te::ComputeOpNode>()->body;
}

TEST(Tensor, Indexing) {
using namespace tvm;
using namespace tvm::te;

Var x("x"), y("y");
te::Tensor A = te::placeholder({x, y}, DataType::Float(32), "A");
LOG(INFO) << A(0, 0);
LOG(INFO) << A.IndexWithNegativeIndices(-1, -1);
LOG(INFO) << A.IndexWithNegativeIndices(0, -1);
}