Skip to content

Commit

Permalink
Move all of the impl over to sparse name scope
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong committed Apr 5, 2024
1 parent 48cc460 commit 04a8bb4
Show file tree
Hide file tree
Showing 22 changed files with 548 additions and 348 deletions.
12 changes: 2 additions & 10 deletions cpp/bench/prims/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -122,16 +122,8 @@ if(BUILD_PRIMS_BENCH)
)

# Registers the MATRIX_BENCH benchmark target from the argmin, gather and select_k
# benchmark sources plus the shared bench/prims/main.cpp, with the OPTIONAL, LIB and
# EXPLICIT_INSTANTIATE_ONLY flags.
# NOTE(review): the argument list appears twice below (one-per-line form, then the
# compact two-line form). This looks like the old and new formatting of the same call
# shown together by a diff view — confirm that only one form exists in the real file.
ConfigureBench(
NAME
MATRIX_BENCH
PATH
bench/prims/matrix/argmin.cu
bench/prims/matrix/gather.cu
bench/prims/matrix/select_k.cu
bench/prims/main.cpp
OPTIONAL
LIB
EXPLICIT_INSTANTIATE_ONLY
NAME MATRIX_BENCH PATH bench/prims/matrix/argmin.cu bench/prims/matrix/gather.cu
bench/prims/matrix/select_k.cu bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY
)

ConfigureBench(
Expand Down
32 changes: 0 additions & 32 deletions cpp/include/raft/matrix/detail/select_k-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

#pragma once

#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/matrix/select_k_types.hpp>
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT
Expand Down Expand Up @@ -45,16 +44,6 @@ void select_k(raft::resources const& handle,
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto,
const IdxT* len_i = nullptr) RAFT_EXPLICIT;

/**
 * Declaration of the `select_k` overload that takes its input as a CSR matrix view.
 *
 * @tparam T     element type of the CSR values and of `out_val`
 * @tparam IdxT  index type; used for the CSR index/indptr/nnz types, for `in_idx`,
 *               and for the extents and elements of `out_idx`
 *
 * @param[in]  handle      raft resources handle
 * @param[in]  in_val      input values as a CSR matrix view
 * @param[in]  in_idx      optional vector of indices accompanying the input values
 * @param[out] out_val     row-major dense matrix receiving the selected values
 * @param[out] out_idx     row-major dense matrix receiving the selected indices
 * @param[in]  select_min  select the smallest (true) or largest (false) entries
 * @param[in]  sorted      defaults to false
 * @param[in]  algo        selection algorithm; defaults to SelectAlgo::kAuto
 *
 * NOTE(review): RAFT_EXPLICIT comes from <raft/util/raft_explicit.hpp>; presumably it
 * marks the template as only usable through explicit instantiations — confirm there.
 */
template <typename T, typename IdxT>
void select_k(raft::resources const& handle,
raft::device_csr_matrix_view<const T, IdxT, IdxT, IdxT> in_val,
std::optional<raft::device_vector_view<const IdxT, IdxT>> in_idx,
raft::device_matrix_view<T, IdxT, raft::row_major> out_val,
raft::device_matrix_view<IdxT, IdxT, raft::row_major> out_idx,
bool select_min,
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto) RAFT_EXPLICIT;
} // namespace raft::matrix::detail

#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY
Expand Down Expand Up @@ -84,24 +73,3 @@ instantiate_raft_matrix_detail_select_k(double, int64_t);
instantiate_raft_matrix_detail_select_k(double, uint32_t);

#undef instantiate_raft_matrix_detail_select_k

// Helper macro: emits an explicit instantiation *declaration* (`extern template`) of the
// CSR-input raft::matrix::detail::select_k overload for one (T, IdxT) pair, so that
// translation units including this header do not implicitly instantiate it.
// NOTE(review): the matching explicit instantiation definitions are expected to live in
// a compiled source file that is not visible here — confirm the two lists stay in sync.
#define instantiate_raft_matrix_detail_select_k(T, IdxT) \
extern template void raft::matrix::detail::select_k( \
raft::resources const& handle, \
raft::device_csr_matrix_view<const T, IdxT, IdxT, IdxT> in_val, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> in_idx, \
raft::device_matrix_view<T, IdxT, raft::row_major> out_val, \
raft::device_matrix_view<IdxT, IdxT, raft::row_major> out_idx, \
bool select_min, \
bool sorted, \
raft::matrix::SelectAlgo algo)

