From b642719c6ead42e713ada54e6ca7aa8146ca6d6f Mon Sep 17 00:00:00 2001 From: AndrewZhaoLuo Date: Wed, 12 Jan 2022 12:56:59 -0800 Subject: [PATCH] [TE] Support negative indices (#9023) * initial change * more explicit api * switch to select * add support for negative indices * reduce things further * lint * to CamelCase * unit test Co-authored-by: Andrew Zhao Luo --- include/tvm/te/tensor.h | 32 ++++++++++++++++++++++++++++++++ src/te/tensor.cc | 33 +++++++++++++++++++++++++++------ tests/cpp/tensor_test.cc | 11 +++++++++++ 3 files changed, 70 insertions(+), 6 deletions(-) diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index 85677a7265743..30480e1508231 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -100,6 +100,15 @@ class TensorNode : public DataProducerNode { * or intermediate computation result. */ class Tensor : public DataProducer { + private: + /*! + * \brief Helper for indexing operations into tensors + * \param indices The indices + * \param support_negative_indices Whether to normalize indices in the case of negative indices. + * \return the result expression representing tensor read. + */ + inline PrimExpr IndexTensor(Array indices, bool support_negative_indices) const; + public: TVM_DLL Tensor(Array shape, DataType dtype, Operation op, int value_index); /*! @@ -138,6 +147,29 @@ class Tensor : public DataProducer { * \return the result expression representing tensor read. */ TVM_DLL PrimExpr operator()(Array indices) const; + /*! + * \brief Take elements from the tensor with support for negative indices. + * \param args The indices + * \return the result expression representing tensor read. + */ + template + TVM_DLL PrimExpr IndexWithNegativeIndices(Args&&... args) const { + Array indices{std::forward(args)...}; + return IndexWithNegativeIndices(indices); + } + /*! + * \brief Take elements from the tensor with support for negative indices. + * \param indices the indices. + * \return the result expression representing tensor read. + */ + TVM_DLL PrimExpr IndexWithNegativeIndices(Array indices) const; + /*! + * \brief Take elements from the tensor with support for negative indices. + * \param indices the indices. + * \return the result expression representing tensor read. + */ + TVM_DLL PrimExpr IndexWithNegativeIndices(Array indices) const; + /*! * \brief data structure to represent a slice that fixes first k coordinates. * This is used to enable syntax sugar of Tensor[x][y][z] to get the element. diff --git a/src/te/tensor.cc b/src/te/tensor.cc index b48f39a38627e..1d75761216f1e 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -39,18 +39,39 @@ IterVar reduce_axis(Range dom, std::string name) { return IterVar(dom, Var(name) Var var(std::string name_hint, DataType t) { return Var(name_hint, t); } // Tensor +inline PrimExpr Tensor::IndexTensor(Array indices, bool support_negative_indices) const { + Array shape = (*this)->shape; + + if (shape.size() != 0) { + ICHECK_EQ(shape.size(), indices.size()) + << "Tensor dimension mismatch in read " + << "ndim = " << ndim() << ", indices.size=" << indices.size(); + } + + if (support_negative_indices) { + for (size_t i = 0; i < shape.size(); i++) { + PrimExpr new_index = + Select(indices[i] < make_const(indices[i]->dtype, 0), indices[i] + shape[i], indices[i]); + indices.Set(i, new_index); + } + } + return ProducerLoad((*this), indices); +} + PrimExpr Tensor::operator()(Array indices) const { Array arr(indices.begin(), indices.end()); return operator()(arr); } -PrimExpr Tensor::operator()(Array indices) const { - if (ndim() != 0) { - ICHECK_EQ(ndim(), indices.size()) << "Tensor dimension mismatch in read " - << "ndim = " << ndim() << ", indices.size=" << indices.size(); - } +PrimExpr Tensor::operator()(Array indices) const { return IndexTensor(indices, false); } - return ProducerLoad((*this), indices); +PrimExpr Tensor::IndexWithNegativeIndices(Array indices) const { + Array arr(indices.begin(), indices.end()); + return IndexWithNegativeIndices(arr); +} + +PrimExpr Tensor::IndexWithNegativeIndices(Array indices) const { + return IndexTensor(indices, true); } String TensorNode::GetNameHint() const { diff --git a/tests/cpp/tensor_test.cc b/tests/cpp/tensor_test.cc index a50af838f735d..e53f6d05a9911 100644 --- a/tests/cpp/tensor_test.cc +++ b/tests/cpp/tensor_test.cc @@ -49,3 +49,14 @@ TEST(Tensor, Reduce) { {m, n}, [&](Var i, Var j) { return sum(max(1 + A[i][rv] + 1, B[j][rv]), {rv}); }, "C"); LOG(INFO) << C->op.as()->body; } + +TEST(Tensor, Indexing) { + using namespace tvm; + using namespace tvm::te; + + Var x("x"), y("y"); + te::Tensor A = te::placeholder({x, y}, DataType::Float(32), "A"); + LOG(INFO) << A(0, 0); + LOG(INFO) << A.IndexWithNegativeIndices(-1, -1); + LOG(INFO) << A.IndexWithNegativeIndices(0, -1); +}