Skip to content

Commit

Permalink
Fix benchmarks build
Browse files Browse the repository at this point in the history
  • Loading branch information
harrism committed Apr 23, 2024
1 parent debb2f5 commit 87e0737
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 16 deletions.
11 changes: 6 additions & 5 deletions cpp/bench/ann/src/raft/raft_cagra_hnswlib.cu
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,11 @@ int main(int argc, char** argv)
// and is initially sized to half of free device memory.
rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource> pool_mr{
&cuda_mr, rmm::percent_of_free_device_memory(50)};
rmm::mr::set_current_device_resource(
&pool_mr); // Updates the current device resource pointer to `pool_mr`
rmm::device_async_resource_ref mr =
rmm::mr::get_current_device_resource(); // Points to `pool_mr`
return raft::bench::ann::run_main(argc, argv);
// Updates the current device resource pointer to `pool_mr`
auto old_mr = rmm::mr::set_current_device_resource(&pool_mr);
auto ret = raft::bench::ann::run_main(argc, argv);
// Restores the current device resource pointer to its previous value
rmm::mr::set_current_device_resource(old_mr);
return ret;
}
#endif
11 changes: 9 additions & 2 deletions cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,14 @@ void RaftIvfFlatGpu<T, IdxT>::search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
{
static_assert(sizeof(size_t) == sizeof(IdxT), "IdxT is incompatible with size_t");
raft::neighbors::ivf_flat::search(
handle_, search_params_, *index_, queries, batch_size, k, (IdxT*)neighbors, distances);
raft::neighbors::ivf_flat::search(handle_,
search_params_,
*index_,
queries,
batch_size,
k,
(IdxT*)neighbors,
distances,
resource::get_workspace_resource(handle_));
}
} // namespace raft::bench::ann
4 changes: 2 additions & 2 deletions cpp/bench/prims/common/benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
#include <rmm/cuda_stream.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_buffer.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <benchmark/benchmark.h>

Expand All @@ -44,7 +44,7 @@ namespace raft::bench {
*/
struct using_pool_memory_res {
private:
rmm::device_async_resource_ref orig_res_;
rmm::mr::device_memory_resource* orig_res_;
rmm::mr::cuda_memory_resource cuda_res_{};
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> pool_res_;

Expand Down
4 changes: 2 additions & 2 deletions cpp/bench/prims/matrix/gather.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
#include <raft/util/itertools.hpp>

#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

namespace raft::bench::matrix {

Expand Down Expand Up @@ -109,7 +109,7 @@ struct Gather : public fixture {

private:
GatherParams<IdxT> params;
rmm::device_async_resource_ref old_mr;
rmm::mr::device_memory_resource* old_mr;
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> pool_mr;
raft::device_matrix<T, IdxT> matrix, out;
raft::host_matrix<T, IdxT> matrix_h;
Expand Down
14 changes: 11 additions & 3 deletions cpp/bench/prims/neighbors/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <raft/spatial/knn/knn.cuh>
#include <raft/util/itertools.hpp>

#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/managed_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/host/new_delete_resource.hpp>
Expand Down Expand Up @@ -106,7 +107,7 @@ struct device_resource {

private:
const bool managed_;
rmm::device_async_resource_ref res_;
rmm::mr::device_memory_resource* res_;
};

template <typename T>
Expand Down Expand Up @@ -159,8 +160,15 @@ struct ivf_flat_knn {
IdxT* out_idxs)
{
search_params.n_probes = 20;
raft::neighbors::ivf_flat::search(
handle, search_params, *index, search_items, ps.n_queries, ps.k, out_idxs, out_dists);
raft::neighbors::ivf_flat::search(handle,
search_params,
*index,
search_items,
ps.n_queries,
ps.k,
out_idxs,
out_dists,
resource::get_workspace_resource(handle));
}
};

Expand Down
4 changes: 2 additions & 2 deletions cpp/bench/prims/random/subsample.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
#include <raft/util/cudart_utils.hpp>

#include <rmm/device_scalar.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <cub/cub.cuh>

Expand Down Expand Up @@ -95,7 +95,7 @@ struct sample : public fixture {
private:
float GiB = 1073741824.0f;
raft::device_resources res;
rmm::device_async_resource_ref old_mr;
rmm::mr::device_memory_resource* old_mr;
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> pool_mr;
sample_inputs params;
raft::device_vector<T, int64_t> out, in;
Expand Down

0 comments on commit 87e0737

Please sign in to comment.