Skip to content

Commit

Permalink
Merge pull request #2237 from rapidsai/branch-24.04
Browse files Browse the repository at this point in the history
Forward-merge branch-24.04 to branch-24.06
  • Loading branch information
GPUtester authored Mar 21, 2024
2 parents b06b936 + b773494 commit 0eb6ed8
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 3 deletions.
23 changes: 23 additions & 0 deletions cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ extern template class raft::bench::ann::RaftIvfPQ<int8_t, int64_t>;
#endif
#ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA
// Explicit instantiation declarations for the CAGRA benchmark wrapper: they suppress
// implicit instantiation of these specializations in translation units that include
// this header. One declaration per supported data type (float, half, uint8_t, int8_t),
// all with uint32_t as the second template argument.
// NOTE(review): the matching explicit instantiation definitions are not visible here;
// presumably they live in a separate source file - confirm when adding a new type.
extern template class raft::bench::ann::RaftCagra<float, uint32_t>;
extern template class raft::bench::ann::RaftCagra<half, uint32_t>;
extern template class raft::bench::ann::RaftCagra<uint8_t, uint32_t>;
extern template class raft::bench::ann::RaftCagra<int8_t, uint32_t>;
#endif
Expand Down Expand Up @@ -149,6 +150,20 @@ void parse_build_param(const nlohmann::json& conf,
}
}

/**
 * @brief Fill a `vpq_params` struct from a JSON configuration object.
 *
 * Every key is optional: a field of `param` is overwritten only when the key of the
 * same name is present in `conf`; otherwise the field keeps the value it had on entry.
 *
 * Recognized keys: `pq_bits`, `pq_dim`, `vq_n_centers`, `kmeans_n_iters`,
 * `vq_kmeans_trainset_fraction`, `pq_kmeans_trainset_fraction`.
 *
 * @param[in]    conf  JSON object holding the (already prefix-stripped) settings
 * @param[inout] param parameters to update in place
 */
inline void parse_build_param(const nlohmann::json& conf, raft::neighbors::vpq_params& param)
{
  // Assign conf[key] to `field` (using the JSON value's implicit conversion to the
  // field's own type) when, and only when, the key exists.
  const auto assign_if_present = [&conf](const char* key, auto& field) {
    if (conf.contains(key)) { field = conf.at(key); }
  };

  assign_if_present("pq_bits", param.pq_bits);
  assign_if_present("pq_dim", param.pq_dim);
  assign_if_present("vq_n_centers", param.vq_n_centers);
  assign_if_present("kmeans_n_iters", param.kmeans_n_iters);
  assign_if_present("vq_kmeans_trainset_fraction", param.vq_kmeans_trainset_fraction);
  assign_if_present("pq_kmeans_trainset_fraction", param.pq_kmeans_trainset_fraction);
}

nlohmann::json collect_conf_with_prefix(const nlohmann::json& conf,
const std::string& prefix,
bool remove_prefix = true)
Expand Down Expand Up @@ -204,6 +219,12 @@ void parse_build_param(const nlohmann::json& conf,
}
param.nn_descent_params = nn_param;
}
nlohmann::json comp_search_conf = collect_conf_with_prefix(conf, "compression_");
if (!comp_search_conf.empty()) {
raft::neighbors::vpq_params vpq_pams;
parse_build_param(comp_search_conf, vpq_pams);
param.cagra_params.compression.emplace(vpq_pams);
}
}

