From 8bf88913b9bc02730120a0695138ed3fb8ed49ae Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 24 May 2021 17:39:50 +0900 Subject: [PATCH] fix format --- include/tvm/topi/transform.h | 293 +++++++++++++++---------------- src/relay/op/tensor/transform.cc | 5 +- 2 files changed, 149 insertions(+), 149 deletions(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 9fbc4f01c79f..1acf99eb1c44 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1803,169 +1803,168 @@ inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Array& indices) { - PrimExpr ret = default_value; - if (0 == rank_sparse_indices) { - ret = if_then_else(indices[0] == sparse_indices[0], sparse_values[0], ret); - } else if (1 == rank_sparse_indices) { - for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) { - ret = if_then_else(indices[0] == sparse_indices[j], sparse_values[j], ret); - } - } else { - for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) { - PrimExpr aggregate_condition; - for (int k = 0; k < GetConstInt(sparse_indices->shape[1]); k++) { - PrimExpr comparision = indices[k] == sparse_indices[j][k]; - aggregate_condition = 0 == k ? comparision : aggregate_condition && comparision; - } - ret = if_then_else(aggregate_condition, sparse_values[j], ret); + [&](const Array& indices) { + PrimExpr ret = default_value; + if (0 == rank_sparse_indices) { + ret = if_then_else(indices[0] == sparse_indices[0], sparse_values[0], ret); + } else if (1 == rank_sparse_indices) { + for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) { + ret = if_then_else(indices[0] == sparse_indices[j], sparse_values[j], ret); + } + } else { + for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) { + PrimExpr aggregate_condition; + for (int k = 0; k < GetConstInt(sparse_indices->shape[1]); k++) { + PrimExpr comparision = indices[k] == sparse_indices[j][k]; + aggregate_condition = 0 == k ? comparision : aggregate_condition && comparision; } + ret = if_then_else(aggregate_condition, sparse_values[j], ret); } - return ret; - }, - name, tag); - } + } + return ret; + }, + name, tag); +} - /*! - * \brief Returns a tensor with the diagonal of input tensor replaced with the provided diagonals. - * \param input input tensor. - * \param diagonal values to be filled in the diagonals. - * \param k1 lower limit (included) of the range of diagonals. - * \param k2 upper limit (included) of the range of diagonals. - * \param super_diag_right_align bool, true iff super-diagonal is right aligned (left-padded). - * \param sub_diag_right_align bool, true iff sub-diagonal is right aligned (left-padded). - * \param name output tensor name. - * \param tag output tensor tag. - * \return new tensor with given diagonal values. - */ - inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k1, int k2, - bool super_diag_right_align, bool sub_diag_right_align, - const std::string name = "T_matrix_set_diag", - const std::string tag = kInjective) { - size_t ndim = input->shape.size() - 1; - - bool only_one_diagonal = k1 == k2; +/*! + * \brief Returns a tensor with the diagonal of input tensor replaced with the provided diagonals. + * \param input input tensor. + * \param diagonal values to be filled in the diagonals. + * \param k1 lower limit (included) of the range of diagonals. + * \param k2 upper limit (included) of the range of diagonals. + * \param super_diag_right_align bool, true iff super-diagonal is right aligned (left-padded). + * \param sub_diag_right_align bool, true iff sub-diagonal is right aligned (left-padded). + * \param name output tensor name. + * \param tag output tensor tag. + * \return new tensor with given diagonal values. + */ +inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k1, int k2, + bool super_diag_right_align, bool sub_diag_right_align, + const std::string name = "T_matrix_set_diag", + const std::string tag = kInjective) { + size_t ndim = input->shape.size() - 1; - return compute( - input->shape, - [&](const Array& iter_vars) { - auto get_diag = [&]() { - Array diagonal_indices; - PrimExpr k, offset = 0; - for (size_t i = 0; i < ndim - 1; i++) { - diagonal_indices.push_back(iter_vars[i]); - } - if (only_one_diagonal) { - k = k1; - } else { - // Determining which diagonal/sub-diagonal/super-diagonal it is - k = iter_vars[ndim] - iter_vars[ndim - 1]; - diagonal_indices.push_back(k2 - k); - - // Calculating the offset in diagonal tensor for this diagonal - auto get_offset = [&](PrimExpr M, PrimExpr N) { - // offset = max_diagonal_length - diagonal_length - return diagonal->shape[diagonal->shape.size() - 1] - if_then_else(M < N, M, N); - }; - offset = if_then_else(k >= 0, - super_diag_right_align - ? get_offset(input->shape[ndim] - k, input->shape[ndim - 1]) - : 0, - sub_diag_right_align - ? get_offset(input->shape[ndim], input->shape[ndim - 1] + k) - : 0); - } - diagonal_indices.push_back(if_then_else(k >= 0, iter_vars[ndim - 1], iter_vars[ndim]) + - offset); - return diagonal(diagonal_indices); - }; - return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1, - if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2, - get_diag(), input(iter_vars)), - input(iter_vars)); - }, - name, tag); - } + bool only_one_diagonal = k1 == k2; - /*! - * \brief Numpy style advanced indexing with tensor. - * \param data is input data. - * \param indices is list of indexing tensors. - * \param name output tensor name. - * \param tag output tensor tag. - * \return Output tensor. - */ - inline Tensor adv_index(const Tensor& data, const Array& indices, - const std::string name = "advanced_index", - const std::string tag = kInjective) { - Array oshape; - Array broadcast_shape; - Array bindices; - std::vector flatten_shape_lens; - int64_t num_picked_elems = 1; - bool has_dyn_shape = false; - - if (indices.size() == 1) { - broadcast_shape = indices[0]->shape; - bindices = indices; - } else { - for (const auto& index : indices) { - int64_t flatten_len = 1; - for (const auto& dim : index->shape) { - const IntImmNode* axis_len = dim.as(); - if (!axis_len) { - broadcast_shape = index->shape; - has_dyn_shape = true; - break; + return compute( + input->shape, + [&](const Array& iter_vars) { + auto get_diag = [&]() { + Array diagonal_indices; + PrimExpr k, offset = 0; + for (size_t i = 0; i < ndim - 1; i++) { + diagonal_indices.push_back(iter_vars[i]); } - flatten_len *= axis_len->value; - } - if (has_dyn_shape) break; - flatten_shape_lens.push_back(flatten_len); - if (flatten_len > num_picked_elems) { - num_picked_elems = flatten_len; + if (only_one_diagonal) { + k = k1; + } else { + // Determining which diagonal/sub-diagonal/super-diagonal it is + k = iter_vars[ndim] - iter_vars[ndim - 1]; + diagonal_indices.push_back(k2 - k); + + // Calculating the offset in diagonal tensor for this diagonal + auto get_offset = [&](PrimExpr M, PrimExpr N) { + // offset = max_diagonal_length - diagonal_length + return diagonal->shape[diagonal->shape.size() - 1] - if_then_else(M < N, M, N); + }; + offset = if_then_else( + k >= 0, + super_diag_right_align ? get_offset(input->shape[ndim] - k, input->shape[ndim - 1]) + : 0, + sub_diag_right_align ? get_offset(input->shape[ndim], input->shape[ndim - 1] + k) + : 0); + } + diagonal_indices.push_back(if_then_else(k >= 0, iter_vars[ndim - 1], iter_vars[ndim]) + + offset); + return diagonal(diagonal_indices); + }; + return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1, + if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2, + get_diag(), input(iter_vars)), + input(iter_vars)); + }, + name, tag); +} + +/*! + * \brief Numpy style advanced indexing with tensor. + * \param data is input data. + * \param indices is list of indexing tensors. + * \param name output tensor name. + * \param tag output tensor tag. + * \return Output tensor. + */ +inline Tensor adv_index(const Tensor& data, const Array& indices, + const std::string name = "advanced_index", + const std::string tag = kInjective) { + Array oshape; + Array broadcast_shape; + Array bindices; + std::vector flatten_shape_lens; + int64_t num_picked_elems = 1; + bool has_dyn_shape = false; + + if (indices.size() == 1) { + broadcast_shape = indices[0]->shape; + bindices = indices; + } else { + for (const auto& index : indices) { + int64_t flatten_len = 1; + for (const auto& dim : index->shape) { + const IntImmNode* axis_len = dim.as(); + if (!axis_len) { broadcast_shape = index->shape; + has_dyn_shape = true; + break; } + flatten_len *= axis_len->value; } - - // Do broadcast for indices - for (size_t i = 0; i < indices.size(); ++i) { - if (!has_dyn_shape && flatten_shape_lens[i] < num_picked_elems) { - bindices.push_back(broadcast_to(indices[i], broadcast_shape)); - } else { - bindices.push_back(indices[i]); - } + if (has_dyn_shape) break; + flatten_shape_lens.push_back(flatten_len); + if (flatten_len > num_picked_elems) { + num_picked_elems = flatten_len; + broadcast_shape = index->shape; } } - for (const auto& dim : broadcast_shape) { - oshape.push_back(dim); - } - for (size_t i = indices.size(); i < data->shape.size(); ++i) { - oshape.push_back(data->shape[i]); + // Do broadcast for indices + for (size_t i = 0; i < indices.size(); ++i) { + if (!has_dyn_shape && flatten_shape_lens[i] < num_picked_elems) { + bindices.push_back(broadcast_to(indices[i], broadcast_shape)); + } else { + bindices.push_back(indices[i]); + } } + } - return compute( - oshape, - [&](const Array& iter_var) { - Array tensor_indices; - for (size_t i = 0; i < broadcast_shape.size(); ++i) { - tensor_indices.push_back(iter_var[i]); - } + for (const auto& dim : broadcast_shape) { + oshape.push_back(dim); + } + for (size_t i = indices.size(); i < data->shape.size(); ++i) { + oshape.push_back(data->shape[i]); + } - Array real_indices; - for (size_t i = 0; i < bindices.size(); ++i) { - real_indices.push_back(bindices[i](tensor_indices)); - } - for (size_t i = broadcast_shape.size(); i < iter_var.size(); ++i) { - real_indices.push_back(iter_var[i]); - } + return compute( + oshape, + [&](const Array& iter_var) { + Array tensor_indices; + for (size_t i = 0; i < broadcast_shape.size(); ++i) { + tensor_indices.push_back(iter_var[i]); + } - return data(real_indices); - }, - name, tag); - } + Array real_indices; + for (size_t i = 0; i < bindices.size(); ++i) { + real_indices.push_back(bindices[i](tensor_indices)); + } + for (size_t i = broadcast_shape.size(); i < iter_var.size(); ++i) { + real_indices.push_back(iter_var[i]); + } + + return data(real_indices); + }, + name, tag); +} } // namespace topi -} // namespace topi +} // namespace tvm #endif // TVM_TOPI_TRANSFORM_H_ diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 244f9bd73702..7265520bade1 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2457,8 +2457,9 @@ bool StridedSliceRel(const Array& types, int num_inputs, const Attrs& attr Array axes; if (param->axes) { axes = param->axes.value(); - ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size()) - << "axes, begin, end, and strides must have the same length"; + ICHECK(axes.size() == begin.size() && axes.size() == end.size() && + axes.size() == strides.size()) + << "axes, begin, end, and strides must have the same length"; } else { for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i);