Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

[Cherry-Pick]Split queries in IVF and BF #703

Merged
merged 1 commit into from
Feb 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions include/knowhere/comp/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

#pragma once

#include <omp.h>

#include <memory>
#include <utility>

Expand Down Expand Up @@ -95,6 +97,18 @@ class ThreadPool {
return pool;
}

/// Scoped guard that overrides the OpenMP thread-count setting and puts the
/// previous setting back when it goes out of scope.
///
/// The constructor records the value reported by omp_get_max_threads() and then
/// calls omp_set_num_threads(num_threads); the destructor calls
/// omp_set_num_threads() with the recorded value.
class ScopedOmpSetter {
    // Thread-count setting that was in effect before this guard was created.
    int omp_before;

 public:
    /// @param num_threads value handed to omp_set_num_threads() for the lifetime
    ///        of the guard (defaults to 1).
    explicit ScopedOmpSetter(int num_threads = 1) : omp_before(omp_get_max_threads()) {
        // omp_get_max_threads() reports the setting that omp_set_num_threads()
        // modifies. omp_get_num_threads() would report the size of the current
        // team instead, which is 1 outside a parallel region, so "restoring" it
        // would leave every later parallel region limited to a single thread.
        omp_set_num_threads(num_threads);
    }

    ~ScopedOmpSetter() {
        omp_set_num_threads(omp_before);
    }

    // Non-copyable: a copy would run the restore a second time.
    ScopedOmpSetter(const ScopedOmpSetter&) = delete;
    ScopedOmpSetter&
    operator=(const ScopedOmpSetter&) = delete;
};

private:
std::unique_ptr<ctpl::thread_pool> pool_;
};
Expand Down
309 changes: 195 additions & 114 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "common/range_util.h"
#include "faiss/utils/BinaryDistance.h"
#include "faiss/utils/distances.h"
#include "knowhere/comp/thread_pool.h"
#include "knowhere/config.h"

namespace knowhere {
Expand Down Expand Up @@ -49,51 +50,72 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
auto distances = new float[nq * topk];

auto faiss_metric_type = metric_type.value();
switch (faiss_metric_type) {
case faiss::METRIC_L2: {
faiss::float_maxheap_array_t buf{(size_t)nq, (size_t)topk, labels, distances};
faiss::knn_L2sqr((const float*)xq, (const float*)xb, dim, nq, nb, &buf, nullptr, bitset);
break;
}
case faiss::METRIC_INNER_PRODUCT: {
faiss::float_minheap_array_t buf{(size_t)nq, (size_t)topk, labels, distances};
faiss::knn_inner_product((const float*)xq, (const float*)xb, dim, nq, nb, &buf, bitset);
break;
}
case faiss::METRIC_Jaccard:
case faiss::METRIC_Tanimoto: {
faiss::float_maxheap_array_t res = {size_t(nq), size_t(topk), labels, distances};
binary_distance_knn_hc(faiss::METRIC_Jaccard, &res, (const uint8_t*)xq, (const uint8_t*)xb, nb, dim / 8,
bitset);

if (faiss_metric_type == faiss::METRIC_Tanimoto) {
for (int i = 0; i < topk * nq; i++) {
distances[i] = faiss::Jaccard_2_Tanimoto(distances[i]);
auto pool = ThreadPool::GetGlobalThreadPool();
std::vector<std::future<Status>> futs;
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
futs.push_back(pool->push([&, index = i] {
ThreadPool::ScopedOmpSetter setter(1);
auto cur_labels = labels + topk * index;
auto cur_distances = distances + topk * index;
switch (faiss_metric_type) {
case faiss::METRIC_L2: {
auto cur_query = (const float*)xq + dim * index;
faiss::float_maxheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
faiss::knn_L2sqr(cur_query, (const float*)xb, dim, 1, nb, &buf, nullptr, bitset);
break;
}
case faiss::METRIC_INNER_PRODUCT: {
auto cur_query = (const float*)xq + dim * index;
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
break;
}
case faiss::METRIC_Jaccard:
case faiss::METRIC_Tanimoto: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
faiss::float_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, cur_distances};
binary_distance_knn_hc(faiss::METRIC_Jaccard, &res, cur_query, (const uint8_t*)xb, nb, dim / 8,
bitset);

if (faiss_metric_type == faiss::METRIC_Tanimoto) {
for (int i = 0; i < topk; i++) {
cur_distances[i] = faiss::Jaccard_2_Tanimoto(distances[i]);
}
}
break;
}
case faiss::METRIC_Hamming: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
std::vector<int32_t> int_distances(topk);
faiss::int_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, int_distances.data()};
binary_distance_knn_hc(faiss::METRIC_Hamming, &res, (const uint8_t*)cur_query, (const uint8_t*)xb,
nb, dim / 8, bitset);
for (int i = 0; i < topk; ++i) {
cur_distances[i] = int_distances[i];
}
break;
}
case faiss::METRIC_Substructure:
case faiss::METRIC_Superstructure: {
// only matched ids will be chosen, not to use heap
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
binary_distance_knn_mc(faiss_metric_type, cur_query, (const uint8_t*)xb, 1, nb, topk, dim / 8,
cur_distances, cur_labels, bitset);
break;
}
default:
return Status::invalid_metric_type;
}
break;
}
case faiss::METRIC_Hamming: {
std::vector<int32_t> int_distances(nq * topk);
faiss::int_maxheap_array_t res = {size_t(nq), size_t(topk), labels, int_distances.data()};
binary_distance_knn_hc(faiss::METRIC_Hamming, &res, (const uint8_t*)xq, (const uint8_t*)xb, nb, dim / 8,
bitset);
for (int i = 0; i < nq * topk; ++i) {
distances[i] = int_distances[i];
}
break;
}
case faiss::METRIC_Substructure:
case faiss::METRIC_Superstructure: {
// only matched ids will be chosen, not to use heap
binary_distance_knn_mc(faiss_metric_type, (const uint8_t*)xq, (const uint8_t*)xb, nq, nb, topk, dim / 8,
distances, labels, bitset);
break;
return Status::success;
}));
}
for (auto& fut : futs) {
auto ret = fut.get();
if (ret != Status::success) {
return unexpected(ret);
}
default:
return unexpected(Status::invalid_metric_type);
}

