Skip to content

Commit

Permalink
fix format
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 27, 2021
1 parent e89d599 commit 8bf8891
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 149 deletions.
293 changes: 146 additions & 147 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -1803,169 +1803,168 @@ inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Array<PrimExpr
}
return compute(
oshape,
[&](const Array<Var>& indices) {
PrimExpr ret = default_value;
if (0 == rank_sparse_indices) {
ret = if_then_else(indices[0] == sparse_indices[0], sparse_values[0], ret);
} else if (1 == rank_sparse_indices) {
for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) {
ret = if_then_else(indices[0] == sparse_indices[j], sparse_values[j], ret);
}
} else {
for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) {
PrimExpr aggregate_condition;
for (int k = 0; k < GetConstInt(sparse_indices->shape[1]); k++) {
PrimExpr comparision = indices[k] == sparse_indices[j][k];
aggregate_condition = 0 == k ? comparision : aggregate_condition && comparision;
}
ret = if_then_else(aggregate_condition, sparse_values[j], ret);
[&](const Array<Var>& indices) {
PrimExpr ret = default_value;
if (0 == rank_sparse_indices) {
ret = if_then_else(indices[0] == sparse_indices[0], sparse_values[0], ret);
} else if (1 == rank_sparse_indices) {
for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) {
ret = if_then_else(indices[0] == sparse_indices[j], sparse_values[j], ret);
}
} else {
for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) {
PrimExpr aggregate_condition;
for (int k = 0; k < GetConstInt(sparse_indices->shape[1]); k++) {
PrimExpr comparision = indices[k] == sparse_indices[j][k];
aggregate_condition = 0 == k ? comparision : aggregate_condition && comparision;
}
ret = if_then_else(aggregate_condition, sparse_values[j], ret);
}
return ret;
},
name, tag);
}
}
return ret;
},
name, tag);
}

/*!
 * \brief Returns a tensor with the diagonal of input tensor replaced with the provided diagonals.
 * \param input input tensor.
 * \param diagonal values to be filled in the diagonals.
 * \param k1 lower limit (included) of the range of diagonals.
 * \param k2 upper limit (included) of the range of diagonals.
 * \param super_diag_right_align bool, true iff super-diagonal is right aligned (left-padded).
 * \param sub_diag_right_align bool, true iff sub-diagonal is right aligned (left-padded).
 * \param name output tensor name.
 * \param tag output tensor tag.
 * \return new tensor with given diagonal values.
 */
inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k1, int k2,
                              bool super_diag_right_align, bool sub_diag_right_align,
                              const std::string name = "T_matrix_set_diag",
                              const std::string tag = kInjective) {
  // Index of the last (column) axis; input is treated as a batch of matrices
  // over its last two axes.
  size_t ndim = input->shape.size() - 1;

  // With k1 == k2 the diagonal tensor holds a single diagonal and has no
  // per-diagonal axis, so no k-index is pushed below.
  bool only_one_diagonal = k1 == k2;

  return compute(
      input->shape,
      [&](const Array<Var>& iter_vars) {
        // Maps an element position inside the diagonal band to its location
        // in the `diagonal` tensor.
        auto get_diag = [&]() {
          Array<PrimExpr> diagonal_indices;
          PrimExpr k, offset = 0;
          // Batch axes pass through unchanged.
          for (size_t i = 0; i < ndim - 1; i++) {
            diagonal_indices.push_back(iter_vars[i]);
          }
          if (only_one_diagonal) {
            k = k1;
          } else {
            // Determining which diagonal/sub-diagonal/super-diagonal it is
            k = iter_vars[ndim] - iter_vars[ndim - 1];
            diagonal_indices.push_back(k2 - k);

            // Calculating the offset in diagonal tensor for this diagonal
            auto get_offset = [&](PrimExpr M, PrimExpr N) {
              // offset = max_diagonal_length - diagonal_length
              return diagonal->shape[diagonal->shape.size() - 1] - if_then_else(M < N, M, N);
            };
            // Right-aligned diagonals are left-padded inside the diagonal
            // tensor, so shorter diagonals start at a non-zero offset.
            offset = if_then_else(
                k >= 0,
                super_diag_right_align ? get_offset(input->shape[ndim] - k, input->shape[ndim - 1])
                                       : 0,
                sub_diag_right_align ? get_offset(input->shape[ndim], input->shape[ndim - 1] + k)
                                     : 0);
          }
          // Position along the diagonal: row index for super-diagonals,
          // column index for sub-diagonals.
          diagonal_indices.push_back(if_then_else(k >= 0, iter_vars[ndim - 1], iter_vars[ndim]) +
                                     offset);
          return diagonal(diagonal_indices);
        };
        // Replace elements whose diagonal offset lies in [k1, k2]; copy the
        // input elsewhere.
        return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1,
                            if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2,
                                         get_diag(), input(iter_vars)),
                            input(iter_vars));
      },
      name, tag);
}

