Skip to content

Commit

Permalink
[FEA] Add distance epilogue for NN Descent (#2364)
Browse files Browse the repository at this point in the history
Adding distance epilogue to NN Descent.
Planning to use them for the following use cases as of now;
- Calculating mutual reachability distance for HDBSCAN (possibly for usage with HDBSCAN - [related PR here](rapidsai/cuml#5939))
- Enabling `L2SqrtExpanded` distance metric, by `sqrt`-ing the current supported metric (`L2Expanded`) of NN Descent in the distance epilogue. (for usage with UMAP - [related PR here](rapidsai/cuml#5910))

Authors:
  - Jinsol Park (https://github.com/jinsolp)

Approvers:
  - Tarang Jain (https://github.com/tarang-jain)
  - Tamas Bela Feher (https://github.com/tfeher)
  - Divye Gala (https://github.com/divyegala)

URL: #2364
  • Loading branch information
jinsolp authored Jul 11, 2024
1 parent 5bf6642 commit ab5e128
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 55 deletions.
104 changes: 62 additions & 42 deletions cpp/include/raft/neighbors/detail/nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <raft/core/device_mdarray.hpp>
#include <raft/core/error.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/operators.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/neighbors/detail/cagra/device_common.hpp>
Expand Down Expand Up @@ -340,7 +341,7 @@ struct GnndGraph {
~GnndGraph();
};

template <typename Data_t = float, typename Index_t = int>
template <typename Data_t = float, typename Index_t = int, typename epilogue_op = raft::identity_op>
class GNND {
public:
GNND(raft::resources const& res, const BuildConfig& build_config);
Expand All @@ -351,7 +352,8 @@ class GNND {
const Index_t nrow,
Index_t* output_graph,
bool return_distances,
DistData_t* output_distances);
DistData_t* output_distances,
epilogue_op distance_epilogue = raft::identity_op());
~GNND() = default;
using ID_t = InternalID_t<Index_t>;

Expand All @@ -361,7 +363,7 @@ class GNND {
Index_t* d_rev_graph_ptr,
int2* list_sizes,
cudaStream_t stream = 0);
void local_join(cudaStream_t stream = 0);
void local_join(cudaStream_t stream = 0, epilogue_op distance_epilogue = raft::identity_op());

raft::resources const& res;

Expand Down Expand Up @@ -694,7 +696,9 @@ __device__ __forceinline__ void remove_duplicates(
// MAX_RESIDENT_THREAD_PER_SM = BLOCK_SIZE * BLOCKS_PER_SM = 2048
// For architectures 750 and 860 (890), the values for MAX_RESIDENT_THREAD_PER_SM
// is 1024 and 1536 respectively, which means the bounds don't work anymore
template <typename Index_t, typename ID_t = InternalID_t<Index_t>>
template <typename Index_t,
typename ID_t = InternalID_t<Index_t>,
typename epilogue_op = raft::identity_op>
RAFT_KERNEL
#ifdef __CUDA_ARCH__
#if (__CUDA_ARCH__) == 750 || ((__CUDA_ARCH__) >= 860 && (__CUDA_ARCH__) <= 890)
Expand All @@ -716,7 +720,8 @@ __launch_bounds__(BLOCK_SIZE, 4)
DistData_t* dists,
int graph_width,
int* locks,
DistData_t* l2_norms)
DistData_t* l2_norms,
epilogue_op distance_epilogue)
{
#if (__CUDA_ARCH__ >= 700)
using namespace nvcuda;
Expand Down Expand Up @@ -826,14 +831,17 @@ __launch_bounds__(BLOCK_SIZE, 4)
__syncthreads();

for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) {
if (i % SKEWED_MAX_NUM_BI_SAMPLES < list_new_size &&
i / SKEWED_MAX_NUM_BI_SAMPLES < list_new_size) {
auto row_idx = i % SKEWED_MAX_NUM_BI_SAMPLES;
auto col_idx = i / SKEWED_MAX_NUM_BI_SAMPLES;
if (row_idx < list_new_size && col_idx < list_new_size) {
auto r = new_neighbors[row_idx];
auto c = new_neighbors[col_idx];
if (l2_norms == nullptr) {
s_distances[i] = -s_distances[i];
auto dist_val = -s_distances[i];
s_distances[i] = distance_epilogue(dist_val, r, c);
} else {
s_distances[i] = l2_norms[new_neighbors[i % SKEWED_MAX_NUM_BI_SAMPLES]] +
l2_norms[new_neighbors[i / SKEWED_MAX_NUM_BI_SAMPLES]] -
2.0 * s_distances[i];
auto dist_val = l2_norms[r] + l2_norms[c] - 2.0 * s_distances[i];
s_distances[i] = distance_epilogue(dist_val, r, c);
}
} else {
s_distances[i] = std::numeric_limits<float>::max();
Expand Down Expand Up @@ -905,14 +913,17 @@ __launch_bounds__(BLOCK_SIZE, 4)
__syncthreads();

for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) {
if (i % SKEWED_MAX_NUM_BI_SAMPLES < list_old_size &&
i / SKEWED_MAX_NUM_BI_SAMPLES < list_new_size) {
auto row_idx = i % SKEWED_MAX_NUM_BI_SAMPLES;
auto col_idx = i / SKEWED_MAX_NUM_BI_SAMPLES;
if (row_idx < list_old_size && col_idx < list_new_size) {
auto r = old_neighbors[row_idx];
auto c = new_neighbors[col_idx];
if (l2_norms == nullptr) {
s_distances[i] = -s_distances[i];
auto dist_val = -s_distances[i];
s_distances[i] = distance_epilogue(dist_val, r, c);
} else {
s_distances[i] = l2_norms[old_neighbors[i % SKEWED_MAX_NUM_BI_SAMPLES]] +
l2_norms[new_neighbors[i / SKEWED_MAX_NUM_BI_SAMPLES]] -
2.0 * s_distances[i];
auto dist_val = l2_norms[r] + l2_norms[c] - 2.0 * s_distances[i];
s_distances[i] = distance_epilogue(dist_val, r, c);
}
} else {
s_distances[i] = std::numeric_limits<float>::max();
Expand Down Expand Up @@ -1140,8 +1151,9 @@ GnndGraph<Index_t>::~GnndGraph()
assert(h_graph == nullptr);
}

template <typename Data_t, typename Index_t>
GNND<Data_t, Index_t>::GNND(raft::resources const& res, const BuildConfig& build_config)
template <typename Data_t, typename Index_t, typename epilogue_op>
GNND<Data_t, Index_t, epilogue_op>::GNND(raft::resources const& res,
const BuildConfig& build_config)
: res(res),
build_config_(build_config),
graph_(build_config.max_dataset_size,
Expand Down Expand Up @@ -1180,21 +1192,22 @@ GNND<Data_t, Index_t>::GNND(raft::resources const& res, const BuildConfig& build
thrust::fill(thrust::device, d_locks_.data_handle(), d_locks_.data_handle() + d_locks_.size(), 0);
};

template <typename Data_t, typename Index_t>
void GNND<Data_t, Index_t>::add_reverse_edges(Index_t* graph_ptr,
Index_t* h_rev_graph_ptr,
Index_t* d_rev_graph_ptr,
int2* list_sizes,
cudaStream_t stream)
template <typename Data_t, typename Index_t, typename epilogue_op>
void GNND<Data_t, Index_t, epilogue_op>::add_reverse_edges(Index_t* graph_ptr,
Index_t* h_rev_graph_ptr,
Index_t* d_rev_graph_ptr,
int2* list_sizes,
cudaStream_t stream)
{
add_rev_edges_kernel<<<nrow_, raft::warp_size(), 0, stream>>>(
graph_ptr, d_rev_graph_ptr, NUM_SAMPLES, list_sizes);
raft::copy(
h_rev_graph_ptr, d_rev_graph_ptr, nrow_ * NUM_SAMPLES, raft::resource::get_cuda_stream(res));
}

template <typename Data_t, typename Index_t>
void GNND<Data_t, Index_t>::local_join(cudaStream_t stream)
template <typename Data_t, typename Index_t, typename epilogue_op>
void GNND<Data_t, Index_t, epilogue_op>::local_join(cudaStream_t stream,
epilogue_op distance_epilogue)
{
thrust::fill(thrust::device.on(stream),
dists_buffer_.data_handle(),
Expand All @@ -1214,15 +1227,17 @@ void GNND<Data_t, Index_t>::local_join(cudaStream_t stream)
dists_buffer_.data_handle(),
DEGREE_ON_DEVICE,
d_locks_.data_handle(),
l2_norms_.data_handle());
l2_norms_.data_handle(),
distance_epilogue);
}

template <typename Data_t, typename Index_t>
void GNND<Data_t, Index_t>::build(Data_t* data,
const Index_t nrow,
Index_t* output_graph,
bool return_distances,
DistData_t* output_distances)
template <typename Data_t, typename Index_t, typename epilogue_op>
void GNND<Data_t, Index_t, epilogue_op>::build(Data_t* data,
const Index_t nrow,
Index_t* output_graph,
bool return_distances,
DistData_t* output_distances,
epilogue_op distance_epilogue)
{
using input_t = typename std::remove_const<Data_t>::type;

Expand Down Expand Up @@ -1318,7 +1333,7 @@ void GNND<Data_t, Index_t>::build(Data_t* data,
raft::util::arch::SM_range(raft::util::arch::SM_70(), raft::util::arch::SM_future());
if (wmma_range.contains(runtime_arch)) {
local_join(stream);
local_join(stream, distance_epilogue);
} else {
THROW("NN_DESCENT cannot be run for __CUDA_ARCH__ < 700");
}
Expand Down Expand Up @@ -1385,13 +1400,15 @@ void GNND<Data_t, Index_t>::build(Data_t* data,
}
template <typename T,
typename IdxT = uint32_t,
typename IdxT = uint32_t,
typename epilogue_op = raft::identity_op,
typename Accessor =
host_device_accessor<std::experimental::default_accessor<T>, memory_type::host>>
void build(raft::resources const& res,
const index_params& params,
mdspan<const T, matrix_extent<int64_t>, row_major, Accessor> dataset,
index<IdxT>& idx)
index<IdxT>& idx,
epilogue_op distance_epilogue = raft::identity_op())
{
RAFT_EXPECTS(dataset.extent(0) < std::numeric_limits<int>::max() - 1,
"The dataset size for GNND should be less than %d",
Expand Down Expand Up @@ -1433,7 +1450,7 @@ void build(raft::resources const& res,
.termination_threshold = params.termination_threshold,
.output_graph_degree = params.graph_degree};
GNND<const T, int> nnd(res, build_config);
GNND<const T, int, epilogue_op> nnd(res, build_config);
if (idx.distances().has_value() || !params.return_distances) {
nnd.build(dataset.data_handle(),
Expand All @@ -1442,7 +1459,8 @@ void build(raft::resources const& res,
params.return_distances,
idx.distances()
.value_or(raft::make_device_matrix<float, int64_t>(res, 0, 0).view())
.data_handle());
.data_handle(),
distance_epilogue);
} else {
RAFT_EXPECTS(!params.return_distances,
"Distance view not allocated. Using return_distances set to true requires "
Expand All @@ -1459,12 +1477,14 @@ void build(raft::resources const& res,
}
template <typename T,
typename IdxT = uint32_t,
typename IdxT = uint32_t,
typename epilogue_op = raft::identity_op,
typename Accessor =
host_device_accessor<std::experimental::default_accessor<T>, memory_type::host>>
index<IdxT> build(raft::resources const& res,
const index_params& params,
mdspan<const T, matrix_extent<int64_t>, row_major, Accessor> dataset)
mdspan<const T, matrix_extent<int64_t>, row_major, Accessor> dataset,
epilogue_op distance_epilogue = raft::identity_op())
{
size_t intermediate_degree = params.intermediate_graph_degree;
size_t graph_degree = params.graph_degree;
Expand All @@ -1481,7 +1501,7 @@ index<IdxT> build(raft::resources const& res,
index<IdxT> idx{
res, dataset.extent(0), static_cast<int64_t>(graph_degree), params.return_distances};
build(res, params, dataset, idx);
build(res, params, dataset, idx, distance_epilogue);
return idx;
}
Expand Down
38 changes: 25 additions & 13 deletions cpp/include/raft/neighbors/nn_descent.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -48,19 +48,22 @@ namespace raft::neighbors::experimental::nn_descent {
*
* @tparam T data-type of the input dataset
* @tparam IdxT data-type for the output index
* @tparam epilogue_op epilogue operation type for distances
* @param[in] res raft::resources is an object mangaging resources
* @param[in] params an instance of nn_descent::index_params that are parameters
* to run the nn-descent algorithm
* @param[in] dataset raft::device_matrix_view input dataset expected to be located
* in device memory
* @param[in] distance_epilogue epilogue operation for distances
* @return index<IdxT> index containing all-neighbors knn graph in host memory
*/
template <typename T, typename IdxT = uint32_t>
template <typename T, typename IdxT = uint32_t, typename epilogue_op = raft::identity_op>
index<IdxT> build(raft::resources const& res,
index_params const& params,
raft::device_matrix_view<const T, int64_t, row_major> dataset)
raft::device_matrix_view<const T, int64_t, row_major> dataset,
epilogue_op distance_epilogue = raft::identity_op())
{
return detail::build<T, IdxT>(res, params, dataset);
return detail::build<T, IdxT>(res, params, dataset, distance_epilogue);
}

/**
Expand All @@ -85,21 +88,24 @@ index<IdxT> build(raft::resources const& res,
*
* @tparam T data-type of the input dataset
* @tparam IdxT data-type for the output index
* @tparam epilogue_op epilogue operation type for distances
* @param res raft::resources is an object mangaging resources
* @param[in] params an instance of nn_descent::index_params that are parameters
* to run the nn-descent algorithm
* @param[in] dataset raft::device_matrix_view input dataset expected to be located
* in device memory
* @param[out] idx raft::neighbors::experimental::nn_descentindex containing all-neighbors knn graph
* in host memory
* @param[in] distance_epilogue epilogue operation for distances
*/
template <typename T, typename IdxT = uint32_t>
template <typename T, typename IdxT = uint32_t, typename epilogue_op = raft::identity_op>
void build(raft::resources const& res,
index_params const& params,
raft::device_matrix_view<const T, int64_t, row_major> dataset,
index<IdxT>& idx)
index<IdxT>& idx,
epilogue_op distance_epilogue = raft::identity_op())
{
detail::build<T, IdxT>(res, params, dataset, idx);
detail::build<T, IdxT>(res, params, dataset, idx, distance_epilogue);
}

/**
Expand All @@ -122,19 +128,22 @@ void build(raft::resources const& res,
*
* @tparam T data-type of the input dataset
* @tparam IdxT data-type for the output index
* @tparam epilogue_op epilogue operation type for distances
* @param res raft::resources is an object mangaging resources
* @param[in] params an instance of nn_descent::index_params that are parameters
* to run the nn-descent algorithm
* @param[in] dataset raft::host_matrix_view input dataset expected to be located
* in host memory
* @param[in] distance_epilogue epilogue operation for distances
* @return index<IdxT> index containing all-neighbors knn graph in host memory
*/
template <typename T, typename IdxT = uint32_t>
template <typename T, typename IdxT = uint32_t, typename epilogue_op = raft::identity_op>
index<IdxT> build(raft::resources const& res,
index_params const& params,
raft::host_matrix_view<const T, int64_t, row_major> dataset)
raft::host_matrix_view<const T, int64_t, row_major> dataset,
epilogue_op distance_epilogue = raft::identity_op())
{
return detail::build<T, IdxT>(res, params, dataset);
return detail::build<T, IdxT>(res, params, dataset, distance_epilogue);
}

/**
Expand All @@ -159,21 +168,24 @@ index<IdxT> build(raft::resources const& res,
*
* @tparam T data-type of the input dataset
* @tparam IdxT data-type for the output index
* @tparam epilogue_op epilogue operation type for distances
* @param[in] res raft::resources is an object mangaging resources
* @param[in] params an instance of nn_descent::index_params that are parameters
* to run the nn-descent algorithm
* @param[in] dataset raft::host_matrix_view input dataset expected to be located
* in host memory
* @param[out] idx raft::neighbors::experimental::nn_descentindex containing all-neighbors knn graph
* in host memory
* @param[in] distance_epilogue epilogue operation for distances
*/
template <typename T, typename IdxT = uint32_t>
template <typename T, typename IdxT = uint32_t, typename epilogue_op = raft::identity_op>
void build(raft::resources const& res,
index_params const& params,
raft::host_matrix_view<const T, int64_t, row_major> dataset,
index<IdxT>& idx)
index<IdxT>& idx,
epilogue_op distance_epilogue = raft::identity_op())
{
detail::build<T, IdxT>(res, params, dataset, idx);
detail::build<T, IdxT>(res, params, dataset, idx, distance_epilogue);
}

/** @} */ // end group nn-descent
Expand Down

0 comments on commit ab5e128

Please sign in to comment.