return GenResultDataSet(nq, cfg.k, labels, distances);
}

Expand All @@ -120,49 +142,71 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
auto distances = dis;

auto faiss_metric_type = metric_type.value();
switch (faiss_metric_type) {
case faiss::METRIC_L2: {
faiss::float_maxheap_array_t buf{(size_t)nq, (size_t)topk, labels, distances};
faiss::knn_L2sqr((const float*)xq, (const float*)xb, dim, nq, nb, &buf, nullptr, bitset);
break;
}
case faiss::METRIC_INNER_PRODUCT: {
faiss::float_minheap_array_t buf{(size_t)nq, (size_t)topk, labels, distances};
faiss::knn_inner_product((const float*)xq, (const float*)xb, dim, nq, nb, &buf, bitset);
break;
}
case faiss::METRIC_Jaccard:
case faiss::METRIC_Tanimoto: {
faiss::float_maxheap_array_t res = {size_t(nq), size_t(topk), labels, distances};
binary_distance_knn_hc(faiss::METRIC_Jaccard, &res, (const uint8_t*)xq, (const uint8_t*)xb, nb, dim / 8,
bitset);

if (faiss_metric_type == faiss::METRIC_Tanimoto) {
for (int i = 0; i < topk * nq; i++) {
distances[i] = faiss::Jaccard_2_Tanimoto(distances[i]);
auto pool = ThreadPool::GetGlobalThreadPool();
std::vector<std::future<Status>> futs;
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
futs.push_back(pool->push([&, index = i] {
ThreadPool::ScopedOmpSetter setter(1);
auto cur_labels = labels + topk * index;
auto cur_distances = distances + topk * index;
switch (faiss_metric_type) {
case faiss::METRIC_L2: {
auto cur_query = (const float*)xq + dim * index;
faiss::float_maxheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
faiss::knn_L2sqr(cur_query, (const float*)xb, dim, 1, nb, &buf, nullptr, bitset);
break;
}
case faiss::METRIC_INNER_PRODUCT: {
auto cur_query = (const float*)xq + dim * index;
faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
break;
}
case faiss::METRIC_Jaccard:
case faiss::METRIC_Tanimoto: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
faiss::float_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, cur_distances};
binary_distance_knn_hc(faiss::METRIC_Jaccard, &res, cur_query, (const uint8_t*)xb, nb, dim / 8,
bitset);

if (faiss_metric_type == faiss::METRIC_Tanimoto) {
for (int i = 0; i < topk; i++) {
cur_distances[i] = faiss::Jaccard_2_Tanimoto(distances[i]);
}
}
break;
}
case faiss::METRIC_Hamming: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
std::vector<int32_t> int_distances(topk);
faiss::int_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, int_distances.data()};
binary_distance_knn_hc(faiss::METRIC_Hamming, &res, (const uint8_t*)cur_query, (const uint8_t*)xb,
nb, dim / 8, bitset);
for (int i = 0; i < topk; ++i) {
cur_distances[i] = int_distances[i];
}
break;
}
case faiss::METRIC_Substructure:
case faiss::METRIC_Superstructure: {
// only matched ids will be chosen, not to use heap
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
binary_distance_knn_mc(faiss_metric_type, cur_query, (const uint8_t*)xb, 1, nb, topk, dim / 8,
cur_distances, cur_labels, bitset);
break;
}
default:
return Status::invalid_metric_type;
}
break;
}
case faiss::METRIC_Hamming: {
std::vector<int32_t> int_distances(nq * topk);
faiss::int_maxheap_array_t res = {size_t(nq), size_t(topk), labels, int_distances.data()};
binary_distance_knn_hc(faiss::METRIC_Hamming, &res, (const uint8_t*)xq, (const uint8_t*)xb, nb, dim / 8,
bitset);
for (int i = 0; i < nq * topk; ++i) {
distances[i] = int_distances[i];
}
break;
}
case faiss::METRIC_Substructure:
case faiss::METRIC_Superstructure: {
// only matched ids will be chosen, not to use heap
binary_distance_knn_mc(faiss_metric_type, (const uint8_t*)xq, (const uint8_t*)xb, nq, nb, topk, dim / 8,
distances, labels, bitset);
break;
return Status::success;
}));
}
for (auto& fut : futs) {
auto ret = fut.get();
if (ret != Status::success) {
return ret;
}
default:
return Status::invalid_metric_type;
}
return Status::success;
}
Expand All @@ -189,44 +233,81 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da

