Skip to content

Commit

Permalink
New solution to handle IVF_FLAT backward compatibility
Browse files Browse the repository at this point in the history
Signed-off-by: Yudong Cai <yudong.cai@zilliz.com>
  • Loading branch information
cydrain committed Sep 12, 2023
1 parent 5f337f9 commit 3799781
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 210 deletions.
53 changes: 32 additions & 21 deletions src/common/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,28 +83,39 @@ ConvertIVFFlatIfNeeded(const BinarySet& binset, const uint8_t* raw_data, const s

MemoryIOReader reader(binary->data.get(), binary->size);

// there are 2 possibilities for the input index binary:
// 1. native IVF_FLAT, do nothing
// 2. IVF_FLAT_NM, convert to native IVF_FLAT
try {
// try to parse as native format, if it's actually _NM format,
// faiss will raise a "read error" exception for IVF_FLAT_NM format
faiss::read_index(&reader);
} catch (faiss::FaissException& e) {
reader.reset();

// convert IVF_FLAT_NM to native IVF_FLAT
auto* index = static_cast<faiss::IndexIVFFlat*>(faiss::read_index_nm(&reader));
index->restore_codes(raw_data, raw_size);

// over-write IVF_FLAT_NM binary with native IVF_FLAT binary
MemoryIOWriter writer;
faiss::write_index(index, &writer);
std::shared_ptr<uint8_t[]> data(writer.data());
binary->data = data;
binary->size = writer.tellg();

LOG_KNOWHERE_INFO_ << "Convert IVF_FLAT_NM to native IVF_FLAT";
uint32_t h;
reader.read(&h, sizeof(h), 1);

// only read IVF_FLAT index header
faiss::IndexIVFFlat* ivfl = new faiss::IndexIVFFlat();
faiss::read_ivf_header(ivfl, &reader);
ivfl->code_size = ivfl->d * sizeof(float);

auto remains = binary->size - reader.tellg() - sizeof(uint32_t) - sizeof(ivfl->invlists->nlist) -
sizeof(ivfl->invlists->code_size);
auto invlist_size = sizeof(uint32_t) + sizeof(size_t) + ivfl->nlist * sizeof(size_t);
auto ids_size = ivfl->ntotal * sizeof(faiss::Index::idx_t);
// auto codes_size = ivfl->d * ivfl->ntotal * sizeof(float);

// IVF_FLAT_NM format, need convert to new format
if (remains == invlist_size + ids_size) {
faiss::read_InvertedLists_nm(ivfl, &reader);
ivfl->restore_codes(raw_data, raw_size);

// over-write IVF_FLAT_NM binary with native IVF_FLAT binary
MemoryIOWriter writer;
faiss::write_index(ivfl, &writer);
std::shared_ptr<uint8_t[]> data(writer.data());
binary->data = data;
binary->size = writer.tellg();

LOG_KNOWHERE_INFO_ << "Convert IVF_FLAT_NM to native IVF_FLAT, rows " << ivfl->ntotal << ", dim "
<< ivfl->d;
}
} catch (...) {
// not IVF_FLAT_NM format, do nothing
return;
}
}

