Convert device_memory_resource* to device_async_resource_ref #2269

Merged: 6 commits, Apr 24, 2024
Changes from 5 commits
1 change: 0 additions & 1 deletion cpp/CMakeLists.txt
@@ -565,7 +565,6 @@ if(RAFT_COMPILE_LIBRARY)
src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu
src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu
src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu
src/util/memory_pool.cpp
)
set_target_properties(
raft_objs
5 changes: 3 additions & 2 deletions cpp/bench/ann/src/raft/raft_ann_bench_utils.h
@@ -28,6 +28,7 @@
#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/failure_callback_resource_adaptor.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <memory>
#include <type_traits>
@@ -130,8 +131,8 @@ class configured_raft_resources {
{
}

configured_raft_resources(configured_raft_resources&&) = default;
configured_raft_resources& operator=(configured_raft_resources&&) = default;
configured_raft_resources(configured_raft_resources&&) = delete;
configured_raft_resources& operator=(configured_raft_resources&&) = delete;
Comment on lines +134 to +135

Contributor:

Did you mark this constructor as deleted only because it already defaulted to deleted due to the raft::device_resources member, or was there a more specific design issue?

If the former, I'd prefer to wrap the device_resources in a unique pointer to keep the move constructor (we move it in benchmarks when initializing algo wrappers). But I can do it later in a follow-on PR.

Member Author:

Yes, exactly. clang-tidy complains that this is explicitly defaulted even though it is implicitly deleted due to the deleted constructors in the res_ member.

If you are going to have these clang-tidy settings, why ignore the warnings? Wrapping the device resources in a unique_ptr is up to you; it's outside my scope and the scope of this PR.

I think I can revert all changes to this file. It turns out the resource_ref header is no longer used.

Member Author (@harrism, Apr 23, 2024):

Actually, it should include device_memory_resource.hpp to follow IWYU (include what you use) correctly.
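A minimal sketch of the follow-on idea discussed above, assuming all the wrapper needs from res_ is to hand out a raft::device_resources&; the class and member names here are illustrative, not the actual benchmark code:

```cpp
// Hypothetical follow-on (not part of this PR): holding raft::device_resources
// behind a std::unique_ptr restores movability, because moving the wrapper then
// only transfers ownership of the pointer rather than moving device_resources itself.
#include <raft/core/device_resources.hpp>

#include <memory>

class movable_raft_resources {
 public:
  movable_raft_resources() : res_{std::make_unique<raft::device_resources>()} {}

  // These can be defaulted again; clang-tidy no longer sees an implicitly deleted move.
  movable_raft_resources(movable_raft_resources&&)            = default;
  movable_raft_resources& operator=(movable_raft_resources&&) = default;

  operator raft::device_resources&() { return *res_; }

 private:
  std::unique_ptr<raft::device_resources> res_;
};
```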

~configured_raft_resources() = default;
configured_raft_resources(const configured_raft_resources& res)
: configured_raft_resources{res.shared_res_}
12 changes: 7 additions & 5 deletions cpp/bench/ann/src/raft/raft_cagra_hnswlib.cu
@@ -20,6 +20,7 @@

#include <rmm/cuda_device.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

#define JSON_DIAGNOSTICS 1
#include <nlohmann/json.hpp>
@@ -89,10 +90,11 @@ int main(int argc, char** argv)
// and is initially sized to half of free device memory.
rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource> pool_mr{
&cuda_mr, rmm::percent_of_free_device_memory(50)};
rmm::mr::set_current_device_resource(
&pool_mr); // Updates the current device resource pointer to `pool_mr`
rmm::mr::device_memory_resource* mr =
rmm::mr::get_current_device_resource(); // Points to `pool_mr`
return raft::bench::ann::run_main(argc, argv);
// Updates the current device resource pointer to `pool_mr`
auto old_mr = rmm::mr::set_current_device_resource(&pool_mr);
auto ret = raft::bench::ann::run_main(argc, argv);
// Restores the current device resource pointer to its previous value
rmm::mr::set_current_device_resource(old_mr);
return ret;
}
#endif
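The change above relies on rmm::mr::set_current_device_resource returning the resource that was previously set. As a sketch only (not part of this PR), the same save/restore logic could be wrapped in an RAII guard so that early returns or exceptions cannot leak the override; the class name is hypothetical:

```cpp
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>

// Hypothetical RAII helper: installs new_mr as the current device resource and
// restores the previous resource when the guard goes out of scope.
class scoped_current_device_resource {
 public:
  explicit scoped_current_device_resource(rmm::mr::device_memory_resource* new_mr)
    : old_mr_{rmm::mr::set_current_device_resource(new_mr)}
  {
  }
  ~scoped_current_device_resource() { rmm::mr::set_current_device_resource(old_mr_); }

  scoped_current_device_resource(const scoped_current_device_resource&)            = delete;
  scoped_current_device_resource& operator=(const scoped_current_device_resource&) = delete;

 private:
  rmm::mr::device_memory_resource* old_mr_;
};
```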
4 changes: 2 additions & 2 deletions cpp/bench/ann/src/raft/raft_cagra_wrapper.h
@@ -36,7 +36,7 @@
#include <raft/util/cudart_utils.hpp>

#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <cassert>
#include <fstream>
@@ -138,7 +138,7 @@ class RaftCagra : public ANN<T>, public AnnGPU {
std::shared_ptr<raft::device_matrix<T, int64_t, row_major>> dataset_;
std::shared_ptr<raft::device_matrix_view<const T, int64_t, row_major>> input_dataset_v_;

inline rmm::mr::device_memory_resource* get_mr(AllocatorType mem_type)
inline rmm::device_async_resource_ref get_mr(AllocatorType mem_type)
{
switch (mem_type) {
case (AllocatorType::HostPinned): return &mr_pinned_;
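The get_mr change above compiles with its return statements unchanged because rmm::device_async_resource_ref is a non-owning handle that is constructible from a pointer to a device-accessible memory resource. A small stand-alone sketch of that conversion (the resources below are local examples, not the wrapper's members):

```cpp
#include <rmm/mr/device/cuda_memory_resource.hpp>
#include <rmm/mr/device/managed_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

// Returning a resource_ref: the raw pointers convert implicitly, and the caller
// receives a copyable handle instead of an owning pointer.
inline rmm::device_async_resource_ref pick_mr(bool use_managed,
                                              rmm::mr::cuda_memory_resource& cuda_mr,
                                              rmm::mr::managed_memory_resource& managed_mr)
{
  if (use_managed) { return &managed_mr; }
  return &cuda_mr;
}
```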
11 changes: 9 additions & 2 deletions cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h
@@ -134,7 +134,14 @@ void RaftIvfFlatGpu<T, IdxT>::search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
{
static_assert(sizeof(size_t) == sizeof(IdxT), "IdxT is incompatible with size_t");
raft::neighbors::ivf_flat::search(
handle_, search_params_, *index_, queries, batch_size, k, (IdxT*)neighbors, distances);
raft::neighbors::ivf_flat::search(handle_,
search_params_,
*index_,
queries,
batch_size,
k,
(IdxT*)neighbors,
distances,
resource::get_workspace_resource(handle_));
}
} // namespace raft::bench::ann
3 changes: 0 additions & 3 deletions cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h
@@ -32,9 +32,6 @@
#include <raft/neighbors/refine.cuh>
#include <raft/util/cudart_utils.hpp>

#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>

#include <type_traits>

namespace raft::bench::ann {
1 change: 1 addition & 0 deletions cpp/bench/prims/common/benchmark.hpp
@@ -28,6 +28,7 @@
#include <rmm/cuda_stream.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_buffer.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

1 change: 1 addition & 0 deletions cpp/bench/prims/matrix/gather.cu
@@ -24,6 +24,7 @@
#include <raft/util/itertools.hpp>

#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

namespace raft::bench::matrix {
15 changes: 12 additions & 3 deletions cpp/bench/prims/neighbors/knn.cuh
@@ -27,10 +27,12 @@
#include <raft/spatial/knn/knn.cuh>
#include <raft/util/itertools.hpp>

#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/managed_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/host/new_delete_resource.hpp>
#include <rmm/mr/host/pinned_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <thrust/sequence.h>

@@ -101,7 +103,7 @@ struct device_resource {
if (managed_) { delete res_; }
}

[[nodiscard]] auto get() const -> rmm::mr::device_memory_resource* { return res_; }
[[nodiscard]] auto get() const -> rmm::device_async_resource_ref { return res_; }

private:
const bool managed_;
@@ -158,8 +160,15 @@ struct ivf_flat_knn {
IdxT* out_idxs)
{
search_params.n_probes = 20;
raft::neighbors::ivf_flat::search(
handle, search_params, *index, search_items, ps.n_queries, ps.k, out_idxs, out_dists);
raft::neighbors::ivf_flat::search(handle,
search_params,
*index,
search_items,
ps.n_queries,
ps.k,
out_idxs,
out_dists,
resource::get_workspace_resource(handle));
}
};

1 change: 1 addition & 0 deletions cpp/bench/prims/random/subsample.cu
@@ -27,6 +27,7 @@
#include <raft/util/cudart_utils.hpp>

#include <rmm/device_scalar.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

#include <cub/cub.cuh>
44 changes: 21 additions & 23 deletions cpp/include/raft/cluster/detail/kmeans_balanced.cuh
@@ -43,15 +43,14 @@

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_scalar.hpp>
#include <rmm/device_vector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/managed_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <thrust/gather.h>
#include <thrust/transform.h>

#include <limits>
#include <optional>
#include <tuple>
#include <type_traits>

@@ -91,7 +90,7 @@ inline std::enable_if_t<std::is_floating_point_v<MathT>> predict_core(
const MathT* dataset_norm,
IdxT n_rows,
LabelT* labels,
rmm::mr::device_memory_resource* mr)
rmm::device_async_resource_ref mr)
{
auto stream = resource::get_cuda_stream(handle);
switch (params.metric) {
@@ -263,10 +262,9 @@ void calc_centers_and_sizes(const raft::resources& handle,
const LabelT* labels,
bool reset_counters,
MappingOpT mapping_op,
rmm::mr::device_memory_resource* mr = nullptr)
rmm::device_async_resource_ref mr)
{
auto stream = resource::get_cuda_stream(handle);
if (mr == nullptr) { mr = resource::get_workspace_resource(handle); }

if (!reset_counters) {
raft::linalg::matrixVectorOp(
@@ -322,12 +320,12 @@ void compute_norm(const raft::resources& handle,
IdxT dim,
IdxT n_rows,
MappingOpT mapping_op,
rmm::mr::device_memory_resource* mr = nullptr)
std::optional<rmm::device_async_resource_ref> mr = std::nullopt)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope("compute_norm");
auto stream = resource::get_cuda_stream(handle);
if (mr == nullptr) { mr = resource::get_workspace_resource(handle); }
rmm::device_uvector<MathT> mapped_dataset(0, stream, mr);
rmm::device_uvector<MathT> mapped_dataset(
0, stream, mr.value_or(resource::get_workspace_resource(handle)));

const MathT* dataset_ptr = nullptr;

@@ -338,7 +336,7 @@

linalg::unaryOp(mapped_dataset.data(), dataset, n_rows * dim, mapping_op, stream);

dataset_ptr = (const MathT*)mapped_dataset.data();
dataset_ptr = static_cast<const MathT*>(mapped_dataset.data());
}

raft::linalg::rowNorm<MathT, IdxT>(
@@ -376,22 +374,22 @@ void predict(const raft::resources& handle,
IdxT n_rows,
LabelT* labels,
MappingOpT mapping_op,
rmm::mr::device_memory_resource* mr = nullptr,
const MathT* dataset_norm = nullptr)
std::optional<rmm::device_async_resource_ref> mr = std::nullopt,
const MathT* dataset_norm = nullptr)
{
auto stream = resource::get_cuda_stream(handle);
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"predict(%zu, %u)", static_cast<size_t>(n_rows), n_clusters);
if (mr == nullptr) { mr = resource::get_workspace_resource(handle); }
auto mem_res = mr.value_or(resource::get_workspace_resource(handle));
auto [max_minibatch_size, _mem_per_row] =
calc_minibatch_size<MathT>(n_clusters, n_rows, dim, params.metric, std::is_same_v<T, MathT>);
rmm::device_uvector<MathT> cur_dataset(
std::is_same_v<T, MathT> ? 0 : max_minibatch_size * dim, stream, mr);
std::is_same_v<T, MathT> ? 0 : max_minibatch_size * dim, stream, mem_res);
bool need_compute_norm =
dataset_norm == nullptr && (params.metric == raft::distance::DistanceType::L2Expanded ||
params.metric == raft::distance::DistanceType::L2SqrtExpanded);
rmm::device_uvector<MathT> cur_dataset_norm(
need_compute_norm ? max_minibatch_size : 0, stream, mr);
need_compute_norm ? max_minibatch_size : 0, stream, mem_res);
const MathT* dataset_norm_ptr = nullptr;
auto cur_dataset_ptr = cur_dataset.data();
for (IdxT offset = 0; offset < n_rows; offset += max_minibatch_size) {
@@ -407,7 +405,7 @@
// Compute the norm now if it hasn't been pre-computed.
if (need_compute_norm) {
compute_norm(
handle, cur_dataset_norm.data(), cur_dataset_ptr, dim, minibatch_size, mapping_op, mr);
handle, cur_dataset_norm.data(), cur_dataset_ptr, dim, minibatch_size, mapping_op, mem_res);
dataset_norm_ptr = cur_dataset_norm.data();
} else if (dataset_norm != nullptr) {
dataset_norm_ptr = dataset_norm + offset;
@@ -422,7 +420,7 @@
dataset_norm_ptr,
minibatch_size,
labels + offset,
mr);
mem_res);
}
}

@@ -530,7 +528,7 @@ auto adjust_centers(MathT* centers,
MathT threshold,
MappingOpT mapping_op,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* device_memory) -> bool
rmm::device_async_resource_ref device_memory) -> bool
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"adjust_centers(%zu, %u)", static_cast<size_t>(n_rows), n_clusters);
@@ -628,7 +626,7 @@ void balancing_em_iters(const raft::resources& handle,
uint32_t balancing_pullback,
MathT balancing_threshold,
MappingOpT mapping_op,
rmm::mr::device_memory_resource* device_memory)
rmm::device_async_resource_ref device_memory)
{
auto stream = resource::get_cuda_stream(handle);
uint32_t balancing_counter = balancing_pullback;
@@ -711,7 +709,7 @@ void build_clusters(const raft::resources& handle,
LabelT* cluster_labels,
CounterT* cluster_sizes,
MappingOpT mapping_op,
rmm::mr::device_memory_resource* device_memory,
rmm::device_async_resource_ref device_memory,
const MathT* dataset_norm = nullptr)
{
auto stream = resource::get_cuda_stream(handle);
@@ -853,8 +851,8 @@ auto build_fine_clusters(const raft::resources& handle,
IdxT fine_clusters_nums_max,
MathT* cluster_centers,
MappingOpT mapping_op,
rmm::mr::device_memory_resource* managed_memory,
rmm::mr::device_memory_resource* device_memory) -> IdxT
rmm::device_async_resource_ref managed_memory,
rmm::device_async_resource_ref device_memory) -> IdxT
{
auto stream = resource::get_cuda_stream(handle);
rmm::device_uvector<IdxT> mc_trainset_ids_buf(mesocluster_size_max, stream, managed_memory);
@@ -971,7 +969,7 @@ void build_hierarchical(const raft::resources& handle,

// TODO: Remove the explicit managed memory- we shouldn't be creating this on the user's behalf.
rmm::mr::managed_memory_resource managed_memory;
rmm::mr::device_memory_resource* device_memory = resource::get_workspace_resource(handle);
rmm::device_async_resource_ref device_memory = resource::get_workspace_resource(handle);
auto [max_minibatch_size, mem_per_row] =
calc_minibatch_size<MathT>(n_clusters, n_rows, dim, params.metric, std::is_same_v<T, MathT>);

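In compute_norm and predict above, the old `rmm::mr::device_memory_resource* mr = nullptr` default becomes `std::optional<rmm::device_async_resource_ref>`, since a resource_ref has no null state; `mr.value_or(...)` then selects the fallback. A minimal stand-alone sketch of the same pattern, with a hypothetical function name and a generic fallback resource:

```cpp
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <cstddef>
#include <optional>

// Hypothetical helper: callers may pass a memory resource for the temporary
// buffer; if they do not, fall back to the current device resource.
void fill_scratch(std::size_t n,
                  rmm::cuda_stream_view stream,
                  std::optional<rmm::device_async_resource_ref> mr = std::nullopt)
{
  rmm::device_async_resource_ref mem_res = mr.value_or(rmm::mr::get_current_device_resource());
  rmm::device_uvector<float> scratch(n, stream, mem_res);  // temporary workspace
  // ... use scratch.data() on `stream` ...
}
```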
3 changes: 2 additions & 1 deletion cpp/include/raft/cluster/kmeans_balanced.cuh
@@ -358,7 +358,8 @@ void calc_centers_and_sizes(const raft::resources& handle,
X.extent(0),
labels.data_handle(),
reset_counters,
mapping_op);
mapping_op,
resource::get_workspace_resource(handle));
}

} // namespace helpers
19 changes: 6 additions & 13 deletions cpp/include/raft/core/device_container_policy.hpp
@@ -31,7 +31,8 @@

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <thrust/device_ptr.h>

@@ -117,7 +118,7 @@ class device_uvector {
*/
explicit device_uvector(std::size_t size,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
rmm::device_async_resource_ref mr)
: data_{size, stream, mr}
{
}
@@ -164,19 +165,11 @@ class device_uvector_policy {
public:
auto create(raft::resources const& res, size_t n) -> container_type
{
if (mr_ == nullptr) {
// NB: not using the workspace resource by default!
// The workspace resource is for short-lived temporary allocations.
return container_type(n, resource::get_cuda_stream(res));
} else {
return container_type(n, resource::get_cuda_stream(res), mr_);
}
return container_type(n, resource::get_cuda_stream(res), mr_);
}

constexpr device_uvector_policy() = default;
constexpr explicit device_uvector_policy(rmm::mr::device_memory_resource* mr) noexcept : mr_(mr)
{
}
explicit device_uvector_policy(rmm::device_async_resource_ref mr) noexcept : mr_(mr) {}

[[nodiscard]] constexpr auto access(container_type& c, size_t n) const noexcept -> reference
{
@@ -192,7 +185,7 @@
[[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; }

private:
rmm::mr::device_memory_resource* mr_{nullptr};
rmm::device_async_resource_ref mr_{rmm::mr::get_current_device_resource()};
};

} // namespace raft
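Because a device_async_resource_ref cannot be null, the policy above replaces its nullptr sentinel with a default member initializer and drops the branch in create(). An illustrative use of the explicit-resource constructor, assuming device_uvector_policy is templated on the element type as in the surrounding header; the pool below is only an example upstream:

```cpp
#include <raft/core/device_container_policy.hpp>
#include <raft/core/resources.hpp>

#include <rmm/cuda_device.hpp>
#include <rmm/mr/device/cuda_memory_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

void example(raft::resources const& res)
{
  rmm::mr::cuda_memory_resource cuda_mr;
  rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource> pool_mr{
    &cuda_mr, rmm::percent_of_free_device_memory(50)};

  // The pointer converts to a device_async_resource_ref at the constructor call.
  raft::device_uvector_policy<float> policy{&pool_mr};
  auto buf = policy.create(res, 1024);  // rmm::device_uvector<float> backed by pool_mr
}
```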
4 changes: 3 additions & 1 deletion cpp/include/raft/core/device_mdarray.hpp
@@ -21,6 +21,8 @@
#include <raft/core/mdarray.hpp>
#include <raft/core/resources.hpp>

#include <rmm/resource_ref.hpp>

#include <cstdint>

namespace raft {
@@ -107,7 +109,7 @@ template <typename ElementType,
typename LayoutPolicy = layout_c_contiguous,
size_t... Extents>
auto make_device_mdarray(raft::resources const& handle,
rmm::mr::device_memory_resource* mr,
rmm::device_async_resource_ref mr,
extents<IndexType, Extents...> exts)
{
using mdarray_t = device_mdarray<ElementType, decltype(exts), LayoutPolicy>;
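With this change, make_device_mdarray takes the memory resource by value as a device_async_resource_ref, so callers can pass either a resource_ref or, via implicit conversion, a pointer to a device-accessible resource. A hedged usage sketch (the managed resource and the extents are illustrative, not from the PR):

```cpp
#include <raft/core/device_mdarray.hpp>
#include <raft/core/mdspan.hpp>
#include <raft/core/resources.hpp>

#include <rmm/mr/device/managed_memory_resource.hpp>

#include <cstdint>

void build_scratch(raft::resources const& handle)
{
  rmm::mr::managed_memory_resource managed_mr;
  // &managed_mr converts to rmm::device_async_resource_ref at the call site.
  auto scratch = raft::make_device_mdarray<float>(
    handle, &managed_mr, raft::make_extents<std::int64_t>(128, 64));
}
```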