refactoring slice with axes
masahi committed May 27, 2021
1 parent 32698b7 commit 36aa777
Showing 5 changed files with 104 additions and 154 deletions.
2 changes: 1 addition & 1 deletion include/tvm/topi/nn.h
@@ -619,7 +619,7 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data,
out = reshape(out, r_p_shape);

// Crop the start and end of dimensions of out
Array<PrimExpr> begin_idx, end_idx, strides;
Array<Integer> begin_idx, end_idx, strides;
for (size_t i = 0; i < r_p_shape.size(); ++i) {
strides.push_back(Integer(1));
if (i > 0 && i <= num_block_dims) {
234 changes: 88 additions & 146 deletions include/tvm/topi/transform.h
@@ -593,54 +593,54 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b
name, tag);
}

/*!
* \brief strided_slice of a tensor
*
* \param x The input tensor
* \param begin The indices to begin with in the slicing
* \param end Indices indicating the end of the slice
* \param strides Specifies the stride values; these can be negative,
* in which case the input tensor will be reversed in that particular axis
* \param slice_mode Specifies the slice mode
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the strided_slice operation
*/
inline Tensor strided_slice(const Tensor& x, const Array<PrimExpr>& begin,
const Array<PrimExpr>& end, const Array<PrimExpr>& strides,
std::string slice_mode = "end", std::string name = "T_strided_slice",
std::string tag = kInjective) {
size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
// Quick path for dynamic shape strided slice.
// This is for ease of use with dynamic strided slice in topi.
bool is_static = IsConstIntArray(x->shape);
is_static &= IsConstIntArray(begin);
is_static &= IsConstIntArray(end);
is_static &= IsConstIntArray(strides);

inline Tensor strided_slice_dynamic_input(const Tensor& input, const Array<Integer>& begin,
const Array<Integer>& end, const Array<Integer>& strides,
std::string slice_mode = "end",
std::string name = "T_strided_slice_dynamic_input",
std::string tag = kInjective) {
size_t src_tensor_dim = input->shape.size();
ICHECK(begin.size() == src_tensor_dim)
<< "for dynamic inputs, len(begin) must equal the input dimension";
Array<PrimExpr> out_shape;
if (!is_static) {
ICHECK_EQ(strides.size(), src_tensor_dim);
for (size_t i = 0; i < src_tensor_dim; ++i) {
out_shape.push_back(indexdiv(end[i] - begin[i], strides[i]));
for (size_t i = 0; i < src_tensor_dim; ++i) {
out_shape.push_back(tvm::tir::Var("dim"));
}
Array<PrimExpr> begin_expr, end_expr, strides_expr;
for (size_t i = 0; i < src_tensor_dim; ++i) {
int64_t begin_i = begin[i]->value;
if (begin_i < 0) {
begin_i += topi::detail::GetConstInt(input->shape[i]);
}
return te::compute(
out_shape,
[&](const Array<tvm::tir::Var>& indices) {
Array<PrimExpr> real_indices;
for (size_t i = 0; i < src_tensor_dim; ++i) {
real_indices.push_back(indices[i] * strides[i] + begin[i]);
}
return x(real_indices);
},
name, tag);
begin_expr.push_back(tir::make_const(begin[0].dtype(), begin_i));
strides_expr.push_back(
tir::make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()),
(i < strides.size() ? strides[i]->value : 1)));
}
return te::compute(
out_shape,
[&](const Array<tir::Var>& indices) {
Array<PrimExpr> real_indices;
for (size_t i = 0; i < src_tensor_dim; ++i) {
real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]);
}
return input(real_indices);
},
std::string{"T_strided_slice_dynamic_input"}, std::string{topi::kInjective});
}
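
For orientation, the index arithmetic above can be summarized in a standalone sketch (not part of this commit; the helper name and plain-vector interface are purely illustrative): a negative begin is shifted by the static extent of its axis, a missing stride defaults to 1, and output coordinate i maps to input index indices[i] * stride[i] + begin[i].

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Map one output coordinate of a strided slice back to the input coordinate,
// mirroring the begin canonicalization and stride defaulting done above.
std::vector<int64_t> MapOutputToInput(const std::vector<int64_t>& out_index,
                                      std::vector<int64_t> begin,
                                      const std::vector<int64_t>& strides,
                                      const std::vector<int64_t>& shape) {
  assert(begin.size() == shape.size() && "len(begin) must equal the input dimension");
  std::vector<int64_t> real_index(shape.size());
  for (size_t i = 0; i < shape.size(); ++i) {
    if (begin[i] < 0) begin[i] += shape[i];                 // canonicalize negative begin
    int64_t stride = i < strides.size() ? strides[i] : 1;   // missing strides default to 1
    real_index[i] = out_index[i] * stride + begin[i];
  }
  return real_index;
}

int main() {
  // Slicing a (4, 6) input with begin = {1, -4} and strides = {1, 2}:
  // output element (0, 1) reads input element (1, 4).
  auto idx = MapOutputToInput({0, 1}, {1, -4}, {1, 2}, {4, 6});
  std::cout << idx[0] << ", " << idx[1] << "\n";
  return 0;
}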

inline Tensor strided_slice_with_axes(const Tensor& x, const Array<Integer>& begin,
const Array<Integer>& end, const Array<Integer>& strides,
const Array<Integer>& axes, std::string slice_mode = "end",
std::string name = "T_strided_slice_dynamic_input",
std::string tag = kInjective) {
size_t src_tensor_dim = x->shape.size();

ICHECK(axes.size() <= src_tensor_dim);
ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size());

// Setup the ranges.
// NOTE: this code duplicates the shape inference logic in relay.op.
// Consider refactoring it in the future.
std::vector<int64_t> stride_vec(src_tensor_dim, 1);
std::vector<int64_t> stride_vec(strides.size(), 1);
for (size_t i = 0; i < strides.size(); ++i) {
ICHECK(strides[i].defined());
stride_vec[i] = GetConstInt(strides[i]);
@@ -657,9 +657,6 @@ inline Tensor strided_slice(const Tensor& x, const Array<PrimExpr>& begin,
begin_vec.push_back(GetConstInt(begin[i]));
}
}
for (size_t i = begin_vec.size(); i < src_tensor_dim; ++i) {
begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
}

std::vector<int64_t> end_vec;
for (size_t i = 0; i < end.size(); ++i) {
@@ -678,16 +675,17 @@ inline Tensor strided_slice(const Tensor& x, const Array<PrimExpr>& begin,
end_vec.push_back(GetConstInt(end[i]));
}
}
for (size_t i = end_vec.size(); i < src_tensor_dim; ++i) {
end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
}

// Compute
Array<PrimExpr> begin_expr;
Array<PrimExpr> strides_expr;

for (size_t i = 0; i < src_tensor_dim; ++i) {
Array<PrimExpr> out_shape;
for (size_t i = 0; i < axes.size(); ++i) {
int64_t begin_range = stride_vec[i] < 0 ? -1 : 0;
int64_t dim_i = GetConstInt(x->shape[i]);
ICHECK(x->shape[axes[i]]->IsInstance<tvm::IntImmNode>())
<< "Input shape at axis " << axes[i] << " is not static";
int64_t dim_i = GetConstInt(x->shape[axes[i]]);
int64_t end_range = stride_vec[i] < 0 ? dim_i - 1 : dim_i;
// transform negative indices to positive values, clipping them to the correct range
auto index_canonicalization = [dim_i, begin_range, end_range](int64_t index) {
@@ -713,116 +711,60 @@ inline Tensor strided_slice(const Tensor& x, const Array<PrimExpr>& begin,
out_shape.push_back(slice_size);
}

return compute(
return te::compute(
out_shape,
[&](const Array<Var>& indices) {
[&](const Array<tir::Var>& indices) {
Array<PrimExpr> real_indices;
for (size_t i = 0; i < src_tensor_dim; ++i) {
real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]);
for (size_t i = 0; i < src_tensor_dim; ++i) real_indices.push_back(indices[i]);
for (size_t i = 0; i < axes.size(); ++i) {
PrimExpr ind = indices[axes[i]] * strides[i] + begin_expr[i];
real_indices.Set(axes[i], ind);
}
return x(real_indices);
},
name, tag);
std::string{"T_strided_slice_with_axes"}, std::string{topi::kInjective});
}
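
The per-axis output extent computed above follows a clip-then-ceil-divide rule: begin and end are wrapped if negative, clipped to the range implied by the stride sign, and the extent is ceil(|end - begin| / |stride|). Below is a standalone sketch of just that rule (not part of this commit; the function name and scalar interface are illustrative).

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <iostream>

// Extent of one sliced axis with static dim/begin/end/stride, mirroring the
// index canonicalization and slice_size computation above.
int64_t SliceExtent(int64_t dim, int64_t begin, int64_t end, int64_t stride) {
  const int64_t begin_range = stride < 0 ? -1 : 0;
  const int64_t end_range = stride < 0 ? dim - 1 : dim;
  auto canonicalize = [&](int64_t index) {
    if (index < 0) index += dim;                               // wrap negative indices
    return std::min(std::max(index, begin_range), end_range);  // clip to the valid range
  };
  begin = canonicalize(begin);
  end = canonicalize(end);
  const int64_t interval = std::abs(end - begin);
  return (interval + std::abs(stride) - 1) / std::abs(stride);  // ceiling division
}

int main() {
  std::cout << SliceExtent(10, 2, 8, 2) << "\n";     // indices 2, 4, 6       -> 3
  std::cout << SliceExtent(10, -1, -6, -1) << "\n";  // indices 9, 8, 7, 6, 5 -> 5
  return 0;
}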

/*!
* \brief strided_slice of a tensor
*
* \param x The input tensor
* \param begin The indices to begin with in the slicing
* \param end Indices indicating the end of the slice
* \param strides Specifies the stride values; these can be negative,
* in which case the input tensor will be reversed in that particular axis
* \param slice_mode Specifies the slice mode
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the strided_slice operation
*/
inline Tensor strided_slice(const Tensor& x, const Array<Integer>& begin, const Array<Integer>& end,
const Array<Integer>& strides, std::string slice_mode = "end",
std::string name = "T_strided_slice", std::string tag = kInjective) {
Array<PrimExpr> begin_expr, end_expr, strides_expr;
for (size_t i = 0; i < begin.size(); ++i) {
begin_expr.push_back(begin[i]);
}
for (size_t i = 0; i < end.size(); ++i) {
end_expr.push_back(end[i]);
}
for (size_t i = 0; i < strides.size(); ++i) {
strides_expr.push_back(strides[i]);
}
return strided_slice(x, begin_expr, end_expr, strides_expr, slice_mode, name, tag);
}

inline Tensor strided_slice_dynamic_input(const Tensor& input, const Array<Integer>& begin,
const Array<Integer>& end, const Array<Integer>& strides,
std::string slice_mode = "end",
std::string name = "T_strided_slice_dynamic_input",
std::string tag = kInjective) {
size_t src_tensor_dim = input->shape.size();
ICHECK(begin.size() == src_tensor_dim)
<< "for dynamic inputs, len(begin) must equal the input dimension";
Array<PrimExpr> out_shape;
for (size_t i = 0; i < src_tensor_dim; ++i) {
out_shape.push_back(tvm::tir::Var("dim"));
}
Array<PrimExpr> begin_expr, end_expr, strides_expr;
for (size_t i = 0; i < src_tensor_dim; ++i) {
int64_t begin_i = begin[i]->value;
if (begin_i < 0) {
begin_i += topi::detail::GetConstInt(input->shape[i]);
}
begin_expr.push_back(tir::make_const(begin[0].dtype(), begin_i));
strides_expr.push_back(
tir::make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()),
(i < strides.size() ? strides[i]->value : 1)));
}
return te::compute(
out_shape,
[&](const Array<tir::Var>& indices) {
Array<PrimExpr> real_indices;
for (size_t i = 0; i < src_tensor_dim; ++i) {
real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]);
}
return input(real_indices);
},
std::string{"T_strided_slice_dynamic_input"}, std::string{topi::kInjective});
}

inline Tensor strided_slice_with_axes(const Tensor& input, const Array<Integer>& begin,
const Array<Integer>& end, const Array<Integer>& strides,
const Array<Integer>& axes, std::string slice_mode = "end",
std::string name = "T_strided_slice_dynamic_input",
std::string tag = kInjective) {
size_t src_tensor_dim = input->shape.size();
size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
Array<Integer> axes;
for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i);
Array<Integer> begin_full(begin);
Array<Integer> end_full(end);
Array<Integer> strides_full(strides);

ICHECK(axes.size() <= src_tensor_dim);
ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size());
const IntImm one = IntImm(DataType::Int(64), 1);
const IntImm zero = IntImm(DataType::Int(64), 0);
const IntImm max_range = IntImm(DataType::Int(64), std::numeric_limits<int64_t>::max());

Array<PrimExpr> out_shape;
for (size_t i = 0; i < src_tensor_dim; ++i) {
out_shape.push_back(input->shape[i]);
for (size_t i = strides_full.size(); i < src_tensor_dim; ++i) {
strides_full.push_back(one);
}
Array<PrimExpr> begin_expr;
for (size_t i = 0; i < axes.size(); ++i) {
auto idim = input->shape[axes[i]];
auto b = tvm::if_then_else(begin[i] < 0, begin[i] + idim, begin[i]);
auto e = tvm::if_then_else(end[i] < 0, end[i] + idim, end[i]);
auto s = strides[i]->value;
PrimExpr range;
if (s < 0) {
b = tvm::min(b, idim - 1);
e = tvm::if_then_else(e < -1, -1, e);
range = b - e;
s = -s;
} else {
b = tvm::if_then_else(b < 0, 0, b);
e = tvm::min(e, idim);
range = e - b;
}
PrimExpr odim = indexdiv(range + tvm::PrimExpr(static_cast<int32_t>(s - 1)), s);
out_shape.Set(axes[i], cast(out_shape[i].dtype(), odim));
begin_expr.push_back(b);
for (size_t i = begin.size(); i < src_tensor_dim; ++i) {
begin_full.push_back(GetConstInt(strides_full[i]) > 0 ? zero : max_range);
}
return te::compute(
out_shape,
[&](const Array<tir::Var>& indices) {
Array<PrimExpr> real_indices;
for (size_t i = 0; i < src_tensor_dim; ++i) real_indices.push_back(indices[i]);
for (size_t i = 0; i < axes.size(); ++i) {
PrimExpr ind = indices[axes[i]] * strides[i] + begin_expr[i];
real_indices.Set(axes[i], ind);
}
return input(real_indices);
},
std::string{"T_strided_slice_with_axes"}, std::string{topi::kInjective});
for (size_t i = end.size(); i < src_tensor_dim; ++i) {
end_full.push_back(GetConstInt(strides_full[i]) < 0 ? zero : max_range);
}

return strided_slice_with_axes(x, begin_full, end_full, strides_full, axes, slice_mode, name,
tag);
}
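
A hedged usage sketch of this wrapper (not from the commit; it assumes the standard TVM headers and namespaces, and the function name is illustrative): a partial begin/end/strides specification slices only the leading axes, because the missing entries are padded with begin = 0, end = int64 max, and stride = 1 before delegating to strided_slice_with_axes.

#include <tvm/te/operation.h>
#include <tvm/topi/transform.h>

// Illustrative only: slice axis 0 of a (4, 5, 6) placeholder to the range [1, 3).
tvm::te::Tensor SliceLeadingAxis() {
  using namespace tvm;
  te::Tensor data = te::placeholder({4, 5, 6}, DataType::Float(32), "data");
  // Only axis 0 is specified; axes 1 and 2 keep their full extent,
  // so the result has shape (2, 5, 6).
  return topi::strided_slice(data, {Integer(1)}, {Integer(3)}, {Integer(1)});
}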

/*!
15 changes: 11 additions & 4 deletions src/relay/op/tensor/transform.cc
@@ -3072,16 +3072,21 @@ Array<te::Tensor> SliceLikeCompute(const Attrs& attrs, const Array<te::Tensor>&
ICHECK(param != nullptr);
Array<IndexExpr> src_shape = inputs[0]->shape;
Array<IndexExpr> target_shape = inputs[1]->shape;
Array<IndexExpr> begin_idx, end_idx, strides;
Array<Integer> begin_idx, end_idx, strides;
for (size_t i = 0; i < src_shape.size(); ++i) {
begin_idx.push_back(0);
strides.push_back(1);
}
end_idx = Array<IndexExpr>(src_shape);
for (auto s : src_shape) {
ICHECK(s->IsInstance<tvm::IntImmNode>()) << "slice_like does not support dynamic input shape";
end_idx.push_back(topi::GetConstInt(s));
}
if (!param->axes.defined()) {
for (size_t i = 0; i < src_shape.size(); ++i) {
if (i < target_shape.size()) {
end_idx.Set(i, target_shape[i]);
ICHECK(target_shape[i]->IsInstance<tvm::IntImmNode>())
<< "slice_like does not support dynamic output shape";
end_idx.Set(i, topi::GetConstInt(target_shape[i]));
ICHECK_LE(topi::GetConstInt(end_idx[i]), topi::GetConstInt(src_shape[i]))
<< "End index of axis " << i
<< " exceeds input shape: " << topi::GetConstInt(end_idx[i]) << " vs "
@@ -3093,7 +3098,9 @@ Array<te::Tensor> SliceLikeCompute(const Attrs& attrs, const Array<te::Tensor>&
if (axis < 0) {
axis = static_cast<int>(src_shape.size()) + axis;
}
end_idx.Set(axis, target_shape[axis]);
ICHECK(target_shape[axis]->IsInstance<tvm::IntImmNode>())
<< "slice_like does not support dynamic output shape";
end_idx.Set(axis, topi::GetConstInt(target_shape[axis]));
ICHECK_LE(topi::GetConstInt(end_idx[axis]), topi::GetConstInt(src_shape[axis]))
<< "End index of axis " << axis
<< " exceeds input shape: " << topi::GetConstInt(end_idx[axis]) << " vs "
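
The end-index derivation above can be restated as a standalone sketch (not part of this commit; names and the plain-vector interface are illustrative): start from the source shape, then overwrite the extents of the sliced axes, or of all leading axes when no axes are given, with the static target extents.

#include <cassert>
#include <cstdint>
#include <vector>

// Compute the end indices that slice_like would pass on to strided_slice.
std::vector<int64_t> SliceLikeEnd(const std::vector<int64_t>& src_shape,
                                  const std::vector<int64_t>& target_shape,
                                  std::vector<int64_t> axes) {
  std::vector<int64_t> end_idx(src_shape);  // default: keep the full source extent
  if (axes.empty()) {                       // no axes given: slice every common axis
    for (size_t i = 0; i < src_shape.size() && i < target_shape.size(); ++i) {
      axes.push_back(static_cast<int64_t>(i));
    }
  }
  for (int64_t axis : axes) {
    if (axis < 0) axis += static_cast<int64_t>(src_shape.size());  // wrap negative axes
    assert(target_shape[axis] <= src_shape[axis] && "End index exceeds input shape");
    end_idx[axis] = target_shape[axis];
  }
  return end_idx;
}
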
1 change: 1 addition & 0 deletions src/target/target_kind.cc
@@ -347,6 +347,7 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
.set_default_keys({"vulkan", "gpu"})
.set_attrs_preprocessor(UpdateVulkanAttrs);


TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU)
.add_attr_option<Bool>("system-lib")
.add_attr_option<Integer>("max_num_threads", Integer(256))
6 changes: 3 additions & 3 deletions src/topi/transform.cc
@@ -174,9 +174,9 @@ TVM_REGISTER_GLOBAL("topi.einsum").set_body([](TVMArgs args, TVMRetValue* rv) {
});

TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) {
Array<PrimExpr> begin = args[1];
Array<PrimExpr> end = args[2];
Array<PrimExpr> strides = args[3];
Array<Integer> begin = args[1];
Array<Integer> end = args[2];
Array<Integer> strides = args[3];
*rv = strided_slice(args[0], begin, end, strides, args[4]);
});

