From fe885292b06e3c7f093ac469d8477d031d841fb6 Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Sun, 28 Mar 2021 19:12:33 +0300 Subject: [PATCH 1/7] Added high-level IVFFlatIndex impl --- faiss-sys/src/bindings.rs | 7 +- src/index/ivf_flat.rs | 201 ++++++++++++++++++++++++++++++++++++++ src/index/mod.rs | 1 + 3 files changed, 203 insertions(+), 6 deletions(-) create mode 100644 src/index/ivf_flat.rs diff --git a/faiss-sys/src/bindings.rs b/faiss-sys/src/bindings.rs index 4da43bb..66f76c9 100644 --- a/faiss-sys/src/bindings.rs +++ b/faiss-sys/src/bindings.rs @@ -694,12 +694,7 @@ extern "C" { index: *mut FaissIndexFlat1D, ) -> ::std::os::raw::c_int; } -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct FaissIndexIVFFlat_H { - _unused: [u8; 0], -} -pub type FaissIndexIVFFlat = FaissIndexIVFFlat_H; +pub type FaissIndexIVFFlat = FaissIndex_H; extern "C" { pub fn faiss_IndexIVFFlat_free(obj: *mut FaissIndexIVFFlat); } diff --git a/src/index/ivf_flat.rs b/src/index/ivf_flat.rs new file mode 100644 index 0000000..99dc82f --- /dev/null +++ b/src/index/ivf_flat.rs @@ -0,0 +1,201 @@ +//! Interface and implementation to IVFFlat index type. + +use super::*; + +use crate::error::Result; +use crate::faiss_try; +use std::mem; +use std::ptr; + +/// Alias for the native implementation of a flat index. +pub type IVFFlatIndex = IVFFlatIndexImpl; + +/// Native implementation of a flat index. +#[derive(Debug)] +pub struct IVFFlatIndexImpl { + inner: *mut FaissIndexIVFFlat, +} + +unsafe impl Send for IVFFlatIndexImpl {} +unsafe impl Sync for IVFFlatIndexImpl {} + +impl CpuIndex for IVFFlatIndexImpl {} + +impl Drop for IVFFlatIndexImpl { + fn drop(&mut self) { + unsafe { + faiss_IndexIVFFlat_free(self.inner); + } + } +} + +impl IVFFlatIndexImpl { + /// Create a new IVF flat index. + pub fn new(quantizer: flat::FlatIndex, d: u32, nlist: u32, metric: MetricType) -> Result { + unsafe { + let metric = metric as c_uint; + let mut inner = ptr::null_mut(); + faiss_try(faiss_IndexIVFFlat_new_with_metric( + &mut inner, + quantizer.inner_ptr(), + d as usize, + nlist as usize, + metric, + ))?; + + mem::forget(quantizer); // own_fields default == true + Ok(IVFFlatIndexImpl { inner }) + } + } + + /// Create a new IVF flat index with L2 as the metric type. + pub fn new_l2(quantizer: flat::FlatIndex, d: u32, nlist: u32) -> Result { + IVFFlatIndexImpl::new(quantizer, d, nlist, MetricType::L2) + } + + /// Create a new IVF flat index with IP (inner product) as the metric type. + pub fn new_ip(quantizer: flat::FlatIndex, d: u32, nlist: u32) -> Result { + IVFFlatIndexImpl::new(quantizer, d, nlist, MetricType::InnerProduct) + } +} + +impl NativeIndex for IVFFlatIndexImpl { + fn inner_ptr(&self) -> *mut FaissIndex { + self.inner + } +} + +impl FromInnerPtr for IVFFlatIndexImpl { + unsafe fn from_inner_ptr(inner_ptr: *mut FaissIndex) -> Self { + IVFFlatIndexImpl { + inner: inner_ptr as *mut FaissIndexIVFFlat, + } + } +} + +impl_native_index!(IVFFlatIndex); + +impl_native_index_clone!(IVFFlatIndex); + +impl ConcurrentIndex for IVFFlatIndexImpl { + fn assign(&self, query: &[f32], k: usize) -> Result { + unsafe { + let nq = query.len() / self.d() as usize; + let mut out_labels = vec![Idx::none(); k * nq]; + faiss_try(faiss_Index_assign( + self.inner, + nq as idx_t, + query.as_ptr(), + out_labels.as_mut_ptr() as *mut _, + k as i64, + ))?; + Ok(AssignSearchResult { labels: out_labels }) + } + } + fn search(&self, query: &[f32], k: usize) -> Result { + unsafe { + let nq = query.len() / self.d() as usize; + let mut distances = vec![0_f32; k * nq]; + let mut labels = vec![Idx::none(); k * nq]; + faiss_try(faiss_Index_search( + self.inner, + nq as idx_t, + query.as_ptr(), + k as idx_t, + distances.as_mut_ptr(), + labels.as_mut_ptr() as *mut _, + ))?; + Ok(SearchResult { distances, labels }) + } + } + fn range_search(&self, query: &[f32], radius: f32) -> Result { + unsafe { + let nq = (query.len() / self.d() as usize) as idx_t; + let mut p_res: *mut FaissRangeSearchResult = ptr::null_mut(); + faiss_try(faiss_RangeSearchResult_new(&mut p_res, nq))?; + faiss_try(faiss_Index_range_search( + self.inner, + nq, + query.as_ptr(), + radius, + p_res, + ))?; + Ok(RangeSearchResult { inner: p_res }) + } + } +} + +#[cfg(test)] +mod tests { + + use super::IVFFlatIndex; + use crate::index::flat::FlatIndexImpl; + use crate::index::{ConcurrentIndex, Idx, Index}; + + const D: u32 = 8; + + #[test] + // #[ignore] + fn index_search() { + let q = FlatIndexImpl::new_l2(D).unwrap(); + let mut index = IVFFlatIndex::new_l2(q, D, 1).unwrap(); + assert_eq!(index.d(), D); + assert_eq!(index.ntotal(), 0); + let some_data = &[ + 7.5_f32, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, -1., 1., 1., 1., 1., 1., 1., -1., 4., + -4., -8., 1., 1., 2., 4., -1., 8., 8., 10., -10., -10., 10., -10., 10., 16., 16., 32., + 25., 20., 20., 40., 15., + ]; + index.train(some_data).unwrap(); + index.add(some_data).unwrap(); + assert_eq!(index.ntotal(), 5); + + let my_query = [0.; D as usize]; + let result = index.search(&my_query, 3).unwrap(); + assert_eq!(result.labels.len(), 3); + assert!(result.labels.into_iter().all(Idx::is_some)); + assert_eq!(result.distances.len(), 3); + assert!(result.distances.iter().all(|x| *x > 0.)); + + let my_query = [100.; D as usize]; + // flat index can be used behind an immutable ref + let result = (&index).search(&my_query, 3).unwrap(); + assert_eq!(result.labels.len(), 3); + assert!(result.labels.into_iter().all(Idx::is_some)); + assert_eq!(result.distances.len(), 3); + assert!(result.distances.iter().all(|x| *x > 0.)); + + index.reset().unwrap(); + assert_eq!(index.ntotal(), 0); + } + + #[test] + fn index_assign() { + let q = FlatIndexImpl::new_l2(D).unwrap(); + let mut index = IVFFlatIndex::new_l2(q, D, 1).unwrap(); + assert_eq!(index.d(), D); + assert_eq!(index.ntotal(), 0); + let some_data = &[ + 7.5_f32, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, -1., 1., 1., 1., 1., 1., 1., -1., 4., + -4., -8., 1., 1., 2., 4., -1., 8., 8., 10., -10., -10., 10., -10., 10., 16., 16., 32., + 25., 20., 20., 40., 15., + ]; + index.train(some_data).unwrap(); + index.add(some_data).unwrap(); + assert_eq!(index.ntotal(), 5); + + let my_query = [0.; D as usize]; + let result = index.assign(&my_query, 3).unwrap(); + assert_eq!(result.labels.len(), 3); + assert!(result.labels.into_iter().all(Idx::is_some)); + + let my_query = [100.; D as usize]; + // flat index can be used behind an immutable ref + let result = (&index).assign(&my_query, 3).unwrap(); + assert_eq!(result.labels.len(), 3); + assert!(result.labels.into_iter().all(Idx::is_some)); + + index.reset().unwrap(); + assert_eq!(index.ntotal(), 0); + } +} diff --git a/src/index/mod.rs b/src/index/mod.rs index 3853238..d41feb7 100644 --- a/src/index/mod.rs +++ b/src/index/mod.rs @@ -24,6 +24,7 @@ use faiss_sys::*; pub mod flat; pub mod id_map; pub mod io; +pub mod ivf_flat; pub mod lsh; #[cfg(feature = "gpu")] From 419a812d44237fac8a917fb5468157437d4ef929 Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Tue, 6 Apr 2021 15:50:51 +0300 Subject: [PATCH 2/7] Impl cast fn for IVFFlatIndexImpl --- src/index/ivf_flat.rs | 41 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 37 insertions(+), 4 deletions(-) diff --git a/src/index/ivf_flat.rs b/src/index/ivf_flat.rs index 99dc82f..26e63e9 100644 --- a/src/index/ivf_flat.rs +++ b/src/index/ivf_flat.rs @@ -125,12 +125,28 @@ impl ConcurrentIndex for IVFFlatIndexImpl { } } +impl IndexImpl { + /// Attempt a dynamic cast of an index to the IVF flat index type. + pub fn into_ivf_flat(self) -> Result { + unsafe { + let new_inner = faiss_IndexIVFFlat_cast(self.inner_ptr()); + if new_inner.is_null() { + Err(Error::BadCast) + } else { + mem::forget(self); + Ok(IVFFlatIndexImpl { inner: new_inner }) + } + } + } +} + #[cfg(test)] mod tests { - use super::IVFFlatIndex; + use super::IVFFlatIndexImpl; use crate::index::flat::FlatIndexImpl; - use crate::index::{ConcurrentIndex, Idx, Index}; + use crate::index::{index_factory, ConcurrentIndex, Idx, Index}; + use crate::MetricType; const D: u32 = 8; @@ -138,7 +154,7 @@ mod tests { // #[ignore] fn index_search() { let q = FlatIndexImpl::new_l2(D).unwrap(); - let mut index = IVFFlatIndex::new_l2(q, D, 1).unwrap(); + let mut index = IVFFlatIndexImpl::new_l2(q, D, 1).unwrap(); assert_eq!(index.d(), D); assert_eq!(index.ntotal(), 0); let some_data = &[ @@ -172,7 +188,7 @@ mod tests { #[test] fn index_assign() { let q = FlatIndexImpl::new_l2(D).unwrap(); - let mut index = IVFFlatIndex::new_l2(q, D, 1).unwrap(); + let mut index = IVFFlatIndexImpl::new_l2(q, D, 1).unwrap(); assert_eq!(index.d(), D); assert_eq!(index.ntotal(), 0); let some_data = &[ @@ -198,4 +214,21 @@ mod tests { index.reset().unwrap(); assert_eq!(index.ntotal(), 0); } + + #[test] + fn ivf_flat_index_from_cast() { + let mut index = index_factory(8, "IVF1,Flat", MetricType::L2).unwrap(); + let some_data = &[ + 7.5_f32, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, -1., 1., 1., 1., 1., 1., 1., -1., 0., 0., + 0., 1., 1., 0., 0., -1., 100., 100., 100., 100., -100., 100., 100., 100., 120., 100., + 100., 105., -100., 100., 100., 105., + ]; + index.train(some_data).unwrap(); + index.add(some_data).unwrap(); + assert_eq!(index.ntotal(), 5); + + let index: IVFFlatIndexImpl = index.into_ivf_flat().unwrap(); + assert_eq!(index.is_trained(), true); + assert_eq!(index.ntotal(), 5); + } } From af9a79fa338a739f4b1af4eb721cb41572c81af6 Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Thu, 15 Apr 2021 14:56:29 +0300 Subject: [PATCH 3/7] Expected own_fields == false for IndexIVF --- src/index/ivf_flat.rs | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/index/ivf_flat.rs b/src/index/ivf_flat.rs index 26e63e9..9623366 100644 --- a/src/index/ivf_flat.rs +++ b/src/index/ivf_flat.rs @@ -31,7 +31,12 @@ impl Drop for IVFFlatIndexImpl { impl IVFFlatIndexImpl { /// Create a new IVF flat index. - pub fn new(quantizer: flat::FlatIndex, d: u32, nlist: u32, metric: MetricType) -> Result { + pub fn new( + quantizer: &flat::FlatIndex, + d: u32, + nlist: u32, + metric: MetricType, + ) -> Result { unsafe { let metric = metric as c_uint; let mut inner = ptr::null_mut(); @@ -43,18 +48,17 @@ impl IVFFlatIndexImpl { metric, ))?; - mem::forget(quantizer); // own_fields default == true Ok(IVFFlatIndexImpl { inner }) } } /// Create a new IVF flat index with L2 as the metric type. - pub fn new_l2(quantizer: flat::FlatIndex, d: u32, nlist: u32) -> Result { + pub fn new_l2(quantizer: &flat::FlatIndex, d: u32, nlist: u32) -> Result { IVFFlatIndexImpl::new(quantizer, d, nlist, MetricType::L2) } /// Create a new IVF flat index with IP (inner product) as the metric type. - pub fn new_ip(quantizer: flat::FlatIndex, d: u32, nlist: u32) -> Result { + pub fn new_ip(quantizer: &flat::FlatIndex, d: u32, nlist: u32) -> Result { IVFFlatIndexImpl::new(quantizer, d, nlist, MetricType::InnerProduct) } } @@ -154,7 +158,7 @@ mod tests { // #[ignore] fn index_search() { let q = FlatIndexImpl::new_l2(D).unwrap(); - let mut index = IVFFlatIndexImpl::new_l2(q, D, 1).unwrap(); + let mut index = IVFFlatIndexImpl::new_l2(&q, D, 1).unwrap(); assert_eq!(index.d(), D); assert_eq!(index.ntotal(), 0); let some_data = &[ @@ -188,7 +192,7 @@ mod tests { #[test] fn index_assign() { let q = FlatIndexImpl::new_l2(D).unwrap(); - let mut index = IVFFlatIndexImpl::new_l2(q, D, 1).unwrap(); + let mut index = IVFFlatIndexImpl::new_l2(&q, D, 1).unwrap(); assert_eq!(index.d(), D); assert_eq!(index.ntotal(), 0); let some_data = &[ From 690a24646172d9a5b32291e7a5eea9f6432e196f Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Thu, 6 May 2021 11:05:26 +0300 Subject: [PATCH 4/7] Add setters, getters for IVFFlatIndexImpl Add: - nprobe - own_fields --- faiss-sys/src/bindings.rs | 21 ++++++++++++++++ src/index/ivf_flat.rs | 52 +++++++++++++++++++++++++++++++++++---- 2 files changed, 68 insertions(+), 5 deletions(-) diff --git a/faiss-sys/src/bindings.rs b/faiss-sys/src/bindings.rs index 66f76c9..8c0e4aa 100644 --- a/faiss-sys/src/bindings.rs +++ b/faiss-sys/src/bindings.rs @@ -701,6 +701,21 @@ extern "C" { extern "C" { pub fn faiss_IndexIVFFlat_cast(arg1: *mut FaissIndex) -> *mut FaissIndexIVFFlat; } +extern "C" { + pub fn faiss_IndexIVFFlat_nprobe(arg1: *const FaissIndexIVFFlat) -> usize; +} +extern "C" { + pub fn faiss_IndexIVFFlat_set_nprobe(arg1: *mut FaissIndexIVFFlat, arg2: usize); +} +extern "C" { + pub fn faiss_IndexIVFFlat_own_fields(arg1: *const FaissIndexIVFFlat) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn faiss_IndexIVFFlat_set_own_fields( + arg1: *mut FaissIndexIVFFlat, + arg2: ::std::os::raw::c_int, + ); +} extern "C" { #[doc = " Inverted file with stored vectors. Here the inverted file"] #[doc = " pre-selects the vectors to be searched, but they are not otherwise"] @@ -761,6 +776,9 @@ extern "C" { extern "C" { pub fn faiss_IndexIVF_nprobe(arg1: *const FaissIndexIVF) -> usize; } +extern "C" { + pub fn faiss_IndexIVF_set_nprobe(arg1: *mut FaissIndexIVF, arg2: usize); +} extern "C" { pub fn faiss_IndexIVF_quantizer(arg1: *const FaissIndexIVF) -> *mut FaissIndex; } @@ -772,6 +790,9 @@ extern "C" { extern "C" { pub fn faiss_IndexIVF_own_fields(arg1: *const FaissIndexIVF) -> ::std::os::raw::c_int; } +extern "C" { + pub fn faiss_IndexIVF_set_own_fields(arg1: *mut FaissIndexIVF, arg2: ::std::os::raw::c_int); +} extern "C" { #[doc = " moves the entries from another dataset to self. On output,"] #[doc = " other is empty. add_id is added to all moved ids (for"] diff --git a/src/index/ivf_flat.rs b/src/index/ivf_flat.rs index 9623366..dde5487 100644 --- a/src/index/ivf_flat.rs +++ b/src/index/ivf_flat.rs @@ -31,11 +31,21 @@ impl Drop for IVFFlatIndexImpl { impl IVFFlatIndexImpl { /// Create a new IVF flat index. - pub fn new( + pub fn new_by_ref( quantizer: &flat::FlatIndex, d: u32, nlist: u32, metric: MetricType, + ) -> Result { + IVFFlatIndexImpl::new_helper(quantizer, d, nlist, metric, false) + } + + fn new_helper( + quantizer: &flat::FlatIndex, + d: u32, + nlist: u32, + metric: MetricType, + own_fields: bool, ) -> Result { unsafe { let metric = metric as c_uint; @@ -47,20 +57,52 @@ impl IVFFlatIndexImpl { nlist as usize, metric, ))?; + let own_fields_ = if own_fields { 1 } else { 0 }; + faiss_IndexIVFFlat_set_own_fields(inner, own_fields_); Ok(IVFFlatIndexImpl { inner }) } } + /// Create a new IVF flat index. + // The index owns the quantizer. + pub fn new(quantizer: flat::FlatIndex, d: u32, nlist: u32, metric: MetricType) -> Result { + IVFFlatIndexImpl::new_helper(&quantizer, d, nlist, metric, true) + } + + /// Create a new IVF flat index with L2 as the metric type. + pub fn new_l2_by_ref(quantizer: &flat::FlatIndex, d: u32, nlist: u32) -> Result { + IVFFlatIndexImpl::new_by_ref(quantizer, d, nlist, MetricType::L2) + } + /// Create a new IVF flat index with L2 as the metric type. - pub fn new_l2(quantizer: &flat::FlatIndex, d: u32, nlist: u32) -> Result { + // The index owns the quantizer. + pub fn new_l2(quantizer: flat::FlatIndex, d: u32, nlist: u32) -> Result { IVFFlatIndexImpl::new(quantizer, d, nlist, MetricType::L2) } /// Create a new IVF flat index with IP (inner product) as the metric type. - pub fn new_ip(quantizer: &flat::FlatIndex, d: u32, nlist: u32) -> Result { + pub fn new_ip_by_ref(quantizer: &flat::FlatIndex, d: u32, nlist: u32) -> Result { + IVFFlatIndexImpl::new_by_ref(quantizer, d, nlist, MetricType::InnerProduct) + } + + /// Create a new IVF flat index with IP (inner product) as the metric type. + // The index owns the quantizer. + pub fn new_ip(quantizer: flat::FlatIndex, d: u32, nlist: u32) -> Result { IVFFlatIndexImpl::new(quantizer, d, nlist, MetricType::InnerProduct) } + + /// Get number of probes at query time + pub fn nprobe(&self) -> u32 { + unsafe { faiss_IndexIVFFlat_nprobe(self.inner_ptr()) as u32 } + } + + /// Set number of probes at query time + pub fn set_nprobe(&mut self, value: u32) { + unsafe { + faiss_IndexIVFFlat_set_nprobe(self.inner_ptr(), value as usize); + } + } } impl NativeIndex for IVFFlatIndexImpl { @@ -158,7 +200,7 @@ mod tests { // #[ignore] fn index_search() { let q = FlatIndexImpl::new_l2(D).unwrap(); - let mut index = IVFFlatIndexImpl::new_l2(&q, D, 1).unwrap(); + let mut index = IVFFlatIndexImpl::new_l2_by_ref(&q, D, 1).unwrap(); assert_eq!(index.d(), D); assert_eq!(index.ntotal(), 0); let some_data = &[ @@ -192,7 +234,7 @@ mod tests { #[test] fn index_assign() { let q = FlatIndexImpl::new_l2(D).unwrap(); - let mut index = IVFFlatIndexImpl::new_l2(&q, D, 1).unwrap(); + let mut index = IVFFlatIndexImpl::new_l2_by_ref(&q, D, 1).unwrap(); assert_eq!(index.d(), D); assert_eq!(index.ntotal(), 0); let some_data = &[ From 2f6a7b28b9c4bec29f00be1fd7974a07daa460e0 Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Thu, 6 May 2021 12:38:52 +0300 Subject: [PATCH 5/7] Add some getters for IVFFlatIndexImpl Add: - nlist - train_type sync c_api from PR faiss 1787 --- faiss-sys/src/bindings.rs | 13 ++++++- src/index/ivf_flat.rs | 78 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 89 insertions(+), 2 deletions(-) diff --git a/faiss-sys/src/bindings.rs b/faiss-sys/src/bindings.rs index 8c0e4aa..8f51d34 100644 --- a/faiss-sys/src/bindings.rs +++ b/faiss-sys/src/bindings.rs @@ -704,9 +704,20 @@ extern "C" { extern "C" { pub fn faiss_IndexIVFFlat_nprobe(arg1: *const FaissIndexIVFFlat) -> usize; } +extern "C" { + pub fn faiss_IndexIVFFlat_nlist(arg1: *const FaissIndexIVFFlat) -> usize; +} extern "C" { pub fn faiss_IndexIVFFlat_set_nprobe(arg1: *mut FaissIndexIVFFlat, arg2: usize); } +extern "C" { + pub fn faiss_IndexIVFFlat_quantizer(arg1: *const FaissIndexIVFFlat) -> *mut FaissIndex; +} +extern "C" { + pub fn faiss_IndexIVFFlat_quantizer_trains_alone( + arg1: *const FaissIndexIVFFlat, + ) -> ::std::os::raw::c_char; +} extern "C" { pub fn faiss_IndexIVFFlat_own_fields(arg1: *const FaissIndexIVFFlat) -> ::std::os::raw::c_int; } @@ -850,7 +861,7 @@ extern "C" { pub fn faiss_IndexIVF_get_list_size(index: *const FaissIndexIVF, list_no: usize) -> usize; } extern "C" { - #[doc = " intialize a direct map"] + #[doc = " initialize a direct map"] #[doc = ""] #[doc = " @param new_maintain_direct_map if true, create a direct map,"] #[doc = " else clear it"] diff --git a/src/index/ivf_flat.rs b/src/index/ivf_flat.rs index dde5487..ae31dee 100644 --- a/src/index/ivf_flat.rs +++ b/src/index/ivf_flat.rs @@ -67,7 +67,10 @@ impl IVFFlatIndexImpl { /// Create a new IVF flat index. // The index owns the quantizer. pub fn new(quantizer: flat::FlatIndex, d: u32, nlist: u32, metric: MetricType) -> Result { - IVFFlatIndexImpl::new_helper(&quantizer, d, nlist, metric, true) + let result = IVFFlatIndexImpl::new_helper(&quantizer, d, nlist, metric, true); + std::mem::forget(quantizer); + + result } /// Create a new IVF flat index with L2 as the metric type. @@ -103,6 +106,45 @@ impl IVFFlatIndexImpl { faiss_IndexIVFFlat_set_nprobe(self.inner_ptr(), value as usize); } } + + /// Get number of possible key values + pub fn nlist(&self) -> u32 { + unsafe { faiss_IndexIVFFlat_nlist(self.inner_ptr()) as u32 } + } + + /// Get train type + pub fn train_type(&self) -> Option { + unsafe { + let code = faiss_IndexIVFFlat_quantizer_trains_alone(self.inner_ptr()); + TrainType::from_code(code) + } + } +} + +/** + * = 0: use the quantizer as index in a kmeans training + * = 1: just pass on the training set to the train() of the quantizer + * = 2: kmeans training on a flat index + add the centroids to the quantizer + */ +#[derive(Debug, Clone)] +pub enum TrainType { + /// use the quantizer as index in a kmeans training + QuantizerAsIndex, + /// just pass on the training set to the train() of the quantizer + QuantizerTrainsAlone, + /// kmeans training on a flat index + add the centroids to the quantizer + FlatIndexAndQuantizer, +} + +impl TrainType { + pub(crate) fn from_code(code: i8) -> Option { + match code { + 0 => Some(TrainType::QuantizerAsIndex), + 1 => Some(TrainType::QuantizerTrainsAlone), + 2 => Some(TrainType::FlatIndexAndQuantizer), + _ => None, + } + } } impl NativeIndex for IVFFlatIndexImpl { @@ -231,6 +273,40 @@ mod tests { assert_eq!(index.ntotal(), 0); } + #[test] + fn index_search_own() { + let q = FlatIndexImpl::new_l2(D).unwrap(); + let mut index = IVFFlatIndexImpl::new_l2(q, D, 1).unwrap(); + assert_eq!(index.d(), D); + assert_eq!(index.ntotal(), 0); + let some_data = &[ + 7.5_f32, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, -1., 1., 1., 1., 1., 1., 1., -1., 4., + -4., -8., 1., 1., 2., 4., -1., 8., 8., 10., -10., -10., 10., -10., 10., 16., 16., 32., + 25., 20., 20., 40., 15., + ]; + index.train(some_data).unwrap(); + index.add(some_data).unwrap(); + assert_eq!(index.ntotal(), 5); + + let my_query = [0.; D as usize]; + let result = index.search(&my_query, 3).unwrap(); + assert_eq!(result.labels.len(), 3); + assert!(result.labels.into_iter().all(Idx::is_some)); + assert_eq!(result.distances.len(), 3); + assert!(result.distances.iter().all(|x| *x > 0.)); + + let my_query = [100.; D as usize]; + // flat index can be used behind an immutable ref + let result = (&index).search(&my_query, 3).unwrap(); + assert_eq!(result.labels.len(), 3); + assert!(result.labels.into_iter().all(Idx::is_some)); + assert_eq!(result.distances.len(), 3); + assert!(result.distances.iter().all(|x| *x > 0.)); + + index.reset().unwrap(); + assert_eq!(index.ntotal(), 0); + } + #[test] fn index_assign() { let q = FlatIndexImpl::new_l2(D).unwrap(); From 1a85213a2e09a266021c669a3d026abcd596c21b Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Thu, 13 May 2021 21:15:13 +0300 Subject: [PATCH 6/7] Fix ffi ownership mistakes for ivf_flat --- src/index/ivf_flat.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/index/ivf_flat.rs b/src/index/ivf_flat.rs index ae31dee..4806a86 100644 --- a/src/index/ivf_flat.rs +++ b/src/index/ivf_flat.rs @@ -57,9 +57,7 @@ impl IVFFlatIndexImpl { nlist as usize, metric, ))?; - let own_fields_ = if own_fields { 1 } else { 0 }; - faiss_IndexIVFFlat_set_own_fields(inner, own_fields_); - + faiss_IndexIVFFlat_set_own_fields(inner, own_fields as i32); Ok(IVFFlatIndexImpl { inner }) } } @@ -67,10 +65,10 @@ impl IVFFlatIndexImpl { /// Create a new IVF flat index. // The index owns the quantizer. pub fn new(quantizer: flat::FlatIndex, d: u32, nlist: u32, metric: MetricType) -> Result { - let result = IVFFlatIndexImpl::new_helper(&quantizer, d, nlist, metric, true); + let index = IVFFlatIndexImpl::new_helper(&quantizer, d, nlist, metric, true)?; std::mem::forget(quantizer); - result + Ok(index) } /// Create a new IVF flat index with L2 as the metric type. From 06a6063946b1af1a23611ae9fb634cfd028845b6 Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Tue, 1 Jun 2021 14:08:25 +0300 Subject: [PATCH 7/7] Fix request suggestions in IVFFlatIndex --- src/index/ivf_flat.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/index/ivf_flat.rs b/src/index/ivf_flat.rs index 4806a86..2504056 100644 --- a/src/index/ivf_flat.rs +++ b/src/index/ivf_flat.rs @@ -6,6 +6,7 @@ use crate::error::Result; use crate::faiss_try; use std::mem; use std::ptr; +use std::os::raw::c_int; /// Alias for the native implementation of a flat index. pub type IVFFlatIndex = IVFFlatIndexImpl; @@ -57,7 +58,7 @@ impl IVFFlatIndexImpl { nlist as usize, metric, ))?; - faiss_IndexIVFFlat_set_own_fields(inner, own_fields as i32); + faiss_IndexIVFFlat_set_own_fields(inner, c_int::from(own_fields)); Ok(IVFFlatIndexImpl { inner }) } } @@ -124,7 +125,7 @@ impl IVFFlatIndexImpl { * = 1: just pass on the training set to the train() of the quantizer * = 2: kmeans training on a flat index + add the centroids to the quantizer */ -#[derive(Debug, Clone)] + #[derive(Debug, Copy, Clone, Eq, Hash, PartialEq)] pub enum TrainType { /// use the quantizer as index in a kmeans training QuantizerAsIndex,