/*!
 * \brief Numpy style advanced indexing with tensor.
 * \param data is input data.
 * \param indices is list of indexing tensors.
 * \param name output tensor name.
 * \param tag output tensor tag.
 * \return Output tensor.
 */
inline Tensor adv_index(const Tensor& data, const Array<Tensor>& indices,
                        const std::string name = "advanced_index",
                        const std::string tag = kInjective) {
  Array<PrimExpr> oshape;
  Array<PrimExpr> broadcast_shape;
  Array<Tensor> bindices;
  std::vector<int64_t> flatten_shape_lens;
  int64_t num_picked_elems = 1;
  bool has_dyn_shape = false;

  if (indices.size() == 1) {
    // Single index tensor: nothing to broadcast.
    broadcast_shape = indices[0]->shape;
    bindices = indices;
  } else {
    // Pick the index tensor with the largest number of elements as the
    // broadcast target; a dynamic (non-constant) dimension short-circuits
    // the search and is used as-is.
    for (const auto& index : indices) {
      int64_t flatten_len = 1;
      for (const auto& dim : index->shape) {
        const IntImmNode* axis_len = dim.as<IntImmNode>();
        if (!axis_len) {
          broadcast_shape = index->shape;
          has_dyn_shape = true;
          break;
        }
        flatten_len *= axis_len->value;
      }
      if (has_dyn_shape) break;
      flatten_shape_lens.push_back(flatten_len);
      if (flatten_len > num_picked_elems) {
        num_picked_elems = flatten_len;
        broadcast_shape = index->shape;
      }
    }

    // Do broadcast for indices
    for (size_t i = 0; i < indices.size(); ++i) {
      if (!has_dyn_shape && flatten_shape_lens[i] < num_picked_elems) {
        bindices.push_back(broadcast_to(indices[i], broadcast_shape));
      } else {
        bindices.push_back(indices[i]);
      }
    }
  }

  // Output shape: broadcast shape of the indices followed by the trailing
  // (un-indexed) axes of the data tensor.
  for (const auto& dim : broadcast_shape) {
    oshape.push_back(dim);
  }
  for (size_t i = indices.size(); i < data->shape.size(); ++i) {
    oshape.push_back(data->shape[i]);
  }

  return compute(
      oshape,
      [&](const Array<Var>& iter_var) {
        // Coordinates within the broadcast index space.
        Array<PrimExpr> tensor_indices;
        for (size_t i = 0; i < broadcast_shape.size(); ++i) {
          tensor_indices.push_back(iter_var[i]);
        }

        // Gather one coordinate per index tensor, then pass the remaining
        // output axes straight through to the data tensor.
        Array<PrimExpr> real_indices;
        for (size_t i = 0; i < bindices.size(); ++i) {
          real_indices.push_back(bindices[i](tensor_indices));
        }
        for (size_t i = broadcast_shape.size(); i < iter_var.size(); ++i) {
          real_indices.push_back(iter_var[i]);
        }

        return data(real_indices);
      },
      name, tag);
}

} // namespace topi
} // namespace topi
} // namespace tvm
#endif // TVM_TOPI_TRANSFORM_H_
5 changes: 3 additions & 2 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2457,8 +2457,9 @@ bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attr
Array<Integer> axes;
if (param->axes) {
axes = param->axes.value();
ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size())
<< "axes, begin, end, and strides must have the same length";
ICHECK(axes.size() == begin.size() && axes.size() == end.size() &&
axes.size() == strides.size())
<< "axes, begin, end, and strides must have the same length";
} else {
for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i);

Expand Down

0 comments on commit 8bf8891

Please sign in to comment.