Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scaling workspace resources #2322

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions cpp/bench/ann/src/raft/raft_ann_bench_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/failure_callback_resource_adaptor.hpp>
#include <rmm/mr/device/managed_memory_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

#include <memory>
Expand Down Expand Up @@ -74,13 +75,14 @@ inline auto rmm_oom_callback(std::size_t bytes, void*) -> bool
*/
class shared_raft_resources {
public:
using pool_mr_type = rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>;
using mr_type = rmm::mr::failure_callback_resource_adaptor<pool_mr_type>;
using pool_mr_type = rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>;
using mr_type = rmm::mr::failure_callback_resource_adaptor<pool_mr_type>;
using large_mr_type = rmm::mr::managed_memory_resource;

shared_raft_resources()
try : orig_resource_{rmm::mr::get_current_device_resource()},
pool_resource_(orig_resource_, 1024 * 1024 * 1024ull),
resource_(&pool_resource_, rmm_oom_callback, nullptr) {
resource_(&pool_resource_, rmm_oom_callback, nullptr), large_mr_() {
rmm::mr::set_current_device_resource(&resource_);
} catch (const std::exception& e) {
auto cuda_status = cudaGetLastError();
Expand All @@ -103,10 +105,16 @@ class shared_raft_resources {

~shared_raft_resources() noexcept { rmm::mr::set_current_device_resource(orig_resource_); }

auto get_large_memory_resource() noexcept
{
return static_cast<rmm::mr::device_memory_resource*>(&large_mr_);
}

private:
rmm::mr::device_memory_resource* orig_resource_;
pool_mr_type pool_resource_;
mr_type resource_;
large_mr_type large_mr_;
};

/**
Expand All @@ -129,6 +137,12 @@ class configured_raft_resources {
res_{std::make_unique<raft::device_resources>(
rmm::cuda_stream_view(get_stream_from_global_pool()))}
{
// Set the large workspace resource on the raft handle with a no-op deleter:
// the resource's lifetime is managed by shared_res_, not by the handle.
raft::resource::set_large_workspace_resource(
*res_,
std::shared_ptr<rmm::mr::device_memory_resource>(shared_res_->get_large_memory_resource(),
raft::void_op{}));
}

/** Default constructor creates all resources anew. */
Expand Down
52 changes: 50 additions & 2 deletions cpp/include/raft/core/resource/device_memory_resource.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -35,6 +35,16 @@ namespace raft::resource {
* @{
*/

/**
 * A raft::resource that holds an RMM device memory resource through a
 * shared_ptr, keeping the underlying resource alive for as long as it is
 * stored inside a raft::resources instance.
 */
class device_memory_resource : public resource {
 public:
  explicit device_memory_resource(std::shared_ptr<rmm::mr::device_memory_resource> mr) : mr_(mr)
  {
  }
  ~device_memory_resource() override = default;

  /** Type-erased pointer to the wrapped memory resource. */
  auto get_resource() -> void* override
  {
    auto* raw_mr = mr_.get();
    return raw_mr;
  }

 private:
  // Shared ownership: the wrapped resource may also be owned by the caller.
  std::shared_ptr<rmm::mr::device_memory_resource> mr_;
};

class limiting_memory_resource : public resource {
public:
limiting_memory_resource(std::shared_ptr<rmm::mr::device_memory_resource> mr,
Expand Down Expand Up @@ -66,6 +76,29 @@ class limiting_memory_resource : public resource {
}
};

/**
 * Factory that produces the resource stored under
 * resource_type::LARGE_WORKSPACE_RESOURCE.
 *
 * When constructed with a non-null memory resource, that resource is wrapped
 * and shared. When constructed with a null pointer (the default), it falls
 * back to the current RMM device resource, wrapped with a no-op deleter
 * because that global resource is not owned by this factory.
 */
class large_workspace_resource_factory : public resource_factory {
 public:
  explicit large_workspace_resource_factory(
    std::shared_ptr<rmm::mr::device_memory_resource> mr = {nullptr})
  {
    if (mr) {
      mr_ = mr;
    } else {
      // Non-owning fallback to the global device resource.
      mr_ = std::shared_ptr<rmm::mr::device_memory_resource>{rmm::mr::get_current_device_resource(),
                                                             void_op{}};
    }
  }
  auto get_resource_type() -> resource_type override
  {
    return resource_type::LARGE_WORKSPACE_RESOURCE;
  }
  auto make_resource() -> resource* override { return new device_memory_resource(mr_); }

 private:
  std::shared_ptr<rmm::mr::device_memory_resource> mr_;
};

/**
* Factory that knows how to construct a specific raft::resource to populate
* the resources instance.
Expand Down Expand Up @@ -144,7 +177,7 @@ class workspace_resource_factory : public resource_factory {
// Note, the workspace does not claim all this memory from the start, so it's still usable by
// the main resource as well.
// This limit merely serves as a guideline that lets algorithm internals plan their batching
return total_size / 2;
return total_size / 4;
}
};

Expand Down Expand Up @@ -241,6 +274,21 @@ inline void set_workspace_to_global_resource(
workspace_resource_factory::default_plain_resource(), allocation_limit, std::nullopt));
};

/**
 * Get the memory resource intended for large temporary device allocations.
 *
 * If no large-workspace factory has been set on `res`, a default
 * `large_workspace_resource_factory` is installed first (it falls back to the
 * current RMM device resource).
 *
 * @param[in] res raft resources object
 * @return non-owning pointer to the large-workspace memory resource
 */
inline auto get_large_workspace_resource(resources const& res) -> rmm::mr::device_memory_resource*
{
  // Lazily install the default factory only when the user has not provided one.
  if (!res.has_resource_factory(resource_type::LARGE_WORKSPACE_RESOURCE)) {
    res.add_resource_factory(std::make_shared<large_workspace_resource_factory>());
  }
  return res.get_resource<rmm::mr::device_memory_resource>(resource_type::LARGE_WORKSPACE_RESOURCE);
}

/**
 * Set the memory resource to be used for large temporary device allocations.
 *
 * @param[in] res raft resources object
 * @param[in] mr the resource to install; when null (the default), the factory
 *   falls back to the current RMM device resource wrapped with a no-op deleter.
 */
inline void set_large_workspace_resource(resources const& res,
                                         std::shared_ptr<rmm::mr::device_memory_resource> mr = {
                                           nullptr})
{
  res.add_resource_factory(std::make_shared<large_workspace_resource_factory>(mr));
}

/** @} */

} // namespace raft::resource
35 changes: 18 additions & 17 deletions cpp/include/raft/core/resource/resource_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,24 @@ namespace raft::resource {
*/
enum resource_type {
// device-specific resource types
CUBLAS_HANDLE = 0, // cublas handle
CUSOLVER_DN_HANDLE, // cusolver dn handle
CUSOLVER_SP_HANDLE, // cusolver sp handle
CUSPARSE_HANDLE, // cusparse handle
CUDA_STREAM_VIEW, // view of a cuda stream
CUDA_STREAM_POOL, // cuda stream pool
CUDA_STREAM_SYNC_EVENT, // cuda event for syncing streams
COMMUNICATOR, // raft communicator
SUB_COMMUNICATOR, // raft sub communicator
DEVICE_PROPERTIES, // cuda device properties
DEVICE_ID, // cuda device id
STREAM_VIEW, // view of a cuda stream or a placeholder in
// CUDA-free builds
THRUST_POLICY, // thrust execution policy
WORKSPACE_RESOURCE, // rmm device memory resource
CUBLASLT_HANDLE, // cublasLt handle
CUSTOM, // runtime-shared default-constructible resource
CUBLAS_HANDLE = 0, // cublas handle
CUSOLVER_DN_HANDLE, // cusolver dn handle
CUSOLVER_SP_HANDLE, // cusolver sp handle
CUSPARSE_HANDLE, // cusparse handle
CUDA_STREAM_VIEW, // view of a cuda stream
CUDA_STREAM_POOL, // cuda stream pool
CUDA_STREAM_SYNC_EVENT, // cuda event for syncing streams
COMMUNICATOR, // raft communicator
SUB_COMMUNICATOR, // raft sub communicator
DEVICE_PROPERTIES, // cuda device properties
DEVICE_ID, // cuda device id
STREAM_VIEW, // view of a cuda stream or a placeholder in
// CUDA-free builds
THRUST_POLICY, // thrust execution policy
WORKSPACE_RESOURCE, // rmm device memory resource for small temporary allocations
CUBLASLT_HANDLE, // cublasLt handle
CUSTOM, // runtime-shared default-constructible resource
LARGE_WORKSPACE_RESOURCE, // rmm device memory resource for somewhat large temporary allocations

LAST_KEY // reserved for the last key
};
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/matrix/detail/select_radix.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -894,7 +894,7 @@ void radix_topk(const T* in,
unsigned grid_dim,
int sm_cnt,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource())
rmm::device_async_resource_ref mr)
{
// TODO: is it possible to relax this restriction?
static_assert(calc_num_passes<T, BitsPerPass>() > 1);
Expand Down
Loading