// (value type, index type) combinations declared as externally instantiated.
instantiate_raft_matrix_detail_select_k(__half, uint32_t);
instantiate_raft_matrix_detail_select_k(__half, int64_t);
instantiate_raft_matrix_detail_select_k(float, int64_t);
instantiate_raft_matrix_detail_select_k(float, uint32_t);
instantiate_raft_matrix_detail_select_k(float, int);
instantiate_raft_matrix_detail_select_k(double, int64_t);
instantiate_raft_matrix_detail_select_k(double, uint32_t);

// Remove the helper so it does not leak into files that include this header.
#undef instantiate_raft_matrix_detail_select_k
192 changes: 0 additions & 192 deletions cpp/include/raft/matrix/detail/select_k-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include "select_radix.cuh"
#include "select_warpsort.cuh"

#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/nvtx.hpp>
Expand All @@ -31,8 +30,6 @@

#include <cub/cub.cuh>

#include <type_traits>

namespace raft::matrix::detail {

/**
Expand Down Expand Up @@ -320,193 +317,4 @@ void select_k(raft::resources const& handle,
default: RAFT_FAIL("K-selection Algorithm not supported.");
}
}

/**
 * Row-wise top-k selection over a CSR matrix.
 *
 * For each row of `in_val`, the k smallest (or largest) stored values are selected and
 * written, together with their indices, into the dense row-major outputs. `k` is taken
 * from `out_val.extent(1)`; the number of rows (`batch_size`) and the logical row length
 * (`len`) are read from the CSR structure, and the CSR row offsets (indptr) are handed to
 * the selection routines.
 *
 * When the matrix stores no non-zeros the function returns immediately: no argument
 * validation is performed and neither output is written.
 *
 * @tparam T
 *   Type of the values being compared (keys).
 * @tparam IdxT
 *   Index type; also used for the CSR indptr / indices / nnz types.
 *
 * @param[in] handle
 *   Container for managing reusable resources.
 * @param[in] in_val
 *   Input matrix in CSR format with a logical dense shape of [batch_size, len].
 * @param[in] in_idx
 *   Optional indices, one per stored value (its size must equal nnz). When it holds no
 *   value, the CSR column indices of `in_val` are used as the source indices instead.
 * @param[out] out_val
 *   Output matrix [batch_size, k] receiving the selected values.
 * @param[out] out_idx
 *   Output matrix [batch_size, k] receiving the indices of the selected values.
 * @param[in] select_min
 *   Select the k smallest (true) or the k largest (false) values.
 * @param[in] sorted
 *   Whether selected pairs should be sorted by value. In this function the flag is only
 *   acted upon for the radix algorithms (a segmented sort is run after selection); it is
 *   not forwarded to the warp-sort variants.
 * @param[in] algo
 *   Selection algorithm. `SelectAlgo::kAuto` is resolved through
 *   `choose_select_k_algorithm(batch_size, len, k)`; an unsupported value fails via
 *   RAFT_FAIL.
 */
template <typename T, typename IdxT>
void select_k(raft::resources const& handle,
              raft::device_csr_matrix_view<const T, IdxT, IdxT, IdxT> in_val,
              std::optional<raft::device_vector_view<const IdxT, IdxT>> in_idx,
              raft::device_matrix_view<T, IdxT, raft::row_major> out_val,
              raft::device_matrix_view<IdxT, IdxT, raft::row_major> out_idx,
              bool select_min,
              bool sorted     = false,
              SelectAlgo algo = SelectAlgo::kAuto)
{
  namespace radix    = detail::select::radix;
  namespace warpsort = detail::select::warpsort;

  auto structure = in_val.structure_view();
  auto nnz       = structure.get_nnz();

  // An empty matrix has nothing to select from.
  if (nnz == 0) { return; }

  auto batch_size = structure.get_n_rows();
  auto len        = structure.get_n_cols();
  auto k          = IdxT(out_val.extent(1));

  // Argument validation (same checks, same order, same messages as before).
  RAFT_EXPECTS(out_val.extent(1) <= int64_t(std::numeric_limits<int>::max()),
               "output k must fit the int type.");
  RAFT_EXPECTS(batch_size == out_val.extent(0), "batch sizes must be equal");
  RAFT_EXPECTS(batch_size == out_idx.extent(0), "batch sizes must be equal");
  if (in_idx.has_value()) {
    RAFT_EXPECTS(size_t(nnz) == in_idx->size(),
                 "nnz of in_val must be equal to the length of in_idx");
  }
  RAFT_EXPECTS(IdxT(k) == out_idx.extent(1), "value and index output lengths must be equal");

  if (algo == SelectAlgo::kAuto) { algo = choose_select_k_algorithm(batch_size, len, k); }

  // Raw pointers shared by every algorithm below.
  auto values = in_val.get_elements().data();
  // Caller-provided indices take precedence; otherwise use the CSR column indices.
  auto indices     = in_idx.has_value() ? in_idx->data_handle() : structure.get_indices().data();
  auto row_offsets = structure.get_indptr().data();
  auto out_val_ptr = out_val.data_handle();
  auto out_idx_ptr = out_idx.data_handle();

  switch (algo) {
    case SelectAlgo::kRadix8bits:
    case SelectAlgo::kRadix11bits:
    case SelectAlgo::kRadix11bitsExtraPass: {
      if (algo == SelectAlgo::kRadix8bits) {
        radix::select_k<T, IdxT, 8, 512, false>(handle,
                                                values,
                                                indices,
                                                batch_size,
                                                len,
                                                k,
                                                out_val_ptr,
                                                out_idx_ptr,
                                                select_min,
                                                true,
                                                row_offsets);
      } else {
        // Only the plain 11-bit variant fuses the last filter; the "extra pass" one does not.
        const bool fused_last_filter = (algo == SelectAlgo::kRadix11bits);
        radix::select_k<T, IdxT, 11, 512, false>(handle,
                                                 values,
                                                 indices,
                                                 batch_size,
                                                 len,
                                                 k,
                                                 out_val_ptr,
                                                 out_idx_ptr,
                                                 select_min,
                                                 fused_last_filter,
                                                 row_offsets);
      }

      if (sorted) {
        // Every output row is one segment of length k: offsets are 0, k, 2k, ...
        auto segment_offsets = make_device_mdarray<IdxT, IdxT>(
          handle, resource::get_workspace_resource(handle), make_extents<IdxT>(batch_size + 1));
        raft::linalg::map_offset(handle, segment_offsets.view(), mul_const_op<IdxT>(k));

        const auto flat_len = static_cast<IdxT>(batch_size * k);
        auto sorted_keys    = raft::make_device_vector_view<T, IdxT>(out_val_ptr, flat_len);
        auto sorted_ids     = raft::make_device_vector_view<IdxT, IdxT>(out_idx_ptr, flat_len);

        segmented_sort_by_key<T, IdxT>(handle,
                                       raft::make_const_mdspan(segment_offsets.view()),
                                       sorted_keys,
                                       sorted_ids,
                                       select_min);
      }
      break;
    }
    case SelectAlgo::kWarpDistributed:
      warpsort::select_k_impl<T, IdxT, warpsort::warp_sort_distributed>(handle,
                                                                        values,
                                                                        indices,
                                                                        batch_size,
                                                                        len,
                                                                        k,
                                                                        out_val_ptr,
                                                                        out_idx_ptr,
                                                                        select_min,
                                                                        row_offsets);
      break;
    case SelectAlgo::kWarpDistributedShm:
      warpsort::select_k_impl<T, IdxT, warpsort::warp_sort_distributed_ext>(handle,
                                                                            values,
                                                                            indices,
                                                                            batch_size,
                                                                            len,
                                                                            k,
                                                                            out_val_ptr,
                                                                            out_idx_ptr,
                                                                            select_min,
                                                                            row_offsets);
      break;
    case SelectAlgo::kWarpAuto:
      warpsort::select_k<T, IdxT>(handle,
                                  values,
                                  indices,
                                  batch_size,
                                  len,
                                  k,
                                  out_val_ptr,
                                  out_idx_ptr,
                                  select_min,
                                  row_offsets);
      break;
    case SelectAlgo::kWarpImmediate:
      warpsort::select_k_impl<T, IdxT, warpsort::warp_sort_immediate>(handle,
                                                                      values,
                                                                      indices,
                                                                      batch_size,
                                                                      len,
                                                                      k,
                                                                      out_val_ptr,
                                                                      out_idx_ptr,
                                                                      select_min,
                                                                      row_offsets);
      break;
    case SelectAlgo::kWarpFiltered:
      warpsort::select_k_impl<T, IdxT, warpsort::warp_sort_filtered>(handle,
                                                                     values,
                                                                     indices,
                                                                     batch_size,
                                                                     len,
                                                                     k,
                                                                     out_val_ptr,
                                                                     out_idx_ptr,
                                                                     select_min,
                                                                     row_offsets);
      break;
    default: RAFT_FAIL("K-selection Algorithm not supported.");
  }
}

} // namespace raft::matrix::detail
67 changes: 67 additions & 0 deletions cpp/include/raft/sparse/matrix/detail/select_k-ext.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/matrix/select_k_types.hpp>
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT

