
refactoring dynamic slice
masahi committed May 27, 2021
1 parent 1b3969a commit 510bce6
Showing 2 changed files with 53 additions and 40 deletions.
92 changes: 53 additions & 39 deletions include/tvm/topi/transform.h
@@ -570,6 +570,50 @@ inline te::Tensor strided_slice_compute_common(const te::Tensor& x,
       name, tag);
 }
 
+inline Tensor dynamic_strided_slice(const Tensor& x, const Array<PrimExpr>& begin,
+                                    const Array<PrimExpr>& end, const Array<PrimExpr>& strides,
+                                    std::string name = "T_dynamic_strided_slice",
+                                    std::string tag = kInjective) {
+  const size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
+  ICHECK_LE(begin.size(), src_tensor_dim);
+  ICHECK_LE(end.size(), src_tensor_dim);
+  ICHECK_LE(strides.size(), src_tensor_dim);
+  ICHECK_EQ(begin.size(), end.size());
+  ICHECK_EQ(begin.size(), strides.size());
+
+  const size_t num_slice_axes = begin.size();
+  Array<PrimExpr> out_shape;
+
+  for (size_t i = 0; i < num_slice_axes; ++i) {
+    auto d = indexdiv(end[i] - begin[i], strides[i]);
+    if (d->IsInstance<tvm::IntImmNode>()) {
+      // Preserve static dimension if possible
+      out_shape.push_back(d);
+    } else {
+      out_shape.push_back(tvm::tir::Var("dim"));
+    }
+  }
+
+  for (size_t i = num_slice_axes; i < src_tensor_dim; ++i) {
+    out_shape.push_back(x->shape[i]);
+  }
+
+  return te::compute(
+      out_shape,
+      [&](const Array<tvm::tir::Var>& indices) {
+        Array<PrimExpr> real_indices;
+        for (int32_t i = 0; i < num_slice_axes; ++i) {
+          real_indices.push_back(indices[i] * strides[i] + tvm::min(begin[i], x->shape[i] - 1));
+        }
+        // keep input dim
+        for (int32_t i = num_slice_axes; i < src_tensor_dim; ++i) {
+          real_indices.push_back(indices[i]);
+        }
+        return x(real_indices);
+      },
+      name, tag);
+}
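Note on the overload added above: begin/end/strides are plain PrimExpr arrays and only need to cover a leading subset of the input axes. When end[i] - begin[i] and strides[i] are compile-time constants, indexdiv folds to an IntImm and that output dimension stays static; otherwise it falls back to a fresh Var("dim"). A minimal usage sketch, not part of this commit (the placeholder shape, bounds, and function name are made up for illustration):

// Sketch only: slice the first two axes of a static 10x20x30 placeholder.
#include <tvm/te/operation.h>
#include <tvm/topi/transform.h>

tvm::te::Tensor StaticBoundsSliceSketch() {
  using namespace tvm;
  te::Tensor x = te::placeholder({10, 20, 30}, DataType::Float(32), "x");
  Array<PrimExpr> begin = {PrimExpr(2), PrimExpr(0)};
  Array<PrimExpr> end = {PrimExpr(10), PrimExpr(20)};
  Array<PrimExpr> strides = {PrimExpr(2), PrimExpr(1)};
  // Output shape becomes {4, 20, 30}: both sliced dims fold to IntImm,
  // and the third axis is passed through unchanged.
  return topi::dynamic_strided_slice(x, begin, end, strides);
}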

/*!
* \brief strided_slice of a tensor with dynamic begin/end/stride
*
@@ -587,48 +631,18 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& begin,
                                         const te::Tensor& end, const te::Tensor& strides,
                                         std::string name = "T_strided_slice_dynamic",
                                         std::string tag = topi::kInjective) {
-  int64_t src_tensor_dim = x->shape.size();
-  Array<PrimExpr> out_shape;
   const int64_t num_dynamic_axes = begin->shape[0].as<IntImmNode>()->value;
-  for (int64_t i = 0; i < num_dynamic_axes; ++i) {
-    out_shape.push_back(tvm::tir::Var("dim"));
-  }
-  for (int64_t i = num_dynamic_axes; i < src_tensor_dim; ++i) {
-    out_shape.push_back(x->shape[i]);
-  }
-  return te::compute(
-      out_shape,
-      [&](const Array<tvm::tir::Var>& indices) {
-        Array<PrimExpr> real_indices;
-        // dynamic slicing
-        for (int32_t i = 0; i < num_dynamic_axes; ++i) {
-          real_indices.push_back(indices[i] * strides(i) + tvm::min(begin(i), x->shape[i] - 1));
-        }
-        // keep input dim
-        for (int32_t i = num_dynamic_axes; i < src_tensor_dim; ++i) {
-          real_indices.push_back(indices[i]);
-        }
-        return x(real_indices);
-      },
-      name, tag);
-}
-
-inline Tensor dynamic_strided_slice(const Tensor& x, const Array<PrimExpr>& begin,
-                                    const Array<PrimExpr>& end, const Array<PrimExpr>& strides,
-                                    std::string name = "T_dynamic_strided_slice",
-                                    std::string tag = kInjective) {
-  size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
-  ICHECK_EQ(begin.size(), src_tensor_dim);
-  ICHECK_EQ(end.size(), src_tensor_dim);
-  ICHECK_EQ(strides.size(), src_tensor_dim);
+  ICHECK_EQ(end->shape[0].as<IntImmNode>()->value, num_dynamic_axes);
+  ICHECK_EQ(strides->shape[0].as<IntImmNode>()->value, num_dynamic_axes);
 
-  Array<PrimExpr> out_shape;
-  Array<Integer> axes;
-  for (size_t i = 0; i < src_tensor_dim; ++i) {
-    out_shape.push_back(indexdiv(end[i] - begin[i], strides[i]));
-    axes.push_back(i);
+  Array<PrimExpr> begin_expr, end_expr, strides_expr;
+  for (int64_t i = 0; i < num_dynamic_axes; ++i) {
+    auto i64_ind = IntImm(DataType::Int(64), i);
+    begin_expr.push_back(begin(i64_ind));
+    end_expr.push_back(end(i64_ind));
+    strides_expr.push_back(strides(i64_ind));
   }
-  return strided_slice_compute_common(x, out_shape, begin, strides, axes, name, tag);
+  return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, name, tag);
 }
 
 inline Tensor strided_slice_dynamic_input(const Tensor& x, const Array<Integer>& begin,
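With this refactor the te::Tensor overload above no longer carries its own te::compute body: it checks that begin, end, and strides are 1-D tensors of the same static length, reads each element back out as a PrimExpr, and forwards everything to the Array<PrimExpr> overload, so the index arithmetic lives in one place. A rough caller sketch, not part of this commit (shapes, dtypes, and the function name are hypothetical; in practice the bound tensors would come from the frontend/Relay side):

// Sketch only: bounds arrive as 1-D Int64 tensors of length 2 (two sliced axes).
#include <tvm/te/operation.h>
#include <tvm/topi/transform.h>

tvm::te::Tensor DynamicBoundsSliceSketch() {
  using namespace tvm;
  te::Tensor x = te::placeholder({te::var("n"), 16, 16}, DataType::Float(32), "x");
  te::Tensor begin = te::placeholder({2}, DataType::Int(64), "begin");
  te::Tensor end = te::placeholder({2}, DataType::Int(64), "end");
  te::Tensor strides = te::placeholder({2}, DataType::Int(64), "strides");
  // Delegates to the Array<PrimExpr> overload; both sliced output dims become
  // fresh Var("dim") since the bounds are not compile-time constants.
  return topi::dynamic_strided_slice(x, begin, end, strides);
}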
1 change: 0 additions & 1 deletion python/tvm/topi/cuda/sort.py
@@ -962,7 +962,6 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
             end.append(dshape[i])
     if ret_type == "both":
         values_out, indices_out = output
-        print("end:", end, k)
         values_out = strided_slice(values_out, beg, end, strides)
         indices_out = strided_slice(indices_out, beg, end, strides)
         output = [values_out, indices_out]
