Skip to content

Commit

Permalink
Migrate trustworthiness and silhouette_score stats from RAFT (#313)
Browse files Browse the repository at this point in the history
Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #313
  • Loading branch information
benfred authored Sep 6, 2024
1 parent 9e8ec39 commit 7d144cf
Show file tree
Hide file tree
Showing 14 changed files with 1,837 additions and 25 deletions.
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,8 @@ add_library(
src/selection/select_k_float_int64_t.cu
src/selection/select_k_float_uint32_t.cu
src/selection/select_k_half_uint32_t.cu
src/stats/silhouette_score.cu
src/stats/trustworthiness_score.cu
)

target_compile_definitions(cuvs PRIVATE "CUVS_EXPLICIT_INSTANTIATE_ONLY")
Expand Down
121 changes: 121 additions & 0 deletions cpp/include/cuvs/stats/silhouette_score.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cuvs/distance/distance.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/resources.hpp>

namespace cuvs {
namespace stats {

/**
* @defgroup stats_silhouette_score Silhouette Score
* @{
*/
// NOTE(review): std::optional is used by the declarations in this header, but <optional> is not
// included here directly; presumably it arrives transitively through the headers above --
// consider including it explicitly.
/**
* @brief Returns the average silhouette score for a given set of single-precision data and its
* clusterings
* @param[in] handle: raft handle for managing expensive resources
* @param[in] X_in: input data as a row-major device matrix view (nRows x nCols)
* @param[in] labels: device vector view containing the label of every data sample (length:
* nRows)
* @param[out] silhouette_score_per_sample: optional device vector view populated with the
* silhouette score for every sample (length: nRows)
* @param[in] n_unique_labels: number of unique labels in the labels array
* @param[in] metric: distance metric to use; defaults to
* cuvs::distance::DistanceType::L2Unexpanded
* @return The average silhouette score.
*/
float silhouette_score(
raft::resources const& handle,
raft::device_matrix_view<const float, int64_t, raft::row_major> X_in,
raft::device_vector_view<const int, int64_t> labels,
std::optional<raft::device_vector_view<float, int64_t>> silhouette_score_per_sample,
int64_t n_unique_labels,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded);

/**
* @brief Batched variant that returns the average silhouette score for a given set of
* single-precision data and its clusterings
* @param[in] handle: raft handle for managing expensive resources
* @param[in] X: input data as a row-major device matrix view (nRows x nCols)
* @param[in] labels: device vector view containing the label of every data sample (length:
* nRows)
* @param[out] silhouette_score_per_sample: optional device vector view populated with the
* silhouette score for every sample (length: nRows)
* @param[in] n_unique_labels: number of unique labels in the labels array
* @param[in] batch_size: number of samples per batch
* @param[in] metric: distance metric to use; defaults to
* cuvs::distance::DistanceType::L2Unexpanded
* @return The average silhouette score.
*/
float silhouette_score_batched(
raft::resources const& handle,
raft::device_matrix_view<const float, int64_t, raft::row_major> X,
raft::device_vector_view<const int, int64_t> labels,
std::optional<raft::device_vector_view<float, int64_t>> silhouette_score_per_sample,
int64_t n_unique_labels,
int64_t batch_size,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded);

/**
* @brief Double-precision overload that returns the average silhouette score for a given set of
* data and its clusterings
* @param[in] handle: raft handle for managing expensive resources
* @param[in] X_in: input data as a row-major device matrix view (nRows x nCols)
* @param[in] labels: device vector view containing the label of every data sample (length:
* nRows)
* @param[out] silhouette_score_per_sample: optional device vector view populated with the
* silhouette score for every sample (length: nRows)
* @param[in] n_unique_labels: number of unique labels in the labels array
* @param[in] metric: distance metric to use; defaults to
* cuvs::distance::DistanceType::L2Unexpanded
* @return The average silhouette score.
*/
double silhouette_score(
raft::resources const& handle,
raft::device_matrix_view<const double, int64_t, raft::row_major> X_in,
raft::device_vector_view<const int, int64_t> labels,
std::optional<raft::device_vector_view<double, int64_t>> silhouette_score_per_sample,
int64_t n_unique_labels,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded);

/**
* @brief Batched double-precision overload that returns the average silhouette score for a given
* set of data and its clusterings
* @param[in] handle: raft handle for managing expensive resources
* @param[in] X: input data as a row-major device matrix view (nRows x nCols)
* @param[in] labels: device vector view containing the label of every data sample (length:
* nRows)
* @param[out] silhouette_score_per_sample: optional device vector view populated with the
* silhouette score for every sample (length: nRows)
* @param[in] n_unique_labels: number of unique labels in the labels array
* @param[in] batch_size: number of samples per batch
* @param[in] metric: distance metric to use; defaults to
* cuvs::distance::DistanceType::L2Unexpanded
* @return The average silhouette score.
*/
double silhouette_score_batched(
raft::resources const& handle,
raft::device_matrix_view<const double, int64_t, raft::row_major> X,
raft::device_vector_view<const int, int64_t> labels,
std::optional<raft::device_vector_view<double, int64_t>> silhouette_score_per_sample,
int64_t n_unique_labels,
int64_t batch_size,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded);

/** @} */  // end group stats_silhouette_score

} // namespace stats
} // namespace cuvs
51 changes: 51 additions & 0 deletions cpp/include/cuvs/stats/trustworthiness_score.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include <cuvs/distance/distance.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/resources.hpp>

namespace cuvs {
namespace stats {
/**
* @defgroup stats_trustworthiness Trustworthiness
* @{
*/

/**
* @brief Compute the trustworthiness score
* @param[in] handle: the raft handle
* @param[in] X: Data in original dimension, as a row-major device matrix view
* @param[in] X_embedded: Data in target dimension (embedding), as a row-major device matrix view
* @param[in] n_neighbors: Number of neighbors considered by trustworthiness score
* @param[in] metric: Distance metric to use; defaults to
* cuvs::distance::DistanceType::L2SqrtUnexpanded
* @param[in] batch_size: Batch size; defaults to 512
* @return Trustworthiness score
* @note The constness of the data in X_embedded is currently cast away and the data is slightly
* modified.
*/
double trustworthiness_score(
raft::resources const& handle,
raft::device_matrix_view<const float, int64_t, raft::row_major> X,
raft::device_matrix_view<const float, int64_t, raft::row_major> X_embedded,
int n_neighbors,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2SqrtUnexpanded,
int batch_size = 512);

/** @} */ // end group stats_trustworthiness
} // namespace stats
} // namespace cuvs
Loading

0 comments on commit 7d144cf

Please sign in to comment.