diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 063d69a737..0c5521d447 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -122,16 +122,8 @@ if(BUILD_PRIMS_BENCH) ) ConfigureBench( - NAME - MATRIX_BENCH - PATH - bench/prims/matrix/argmin.cu - bench/prims/matrix/gather.cu - bench/prims/matrix/select_k.cu - bench/prims/main.cpp - OPTIONAL - LIB - EXPLICIT_INSTANTIATE_ONLY + NAME MATRIX_BENCH PATH bench/prims/matrix/argmin.cu bench/prims/matrix/gather.cu + bench/prims/matrix/select_k.cu bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureBench( diff --git a/cpp/include/raft/matrix/detail/select_k-ext.cuh b/cpp/include/raft/matrix/detail/select_k-ext.cuh index 95d806dd43..506cbffcb9 100644 --- a/cpp/include/raft/matrix/detail/select_k-ext.cuh +++ b/cpp/include/raft/matrix/detail/select_k-ext.cuh @@ -16,7 +16,6 @@ #pragma once -#include #include #include #include // RAFT_EXPLICIT @@ -45,16 +44,6 @@ void select_k(raft::resources const& handle, bool sorted = false, SelectAlgo algo = SelectAlgo::kAuto, const IdxT* len_i = nullptr) RAFT_EXPLICIT; - -template -void select_k(raft::resources const& handle, - raft::device_csr_matrix_view in_val, - std::optional> in_idx, - raft::device_matrix_view out_val, - raft::device_matrix_view out_idx, - bool select_min, - bool sorted = false, - SelectAlgo algo = SelectAlgo::kAuto) RAFT_EXPLICIT; } // namespace raft::matrix::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -84,24 +73,3 @@ instantiate_raft_matrix_detail_select_k(double, int64_t); instantiate_raft_matrix_detail_select_k(double, uint32_t); #undef instantiate_raft_matrix_detail_select_k - -#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ - extern template void raft::matrix::detail::select_k( \ - raft::resources const& handle, \ - raft::device_csr_matrix_view in_val, \ - std::optional> in_idx, \ - raft::device_matrix_view out_val, \ - raft::device_matrix_view out_idx, \ - bool 
select_min, \ - bool sorted, \ - raft::matrix::SelectAlgo algo) - -instantiate_raft_matrix_detail_select_k(__half, uint32_t); -instantiate_raft_matrix_detail_select_k(__half, int64_t); -instantiate_raft_matrix_detail_select_k(float, int64_t); -instantiate_raft_matrix_detail_select_k(float, uint32_t); -instantiate_raft_matrix_detail_select_k(float, int); -instantiate_raft_matrix_detail_select_k(double, int64_t); -instantiate_raft_matrix_detail_select_k(double, uint32_t); - -#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/include/raft/matrix/detail/select_k-inl.cuh b/cpp/include/raft/matrix/detail/select_k-inl.cuh index 7b52199530..93d233152b 100644 --- a/cpp/include/raft/matrix/detail/select_k-inl.cuh +++ b/cpp/include/raft/matrix/detail/select_k-inl.cuh @@ -20,7 +20,6 @@ #include "select_radix.cuh" #include "select_warpsort.cuh" -#include #include #include #include @@ -31,8 +30,6 @@ #include -#include - namespace raft::matrix::detail { /** @@ -320,193 +317,4 @@ void select_k(raft::resources const& handle, default: RAFT_FAIL("K-selection Algorithm not supported."); } } - -/** - * Selects the k smallest or largest keys/values from each row of the input CSR matrix. - * - * This function operates on a CSR matrix `in_val` with a logical dense shape of [batch_size, len], - * selecting the k smallest or largest elements from each row. The selected elements are then stored - * in a row-major output matrix `out_val` with dimensions `batch_size` x k. - * - * @tparam T - * Type of the elements being compared (keys). - * @tparam IdxT - * Type of the indices associated with the keys. - * @tparam NZType - * Type representing non-zero elements of `in_val`. - * - * @param[in] handle - * Container for managing reusable resources. - * @param[in] in_val - * Input matrix in CSR format with a logical dense shape of [batch_size, len], - * containing the elements to be compared and selected. 
- * @param[in] in_idx - * Optional input indices [in_val.nnz] associated with `in_val.values`. - * If `in_idx` is `std::nullopt`, it defaults to a contiguous array from 0 to len-1. - * @param[out] out_val - * Output matrix [in_val.get_n_row(), k] storing the selected k smallest/largest elements - * from each row of `in_val`. - * @param[out] out_idx - * Output indices [in_val.get_n_row(), k] corresponding to the selected elements in `out_val`. - * @param[in] select_min - * Flag indicating whether to select the k smallest (true) or largest (false) elements. - * @param[in] sorted - * whether to make sure selected pairs are sorted by value - * @param[in] algo - * the selection algorithm to use - */ -template -void select_k(raft::resources const& handle, - raft::device_csr_matrix_view in_val, - std::optional> in_idx, - raft::device_matrix_view out_val, - raft::device_matrix_view out_idx, - bool select_min, - bool sorted = false, - SelectAlgo algo = SelectAlgo::kAuto) -{ - auto csr_view = in_val.structure_view(); - auto nnz = csr_view.get_nnz(); - - if (nnz == 0) return; - - auto batch_size = csr_view.get_n_rows(); - auto len = csr_view.get_n_cols(); - auto k = IdxT(out_val.extent(1)); - - RAFT_EXPECTS(out_val.extent(1) <= int64_t(std::numeric_limits::max()), - "output k must fit the int type."); - - RAFT_EXPECTS(batch_size == out_val.extent(0), "batch sizes must be equal"); - RAFT_EXPECTS(batch_size == out_idx.extent(0), "batch sizes must be equal"); - - if (in_idx.has_value()) { - RAFT_EXPECTS(size_t(nnz) == in_idx->size(), - "nnz of in_val must be equal to the length of in_idx"); - } - RAFT_EXPECTS(IdxT(k) == out_idx.extent(1), "value and index output lengths must be equal"); - - if (algo == SelectAlgo::kAuto) { algo = choose_select_k_algorithm(batch_size, len, k); } - - auto indptr = csr_view.get_indptr().data(); - - switch (algo) { - case SelectAlgo::kRadix8bits: - case SelectAlgo::kRadix11bits: - case SelectAlgo::kRadix11bitsExtraPass: { - if (algo == 
SelectAlgo::kRadix8bits) { - detail::select::radix::select_k( - handle, - in_val.get_elements().data(), - (in_idx.has_value() ? in_idx->data_handle() : csr_view.get_indices().data()), - batch_size, - len, - k, - out_val.data_handle(), - out_idx.data_handle(), - select_min, - true, - indptr); - } else { - bool fused_last_filter = algo == SelectAlgo::kRadix11bits; - detail::select::radix::select_k( - handle, - in_val.get_elements().data(), - (in_idx.has_value() ? in_idx->data_handle() : csr_view.get_indices().data()), - batch_size, - len, - k, - out_val.data_handle(), - out_idx.data_handle(), - select_min, - fused_last_filter, - indptr); - } - - if (sorted) { - auto offsets = make_device_mdarray( - handle, resource::get_workspace_resource(handle), make_extents(batch_size + 1)); - raft::linalg::map_offset(handle, offsets.view(), mul_const_op(k)); - - auto keys = - raft::make_device_vector_view(out_val.data_handle(), (IdxT)(batch_size * k)); - auto vals = - raft::make_device_vector_view(out_idx.data_handle(), (IdxT)(batch_size * k)); - - segmented_sort_by_key( - handle, raft::make_const_mdspan(offsets.view()), keys, vals, select_min); - } - - return; - } - case SelectAlgo::kWarpDistributed: - return detail::select::warpsort:: - select_k_impl( - handle, - in_val.get_elements().data(), - (in_idx.has_value() ? in_idx->data_handle() : csr_view.get_indices().data()), - batch_size, - len, - k, - out_val.data_handle(), - out_idx.data_handle(), - select_min, - indptr); - case SelectAlgo::kWarpDistributedShm: - return detail::select::warpsort:: - select_k_impl( - handle, - in_val.get_elements().data(), - (in_idx.has_value() ? in_idx->data_handle() : csr_view.get_indices().data()), - batch_size, - len, - k, - out_val.data_handle(), - out_idx.data_handle(), - select_min, - indptr); - case SelectAlgo::kWarpAuto: - return detail::select::warpsort::select_k( - handle, - in_val.get_elements().data(), - (in_idx.has_value() ? 
in_idx->data_handle() : csr_view.get_indices().data()), - batch_size, - len, - k, - out_val.data_handle(), - out_idx.data_handle(), - select_min, - indptr); - case SelectAlgo::kWarpImmediate: - return detail::select::warpsort:: - select_k_impl( - handle, - in_val.get_elements().data(), - (in_idx.has_value() ? in_idx->data_handle() : csr_view.get_indices().data()), - batch_size, - len, - k, - out_val.data_handle(), - out_idx.data_handle(), - select_min, - indptr); - case SelectAlgo::kWarpFiltered: - return detail::select::warpsort:: - select_k_impl( - handle, - in_val.get_elements().data(), - (in_idx.has_value() ? in_idx->data_handle() : csr_view.get_indices().data()), - batch_size, - len, - k, - out_val.data_handle(), - out_idx.data_handle(), - select_min, - indptr); - default: RAFT_FAIL("K-selection Algorithm not supported."); - } - - return; -} - } // namespace raft::matrix::detail diff --git a/cpp/include/raft/sparse/matrix/detail/select_k-ext.cuh b/cpp/include/raft/sparse/matrix/detail/select_k-ext.cuh new file mode 100644 index 0000000000..08bdfa6f30 --- /dev/null +++ b/cpp/include/raft/sparse/matrix/detail/select_k-ext.cuh @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include // RAFT_EXPLICIT + +#include // rmm:cuda_stream_view +#include // rmm::mr::device_memory_resource + +#include // __half + +#include // uint32_t + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::sparse::matrix::detail { + +template +void select_k(raft::resources const& handle, + raft::device_csr_matrix_view in_val, + std::optional> in_idx, + raft::device_matrix_view out_val, + raft::device_matrix_view out_idx, + bool select_min, + bool sorted = false, + raft::matrix::SelectAlgo algo = raft::matrix::SelectAlgo::kAuto) RAFT_EXPLICIT; +} // namespace raft::sparse::matrix::detail + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_sparse_matrix_detail_select_k(T, IdxT) \ + extern template void raft::sparse::matrix::detail::select_k( \ + raft::resources const& handle, \ + raft::device_csr_matrix_view in_val, \ + std::optional> in_idx, \ + raft::device_matrix_view out_val, \ + raft::device_matrix_view out_idx, \ + bool select_min, \ + bool sorted, \ + raft::matrix::SelectAlgo algo) + +instantiate_raft_sparse_matrix_detail_select_k(__half, uint32_t); +instantiate_raft_sparse_matrix_detail_select_k(__half, int64_t); +instantiate_raft_sparse_matrix_detail_select_k(float, int64_t); +instantiate_raft_sparse_matrix_detail_select_k(float, uint32_t); +instantiate_raft_sparse_matrix_detail_select_k(float, int); +instantiate_raft_sparse_matrix_detail_select_k(double, int64_t); +instantiate_raft_sparse_matrix_detail_select_k(double, uint32_t); + +#undef instantiate_raft_sparse_matrix_detail_select_k diff --git a/cpp/include/raft/sparse/matrix/detail/select_k-inl.cuh b/cpp/include/raft/sparse/matrix/detail/select_k-inl.cuh new file mode 100644 index 0000000000..5f39affce6 --- /dev/null +++ b/cpp/include/raft/sparse/matrix/detail/select_k-inl.cuh @@ -0,0 +1,225 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION.
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +namespace raft::sparse::matrix::detail { + +using namespace raft::matrix::detail; +using raft::matrix::SelectAlgo; + +/** + * Selects the k smallest or largest keys/values from each row of the input CSR matrix. + * + * This function operates on a CSR matrix `in_val` with a logical dense shape of [batch_size, len], + * selecting the k smallest or largest elements from each row. The selected elements are then stored + * in a row-major output matrix `out_val` with dimensions `batch_size` x k. + * + * @tparam T + * Type of the elements being compared (keys). + * @tparam IdxT + * Type of the indices associated with the keys. + * @tparam NZType + * Type representing non-zero elements of `in_val`. + * + * @param[in] handle + * Container for managing reusable resources. + * @param[in] in_val + * Input matrix in CSR format with a logical dense shape of [batch_size, len], + * containing the elements to be compared and selected. + * @param[in] in_idx + * Optional input indices [in_val.nnz] associated with `in_val.values`. + * If `in_idx` is `std::nullopt`, the CSR indices array of `in_val` is used instead. + * @param[out] out_val + * Output matrix [batch_size, k] storing the selected k smallest/largest elements + * from each row of `in_val`.
+ * @param[out] out_idx + * Output indices [batch_size, k] corresponding to the selected elements in `out_val`. + * @param[in] select_min + * Flag indicating whether to select the k smallest (true) or largest (false) elements. + * @param[in] sorted + * whether to make sure selected pairs are sorted by value + * @param[in] algo + * the selection algorithm to use + */ +template +void select_k(raft::resources const& handle, + raft::device_csr_matrix_view in_val, + std::optional> in_idx, + raft::device_matrix_view out_val, + raft::device_matrix_view out_idx, + bool select_min, + bool sorted = false, + SelectAlgo algo = SelectAlgo::kAuto) +{ + auto csr_view = in_val.structure_view(); + auto nnz = csr_view.get_nnz(); + + if (nnz == 0) return; + + auto batch_size = csr_view.get_n_rows(); + auto len = csr_view.get_n_cols(); + auto k = IdxT(out_val.extent(1)); + + common::nvtx::range fun_scope( + "sparse::matrix::select_k(batch_size = %zu, len = %zu, k = %d)", size_t(batch_size), size_t(len), int(k)); + + RAFT_EXPECTS(out_val.extent(1) <= int64_t(std::numeric_limits::max()), + "output k must fit the int type."); + + RAFT_EXPECTS(batch_size == out_val.extent(0), "batch sizes must be equal"); + RAFT_EXPECTS(batch_size == out_idx.extent(0), "batch sizes must be equal"); + + if (in_idx.has_value()) { + RAFT_EXPECTS(size_t(nnz) == in_idx->size(), + "nnz of in_val must be equal to the length of in_idx"); + } + RAFT_EXPECTS(IdxT(k) == out_idx.extent(1), "value and index output lengths must be equal"); + + if (algo == SelectAlgo::kAuto) { algo = choose_select_k_algorithm(batch_size, len, k); } + + auto indptr = csr_view.get_indptr().data(); + + switch (algo) { + case SelectAlgo::kRadix8bits: + case SelectAlgo::kRadix11bits: + case SelectAlgo::kRadix11bitsExtraPass: { + if (algo == SelectAlgo::kRadix8bits) { + select::radix::select_k( + handle, + in_val.get_elements().data(), + (in_idx.has_value() ?
in_idx->data_handle() : csr_view.get_indices().data()), + batch_size, + len, + k, + out_val.data_handle(), + out_idx.data_handle(), + select_min, + true, + indptr); + } else { + bool fused_last_filter = algo == SelectAlgo::kRadix11bits; + select::radix::select_k( + handle, + in_val.get_elements().data(), + (in_idx.has_value() ? in_idx->data_handle() : csr_view.get_indices().data()), + batch_size, + len, + k, + out_val.data_handle(), + out_idx.data_handle(), + select_min, + fused_last_filter, + indptr); + } + + if (sorted) { + auto offsets = make_device_mdarray( + handle, resource::get_workspace_resource(handle), make_extents(batch_size + 1)); + raft::linalg::map_offset(handle, offsets.view(), mul_const_op(k)); + + auto keys = + raft::make_device_vector_view(out_val.data_handle(), (IdxT)(batch_size * k)); + auto vals = + raft::make_device_vector_view(out_idx.data_handle(), (IdxT)(batch_size * k)); + + segmented_sort_by_key( + handle, raft::make_const_mdspan(offsets.view()), keys, vals, select_min); + } + + return; + } + case SelectAlgo::kWarpDistributed: + return select::warpsort::select_k_impl( + handle, + in_val.get_elements().data(), + (in_idx.has_value() ? in_idx->data_handle() : csr_view.get_indices().data()), + batch_size, + len, + k, + out_val.data_handle(), + out_idx.data_handle(), + select_min, + indptr); + case SelectAlgo::kWarpDistributedShm: + return select::warpsort::select_k_impl( + handle, + in_val.get_elements().data(), + (in_idx.has_value() ? in_idx->data_handle() : csr_view.get_indices().data()), + batch_size, + len, + k, + out_val.data_handle(), + out_idx.data_handle(), + select_min, + indptr); + case SelectAlgo::kWarpAuto: + return select::warpsort::select_k( + handle, + in_val.get_elements().data(), + (in_idx.has_value() ? 
in_idx->data_handle() : csr_view.get_indices().data()), + batch_size, + len, + k, + out_val.data_handle(), + out_idx.data_handle(), + select_min, + indptr); + case SelectAlgo::kWarpImmediate: + return select::warpsort::select_k_impl( + handle, + in_val.get_elements().data(), + (in_idx.has_value() ? in_idx->data_handle() : csr_view.get_indices().data()), + batch_size, + len, + k, + out_val.data_handle(), + out_idx.data_handle(), + select_min, + indptr); + case SelectAlgo::kWarpFiltered: + return select::warpsort::select_k_impl( + handle, + in_val.get_elements().data(), + (in_idx.has_value() ? in_idx->data_handle() : csr_view.get_indices().data()), + batch_size, + len, + k, + out_val.data_handle(), + out_idx.data_handle(), + select_min, + indptr); + default: RAFT_FAIL("K-selection Algorithm not supported."); + } + + return; +} + +} // namespace raft::sparse::matrix::detail diff --git a/cpp/include/raft/sparse/matrix/detail/select_k.cuh b/cpp/include/raft/sparse/matrix/detail/select_k.cuh new file mode 100644 index 0000000000..711169984b --- /dev/null +++ b/cpp/include/raft/sparse/matrix/detail/select_k.cuh @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "select_k-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +#include "select_k-ext.cuh" +#endif diff --git a/cpp/include/raft/sparse/matrix/select_k.cuh b/cpp/include/raft/sparse/matrix/select_k.cuh index 030b5a354f..3f97e60c99 100644 --- a/cpp/include/raft/sparse/matrix/select_k.cuh +++ b/cpp/include/raft/sparse/matrix/select_k.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,8 +21,8 @@ #include #include #include -#include #include +#include #include @@ -79,7 +79,7 @@ void select_k(raft::resources const& handle, bool sorted = false, SelectAlgo algo = SelectAlgo::kAuto) { - return raft::matrix::detail::select_k( + return detail::select_k( handle, in_val, in_idx, out_val, out_idx, select_min, sorted, algo); } /** @} */ // end of group select_k diff --git a/cpp/src/matrix/detail/select_k_double_int64_t.cu b/cpp/src/matrix/detail/select_k_double_int64_t.cu index c17018efe0..bf234aacbf 100644 --- a/cpp/src/matrix/detail/select_k_double_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_double_int64_t.cu @@ -33,18 +33,3 @@ instantiate_raft_matrix_detail_select_k(double, int64_t); #undef instantiate_raft_matrix_detail_select_k - -#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ - template void raft::matrix::detail::select_k( \ - raft::resources const& handle, \ - raft::device_csr_matrix_view in_val, \ - std::optional> in_idx, \ - raft::device_matrix_view out_val, \ - raft::device_matrix_view out_idx, \ - bool select_min, \ - bool sorted, \ - raft::matrix::SelectAlgo algo) - -instantiate_raft_matrix_detail_select_k(double, int64_t); - -#undef instantiate_raft_matrix_detail_select_k \ No newline at end of file diff --git a/cpp/src/matrix/detail/select_k_double_uint32_t.cu 
b/cpp/src/matrix/detail/select_k_double_uint32_t.cu index fcc3e5d5a7..7f0511a76a 100644 --- a/cpp/src/matrix/detail/select_k_double_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_double_uint32_t.cu @@ -35,18 +35,3 @@ instantiate_raft_matrix_detail_select_k(double, uint32_t); #undef instantiate_raft_matrix_detail_select_k - -#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ - template void raft::matrix::detail::select_k( \ - raft::resources const& handle, \ - raft::device_csr_matrix_view in_val, \ - std::optional> in_idx, \ - raft::device_matrix_view out_val, \ - raft::device_matrix_view out_idx, \ - bool select_min, \ - bool sorted, \ - raft::matrix::SelectAlgo algo) - -instantiate_raft_matrix_detail_select_k(double, uint32_t); - -#undef instantiate_raft_matrix_detail_select_k \ No newline at end of file diff --git a/cpp/src/matrix/detail/select_k_float_int32.cu b/cpp/src/matrix/detail/select_k_float_int32.cu index 82041a9b2d..e68b1e32df 100644 --- a/cpp/src/matrix/detail/select_k_float_int32.cu +++ b/cpp/src/matrix/detail/select_k_float_int32.cu @@ -33,18 +33,3 @@ instantiate_raft_matrix_detail_select_k(float, int); #undef instantiate_raft_matrix_detail_select_k - -#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ - template void raft::matrix::detail::select_k( \ - raft::resources const& handle, \ - raft::device_csr_matrix_view in_val, \ - std::optional> in_idx, \ - raft::device_matrix_view out_val, \ - raft::device_matrix_view out_idx, \ - bool select_min, \ - bool sorted, \ - raft::matrix::SelectAlgo algo) - -instantiate_raft_matrix_detail_select_k(float, int); - -#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/matrix/detail/select_k_float_int64_t.cu b/cpp/src/matrix/detail/select_k_float_int64_t.cu index 4d381b417f..5aa40d8c9d 100644 --- a/cpp/src/matrix/detail/select_k_float_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_float_int64_t.cu @@ -33,18 +33,3 @@ instantiate_raft_matrix_detail_select_k(float, int64_t); #undef 
instantiate_raft_matrix_detail_select_k - -#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ - template void raft::matrix::detail::select_k( \ - raft::resources const& handle, \ - raft::device_csr_matrix_view in_val, \ - std::optional> in_idx, \ - raft::device_matrix_view out_val, \ - raft::device_matrix_view out_idx, \ - bool select_min, \ - bool sorted, \ - raft::matrix::SelectAlgo algo) - -instantiate_raft_matrix_detail_select_k(float, uint64_t); - -#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/matrix/detail/select_k_float_uint32_t.cu b/cpp/src/matrix/detail/select_k_float_uint32_t.cu index 775807cfac..9aba147edf 100644 --- a/cpp/src/matrix/detail/select_k_float_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_float_uint32_t.cu @@ -33,18 +33,3 @@ instantiate_raft_matrix_detail_select_k(float, uint32_t); #undef instantiate_raft_matrix_detail_select_k - -#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ - template void raft::matrix::detail::select_k( \ - raft::resources const& handle, \ - raft::device_csr_matrix_view in_val, \ - std::optional> in_idx, \ - raft::device_matrix_view out_val, \ - raft::device_matrix_view out_idx, \ - bool select_min, \ - bool sorted, \ - raft::matrix::SelectAlgo algo) - -instantiate_raft_matrix_detail_select_k(float, uint32_t); - -#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/matrix/detail/select_k_half_int64_t.cu b/cpp/src/matrix/detail/select_k_half_int64_t.cu index cfd260326b..bc513e4aeb 100644 --- a/cpp/src/matrix/detail/select_k_half_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_half_int64_t.cu @@ -33,18 +33,3 @@ instantiate_raft_matrix_detail_select_k(__half, int64_t); #undef instantiate_raft_matrix_detail_select_k - -#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ - template void raft::matrix::detail::select_k( \ - raft::resources const& handle, \ - raft::device_csr_matrix_view in_val, \ - std::optional> in_idx, \ - raft::device_matrix_view out_val, \ - 
raft::device_matrix_view out_idx, \ - bool select_min, \ - bool sorted, \ - raft::matrix::SelectAlgo algo) - -instantiate_raft_matrix_detail_select_k(__half, int64_t); - -#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/matrix/detail/select_k_half_uint32_t.cu b/cpp/src/matrix/detail/select_k_half_uint32_t.cu index c252337f97..e46c7d46bb 100644 --- a/cpp/src/matrix/detail/select_k_half_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_half_uint32_t.cu @@ -33,18 +33,3 @@ instantiate_raft_matrix_detail_select_k(__half, uint32_t); #undef instantiate_raft_matrix_detail_select_k - -#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ - template void raft::matrix::detail::select_k( \ - raft::resources const& handle, \ - raft::device_csr_matrix_view in_val, \ - std::optional> in_idx, \ - raft::device_matrix_view out_val, \ - raft::device_matrix_view out_idx, \ - bool select_min, \ - bool sorted, \ - raft::matrix::SelectAlgo algo) - -instantiate_raft_matrix_detail_select_k(__half, uint32_t); - -#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/sparse/matrix/detail/select_k_double_int64_t.cu b/cpp/src/sparse/matrix/detail/select_k_double_int64_t.cu new file mode 100644 index 0000000000..c784b50dad --- /dev/null +++ b/cpp/src/sparse/matrix/detail/select_k_double_int64_t.cu @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#define instantiate_raft_sparse_matrix_detail_select_k(T, IdxT) \ + template void raft::sparse::matrix::detail::select_k( \ + raft::resources const& handle, \ + raft::device_csr_matrix_view in_val, \ + std::optional> in_idx, \ + raft::device_matrix_view out_val, \ + raft::device_matrix_view out_idx, \ + bool select_min, \ + bool sorted, \ + raft::matrix::SelectAlgo algo) + +instantiate_raft_sparse_matrix_detail_select_k(double, int64_t); + +#undef instantiate_raft_sparse_matrix_detail_select_k \ No newline at end of file diff --git a/cpp/src/sparse/matrix/detail/select_k_double_uint32_t.cu b/cpp/src/sparse/matrix/detail/select_k_double_uint32_t.cu new file mode 100644 index 0000000000..98bab9a504 --- /dev/null +++ b/cpp/src/sparse/matrix/detail/select_k_double_uint32_t.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include // uint32_t + +#define instantiate_raft_sparse_matrix_detail_select_k(T, IdxT) \ + template void raft::sparse::matrix::detail::select_k( \ + raft::resources const& handle, \ + raft::device_csr_matrix_view in_val, \ + std::optional> in_idx, \ + raft::device_matrix_view out_val, \ + raft::device_matrix_view out_idx, \ + bool select_min, \ + bool sorted, \ + raft::matrix::SelectAlgo algo) + +instantiate_raft_sparse_matrix_detail_select_k(double, uint32_t); + +#undef instantiate_raft_sparse_matrix_detail_select_k \ No newline at end of file diff --git a/cpp/src/sparse/matrix/detail/select_k_float_int32.cu b/cpp/src/sparse/matrix/detail/select_k_float_int32.cu new file mode 100644 index 0000000000..bff213ae69 --- /dev/null +++ b/cpp/src/sparse/matrix/detail/select_k_float_int32.cu @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#define instantiate_raft_sparse_matrix_detail_select_k(T, IdxT) \ + template void raft::sparse::matrix::detail::select_k( \ + raft::resources const& handle, \ + raft::device_csr_matrix_view in_val, \ + std::optional> in_idx, \ + raft::device_matrix_view out_val, \ + raft::device_matrix_view out_idx, \ + bool select_min, \ + bool sorted, \ + raft::matrix::SelectAlgo algo) + +instantiate_raft_sparse_matrix_detail_select_k(float, int); + +#undef instantiate_raft_sparse_matrix_detail_select_k diff --git a/cpp/src/sparse/matrix/detail/select_k_float_int64_t.cu b/cpp/src/sparse/matrix/detail/select_k_float_int64_t.cu new file mode 100644 index 0000000000..412b06e587 --- /dev/null +++ b/cpp/src/sparse/matrix/detail/select_k_float_int64_t.cu @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include + +#define instantiate_raft_sparse_matrix_detail_select_k(T, IdxT) \ + template void raft::sparse::matrix::detail::select_k( \ + raft::resources const& handle, \ + raft::device_csr_matrix_view in_val, \ + std::optional> in_idx, \ + raft::device_matrix_view out_val, \ + raft::device_matrix_view out_idx, \ + bool select_min, \ + bool sorted, \ + raft::matrix::SelectAlgo algo) + +instantiate_raft_sparse_matrix_detail_select_k(float, int64_t); + +#undef instantiate_raft_sparse_matrix_detail_select_k diff --git a/cpp/src/sparse/matrix/detail/select_k_float_uint32_t.cu b/cpp/src/sparse/matrix/detail/select_k_float_uint32_t.cu new file mode 100644 index 0000000000..8ba3f0e22b --- /dev/null +++ b/cpp/src/sparse/matrix/detail/select_k_float_uint32_t.cu @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#define instantiate_raft_sparse_matrix_detail_select_k(T, IdxT) \ + template void raft::sparse::matrix::detail::select_k( \ + raft::resources const& handle, \ + raft::device_csr_matrix_view in_val, \ + std::optional> in_idx, \ + raft::device_matrix_view out_val, \ + raft::device_matrix_view out_idx, \ + bool select_min, \ + bool sorted, \ + raft::matrix::SelectAlgo algo) + +instantiate_raft_sparse_matrix_detail_select_k(float, uint32_t); /* explicit instantiation of the CSR select_k overload for T = float, IdxT = uint32_t */ + +#undef instantiate_raft_sparse_matrix_detail_select_k diff --git a/cpp/src/sparse/matrix/detail/select_k_half_int64_t.cu b/cpp/src/sparse/matrix/detail/select_k_half_int64_t.cu new file mode 100644 index 0000000000..24c844f8c8 --- /dev/null +++ b/cpp/src/sparse/matrix/detail/select_k_half_int64_t.cu @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#define instantiate_raft_sparse_matrix_detail_select_k(T, IdxT) \ + template void raft::sparse::matrix::detail::select_k( \ + raft::resources const& handle, \ + raft::device_csr_matrix_view in_val, \ + std::optional> in_idx, \ + raft::device_matrix_view out_val, \ + raft::device_matrix_view out_idx, \ + bool select_min, \ + bool sorted, \ + raft::matrix::SelectAlgo algo) + +instantiate_raft_sparse_matrix_detail_select_k(__half, int64_t); /* explicit instantiation of the CSR select_k overload for T = __half, IdxT = int64_t */ + +#undef instantiate_raft_sparse_matrix_detail_select_k diff --git a/cpp/src/sparse/matrix/detail/select_k_half_uint32_t.cu b/cpp/src/sparse/matrix/detail/select_k_half_uint32_t.cu new file mode 100644 index 0000000000..d63dc64933 --- /dev/null +++ b/cpp/src/sparse/matrix/detail/select_k_half_uint32_t.cu @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#define instantiate_raft_sparse_matrix_detail_select_k(T, IdxT) \ + template void raft::sparse::matrix::detail::select_k( \ + raft::resources const& handle, \ + raft::device_csr_matrix_view in_val, \ + std::optional> in_idx, \ + raft::device_matrix_view out_val, \ + raft::device_matrix_view out_idx, \ + bool select_min, \ + bool sorted, \ + raft::matrix::SelectAlgo algo) + +instantiate_raft_sparse_matrix_detail_select_k(__half, uint32_t); /* explicit instantiation of the CSR select_k overload for T = __half, IdxT = uint32_t */ + +#undef instantiate_raft_sparse_matrix_detail_select_k diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 17990700e6..4d17aacffd 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -275,12 +275,7 @@ if(BUILD_TESTS) EXPLICIT_INSTANTIATE_ONLY ) - ConfigureTest( - NAME - MATRIX_SELECT_TEST - PATH test/matrix/select_k.cu - LIB - EXPLICIT_INSTANTIATE_ONLY) + ConfigureTest(NAME MATRIX_SELECT_TEST PATH test/matrix/select_k.cu LIB EXPLICIT_INSTANTIATE_ONLY) ConfigureTest( NAME MATRIX_SELECT_LARGE_TEST PATH test/matrix/select_large_k.cu LIB EXPLICIT_INSTANTIATE_ONLY