diff --git a/faiss-sys/src/bindings.rs b/faiss-sys/src/bindings.rs
index 4da43bb..8f51d34 100644
--- a/faiss-sys/src/bindings.rs
+++ b/faiss-sys/src/bindings.rs
@@ -694,18 +694,39 @@ extern "C" {
         index: *mut FaissIndexFlat1D,
     ) -> ::std::os::raw::c_int;
 }
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct FaissIndexIVFFlat_H {
-    _unused: [u8; 0],
-}
-pub type FaissIndexIVFFlat = FaissIndexIVFFlat_H;
+pub type FaissIndexIVFFlat = FaissIndex_H;
 extern "C" {
     pub fn faiss_IndexIVFFlat_free(obj: *mut FaissIndexIVFFlat);
 }
 extern "C" {
     pub fn faiss_IndexIVFFlat_cast(arg1: *mut FaissIndex) -> *mut FaissIndexIVFFlat;
 }
+extern "C" {
+    pub fn faiss_IndexIVFFlat_nprobe(arg1: *const FaissIndexIVFFlat) -> usize;
+}
+extern "C" {
+    pub fn faiss_IndexIVFFlat_nlist(arg1: *const FaissIndexIVFFlat) -> usize;
+}
+extern "C" {
+    pub fn faiss_IndexIVFFlat_set_nprobe(arg1: *mut FaissIndexIVFFlat, arg2: usize);
+}
+extern "C" {
+    pub fn faiss_IndexIVFFlat_quantizer(arg1: *const FaissIndexIVFFlat) -> *mut FaissIndex;
+}
+extern "C" {
+    pub fn faiss_IndexIVFFlat_quantizer_trains_alone(
+        arg1: *const FaissIndexIVFFlat,
+    ) -> ::std::os::raw::c_char;
+}
+extern "C" {
+    pub fn faiss_IndexIVFFlat_own_fields(arg1: *const FaissIndexIVFFlat) -> ::std::os::raw::c_int;
+}
+extern "C" {
+    pub fn faiss_IndexIVFFlat_set_own_fields(
+        arg1: *mut FaissIndexIVFFlat,
+        arg2: ::std::os::raw::c_int,
+    );
+}
 extern "C" {
     #[doc = " Inverted file with stored vectors. Here the inverted file"]
     #[doc = " pre-selects the vectors to be searched, but they are not otherwise"]
@@ -766,6 +787,9 @@ extern "C" {
 extern "C" {
     pub fn faiss_IndexIVF_nprobe(arg1: *const FaissIndexIVF) -> usize;
 }
+extern "C" {
+    pub fn faiss_IndexIVF_set_nprobe(arg1: *mut FaissIndexIVF, arg2: usize);
+}
 extern "C" {
     pub fn faiss_IndexIVF_quantizer(arg1: *const FaissIndexIVF) -> *mut FaissIndex;
 }
@@ -777,6 +801,9 @@ extern "C" {
 extern "C" {
     pub fn faiss_IndexIVF_own_fields(arg1: *const FaissIndexIVF) -> ::std::os::raw::c_int;
 }
+extern "C" {
+    pub fn faiss_IndexIVF_set_own_fields(arg1: *mut FaissIndexIVF, arg2: ::std::os::raw::c_int);
+}
 extern "C" {
     #[doc = " moves the entries from another dataset to self. On output,"]
     #[doc = " other is empty. add_id is added to all moved ids (for"]
@@ -834,7 +861,7 @@ extern "C" {
     pub fn faiss_IndexIVF_get_list_size(index: *const FaissIndexIVF, list_no: usize) -> usize;
 }
 extern "C" {
-    #[doc = " intialize a direct map"]
+    #[doc = " initialize a direct map"]
     #[doc = ""]
     #[doc = " @param new_maintain_direct_map if true, create a direct map,"]
     #[doc = " else clear it"]
diff --git a/src/index/ivf_flat.rs b/src/index/ivf_flat.rs
new file mode 100644
index 0000000..2504056
--- /dev/null
+++ b/src/index/ivf_flat.rs
@@ -0,0 +1,382 @@
+//! Interface and implementation to IVFFlat index type.
+
+use super::*;
+
+use crate::error::Result;
+use crate::faiss_try;
+use std::mem;
+use std::os::raw::{c_char, c_int};
+use std::ptr;
+
+/// Alias for the native implementation of an IVF flat index.
+pub type IVFFlatIndex = IVFFlatIndexImpl;
+
+/// Native implementation of an IVF flat index.
+///
+/// Wraps a raw `FaissIndexIVFFlat` pointer, which is released with
+/// `faiss_IndexIVFFlat_free` when the value is dropped.
+#[derive(Debug)]
+pub struct IVFFlatIndexImpl {
+    inner: *mut FaissIndexIVFFlat,
+}
+
+unsafe impl Send for IVFFlatIndexImpl {}
+unsafe impl Sync for IVFFlatIndexImpl {}
+
+impl CpuIndex for IVFFlatIndexImpl {}
+
+impl Drop for IVFFlatIndexImpl {
+    fn drop(&mut self) {
+        unsafe {
+            faiss_IndexIVFFlat_free(self.inner);
+        }
+    }
+}
+
+impl IVFFlatIndexImpl {
+    /// Create a new IVF flat index that borrows its quantizer.
+    ///
+    /// The index does not take ownership of the quantizer (`own_fields` is
+    /// cleared) and no lifetime ties the two together, so the caller must keep
+    /// `quantizer` alive for as long as the returned index is in use.
+    pub fn new_by_ref(
+        quantizer: &flat::FlatIndex,
+        d: u32,
+        nlist: u32,
+        metric: MetricType,
+    ) -> Result<Self> {
+        IVFFlatIndexImpl::new_helper(quantizer, d, nlist, metric, false)
+    }
+
+    /// Build the native index over `quantizer`, then set its `own_fields`
+    /// flag to tell it whether it owns the quantizer.
+    fn new_helper(
+        quantizer: &flat::FlatIndex,
+        d: u32,
+        nlist: u32,
+        metric: MetricType,
+        own_fields: bool,
+    ) -> Result<Self> {
+        unsafe {
+            let metric = metric as c_uint;
+            let mut inner = ptr::null_mut();
+            faiss_try(faiss_IndexIVFFlat_new_with_metric(
+                &mut inner,
+                quantizer.inner_ptr(),
+                d as usize,
+                nlist as usize,
+                metric,
+            ))?;
+            faiss_IndexIVFFlat_set_own_fields(inner, c_int::from(own_fields));
+            Ok(IVFFlatIndexImpl { inner })
+        }
+    }
+
+    /// Create a new IVF flat index.
+    ///
+    /// The index owns the quantizer.
+    pub fn new(quantizer: flat::FlatIndex, d: u32, nlist: u32, metric: MetricType) -> Result<Self> {
+        let index = IVFFlatIndexImpl::new_helper(&quantizer, d, nlist, metric, true)?;
+        // The native index has `own_fields` set, so the Rust-side handle is
+        // forgotten instead of dropped.
+        mem::forget(quantizer);
+
+        Ok(index)
+    }
+
+    /// Create a new IVF flat index with L2 as the metric type.
+    ///
+    /// The quantizer is only borrowed and must outlive the returned index.
+    pub fn new_l2_by_ref(quantizer: &flat::FlatIndex, d: u32, nlist: u32) -> Result<Self> {
+        IVFFlatIndexImpl::new_by_ref(quantizer, d, nlist, MetricType::L2)
+    }
+
+    /// Create a new IVF flat index with L2 as the metric type.
+    ///
+    /// The index owns the quantizer.
+    pub fn new_l2(quantizer: flat::FlatIndex, d: u32, nlist: u32) -> Result<Self> {
+        IVFFlatIndexImpl::new(quantizer, d, nlist, MetricType::L2)
+    }
+
+    /// Create a new IVF flat index with IP (inner product) as the metric type.
+    ///
+    /// The quantizer is only borrowed and must outlive the returned index.
+    pub fn new_ip_by_ref(quantizer: &flat::FlatIndex, d: u32, nlist: u32) -> Result<Self> {
+        IVFFlatIndexImpl::new_by_ref(quantizer, d, nlist, MetricType::InnerProduct)
+    }
+
+    /// Create a new IVF flat index with IP (inner product) as the metric type.
+    ///
+    /// The index owns the quantizer.
+    pub fn new_ip(quantizer: flat::FlatIndex, d: u32, nlist: u32) -> Result<Self> {
+        IVFFlatIndexImpl::new(quantizer, d, nlist, MetricType::InnerProduct)
+    }
+
+    /// Get number of probes at query time
+    pub fn nprobe(&self) -> u32 {
+        unsafe { faiss_IndexIVFFlat_nprobe(self.inner_ptr()) as u32 }
+    }
+
+    /// Set number of probes at query time
+    pub fn set_nprobe(&mut self, value: u32) {
+        unsafe {
+            faiss_IndexIVFFlat_set_nprobe(self.inner_ptr(), value as usize);
+        }
+    }
+
+    /// Get number of possible key values
+    pub fn nlist(&self) -> u32 {
+        unsafe { faiss_IndexIVFFlat_nlist(self.inner_ptr()) as u32 }
+    }
+
+    /// Get train type
+    ///
+    /// Returns `None` when the native code is not a known [`TrainType`].
+    pub fn train_type(&self) -> Option<TrainType> {
+        unsafe {
+            let code = faiss_IndexIVFFlat_quantizer_trains_alone(self.inner_ptr());
+            TrainType::from_code(code)
+        }
+    }
+}
+
+/// How the quantizer of an IVF index is trained, decoded from the native
+/// `quantizer_trains_alone` code:
+///
+/// - `0`: use the quantizer as index in a kmeans training
+/// - `1`: just pass on the training set to the train() of the quantizer
+/// - `2`: kmeans training on a flat index + add the centroids to the quantizer
+#[derive(Debug, Copy, Clone, Eq, Hash, PartialEq)]
+pub enum TrainType {
+    /// use the quantizer as index in a kmeans training
+    QuantizerAsIndex,
+    /// just pass on the training set to the train() of the quantizer
+    QuantizerTrainsAlone,
+    /// kmeans training on a flat index + add the centroids to the quantizer
+    FlatIndexAndQuantizer,
+}
+
+impl TrainType {
+    /// Decode a native train-type code; unknown codes map to `None`.
+    ///
+    /// Takes `c_char` (the FFI getter's return type) because its signedness
+    /// is platform-dependent, so a fixed `i8` would not compile everywhere.
+    pub(crate) fn from_code(code: c_char) -> Option<Self> {
+        match code {
+            0 => Some(TrainType::QuantizerAsIndex),
+            1 => Some(TrainType::QuantizerTrainsAlone),
+            2 => Some(TrainType::FlatIndexAndQuantizer),
+            _ => None,
+        }
+    }
+}
+
+impl NativeIndex for IVFFlatIndexImpl {
+    fn inner_ptr(&self) -> *mut FaissIndex {
+        self.inner
+    }
+}
+
+impl FromInnerPtr for IVFFlatIndexImpl {
+    unsafe fn from_inner_ptr(inner_ptr: *mut FaissIndex) -> Self {
+        IVFFlatIndexImpl {
+            inner: inner_ptr as *mut FaissIndexIVFFlat,
+        }
+    }
+}
+
+impl_native_index!(IVFFlatIndex);
+
+impl_native_index_clone!(IVFFlatIndex);
+
+impl ConcurrentIndex for IVFFlatIndexImpl {
+    fn assign(&self, query: &[f32], k: usize) -> Result<AssignSearchResult> {
+        unsafe {
+            let nq = query.len() / self.d() as usize;
+            let mut out_labels = vec![Idx::none(); k * nq];
+            faiss_try(faiss_Index_assign(
+                self.inner,
+                nq as idx_t,
+                query.as_ptr(),
+                out_labels.as_mut_ptr() as *mut _,
+                k as i64,
+            ))?;
+            Ok(AssignSearchResult { labels: out_labels })
+        }
+    }
+    fn search(&self, query: &[f32], k: usize) -> Result<SearchResult> {
+        unsafe {
+            let nq = query.len() / self.d() as usize;
+            let mut distances = vec![0_f32; k * nq];
+            let mut labels = vec![Idx::none(); k * nq];
+            faiss_try(faiss_Index_search(
+                self.inner,
+                nq as idx_t,
+                query.as_ptr(),
+                k as idx_t,
+                distances.as_mut_ptr(),
+                labels.as_mut_ptr() as *mut _,
+            ))?;
+            Ok(SearchResult { distances, labels })
+        }
+    }
+    fn range_search(&self, query: &[f32], radius: f32) -> Result<RangeSearchResult> {
+        unsafe {
+            let nq = (query.len() / self.d() as usize) as idx_t;
+            let mut p_res: *mut FaissRangeSearchResult = ptr::null_mut();
+            faiss_try(faiss_RangeSearchResult_new(&mut p_res, nq))?;
+            // Wrap the pointer first so a failed search drops `result` (and runs
+            // whatever cleanup it implements) instead of abandoning the buffer.
+            let result = RangeSearchResult { inner: p_res };
+            faiss_try(faiss_Index_range_search(
+                self.inner,
+                nq,
+                query.as_ptr(),
+                radius,
+                p_res,
+            ))?;
+            Ok(result)
+        }
+    }
+}
+
+impl IndexImpl {
+    /// Attempt a dynamic cast of an index to the IVF flat index type.
+    pub fn into_ivf_flat(self) -> Result<IVFFlatIndexImpl> {
+        unsafe {
+            let new_inner = faiss_IndexIVFFlat_cast(self.inner_ptr());
+            if new_inner.is_null() {
+                Err(Error::BadCast)
+            } else {
+                mem::forget(self);
+                Ok(IVFFlatIndexImpl { inner: new_inner })
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+
+    use super::IVFFlatIndexImpl;
+    use crate::index::flat::FlatIndexImpl;
+    use crate::index::{index_factory, ConcurrentIndex, Idx, Index};
+    use crate::MetricType;
+
+    const D: u32 = 8;
+
+    #[test]
+    fn index_search() {
+        let q = FlatIndexImpl::new_l2(D).unwrap();
+        let mut index = IVFFlatIndexImpl::new_l2_by_ref(&q, D, 1).unwrap();
+        assert_eq!(index.d(), D);
+        assert_eq!(index.ntotal(), 0);
+        let some_data = &[
+            7.5_f32, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, -1., 1., 1., 1., 1., 1., 1., -1., 4.,
+            -4., -8., 1., 1., 2., 4., -1., 8., 8., 10., -10., -10., 10., -10., 10., 16., 16., 32.,
+            25., 20., 20., 40., 15.,
+        ];
+        index.train(some_data).unwrap();
+        index.add(some_data).unwrap();
+        assert_eq!(index.ntotal(), 5);
+
+        let my_query = [0.; D as usize];
+        let result = index.search(&my_query, 3).unwrap();
+        assert_eq!(result.labels.len(), 3);
+        assert!(result.labels.into_iter().all(Idx::is_some));
+        assert_eq!(result.distances.len(), 3);
+        assert!(result.distances.iter().all(|x| *x > 0.));
+
+        let my_query = [100.; D as usize];
+        // the IVF flat index can be searched behind an immutable ref
+        let result = (&index).search(&my_query, 3).unwrap();
+        assert_eq!(result.labels.len(), 3);
+        assert!(result.labels.into_iter().all(Idx::is_some));
+        assert_eq!(result.distances.len(), 3);
+        assert!(result.distances.iter().all(|x| *x > 0.));
+
+        index.reset().unwrap();
+        assert_eq!(index.ntotal(), 0);
+    }
+
+    #[test]
+    fn index_search_own() {
+        let q = FlatIndexImpl::new_l2(D).unwrap();
+        let mut index = IVFFlatIndexImpl::new_l2(q, D, 1).unwrap();
+        assert_eq!(index.d(), D);
+        assert_eq!(index.ntotal(), 0);
+        let some_data = &[
+            7.5_f32, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, -1., 1., 1., 1., 1., 1., 1., -1., 4.,
+            -4., -8., 1., 1., 2., 4., -1., 8., 8., 10., -10., -10., 10., -10., 10., 16., 16., 32.,
+            25., 20., 20., 40., 15.,
+        ];
+        index.train(some_data).unwrap();
+        index.add(some_data).unwrap();
+        assert_eq!(index.ntotal(), 5);
+
+        let my_query = [0.; D as usize];
+        let result = index.search(&my_query, 3).unwrap();
+        assert_eq!(result.labels.len(), 3);
+        assert!(result.labels.into_iter().all(Idx::is_some));
+        assert_eq!(result.distances.len(), 3);
+        assert!(result.distances.iter().all(|x| *x > 0.));
+
+        let my_query = [100.; D as usize];
+        // the IVF flat index can be searched behind an immutable ref
+        let result = (&index).search(&my_query, 3).unwrap();
+        assert_eq!(result.labels.len(), 3);
+        assert!(result.labels.into_iter().all(Idx::is_some));
+        assert_eq!(result.distances.len(), 3);
+        assert!(result.distances.iter().all(|x| *x > 0.));
+
+        index.reset().unwrap();
+        assert_eq!(index.ntotal(), 0);
+    }
+
+    #[test]
+    fn index_assign() {
+        let q = FlatIndexImpl::new_l2(D).unwrap();
+        let mut index = IVFFlatIndexImpl::new_l2_by_ref(&q, D, 1).unwrap();
+        assert_eq!(index.d(), D);
+        assert_eq!(index.ntotal(), 0);
+        let some_data = &[
+            7.5_f32, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, -1., 1., 1., 1., 1., 1., 1., -1., 4.,
+            -4., -8., 1., 1., 2., 4., -1., 8., 8., 10., -10., -10., 10., -10., 10., 16., 16., 32.,
+            25., 20., 20., 40., 15.,
+        ];
+        index.train(some_data).unwrap();
+        index.add(some_data).unwrap();
+        assert_eq!(index.ntotal(), 5);
+
+        let my_query = [0.; D as usize];
+        let result = index.assign(&my_query, 3).unwrap();
+        assert_eq!(result.labels.len(), 3);
+        assert!(result.labels.into_iter().all(Idx::is_some));
+
+        let my_query = [100.; D as usize];
+        // the IVF flat index can be queried behind an immutable ref
+        let result = (&index).assign(&my_query, 3).unwrap();
+        assert_eq!(result.labels.len(), 3);
+        assert!(result.labels.into_iter().all(Idx::is_some));
+
+        index.reset().unwrap();
+        assert_eq!(index.ntotal(), 0);
+    }
+
+    #[test]
+    fn ivf_flat_index_from_cast() {
+        let mut index = index_factory(8, "IVF1,Flat", MetricType::L2).unwrap();
+        let some_data = &[
+            7.5_f32, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, -1., 1., 1., 1., 1., 1., 1., -1., 0., 0.,
+            0., 1., 1., 0., 0., -1., 100., 100., 100., 100., -100., 100., 100., 100., 120., 100.,
+            100., 105., -100., 100., 100., 105.,
+        ];
+        index.train(some_data).unwrap();
+        index.add(some_data).unwrap();
+        assert_eq!(index.ntotal(), 5);
+
+        let index: IVFFlatIndexImpl = index.into_ivf_flat().unwrap();
+        assert!(index.is_trained());
+        assert_eq!(index.ntotal(), 5);
+    }
+}
diff --git a/src/index/mod.rs b/src/index/mod.rs
index 77792de..cc9dcb1 100644
--- a/src/index/mod.rs
+++ b/src/index/mod.rs
@@ -25,6 +25,7 @@ pub mod autotune;
 pub mod flat;
 pub mod id_map;
 pub mod io;
+pub mod ivf_flat;
 pub mod lsh;
 
 #[cfg(feature = "gpu")]