From 510bce6a181604e5eb3f2bd1951ae035a4090700 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Mon, 24 May 2021 12:52:30 +0900
Subject: [PATCH] refactoring dynamic slice

---
 include/tvm/topi/transform.h | 92 +++++++++++++++++++++---------------
 python/tvm/topi/cuda/sort.py |  1 -
 2 files changed, 53 insertions(+), 40 deletions(-)

diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index c2a973b8bbf0..8234d4c69486 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -570,6 +570,50 @@ inline te::Tensor strided_slice_compute_common(const te::Tensor& x,
       name, tag);
 }
 
+inline Tensor dynamic_strided_slice(const Tensor& x, const Array<PrimExpr>& begin,
+                                    const Array<PrimExpr>& end, const Array<PrimExpr>& strides,
+                                    std::string name = "T_dynamic_strided_slice",
+                                    std::string tag = kInjective) {
+  const size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
+  ICHECK_LE(begin.size(), src_tensor_dim);
+  ICHECK_LE(end.size(), src_tensor_dim);
+  ICHECK_LE(strides.size(), src_tensor_dim);
+  ICHECK_EQ(begin.size(), end.size());
+  ICHECK_EQ(begin.size(), strides.size());
+
+  const size_t num_slice_axes = begin.size();
+  Array<PrimExpr> out_shape;
+
+  for (size_t i = 0; i < num_slice_axes; ++i) {
+    auto d = indexdiv(end[i] - begin[i], strides[i]);
+    if (d->IsInstance<tvm::IntImmNode>()) {
+      // Preserve static dimension if possible
+      out_shape.push_back(d);
+    } else {
+      out_shape.push_back(tvm::tir::Var("dim"));
+    }
+  }
+
+  for (size_t i = num_slice_axes; i < src_tensor_dim; ++i) {
+    out_shape.push_back(x->shape[i]);
+  }
+
+  return te::compute(
+      out_shape,
+      [&](const Array<tvm::tir::Var>& indices) {
+        Array<PrimExpr> real_indices;
+        for (int32_t i = 0; i < num_slice_axes; ++i) {
+          real_indices.push_back(indices[i] * strides[i] + tvm::min(begin[i], x->shape[i] - 1));
+        }
+        // keep input dim
+        for (int32_t i = num_slice_axes; i < src_tensor_dim; ++i) {
+          real_indices.push_back(indices[i]);
+        }
+        return x(real_indices);
+      },
+      name, tag);
+}
+
 /*!
  * \brief strided_slice of a tensor with dynamic begin/end/stride
  *
@@ -587,48 +631,18 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b
                                         const te::Tensor& end, const te::Tensor& strides,
                                         std::string name = "T_strided_slice_dynamic",
                                         std::string tag = topi::kInjective) {
-  int64_t src_tensor_dim = x->shape.size();
-  Array<PrimExpr> out_shape;
   const int64_t num_dynamic_axes = begin->shape[0].as<IntImmNode>()->value;
-  for (int64_t i = 0; i < num_dynamic_axes; ++i) {
-    out_shape.push_back(tvm::tir::Var("dim"));
-  }
-  for (int64_t i = num_dynamic_axes; i < src_tensor_dim; ++i) {
-    out_shape.push_back(x->shape[i]);
-  }
-  return te::compute(
-      out_shape,
-      [&](const Array<tvm::tir::Var>& indices) {
-        Array<PrimExpr> real_indices;
-        // dynamic slicing
-        for (int32_t i = 0; i < num_dynamic_axes; ++i) {
-          real_indices.push_back(indices[i] * strides(i) + tvm::min(begin(i), x->shape[i] - 1));
-        }
-        // keep input dim
-        for (int32_t i = num_dynamic_axes; i < src_tensor_dim; ++i) {
-          real_indices.push_back(indices[i]);
-        }
-        return x(real_indices);
-      },
-      name, tag);
-}
-
-inline Tensor dynamic_strided_slice(const Tensor& x, const Array<PrimExpr>& begin,
-                                    const Array<PrimExpr>& end, const Array<PrimExpr>& strides,
-                                    std::string name = "T_dynamic_strided_slice",
-                                    std::string tag = kInjective) {
-  size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
-  ICHECK_EQ(begin.size(), src_tensor_dim);
-  ICHECK_EQ(end.size(), src_tensor_dim);
-  ICHECK_EQ(strides.size(), src_tensor_dim);
+  ICHECK_EQ(end->shape[0].as<IntImmNode>()->value, num_dynamic_axes);
+  ICHECK_EQ(strides->shape[0].as<IntImmNode>()->value, num_dynamic_axes);
 
-  Array<PrimExpr> out_shape;
-  Array<Integer> axes;
-  for (size_t i = 0; i < src_tensor_dim; ++i) {
-    out_shape.push_back(indexdiv(end[i] - begin[i], strides[i]));
-    axes.push_back(i);
+  Array<PrimExpr> begin_expr, end_expr, strides_expr;
+  for (int64_t i = 0; i < num_dynamic_axes; ++i) {
+    auto i64_ind = IntImm(DataType::Int(64), i);
+    begin_expr.push_back(begin(i64_ind));
+    end_expr.push_back(end(i64_ind));
+    strides_expr.push_back(strides(i64_ind));
   }
-  return strided_slice_compute_common(x, out_shape, begin, strides, axes, name, tag);
+  return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, name, tag);
 }
 
 inline Tensor strided_slice_dynamic_input(const Tensor& x, const Array<Integer>& begin,
diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py
index a9ad55c72c81..25cc7a4e2cfb 100644
--- a/python/tvm/topi/cuda/sort.py
+++ b/python/tvm/topi/cuda/sort.py
@@ -962,7 +962,6 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
             end.append(dshape[i])
     if ret_type == "both":
         values_out, indices_out = output
-        print("end:", end, k)
         values_out = strided_slice(values_out, beg, end, strides)
         indices_out = strided_slice(indices_out, beg, end, strides)
         output = [values_out, indices_out]
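
Note (not part of the patch): a minimal sketch of how the PrimExpr-based overload added above can be called, assuming a standard TVM C++ build; the placeholder shape, slice bounds, and the function name SliceExample are illustrative only.

#include <tvm/te/operation.h>
#include <tvm/topi/transform.h>

using namespace tvm;

// Slice axis 0 of a statically shaped placeholder with PrimExpr begin/end/strides.
// Axes not covered by begin/end/strides (here axes 1 and 2) keep their input extent,
// so the output shape is (indexdiv(8 - 2, 2), 20, 30) = (3, 20, 30).
te::Tensor SliceExample() {
  te::Tensor x = te::placeholder({10, 20, 30}, DataType::Float(32), "x");
  Array<PrimExpr> begin = {IntImm(DataType::Int(64), 2)};
  Array<PrimExpr> end = {IntImm(DataType::Int(64), 8)};
  Array<PrimExpr> strides = {IntImm(DataType::Int(64), 2)};
  return topi::dynamic_strided_slice(x, begin, end, strides);
}

This is the same conversion the refactored te::Tensor overload performs internally: it reads begin/end/strides out of the rank-1 tensors as PrimExprs and delegates to this overload, so both entry points share one compute definition.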