#include <rmm/cuda_stream_view.hpp> // rmm:cuda_stream_view
#include <rmm/mr/device/device_memory_resource.hpp> // rmm::mr::device_memory_resource

#include <cuda_fp16.h> // __half

#include <cstdint> // uint32_t

#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY

namespace raft::sparse::matrix::detail {

/**
 * Declaration of the CSR-input `select_k`, visible only when
 * RAFT_EXPLICIT_INSTANTIATE_ONLY is defined (see the enclosing #ifdef).
 *
 * @tparam T     element type of the CSR values and of `out_val`
 * @tparam IdxT  index type; used for the CSR index/indptr/nnz types, for `in_idx`,
 *               and for the extents and elements of `out_idx`
 *
 * @param[in]  handle      raft resources handle
 * @param[in]  in_val      input values as a CSR matrix view
 * @param[in]  in_idx      optional vector of indices accompanying the input values
 * @param[out] out_val     row-major dense matrix receiving the selected values
 * @param[out] out_idx     row-major dense matrix receiving the selected indices
 * @param[in]  select_min  select the smallest (true) or largest (false) entries
 * @param[in]  sorted      defaults to false
 * @param[in]  algo        selection algorithm; defaults to raft::matrix::SelectAlgo::kAuto
 *
 * The default for `algo` is spelled fully qualified on purpose: `SelectAlgo` is declared in
 * `raft::matrix`, which is not an enclosing namespace of `raft::sparse::matrix::detail`, so
 * an unqualified `SelectAlgo::kAuto` would not be found by name lookup here. (A partially
 * qualified `matrix::SelectAlgo` would not work either, since `matrix` resolves to
 * `raft::sparse::matrix` from inside this namespace.)
 */
template <typename T, typename IdxT>
void select_k(raft::resources const& handle,
              raft::device_csr_matrix_view<const T, IdxT, IdxT, IdxT> in_val,
              std::optional<raft::device_vector_view<const IdxT, IdxT>> in_idx,
              raft::device_matrix_view<T, IdxT, raft::row_major> out_val,
              raft::device_matrix_view<IdxT, IdxT, raft::row_major> out_idx,
              bool select_min,
              bool sorted                   = false,
              raft::matrix::SelectAlgo algo = raft::matrix::SelectAlgo::kAuto) RAFT_EXPLICIT;

}  // namespace raft::sparse::matrix::detail

