forked from technicolor-research/faiss-quickeradc
-
Notifications
You must be signed in to change notification settings - Fork 3
/
IndexVPQ.h
180 lines (137 loc) · 5.04 KB
/
IndexVPQ.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
/**
* Copyright (c) 2018-present, Thomson Licensing, SAS.
* Copyright (c) 2015-present, Facebook, Inc.
* All rights reserved.
*
* Modifications related the introduction of Quicker ADC (Vectorized Product Quantization)
* are licensed under the Clear BSD license found in the LICENSE file in the root directory
* of this source tree.
*
* The rest of the source code is licensed under the BSD+Patents license found in the
* LICENSE file in the root directory of this source tree
*/
// Copyright 2004-present Facebook. All Rights Reserved.
// -*- c++ -*-
#ifndef FAISS_INDEX_VPQ_H
#define FAISS_INDEX_VPQ_H
#include <stdint.h>
#include <vector>
#include "Index.h"
#include "VecProductQuantizer.h"
#include <boost/align/aligned_allocator.hpp>
namespace faiss {
/// Search statistics accumulated by IndexVPQ::search.
///
/// The counters are plain size_t fields incremented without any
/// synchronization, so the totals are not reliable when IndexVPQ::search
/// is called concurrently from several threads.
struct IndexVPQStats {
    /// Clears the counters. Defined out of line (presumably zeroes both
    /// fields -- confirm against the implementation file).
    void reset ();

    /// Starts from whatever state reset() establishes.
    IndexVPQStats () { reset (); }

    size_t nq;    ///< number of queries run
    size_t ncode; ///< number of codes visited
};

/// Process-wide statistics instance updated by IndexVPQ::search.
extern IndexVPQStats indexVPQ_stats;
/// Non-templated base shared by every IndexVPQ<T> instantiation, so that
/// the scan parameter can be accessed without knowing the quantizer type.
struct AbstractIndexVPQ {
    /// Parameter forwarded to pq.search() on the L2 search path. IndexVPQ's
    /// main constructor overwrites it with
    /// pq.get_initial_scan_estim_parameter().
    ///
    /// Defaults to 0 so that an index built through a constructor that does
    /// not assign it never reads an indeterminate value.
    /// NOTE(review): 0 is a deterministic placeholder, not a tuned value --
    /// confirm it is acceptable to the quantizer's search routine.
    int initial_scan_estim_param = 0;

    virtual ~AbstractIndexVPQ() = default;
};
/** Index based on a product quantizer. Stored vectors are
* approximated by PQ codes. */
template<class T_VPQ>
struct IndexVPQ: Index, AbstractIndexVPQ {
typedef T_VPQ VPQ_t;
/// The product quantizer used to encode the vectors
T_VPQ pq;
typedef typename T_VPQ::group groupt;
/// Codes. Size ntotal * pq.code_size
std::vector<groupt, boost::alignment::aligned_allocator<groupt, 64>> codes;
/** Constructor.
*
* @param d dimensionality of the input vectors
* @param M number of subquantizers
* @param nbits number of bit per subvector index
*/
IndexVPQ (int d, ///< dimensionality of the input vectors
MetricType metric = METRIC_L2):
Index(d, metric), pq(d)
{
is_trained = false;
search_type = ST_PQ;
this->initial_scan_estim_param=pq.get_initial_scan_estim_parameter();
}
IndexVPQ () {
metric_type = METRIC_L2;
is_trained = false;
search_type = ST_PQ;
}
void train(idx_t n, const float* x) override {
pq.train(n, x);
is_trained = true;
}
void add(idx_t n, const float* x) override {
FAISS_THROW_IF_NOT (is_trained);
codes.resize (pq.nb_groups(n + ntotal));
idx_t already_added=0;
while(n>0){
idx_t count = n < 10000L*pq.codes_per_group ? n : 10000L*pq.codes_per_group ;
// ntotal is offset where the codes from x are added.
pq.encode_multiple(x+already_added*d, codes.data(), ntotal, count);
ntotal += count;
n -= count;
already_added += count;
}
}
void search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const override {
FAISS_THROW_IF_NOT (is_trained);
if (search_type == ST_PQ) { // Simple PQ search
if (metric_type == METRIC_L2) {
float_maxheap_array_t res = {
size_t(n), size_t(k), labels, distances };
pq.search (x, n, codes.data(), ntotal, &res, true, this->initial_scan_estim_param);
} else {
float_minheap_array_t res = {
size_t(n), size_t(k), labels, distances };
pq.search_ip (x, n, codes.data(), ntotal, &res, true);
}
indexVPQ_stats.nq += n;
indexVPQ_stats.ncode += n * ntotal;
} else if(search_type == ST_SDC){ // code-to-code distances
std::vector<groupt, boost::alignment::aligned_allocator<groupt, 64>> q_codes_v;
q_codes_v.resize(pq.nb_groups(n));
groupt * q_codes = q_codes_v.data();
pq.encode_multiple(x, q_codes, 0, n);
float_maxheap_array_t res = {
size_t(n), size_t(k), labels, distances};
pq.search_sdc (q_codes, n, codes.data(), ntotal, &res, true);
indexVPQ_stats.nq += n;
indexVPQ_stats.ncode += n * ntotal;
}
}
void reset() override {
codes.clear();
ntotal = 0;
}
void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override {
FAISS_THROW_IF_NOT (ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
for (idx_t i = 0; i < ni; i++) {
pq.decode (codes.data(), recons + i * d,i);
}
}
void reconstruct(idx_t key, float* recons) const override {
FAISS_THROW_IF_NOT (key >= 0 && key < ntotal);
pq.decode (codes.data(), recons, key);
}
/// how to perform the search in search_core
enum Search_type_t {
ST_PQ, ///< asymmetric product quantizer (default)
ST_SDC, ///< symmetric product quantizer (SDC)
};
Search_type_t search_type;
};
/// Returns the type tag of an IndexVPQ<T>: the letter "j" followed by the
/// tag that cc_vpq yields for the quantizer type T. The index pointer is
/// used only to deduce T; it is never dereferenced.
template <class T>
inline std::string fourcc_vpq(const IndexVPQ<T>* n)
{
    // cc_vpq is selected purely on the pointer's static type, so a null
    // T* is enough to pick the right overload.
    const std::string tag = "j" + cc_vpq(static_cast<T*>(nullptr));
    return tag;
}
} // namespace faiss
#endif