raft::bench::ann::AllocatorType parse_allocator(std::string mem_type)
Expand Down Expand Up @@ -248,5 +269,7 @@ void parse_search_param(const nlohmann::json& conf,
if (conf.contains("internal_dataset_memory_type")) {
param.dataset_mem = parse_allocator(conf.at("internal_dataset_memory_type"));
}
// Same ratio as in IVF-PQ
param.refine_ratio = conf.value("refine_ratio", 1.0f);
}
#endif
83 changes: 80 additions & 3 deletions cpp/bench/ann/src/raft/raft_cagra_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <raft/neighbors/cagra.cuh>
#include <raft/neighbors/cagra_serialize.cuh>
#include <raft/neighbors/cagra_types.hpp>
#include <raft/neighbors/dataset.hpp>
#include <raft/neighbors/detail/cagra/cagra_build.cuh>
#include <raft/neighbors/ivf_pq_types.hpp>
#include <raft/neighbors/nn_descent_types.hpp>
Expand Down Expand Up @@ -56,6 +57,7 @@ class RaftCagra : public ANN<T>, public AnnGPU {

struct SearchParam : public AnnSearchParam {
raft::neighbors::experimental::cagra::search_params p;
float refine_ratio;
AllocatorType graph_mem = AllocatorType::Device;
AllocatorType dataset_mem = AllocatorType::Device;
auto needs_dataset() const -> bool override { return true; }
Expand Down Expand Up @@ -98,6 +100,8 @@ class RaftCagra : public ANN<T>, public AnnGPU {
// will be filled with (size_t)-1
void search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override;
void search_base(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const;

[[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override
{
Expand All @@ -124,6 +128,7 @@ class RaftCagra : public ANN<T>, public AnnGPU {
raft::mr::cuda_huge_page_resource mr_huge_page_;
AllocatorType graph_mem_;
AllocatorType dataset_mem_;
float refine_ratio_;
BuildParam index_params_;
bool need_dataset_update_;
raft::neighbors::cagra::search_params search_params_;
Expand Down Expand Up @@ -151,6 +156,9 @@ void RaftCagra<T, IdxT>::build(const T* dataset, size_t nrow)

auto& params = index_params_.cagra_params;

// Do include the compressed dataset for the CAGRA-Q
bool shall_include_dataset = params.compression.has_value();

index_ = std::make_shared<raft::neighbors::cagra::index<T, IdxT>>(
std::move(raft::neighbors::cagra::detail::build(handle_,
params,
Expand All @@ -159,7 +167,7 @@ void RaftCagra<T, IdxT>::build(const T* dataset, size_t nrow)
index_params_.ivf_pq_refine_rate,
index_params_.ivf_pq_build_params,
index_params_.ivf_pq_search_params,
false)));
shall_include_dataset)));
}

inline std::string allocator_to_string(AllocatorType mem_type)
Expand All @@ -179,6 +187,7 @@ void RaftCagra<T, IdxT>::set_search_param(const AnnSearchParam& param)
{
auto search_param = dynamic_cast<const SearchParam&>(param);
search_params_ = search_param.p;
refine_ratio_ = search_param.refine_ratio;
if (search_param.graph_mem != graph_mem_) {
// Move graph to correct memory space
graph_mem_ = search_param.graph_mem;
Expand Down Expand Up @@ -223,12 +232,16 @@ void RaftCagra<T, IdxT>::set_search_param(const AnnSearchParam& param)
/**
 * @brief Point the wrapper at the dataset to be used during search.
 *
 * When `dataset` or `nrow` differ from what the stored view currently refers to, the
 * view is replaced (built with `make_device_matrix_view` over `nrow` x `dim_` elements)
 * and `need_dataset_update_` is assigned:
 *   - `true`  if the index data is not a VPQ dataset,
 *   - `false` if the index data is a `vpq_dataset` (with `half` or `float` as its first
 *     template argument), i.e. the update is ignored for VPQ datasets.
 * When neither the pointer nor the row count changed, nothing is modified.
 *
 * @param dataset pointer to the first element of the dataset
 * @param nrow    number of rows in `dataset`
 */
template <typename T, typename IdxT>
void RaftCagra<T, IdxT>::set_search_dataset(const T* dataset, size_t nrow)
{
  using ds_idx_type = decltype(index_->data().n_rows());
  // Detect a VPQ dataset by probing the dynamic type of the index's data object.
  const bool is_vpq =
    dynamic_cast<const raft::neighbors::vpq_dataset<half, ds_idx_type>*>(&index_->data()) !=
      nullptr ||
    dynamic_cast<const raft::neighbors::vpq_dataset<float, ds_idx_type>*>(&index_->data()) !=
      nullptr;
  // It can happen that we are re-using a previous algo object which already has
  // the dataset set. Check if we need update.
  if (static_cast<size_t>(input_dataset_v_->extent(0)) != nrow ||
      input_dataset_v_->data_handle() != dataset) {
    *input_dataset_v_ = make_device_matrix_view<const T, int64_t>(dataset, nrow, this->dim_);
    // Single assignment: the previous unconditional `= true` was immediately overwritten
    // by this line and has been dropped. Ignore the update if this is a VPQ dataset.
    need_dataset_update_ = !is_vpq;
  }
}

Expand Down Expand Up @@ -258,7 +271,7 @@ std::unique_ptr<ANN<T>> RaftCagra<T, IdxT>::copy()
}

template <typename T, typename IdxT>
void RaftCagra<T, IdxT>::search(
void RaftCagra<T, IdxT>::search_base(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
{
IdxT* neighbors_IdxT;
Expand Down Expand Up @@ -286,4 +299,68 @@ void RaftCagra<T, IdxT>::search(
raft::resource::get_cuda_stream(handle_));
}
}

/**
 * @brief Search entry point with optional refinement of the results.
 *
 * The number of candidates requested from `search_base` is `refine_ratio_ * k`
 * (truncated to `size_t`). If that is not larger than `k`, the call is forwarded to
 * `search_base` unchanged. Otherwise `search_base` fills temporary candidate buffers and
 * `raft::neighbors::refine` reduces them to the final `k` neighbors and distances, using
 * either the device-view overload or host copies depending on where
 * `raft::get_device_for_address` locates the input dataset.
 *
 * @param queries    pointer to the query vectors (`batch_size` x `dimension_`)
 * @param batch_size number of queries
 * @param k          number of neighbors to return per query
 * @param neighbors  output indices (`batch_size` x `k`)
 * @param distances  output distances (`batch_size` x `k`)
 */
template <typename T, typename IdxT>
void RaftCagra<T, IdxT>::search(
  const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
{
  const raft::resources& res = handle_;
  auto stream                = resource::get_cuda_stream(res);
  auto n_candidates          = static_cast<size_t>(refine_ratio_ * k);

  // Nothing to refine unless more than k candidates are requested.
  if (n_candidates <= static_cast<size_t>(k)) {
    search_base(queries, batch_size, k, neighbors, distances);
    return;
  }

  // Temporary buffers receiving the enlarged candidate list.
  auto candidate_ids = raft::make_device_matrix<int64_t, int64_t>(res, batch_size, n_candidates);
  auto candidate_distances =
    raft::make_device_matrix<float, int64_t>(res, batch_size, n_candidates);
  search_base(queries,
              batch_size,
              n_candidates,
              reinterpret_cast<size_t*>(candidate_ids.data_handle()),
              candidate_distances.data_handle());

  // A non-negative result selects the device-view refinement path below.
  const bool dataset_on_device =
    raft::get_device_for_address(input_dataset_v_->data_handle()) >= 0;

  if (dataset_on_device) {
    auto queries_view =
      raft::make_device_matrix_view<const T, int64_t>(queries, batch_size, dimension_);
    auto neighbors_view = raft::make_device_matrix_view<int64_t, int64_t>(
      reinterpret_cast<int64_t*>(neighbors), batch_size, k);
    auto distances_view = raft::make_device_matrix_view<float, int64_t>(distances, batch_size, k);

    raft::neighbors::refine<int64_t, T, float, int64_t>(
      res,
      *input_dataset_v_,
      queries_view,
      raft::make_const_mdspan(candidate_ids.view()),
      neighbors_view,
      distances_view,
      index_->metric());
  } else {
    // Re-wrap the dataset pointer as a host view and stage everything else in host
    // matrices for the host overload of refine.
    auto dataset_host = raft::make_host_matrix_view<const T, int64_t>(
      input_dataset_v_->data_handle(), input_dataset_v_->extent(0), input_dataset_v_->extent(1));
    auto queries_host    = raft::make_host_matrix<T, int64_t>(batch_size, dimension_);
    auto candidates_host = raft::make_host_matrix<int64_t, int64_t>(batch_size, n_candidates);
    auto neighbors_host  = raft::make_host_matrix<int64_t, int64_t>(batch_size, k);
    auto distances_host  = raft::make_host_matrix<float, int64_t>(batch_size, k);

    raft::copy(queries_host.data_handle(), queries, queries_host.size(), stream);
    raft::copy(
      candidates_host.data_handle(), candidate_ids.data_handle(), candidates_host.size(), stream);

    // The copies above were issued on `stream`; wait for the queries and candidates
    // before refine reads them.
    raft::resource::sync_stream(res);
    raft::neighbors::refine<int64_t, T, float, int64_t>(res,
                                                        dataset_host,
                                                        queries_host.view(),
                                                        candidates_host.view(),
                                                        neighbors_host.view(),
                                                        distances_host.view(),
                                                        index_->metric());

    // Hand the refined results back through the caller-provided output pointers.
    raft::copy(neighbors,
               reinterpret_cast<size_t*>(neighbors_host.data_handle()),
               neighbors_host.size(),
               stream);
    raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream);
  }
}
} // namespace raft::bench::ann

0 comments on commit 0eb6ed8

Please sign in to comment.