diff --git a/cpp/include/cuvs/neighbors/cagra.h b/cpp/include/cuvs/neighbors/cagra.h index 241f5d8b0..14331ebbc 100644 --- a/cpp/include/cuvs/neighbors/cagra.h +++ b/cpp/include/cuvs/neighbors/cagra.h @@ -267,6 +267,15 @@ cuvsError_t cuvsCagraIndexCreate(cuvsCagraIndex_t* index); */ cuvsError_t cuvsCagraIndexDestroy(cuvsCagraIndex_t index); +/** + * @brief Get dimension of the CAGRA index + * + * @param[in] index CAGRA index + * @param[out] dim return dimension of the index + * @return cuvsError_t + */ +cuvsError_t cuvsCagraIndexGetDims(cuvsCagraIndex_t index, int* dim); + /** * @} */ @@ -338,7 +347,7 @@ cuvsError_t cuvsCagraBuild(cuvsResources_t res, * with the same type of `queries`, such that `index.dtype.code == * queries.dl_tensor.dtype.code` Types for input are: * 1. `queries`: - *` a. kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32` + * a. `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32` * b. `kDLDataType.code == kDLInt` and `kDLDataType.bits = 8` * c. `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 8` * 2. `neighbors`: `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 32` diff --git a/cpp/include/cuvs/neighbors/hnsw.h b/cpp/include/cuvs/neighbors/hnsw.h index 5e94de60a..0495c574a 100644 --- a/cpp/include/cuvs/neighbors/hnsw.h +++ b/cpp/include/cuvs/neighbors/hnsw.h @@ -105,8 +105,10 @@ cuvsError_t cuvsHnswIndexDestroy(cuvsHnswIndex_t index); * with the same type of `queries`, such that `index.dtype.code == * queries.dl_tensor.dtype.code` * Supported types for input are: - * 1. `queries`: `kDLDataType.code == kDLFloat` or `kDLDataType.code == kDLInt` and - * `kDLDataType.bits = 32` + * 1. `queries`: + * a. `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32` + * b. `kDLDataType.code == kDLInt` and `kDLDataType.bits = 8` + * c. `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 8` * 2. `neighbors`: `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 64` * 3. 
`distances`: `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32` * NOTE: The HNSW index can only be searched by the hnswlib wrapper in cuVS, diff --git a/cpp/include/cuvs/neighbors/hnsw.hpp b/cpp/include/cuvs/neighbors/hnsw.hpp index 007adef0d..d5abd6d55 100644 --- a/cpp/include/cuvs/neighbors/hnsw.hpp +++ b/cpp/include/cuvs/neighbors/hnsw.hpp @@ -173,6 +173,8 @@ std::unique_ptr> from_cagra( /**@}*/ +// TODO: Filtered Search APIs: https://github.com/rapidsai/cuvs/issues/363 + /** * @defgroup hnsw_cpp_index_search Search hnswlib index * @{ @@ -260,7 +262,7 @@ void search(raft::resources const& res, void search(raft::resources const& res, const search_params& params, const index& idx, - raft::host_matrix_view queries, + raft::host_matrix_view queries, raft::host_matrix_view neighbors, raft::host_matrix_view distances); @@ -303,7 +305,7 @@ void search(raft::resources const& res, void search(raft::resources const& res, const search_params& params, const index& idx, - raft::host_matrix_view queries, + raft::host_matrix_view queries, raft::host_matrix_view neighbors, raft::host_matrix_view distances); diff --git a/cpp/src/neighbors/cagra_c.cpp b/cpp/src/neighbors/cagra_c.cpp index 164448f2c..6985ff094 100644 --- a/cpp/src/neighbors/cagra_c.cpp +++ b/cpp/src/neighbors/cagra_c.cpp @@ -176,6 +176,14 @@ extern "C" cuvsError_t cuvsCagraIndexDestroy(cuvsCagraIndex_t index_c_ptr) }); } +extern "C" cuvsError_t cuvsCagraIndexGetDims(cuvsCagraIndex_t index, int* dim) +{ + return cuvs::core::translate_exceptions([=] { + auto index_ptr = reinterpret_cast*>(index->addr); + *dim = index_ptr->dim(); + }); +} + extern "C" cuvsError_t cuvsCagraBuild(cuvsResources_t res, cuvsCagraIndexParams_t params, DLManagedTensor* dataset_tensor, diff --git a/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh index b92ef0ace..a077c098f 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh +++ 
b/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh @@ -120,9 +120,9 @@ void serialize_to_hnswlib(raft::resources const& res, os.write(reinterpret_cast(&curr_element_count), sizeof(std::size_t)); // Example:M: 16, dim = 128, data_t = float, index_t = uint32_t, list_size_type = uint32_t, // labeltype: size_t size_data_per_element_ = M * 2 * sizeof(index_t) + sizeof(list_size_type) + - // dim * 4 + sizeof(labeltype) - auto size_data_per_element = - static_cast(index_.graph_degree() * sizeof(IdxT) + 4 + index_.dim() * 4 + 8); + // dim * sizeof(T) + sizeof(labeltype) + auto size_data_per_element = static_cast(index_.graph_degree() * sizeof(IdxT) + 4 + + index_.dim() * sizeof(T) + 8); os.write(reinterpret_cast(&size_data_per_element), sizeof(std::size_t)); // label_offset std::size_t label_offset = size_data_per_element - 8; @@ -185,18 +185,9 @@ void serialize_to_hnswlib(raft::resources const& res, } auto data_row = host_dataset.data_handle() + (index_.dim() * i); - if constexpr (std::is_same_v) { - for (std::size_t j = 0; j < index_.dim(); ++j) { - auto data_elem = static_cast(host_dataset(i, j)); - os.write(reinterpret_cast(&data_elem), sizeof(float)); - } - } else if constexpr (std::is_same_v or std::is_same_v) { - for (std::size_t j = 0; j < index_.dim(); ++j) { - auto data_elem = static_cast(host_dataset(i, j)); - os.write(reinterpret_cast(&data_elem), sizeof(int)); - } - } else { - RAFT_FAIL("Unsupported dataset type while saving CAGRA dataset to HNSWlib format"); + for (std::size_t j = 0; j < index_.dim(); ++j) { + auto data_elem = static_cast(host_dataset(i, j)); + os.write(reinterpret_cast(&data_elem), sizeof(T)); } os.write(reinterpret_cast(&i), sizeof(std::size_t)); diff --git a/cpp/src/neighbors/detail/hnsw.hpp b/cpp/src/neighbors/detail/hnsw.hpp index 0d1ae4ec9..ce1e03264 100644 --- a/cpp/src/neighbors/detail/hnsw.hpp +++ b/cpp/src/neighbors/detail/hnsw.hpp @@ -110,9 +110,9 @@ std::unique_ptr> from_cagra(raft::resources const& res, return 
std::unique_ptr>(hnsw_index); } -template -void get_search_knn_results(hnswlib::HierarchicalNSW const* idx, - const QueriesT* query, +template +void get_search_knn_results(hnswlib::HierarchicalNSW::type> const* idx, + const T* query, int k, uint64_t* indices, float* distances) @@ -127,11 +127,11 @@ void get_search_knn_results(hnswlib::HierarchicalNSW const* idx, } } -template +template void search(raft::resources const& res, const search_params& params, const index& idx, - raft::host_matrix_view queries, + raft::host_matrix_view queries, raft::host_matrix_view neighbors, raft::host_matrix_view distances) { @@ -146,7 +146,8 @@ void search(raft::resources const& res, idx.set_ef(params.ef); auto const* hnswlib_index = - reinterpret_cast const*>(idx.get_index()); + reinterpret_cast::type> const*>( + idx.get_index()); // when num_threads == 0, automatically maximize parallelism if (params.num_threads) { diff --git a/cpp/src/neighbors/hnsw.cpp b/cpp/src/neighbors/hnsw.cpp index 36cbb16c9..e6f3fbcc7 100644 --- a/cpp/src/neighbors/hnsw.cpp +++ b/cpp/src/neighbors/hnsw.cpp @@ -34,20 +34,20 @@ CUVS_INST_HNSW_FROM_CAGRA(int8_t); #undef CUVS_INST_HNSW_FROM_CAGRA -#define CUVS_INST_HNSW_SEARCH(T, QueriesT) \ - void search(raft::resources const& res, \ - const search_params& params, \ - const index& idx, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances) \ - { \ - detail::search(res, params, idx, queries, neighbors, distances); \ +#define CUVS_INST_HNSW_SEARCH(T) \ + void search(raft::resources const& res, \ + const search_params& params, \ + const index& idx, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + detail::search(res, params, idx, queries, neighbors, distances); \ } -CUVS_INST_HNSW_SEARCH(float, float); -CUVS_INST_HNSW_SEARCH(uint8_t, int); -CUVS_INST_HNSW_SEARCH(int8_t, int); +CUVS_INST_HNSW_SEARCH(float); 
+CUVS_INST_HNSW_SEARCH(uint8_t); +CUVS_INST_HNSW_SEARCH(int8_t); #undef CUVS_INST_HNSW_SEARCH diff --git a/cpp/src/neighbors/hnsw_c.cpp b/cpp/src/neighbors/hnsw_c.cpp index ab5268a6d..a19875641 100644 --- a/cpp/src/neighbors/hnsw_c.cpp +++ b/cpp/src/neighbors/hnsw_c.cpp @@ -31,7 +31,7 @@ #include namespace { -template +template void _search(cuvsResources_t res, cuvsHnswSearchParams params, cuvsHnswIndex index, @@ -46,7 +46,7 @@ void _search(cuvsResources_t res, search_params.ef = params.ef; search_params.num_threads = params.numThreads; - using queries_mdspan_type = raft::host_matrix_view; + using queries_mdspan_type = raft::host_matrix_view; using neighbors_mdspan_type = raft::host_matrix_view; using distances_mdspan_type = raft::host_matrix_view; auto queries_mds = cuvs::core::from_dlpack(queries_tensor); @@ -127,16 +127,13 @@ extern "C" cuvsError_t cuvsHnswSearch(cuvsResources_t res, auto index = *index_c_ptr; RAFT_EXPECTS(queries.dtype.code == index.dtype.code, "type mismatch between index and queries"); - RAFT_EXPECTS(queries.dtype.bits == 32, "number of bits in queries dtype should be 32"); if (index.dtype.code == kDLFloat) { - _search( - res, *params, index, queries_tensor, neighbors_tensor, distances_tensor); + _search(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor); } else if (index.dtype.code == kDLUInt) { - _search( - res, *params, index, queries_tensor, neighbors_tensor, distances_tensor); + _search(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor); } else if (index.dtype.code == kDLInt) { - _search(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor); + _search(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor); } else { RAFT_FAIL("Unsupported index dtype: %d and bits: %d", queries.dtype.code, queries.dtype.bits); } @@ -152,13 +149,10 @@ extern "C" cuvsError_t cuvsHnswDeserialize(cuvsResources_t res, return cuvs::core::translate_exceptions([=] { if 
(index->dtype.code == kDLFloat && index->dtype.bits == 32) { index->addr = reinterpret_cast(_deserialize(res, filename, dim, metric)); - index->dtype.code = kDLFloat; } else if (index->dtype.code == kDLUInt && index->dtype.bits == 8) { index->addr = reinterpret_cast(_deserialize(res, filename, dim, metric)); - index->dtype.code = kDLInt; } else if (index->dtype.code == kDLInt && index->dtype.bits == 8) { index->addr = reinterpret_cast(_deserialize(res, filename, dim, metric)); - index->dtype.code = kDLUInt; } else { RAFT_FAIL("Unsupported dtype in file %s", filename); } diff --git a/docs/source/c_api/neighbors.rst b/docs/source/c_api/neighbors.rst index dc55a74dc..9c3fce672 100644 --- a/docs/source/c_api/neighbors.rst +++ b/docs/source/c_api/neighbors.rst @@ -13,3 +13,4 @@ Nearest Neighbors neighbors_ivf_flat_c.rst neighbors_ivf_pq_c.rst neighbors_cagra_c.rst + neighbors_hnsw_c.rst diff --git a/docs/source/cpp_api/neighbors.rst b/docs/source/cpp_api/neighbors.rst index 0c68c8415..d55d58eb0 100644 --- a/docs/source/cpp_api/neighbors.rst +++ b/docs/source/cpp_api/neighbors.rst @@ -11,6 +11,7 @@ Nearest Neighbors neighbors_bruteforce.rst neighbors_cagra.rst + neighbors_hnsw.rst neighbors_ivf_flat.rst neighbors_ivf_pq.rst neighbors_nn_descent.rst diff --git a/docs/source/python_api/neighbors.rst b/docs/source/python_api/neighbors.rst index 022c50de3..cd4f2609c 100644 --- a/docs/source/python_api/neighbors.rst +++ b/docs/source/python_api/neighbors.rst @@ -11,5 +11,6 @@ Nearest Neighbors neighbors_brute_force.rst neighbors_cagra.rst + neighbors_hnsw.rst neighbors_ivf_flat.rst neighbors_ivf_pq.rst diff --git a/docs/source/python_api/neighbors_hnsw.rst b/docs/source/python_api/neighbors_hnsw.rst new file mode 100644 index 000000000..9922805b3 --- /dev/null +++ b/docs/source/python_api/neighbors_hnsw.rst @@ -0,0 +1,30 @@ +HNSW +==== + +This is a wrapper for hnswlib, to load a CAGRA index as an immutable HNSW index. 
The loaded HNSW index is only compatible in cuVS, and can be searched using wrapper functions. + +.. role:: py(code) + :language: python + :class: highlight + +Index search parameters +####################### + +.. autoclass:: cuvs.neighbors.hnsw.SearchParams + :members: + +Index +##### + +.. autoclass:: cuvs.neighbors.hnsw.Index + :members: + +Index Conversion +################ + +.. autofunction:: cuvs.neighbors.hnsw.from_cagra + +Index search +############ + +.. autofunction:: cuvs.neighbors.hnsw.search diff --git a/python/cuvs/cuvs/neighbors/CMakeLists.txt b/python/cuvs/cuvs/neighbors/CMakeLists.txt index 21c3db5da..f68bbea53 100644 --- a/python/cuvs/cuvs/neighbors/CMakeLists.txt +++ b/python/cuvs/cuvs/neighbors/CMakeLists.txt @@ -14,6 +14,7 @@ add_subdirectory(brute_force) add_subdirectory(cagra) +add_subdirectory(hnsw) add_subdirectory(ivf_flat) add_subdirectory(ivf_pq) add_subdirectory(filters) diff --git a/python/cuvs/cuvs/neighbors/cagra/cagra.pxd b/python/cuvs/cuvs/neighbors/cagra/cagra.pxd index b23c2a4b3..bba5a91a8 100644 --- a/python/cuvs/cuvs/neighbors/cagra/cagra.pxd +++ b/python/cuvs/cuvs/neighbors/cagra/cagra.pxd @@ -17,6 +17,7 @@ from libc.stdint cimport ( int8_t, + int32_t, int64_t, uint8_t, uint32_t, @@ -100,6 +101,8 @@ cdef extern from "cuvs/neighbors/cagra.h" nogil: cuvsError_t cuvsCagraIndexDestroy(cuvsCagraIndex_t index) + cuvsError_t cuvsCagraIndexGetDims(cuvsCagraIndex_t index, int32_t* dim) + cuvsError_t cuvsCagraBuild(cuvsResources_t res, cuvsCagraIndexParams* params, DLManagedTensor* dataset, @@ -117,6 +120,20 @@ cdef extern from "cuvs/neighbors/cagra.h" nogil: cuvsCagraIndex_t index, bool include_dataset) except + + cuvsError_t cuvsCagraSerializeToHnswlib(cuvsResources_t res, + const char * filename, + cuvsCagraIndex_t index) except + + cuvsError_t cuvsCagraDeserialize(cuvsResources_t res, const char * filename, cuvsCagraIndex_t index) except + + +cdef class Index: + """ + CAGRA index object. 
This object stores the trained CAGRA index state + which can be used to perform nearest neighbors searches. + """ + + cdef cuvsCagraIndex_t index + cdef bool trained + cdef str active_index_type diff --git a/python/cuvs/cuvs/neighbors/cagra/cagra.pyx b/python/cuvs/cuvs/neighbors/cagra/cagra.pyx index f940ab8bf..95209dbeb 100644 --- a/python/cuvs/cuvs/neighbors/cagra/cagra.pyx +++ b/python/cuvs/cuvs/neighbors/cagra/cagra.pyx @@ -36,6 +36,7 @@ from pylibraft.neighbors.common import _check_input_array from libc.stdint cimport ( int8_t, + int32_t, int64_t, uint8_t, uint32_t, @@ -206,16 +207,9 @@ cdef class IndexParams: cdef class Index: - """ - CAGRA index object. This object stores the trained CAGRA index state - which can be used to perform nearest neighbors searches. - """ - - cdef cuvsCagraIndex_t index - cdef bool trained - def __cinit__(self): self.trained = False + self.active_index_type = None check_cuvs(cuvsCagraIndexCreate(&self.index)) def __dealloc__(self): @@ -226,6 +220,12 @@ cdef class Index: def trained(self): return self.trained + @property + def dim(self): + cdef int32_t dim + check_cuvs(cuvsCagraIndexGetDims(self.index, &dim)) + return dim + def __repr__(self): # todo(dgd): update repr as we expose data through C API attr_str = [] @@ -299,6 +299,7 @@ def build(IndexParams index_params, dataset, resources=None): idx.index )) idx.trained = True + idx.active_index_type = dataset_ai.dtype.name return idx diff --git a/python/cuvs/cuvs/neighbors/hnsw/CMakeLists.txt b/python/cuvs/cuvs/neighbors/hnsw/CMakeLists.txt new file mode 100644 index 000000000..1f9c422ca --- /dev/null +++ b/python/cuvs/cuvs/neighbors/hnsw/CMakeLists.txt @@ -0,0 +1,24 @@ +# ============================================================================= +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= + +# Set the list of Cython files to build +set(cython_sources hnsw.pyx) +set(linked_libraries cuvs::cuvs cuvs::c_api) + +# Build all of the Cython targets +rapids_cython_create_modules( + CXX + SOURCE_FILES "${cython_sources}" + LINKED_LIBRARIES "${linked_libraries}" ASSOCIATED_TARGETS cuvs MODULE_PREFIX neighbors_hnsw_ +) diff --git a/python/cuvs/cuvs/neighbors/hnsw/__init__.pxd b/python/cuvs/cuvs/neighbors/hnsw/__init__.pxd new file mode 100644 index 000000000..e69de29bb diff --git a/python/cuvs/cuvs/neighbors/hnsw/__init__.py b/python/cuvs/cuvs/neighbors/hnsw/__init__.py new file mode 100644 index 000000000..5efcdf68b --- /dev/null +++ b/python/cuvs/cuvs/neighbors/hnsw/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from .hnsw import Index, SearchParams, from_cagra, load, save, search + +__all__ = [ + "Index", + "SearchParams", + "load", + "save", + "search", + "from_cagra", +] diff --git a/python/cuvs/cuvs/neighbors/hnsw/hnsw.pxd b/python/cuvs/cuvs/neighbors/hnsw/hnsw.pxd new file mode 100644 index 000000000..1cdc97406 --- /dev/null +++ b/python/cuvs/cuvs/neighbors/hnsw/hnsw.pxd @@ -0,0 +1,53 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: language_level=3 + +from libc.stdint cimport int32_t, uintptr_t + +from cuvs.common.c_api cimport cuvsError_t, cuvsResources_t +from cuvs.common.cydlpack cimport DLDataType, DLManagedTensor +from cuvs.distance_type cimport cuvsDistanceType + + +cdef extern from "cuvs/neighbors/hnsw.h" nogil: + ctypedef struct cuvsHnswSearchParams: + int32_t ef + int32_t numThreads + + ctypedef cuvsHnswSearchParams* cuvsHnswSearchParams_t + + ctypedef struct cuvsHnswIndex: + uintptr_t addr + DLDataType dtype + + ctypedef cuvsHnswIndex* cuvsHnswIndex_t + + cuvsError_t cuvsHnswIndexCreate(cuvsHnswIndex_t* index) + + cuvsError_t cuvsHnswIndexDestroy(cuvsHnswIndex_t index) + + cuvsError_t cuvsHnswSearch(cuvsResources_t res, + cuvsHnswSearchParams* params, + cuvsHnswIndex_t index, + DLManagedTensor* queries, + DLManagedTensor* neighbors, + DLManagedTensor* distances) except + + + cuvsError_t cuvsHnswDeserialize(cuvsResources_t res, + const char * filename, + int32_t dim, + cuvsDistanceType metric, + 
cuvsHnswIndex_t index) except + diff --git a/python/cuvs/cuvs/neighbors/hnsw/hnsw.pyx b/python/cuvs/cuvs/neighbors/hnsw/hnsw.pyx new file mode 100644 index 000000000..018fcfef9 --- /dev/null +++ b/python/cuvs/cuvs/neighbors/hnsw/hnsw.pyx @@ -0,0 +1,380 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: language_level=3 + +from libc.stdint cimport uint32_t +from libcpp cimport bool +from libcpp.string cimport string + +from cuvs.common.exceptions import check_cuvs +from cuvs.common.resources import auto_sync_resources + +from cuvs.common cimport cydlpack + +import numpy as np + +from cuvs.distance import DISTANCE_TYPES + +from cuvs.neighbors.cagra cimport cagra + +import os +import uuid + +from pylibraft.common import auto_convert_output +from pylibraft.common.cai_wrapper import wrap_array +from pylibraft.common.interruptible import cuda_interruptible +from pylibraft.neighbors.common import _check_input_array + + +cdef class SearchParams: + """ + HNSW search parameters + + Parameters + ---------- + ef: int, default = 200 + Maximum number of candidate list size used during search. + num_threads: int, default = 0 + Number of CPU threads used to increase search parallelism. + When set to 0, the number of threads is automatically determined + using OpenMP's `omp_get_max_threads()`. 
+ """ + + cdef cuvsHnswSearchParams params + + def __init__(self, *, + ef=200, + num_threads=0): + self.params.ef = ef + self.params.numThreads = num_threads + + def __repr__(self): + attr_str = [attr + "=" + str(getattr(self, attr)) + for attr in [ + "ef", "num_threads"]] + return "SearchParams(type=HNSW, " + (", ".join(attr_str)) + ")" + + @property + def ef(self): + return self.params.ef + + @property + def num_threads(self): + return self.params.numThreads + + +cdef class Index: + """ + HNSW index object. This object stores the trained HNSW index state + which can be used to perform nearest neighbors searches. + """ + + cdef cuvsHnswIndex_t index + cdef bool trained + + def __cinit__(self): + self.trained = False + check_cuvs(cuvsHnswIndexCreate(&self.index)) + + def __dealloc__(self): + if self.index is not NULL: + check_cuvs(cuvsHnswIndexDestroy(self.index)) + + @property + def trained(self): + return self.trained + + def __repr__(self): + # todo(dgd): update repr as we expose data through C API + attr_str = [] + return "Index(type=HNSW, metric=L2" + (", ".join(attr_str)) + ")" + + +@auto_sync_resources +def save(filename, cagra.Index index, resources=None): + """ + Saves the CAGRA index to a file as an hnswlib index. + The saved index is immutable and can only be searched by the hnswlib + wrapper in cuVS, as the format is not compatible with the original + hnswlib. + + Saving / loading the index is experimental. The serialization format is + subject to change. + + Parameters + ---------- + filename : string + Name of the file. + index : Index + Trained CAGRA index. + {resources_docstring} + + Examples + -------- + >>> import cupy as cp + >>> from cuvs.neighbors import cagra + >>> n_samples = 50000 + >>> n_features = 50 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... 
dtype=cp.float32) + >>> # Build index + >>> index = cagra.build(cagra.IndexParams(), dataset) + >>> # Serialize and deserialize the cagra index built + >>> hnsw.save("my_index.bin", index) + """ + cdef string c_filename = filename.encode('utf-8') + cdef cuvsResources_t res = resources.get_c_obj() + check_cuvs(cagra.cuvsCagraSerializeToHnswlib(res, + c_filename.c_str(), + index.index)) + + +@auto_sync_resources +def load(filename, dim, dtype, metric="sqeuclidean", resources=None): + """ + Loads base-layer-only hnswlib index from file, which was originally + saved as a built CAGRA index. The loaded index is immutable and can only + be searched by the hnswlib wrapper in cuVS, as the format is not + compatible with the original hnswlib. + + Saving / loading the index is experimental. The serialization format is + subject to change, therefore loading an index saved with a previous + version of cuVS is not guaranteed to work. + + Parameters + ---------- + filename : string + Name of the file. + dim : int + Dimensions of the training dataset + dtype : np.dtype of the saved index + Valid values for dtype: [np.float32, np.byte, np.ubyte] + metric : string denoting the metric type, default="sqeuclidean" + Valid values for metric: ["sqeuclidean", "inner_product"], where + - sqeuclidean is the euclidean distance without the square root + operation, i.e.: distance(a,b) = \\sum_i (a_i - b_i)^2, + - inner_product distance is defined as + distance(a, b) = \\sum_i a_i * b_i. + {resources_docstring} + + Returns + ------- + index : HnswIndex + + Examples + -------- + >>> import cupy as cp + >>> from cuvs.neighbors import cagra + >>> from cuvs.neighbors import hnsw + >>> n_samples = 50000 + >>> n_features = 50 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... 
dtype=cp.float32) + >>> # Build index + >>> index = cagra.build(cagra.IndexParams(), dataset) + >>> # Serialize the CAGRA index to hnswlib base layer only index format + >>> hnsw.save("my_index.bin", index) + >>> index = hnsw.load("my_index.bin", n_features, np.float32, + ... "sqeuclidean") + """ + cdef Index idx = Index() + cdef cuvsResources_t res = resources.get_c_obj() + cdef string c_filename = filename.encode('utf-8') + cdef cydlpack.DLDataType dl_dtype + if dtype == np.float32: + dl_dtype.code = cydlpack.kDLFloat + dl_dtype.bits = 32 + dl_dtype.lanes = 1 + elif dtype == np.ubyte: + dl_dtype.code = cydlpack.kDLUInt + dl_dtype.bits = 8 + dl_dtype.lanes = 1 + elif dtype == np.byte: + dl_dtype.code = cydlpack.kDLInt + dl_dtype.bits = 8 + dl_dtype.lanes = 1 + else: + raise ValueError("Only float32 is supported for dtype") + + idx.index.dtype = dl_dtype + cdef cuvsDistanceType distance_type = DISTANCE_TYPES[metric] + + check_cuvs(cuvsHnswDeserialize( + res, + c_filename.c_str(), + dim, + distance_type, + idx.index + )) + idx.trained = True + return idx + + +@auto_sync_resources +def from_cagra(cagra.Index index, temporary_index_path=None, resources=None): + """ + Returns an hnsw base-layer-only index from a CAGRA index. + + NOTE: This method uses the filesystem to write the CAGRA index in + `/tmp/.bin` or the parameter `temporary_index_path` + if not None before reading it as an hnsw index, + then deleting the temporary file. The returned index is immutable + and can only be searched by the hnsw wrapper in cuVS, as the + format is not compatible with the original hnswlib library. + By `base_layer_only`, we mean that the hnsw index is created + without the additional layers that are used for the hierarchical + search in hnswlib. Instead, the base layer is used for the search. + + Saving / loading the index is experimental. The serialization format is + subject to change. + + Parameters + ---------- + index : Index + Trained CAGRA index. 
+ temporary_index_path : string, default = None + Path to save the temporary index file. If None, the temporary file + will be saved in `/tmp/.bin`. + {resources_docstring} + + Examples + -------- + >>> import cupy as cp + >>> from cuvs.neighbors import cagra + >>> from cuvs.neighbors import hnsw + >>> n_samples = 50000 + >>> n_features = 50 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> # Build index + >>> index = cagra.build(cagra.IndexParams(), dataset) + >>> # Serialize the CAGRA index to hnswlib base layer only index format + >>> hnsw_index = hnsw.from_cagra(index) + """ + uuid_num = uuid.uuid4() + filename = temporary_index_path if temporary_index_path else \ + f"/tmp/{uuid_num}.bin" + save(filename, index, resources=resources) + hnsw_index = load(filename, index.dim, np.dtype(index.active_index_type), + "sqeuclidean", resources=resources) + os.remove(filename) + return hnsw_index + + +@auto_sync_resources +@auto_convert_output +def search(SearchParams search_params, + Index index, + queries, + k, + neighbors=None, + distances=None, + resources=None): + """ + Find the k nearest neighbors for each query. + + Parameters + ---------- + search_params : SearchParams + index : Index + Trained HNSW index. + queries : CUDA array interface compliant matrix shape (n_samples, dim) + Supported dtype [float, int] + k : int + The number of neighbors. + neighbors : Optional CUDA array interface compliant matrix shape + (n_queries, k), dtype uint64_t. If supplied, neighbor + indices will be written here in-place. (default None) + distances : Optional CUDA array interface compliant matrix shape + (n_queries, k) If supplied, the distances to the + neighbors will be written here in-place. 
(default None) + {resources_docstring} + + Examples + -------- + >>> import cupy as cp + >>> from cuvs.neighbors import cagra, hnsw + >>> n_samples = 50000 + >>> n_features = 50 + >>> n_queries = 1000 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> # Build index + >>> index = cagra.build(cagra.IndexParams(), dataset) + >>> # Search using the built index + >>> queries = cp.random.random_sample((n_queries, n_features), + ... dtype=cp.float32) + >>> k = 10 + >>> search_params = hnsw.SearchParams( + ... ef=200, + ... num_threads=0 + ... ) + >>> # Convert CAGRA index to HNSW + >>> hnsw_index = hnsw.from_cagra(index) + >>> # Using a pooling allocator reduces overhead of temporary array + >>> # creation during search. This is useful if multiple searches + >>> # are performed with same query size. + >>> distances, neighbors = hnsw.search(search_params, hnsw_index, queries, + ... k) + >>> neighbors = cp.asarray(neighbors) + >>> distances = cp.asarray(distances) + """ + if not index.trained: + raise ValueError("Index needs to be built before calling search.") + + # todo(dgd): we can make the check of dtype a parameter of wrap_array + # in RAFT to make this a single call + queries_ai = wrap_array(queries) + _check_input_array(queries_ai, [np.dtype('float32'), + np.dtype('uint8'), + np.dtype('int8')]) + + cdef uint32_t n_queries = queries_ai.shape[0] + + if neighbors is None: + neighbors = np.empty((n_queries, k), dtype='uint64') + + neighbors_ai = wrap_array(neighbors) + _check_input_array(neighbors_ai, [np.dtype('uint64')], + exp_rows=n_queries, exp_cols=k) + + if distances is None: + distances = np.empty((n_queries, k), dtype='float32') + + distances_ai = wrap_array(distances) + _check_input_array(distances_ai, [np.dtype('float32')], + exp_rows=n_queries, exp_cols=k) + + cdef cuvsHnswSearchParams* params = &search_params.params + cdef cydlpack.DLManagedTensor* queries_dlpack = \ + cydlpack.dlpack_c(queries_ai) + cdef 
cydlpack.DLManagedTensor* neighbors_dlpack = \ + cydlpack.dlpack_c(neighbors_ai) + cdef cydlpack.DLManagedTensor* distances_dlpack = \ + cydlpack.dlpack_c(distances_ai) + cdef cuvsResources_t res = resources.get_c_obj() + + with cuda_interruptible(): + check_cuvs(cuvsHnswSearch( + res, + params, + index.index, + queries_dlpack, + neighbors_dlpack, + distances_dlpack + )) + + return (distances, neighbors) diff --git a/python/cuvs/cuvs/test/test_hnsw.py b/python/cuvs/cuvs/test/test_hnsw.py new file mode 100644 index 000000000..0ae97266b --- /dev/null +++ b/python/cuvs/cuvs/test/test_hnsw.py @@ -0,0 +1,97 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import numpy as np +import pytest +from sklearn.neighbors import NearestNeighbors +from sklearn.preprocessing import normalize + +from cuvs.neighbors import cagra, hnsw +from cuvs.test.ann_utils import calc_recall, generate_data + + +def run_hnsw_build_search_test( + n_rows=1000, + n_cols=10, + n_queries=100, + k=10, + dtype=np.float32, + metric="sqeuclidean", + build_algo="ivf_pq", + intermediate_graph_degree=128, + graph_degree=64, + search_params={}, +): + dataset = generate_data((n_rows, n_cols), dtype) + if metric == "inner_product": + dataset = normalize(dataset, norm="l2", axis=1) + if dtype in [np.int8, np.uint8]: + pytest.skip( + "inner_product metric is not supported for int8/uint8 data" + ) + if build_algo == "nn_descent": + pytest.skip("inner_product metric is not supported for nn_descent") + + build_params = cagra.IndexParams( + metric=metric, + intermediate_graph_degree=intermediate_graph_degree, + graph_degree=graph_degree, + build_algo=build_algo, + ) + + index = cagra.build(build_params, dataset) + + assert index.trained + + hnsw_index = hnsw.from_cagra(index) + + queries = generate_data((n_queries, n_cols), dtype) + + search_params = hnsw.SearchParams(**search_params) + + out_dist, out_idx = hnsw.search(search_params, hnsw_index, queries, k) + + # Calculate reference values with sklearn + skl_metric = { + "sqeuclidean": "sqeuclidean", + "inner_product": "cosine", + "euclidean": "euclidean", + }[metric] + nn_skl = NearestNeighbors( + n_neighbors=k, algorithm="brute", metric=skl_metric + ) + nn_skl.fit(dataset) + skl_dist, skl_idx = nn_skl.kneighbors(queries, return_distance=True) + + recall = calc_recall(out_idx, skl_idx) + assert recall > 0.95 + + +@pytest.mark.parametrize("dtype", [np.float32, np.int8, np.uint8]) +@pytest.mark.parametrize("k", [10, 20]) +@pytest.mark.parametrize("ef", [30, 40]) +@pytest.mark.parametrize("num_threads", [2, 4]) +@pytest.mark.parametrize("metric", ["sqeuclidean"]) +@pytest.mark.parametrize("build_algo", 
["ivf_pq", "nn_descent"]) +def test_hnsw(dtype, k, ef, num_threads, metric, build_algo): + # Note that inner_product tests use normalized input which we cannot + # represent in int8, therefore we test only sqeuclidean metric here. + run_hnsw_build_search_test( + dtype=dtype, + k=k, + metric=metric, + build_algo=build_algo, + search_params={"ef": ef, "num_threads": num_threads}, + )