auto radius = cfg.radius;
bool is_ip = false;
float range_filter = cfg.range_filter;

faiss::RangeSearchResult res(nq);
auto faiss_metric_type = metric_type.value();
switch (faiss_metric_type) {
case faiss::METRIC_L2:
faiss::range_search_L2sqr((const float*)xq, (const float*)xb, dim, nq, nb, radius, &res, bitset);
break;
case faiss::METRIC_INNER_PRODUCT:
is_ip = true;
faiss::range_search_inner_product((const float*)xq, (const float*)xb, dim, nq, nb, radius, &res, bitset);
break;
case faiss::METRIC_Jaccard:
faiss::binary_range_search<faiss::CMin<float, int64_t>, float>(
faiss::METRIC_Jaccard, (const uint8_t*)xq, (const uint8_t*)xb, nq, nb, radius, dim / 8, &res, bitset);
break;
case faiss::METRIC_Tanimoto:
faiss::binary_range_search<faiss::CMin<float, int64_t>, float>(
faiss::METRIC_Tanimoto, (const uint8_t*)xq, (const uint8_t*)xb, nq, nb, radius, dim / 8, &res, bitset);
break;
case faiss::METRIC_Hamming:
faiss::binary_range_search<faiss::CMin<int, int64_t>, int>(faiss::METRIC_Hamming, (const uint8_t*)xq,
(const uint8_t*)xb, nq, nb, (int)radius, dim / 8,
&res, bitset);
break;
default:
return unexpected(Status::invalid_metric_type);
auto pool = ThreadPool::GetGlobalThreadPool();

std::vector<std::vector<int64_t>> result_id_array(nq);
std::vector<std::vector<float>> result_dist_array(nq);
std::vector<size_t> result_size(nq);
std::vector<size_t> result_lims(nq + 1);
std::vector<std::future<Status>> futs;
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
futs.push_back(pool->push([&, index = i] {
ThreadPool::ScopedOmpSetter setter(1);
faiss::RangeSearchResult res(1);
switch (faiss_metric_type) {
case faiss::METRIC_L2: {
auto cur_query = (const float*)xq + dim * index;
faiss::range_search_L2sqr(cur_query, (const float*)xb, dim, 1, nb, radius, &res, bitset);
break;
}
case faiss::METRIC_INNER_PRODUCT: {
is_ip = true;
auto cur_query = (const float*)xq + dim * index;
faiss::range_search_inner_product(cur_query, (const float*)xb, dim, 1, nb, radius, &res, bitset);
break;
}
case faiss::METRIC_Jaccard: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
faiss::binary_range_search<faiss::CMin<float, int64_t>, float>(
faiss::METRIC_Jaccard, cur_query, (const uint8_t*)xb, 1, nb, radius, dim / 8, &res, bitset);
break;
}
case faiss::METRIC_Tanimoto: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
faiss::binary_range_search<faiss::CMin<float, int64_t>, float>(
faiss::METRIC_Tanimoto, cur_query, (const uint8_t*)xb, 1, nb, radius, dim / 8, &res, bitset);
break;
}
case faiss::METRIC_Hamming: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
faiss::binary_range_search<faiss::CMin<int, int64_t>, int>(faiss::METRIC_Hamming, cur_query,
(const uint8_t*)xb, 1, nb, (int)radius,
dim / 8, &res, bitset);
break;
}
default:
return Status::invalid_metric_type;
}
auto elem_cnt = res.lims[1];
result_dist_array[index].resize(elem_cnt);
result_id_array[index].resize(elem_cnt);
result_size[index] = elem_cnt;
for (size_t j = 0; j < elem_cnt; j++) {
result_dist_array[index][j] = res.distances[j];
result_id_array[index][j] = res.labels[j];
}
if (cfg.range_filter != defaultRangeFilter) {
FilterRangeSearchResultForOneNq(result_dist_array[index], result_id_array[index], is_ip, radius,
range_filter);
}
return Status::success;
}));
}
for (auto& fut : futs) {
auto ret = fut.get();
if (ret != Status::success) {
return unexpected(ret);
}
}

int64_t* labels = nullptr;
int64_t* ids = nullptr;
float* distances = nullptr;
size_t* lims = nullptr;

if (cfg.range_filter != defaultRangeFilter) {
GetRangeSearchResult(res, is_ip, nq, radius, cfg.range_filter, distances, labels, lims, bitset);
} else {
GetRangeSearchResult(res, is_ip, nq, radius, distances, labels, lims);
}

return GenResultDataSet(nq, labels, distances, lims);
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter, distances, ids, lims);
return GenResultDataSet(nq, ids, distances, lims);
}
} // namespace knowhere
Loading