#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY

// Helper macro: emits an explicit instantiation *declaration* (`extern template`) of
// raft::sparse::matrix::detail::select_k for one (T, IdxT) pair, so that translation
// units including this header do not implicitly instantiate it.
// NOTE(review): the matching explicit instantiation definitions are expected to live in
// a compiled source file that is not visible here — confirm the two lists stay in sync.
#define instantiate_raft_sparse_matrix_detail_select_k(T, IdxT) \
extern template void raft::sparse::matrix::detail::select_k( \
raft::resources const& handle, \
raft::device_csr_matrix_view<const T, IdxT, IdxT, IdxT> in_val, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> in_idx, \
raft::device_matrix_view<T, IdxT, raft::row_major> out_val, \
raft::device_matrix_view<IdxT, IdxT, raft::row_major> out_idx, \
bool select_min, \
bool sorted, \
raft::matrix::SelectAlgo algo)

// (value type, index type) combinations declared as externally instantiated.
instantiate_raft_sparse_matrix_detail_select_k(__half, uint32_t);
instantiate_raft_sparse_matrix_detail_select_k(__half, int64_t);
instantiate_raft_sparse_matrix_detail_select_k(float, int64_t);
instantiate_raft_sparse_matrix_detail_select_k(float, uint32_t);
instantiate_raft_sparse_matrix_detail_select_k(float, int);
instantiate_raft_sparse_matrix_detail_select_k(double, int64_t);
instantiate_raft_sparse_matrix_detail_select_k(double, uint32_t);

// Remove the helper so it does not leak into files that include this header.
#undef instantiate_raft_sparse_matrix_detail_select_k
Loading

0 comments on commit 04a8bb4

Please sign in to comment.