Expand Down
3 changes: 3 additions & 0 deletions thirdparty/faiss/faiss/Index.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ struct Index {
idx_t ntotal; ///< total nb of indexed vectors
bool verbose; ///< verbosity level

/// both IP and COSINE are regarded as INNER_PRODUCT in faiss
bool is_cosine;

/// set if the Index does not require training, or if training is
/// done already
bool is_trained;
Expand Down
7 changes: 4 additions & 3 deletions thirdparty/faiss/faiss/IndexIVFFlat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ IndexIVFFlat::IndexIVFFlat(
size_t nlist,
bool is_cosine,
MetricType metric)
: is_cosine_(is_cosine), IndexIVF(quantizer, d, nlist, sizeof(float) * d, metric) {
: IndexIVF(quantizer, d, nlist, sizeof(float) * d, metric) {
this->is_cosine = is_cosine;
code_size = sizeof(float) * d;
replace_invlists(new ArrayInvertedLists(nlist, code_size, is_cosine), true);
}
Expand All @@ -50,7 +51,7 @@ void IndexIVFFlat::restore_codes(
}

void IndexIVFFlat::train(idx_t n, const float* x) {
if (is_cosine_) {
if (is_cosine) {
auto x_normalized = knowhere::CopyAndNormalizeVecs(x, n, d);
// use normalized data to train codes for cosine
IndexIVF::train(n, x_normalized.get());
Expand All @@ -61,7 +62,7 @@ void IndexIVFFlat::train(idx_t n, const float* x) {

void IndexIVFFlat::add_with_ids(idx_t n, const float* x, const idx_t* xids) {
std::unique_ptr<idx_t[]> coarse_idx(new idx_t[n]);
if (is_cosine_) {
if (is_cosine) {
auto x_normalized = std::make_unique<float[]>(n * d);
std::memcpy(x_normalized.get(), x, n * d * sizeof(float));
auto norms = knowhere::NormalizeVecs(x_normalized.get(), n, d);
Expand Down
3 changes: 0 additions & 3 deletions thirdparty/faiss/faiss/IndexIVFFlat.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,6 @@ struct IndexIVFFlat : IndexIVF {
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;

IndexIVFFlat() {}

protected:
bool is_cosine_ = false;
};

struct IndexIVFFlatCC : IndexIVFFlat {
Expand Down
146 changes: 59 additions & 87 deletions thirdparty/faiss/faiss/impl/index_read.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,17 @@ namespace faiss {
static void read_index_header(Index* idx, IOReader* f) {
READ1(idx->d);
READ1(idx->ntotal);
READ1(idx->is_cosine);

uint8_t dummy8;
READ1(dummy8);
uint16_t dummy16;
READ1(dummy16);
uint32_t dummy32;
READ1(dummy32);
Index::idx_t dummy;
READ1(dummy);
READ1(dummy);

READ1(idx->is_trained);
READ1(idx->metric_type);
if (idx->metric_type > 1) {
Expand Down Expand Up @@ -213,27 +221,65 @@ InvertedLists* read_InvertedLists(IOReader* f, int io_flags) {
READANDCHECK(ails->readonly_codes.data(), n * code_size);
#endif
return ails;
} else if (h == fourcc("ilca")) {
size_t nlist, code_size, segment_size;
READ1(nlist);
READ1(code_size);
READ1(segment_size);

bool save_norm = io_flags & IO_FLAG_WITH_NORM;
auto lca = new ConcurrentArrayInvertedLists(nlist, code_size, segment_size, save_norm);
std::vector<size_t> sizes(nlist);
read_ArrayInvertedLists_sizes(f, sizes);
for (size_t i = 0; i < lca->nlist; i++) {
lca->resize(i, sizes[i]);
}
for (size_t i = 0; i < lca->nlist; i++) {
size_t n = lca->list_size(i);
if (n > 0) {
size_t seg_num = lca->get_segment_num(i);
for (size_t j = 0; j < seg_num; j++) {
size_t seg_size = lca->get_segment_size(i , j);
size_t seg_off = lca->get_segment_offset(i, j);
READANDCHECK(lca->codes[i][j].data_.data(), seg_size * lca->code_size);
READANDCHECK(lca->ids[i][j].data_.data(), seg_size);
if (save_norm) {
READANDCHECK(lca->code_norms[i][j].data_.data(), seg_size);
}
}
}
}
return lca;
} else if (h == fourcc("ilar") && !(io_flags & IO_FLAG_SKIP_IVF_DATA)) {
auto ails = new ArrayInvertedLists(0, 0);
READ1(ails->nlist);
READ1(ails->code_size);
ails->with_norm = io_flags & IO_FLAG_WITH_NORM;
ails->ids.resize(ails->nlist);
ails->codes.resize(ails->nlist);
if (ails->with_norm) {
ails->code_norms.resize(ails->nlist);
}
std::vector<size_t> sizes(ails->nlist);
read_ArrayInvertedLists_sizes(f, sizes);
for (size_t i = 0; i < ails->nlist; i++) {
ails->ids[i].resize(sizes[i]);
ails->codes[i].resize(sizes[i] * ails->code_size);
if (ails->with_norm) {
ails->code_norms[i].resize(sizes[i]);
}
}
for (size_t i = 0; i < ails->nlist; i++) {
size_t n = ails->ids[i].size();
if (n > 0) {
READANDCHECK(ails->codes[i].data(), n * ails->code_size);
READANDCHECK(ails->ids[i].data(), n);
if (ails->with_norm) {
READANDCHECK(ails->code_norms[i].data(), n);
}
}
}
return ails;

} else if (h == fourcc("ilar") && (io_flags & IO_FLAG_SKIP_IVF_DATA)) {
// code is always ilxx where xx is specific to the type of invlists we
// want so we get the 16 high bits from the io_flag and the 16 low bits
Expand Down Expand Up @@ -327,94 +373,14 @@ InvertedLists *read_InvertedLists_nm(IOReader *f, int io_flags) {
}
}

static void read_InvertedLists_nm(IndexIVF *ivf, IOReader *f, int io_flags) {
void read_InvertedLists_nm(IndexIVF *ivf, IOReader *f, int io_flags) {
InvertedLists *ils = read_InvertedLists_nm (f, io_flags);
FAISS_THROW_IF_NOT(!ils || (ils->nlist == ivf->nlist &&
ils->code_size == ivf->code_size));
ivf->invlists = ils;
ivf->own_invlists = true;
}

InvertedLists* read_InvertedLists_with_norm(IOReader* f, int io_flags) {
uint32_t h;
READ1(h);
if (h == fourcc("ilca")) {
size_t nlist, code_size, segment_size;
bool save_norm;
READ1(nlist);
READ1(code_size);
READ1(segment_size);
READ1(save_norm);

auto lca = new ConcurrentArrayInvertedLists(nlist, code_size, segment_size, save_norm);
std::vector<size_t> sizes(nlist);
read_ArrayInvertedLists_sizes(f, sizes);
for (size_t i = 0; i < lca->nlist; i++) {
lca->resize(i, sizes[i]);
}
for (size_t i = 0; i < lca->nlist; i++) {
size_t n = lca->list_size(i);
if (n > 0) {
size_t seg_num = lca->get_segment_num(i);
for (size_t j = 0; j < seg_num; j++) {
size_t seg_size = lca->get_segment_size(i , j);
size_t seg_off = lca->get_segment_offset(i, j);
READANDCHECK(lca->codes[i][j].data_.data(), seg_size * lca->code_size);
READANDCHECK(lca->ids[i][j].data_.data(), seg_size);
if (save_norm) {
READANDCHECK(lca->code_norms[i][j].data_.data(), seg_size);
}
}
}
}
return lca;
} else if (h == fourcc("ilar") && !(io_flags & IO_FLAG_SKIP_IVF_DATA)) {
auto ails = new ArrayInvertedLists(0, 0);
READ1(ails->nlist);
READ1(ails->code_size);
READ1(ails->with_norm);
ails->ids.resize(ails->nlist);
ails->codes.resize(ails->nlist);
if (ails->with_norm) {
ails->code_norms.resize(ails->nlist);
}
std::vector<size_t> sizes(ails->nlist);
read_ArrayInvertedLists_sizes(f, sizes);
for (size_t i = 0; i < ails->nlist; i++) {
ails->ids[i].resize(sizes[i]);
ails->codes[i].resize(sizes[i] * ails->code_size);
if (ails->with_norm) {
ails->code_norms[i].resize(sizes[i]);
}
}
for (size_t i = 0; i < ails->nlist; i++) {
size_t n = ails->ids[i].size();
if (n > 0) {
READANDCHECK(ails->codes[i].data(), n * ails->code_size);
READANDCHECK(ails->ids[i].data(), n);
if (ails->with_norm) {
READANDCHECK(ails->code_norms[i].data(), n);
}
}
}
return ails;
} else {
return InvertedListsIOHook::lookup(h)->read(f, io_flags);
}
}

static void read_InvertedLists_with_norm(IndexIVF* ivf, IOReader* f, int io_flags) {
InvertedLists* ils = read_InvertedLists_with_norm(f, io_flags);
if (ils) {
FAISS_THROW_IF_NOT(ils->nlist == ivf->nlist);
FAISS_THROW_IF_NOT(
ils->code_size == InvertedLists::INVALID_CODE_SIZE ||
ils->code_size == ivf->code_size);
}
ivf->invlists = ils;
ivf->own_invlists = true;
}

static void read_ProductQuantizer(ProductQuantizer* pq, IOReader* f) {
READ1(pq->d);
READ1(pq->M);
Expand Down Expand Up @@ -566,10 +532,10 @@ static void read_direct_map(DirectMap* dm, IOReader* f) {
}
}

static void read_ivf_header(
void read_ivf_header(
IndexIVF* ivf,
IOReader* f,
std::vector<std::vector<Index::idx_t>>* ids = nullptr) {
std::vector<std::vector<Index::idx_t>>* ids) {
read_index_header(ivf, f);
READ1(ivf->nlist);
READ1(ivf->nprobe);
Expand Down Expand Up @@ -771,13 +737,19 @@ Index* read_index(IOReader* f, int io_flags) {
IndexIVFFlatCC* ivf_cc = new IndexIVFFlatCC();
read_ivf_header(ivf_cc, f);
ivf_cc->code_size = ivf_cc->d * sizeof(float);
read_InvertedLists_with_norm(ivf_cc, f, io_flags);
if (ivf_cc->is_cosine) {
io_flags |= IO_FLAG_WITH_NORM;
}
read_InvertedLists(ivf_cc, f, io_flags);
idx = ivf_cc;
} else if (h == fourcc("IwFl")) {
IndexIVFFlat* ivfl = new IndexIVFFlat();
read_ivf_header(ivfl, f);
ivfl->code_size = ivfl->d * sizeof(float);
read_InvertedLists_with_norm(ivfl, f, io_flags);
if (ivfl->is_cosine) {
io_flags |= IO_FLAG_WITH_NORM;
}
read_InvertedLists(ivfl, f, io_flags);
idx = ivfl;
} else if (h == fourcc("IxSQ")) {
IndexScalarQuantizer* idxs = new IndexScalarQuantizer();
Expand Down
Loading

0 comments on commit 3799781

Please sign in to comment.