Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Implemented the Streaming Feature to stream vectors from Java to JNI layer to enable creation of larger segments for vector indices (#1604) #1608

Merged
merged 1 commit into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Enhancements
* Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549)
* Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573)
* Implemented the Streaming Feature to stream vectors from Java to JNI layer to enable creation of larger segments for vector indices [#1604](https://github.com/opensearch-project/k-NN/pull/1604)
### Bug Fixes
### Infrastructure
* Add micro-benchmark module in k-NN plugin for benchmark streaming vectors to JNI layer functionality. [#1583](https://github.com/opensearch-project/k-NN/pull/1583)
Expand Down
3 changes: 2 additions & 1 deletion jni/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ endif()
# ----------------------------------------------------------------------------

# ---------------------------------- COMMON ----------------------------------
add_library(${TARGET_LIB_COMMON} SHARED ${CMAKE_CURRENT_SOURCE_DIR}/src/jni_util.cpp)
add_library(${TARGET_LIB_COMMON} SHARED ${CMAKE_CURRENT_SOURCE_DIR}/src/jni_util.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/org_opensearch_knn_jni_JNICommons.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/commons.cpp)
target_include_directories(${TARGET_LIB_COMMON} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include $ENV{JAVA_HOME}/include $ENV{JAVA_HOME}/include/${JVM_OS_TYPE})
set_target_properties(${TARGET_LIB_COMMON} PROPERTIES SUFFIX ${LIB_EXT})
set_target_properties(${TARGET_LIB_COMMON} PROPERTIES POSITION_INDEPENDENT_CODE ON)
Expand Down Expand Up @@ -236,6 +236,7 @@ if ("${WIN32}" STREQUAL "")
tests/faiss_util_test.cpp
tests/nmslib_wrapper_test.cpp
tests/test_util.cpp
tests/commons_test.cpp
)

target_link_libraries(
Expand Down
37 changes: 37 additions & 0 deletions jni/include/commons.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/
#include "jni_util.h"
#include <jni.h>
namespace knn_jni {
namespace commons {
/**
* This is utility function that can be used to store data in native memory. This function will allocate memory for
* the data(rows*columns) with initialCapacity and return the memory address where the data is stored.
* If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created.
* For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location
* will throw Exception.
*
* @param memoryAddress The address of the memory location where data will be stored.
* @param data 2D float array containing data to be stored in native memory.
* @param initialCapacity The initial capacity of the memory location.
* @return memory address where the data is stored.
*/
jlong storeVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong);

/**
* Free up the memory allocated for the data stored in memory address. This function should be used with the memory
* address returned by {@link JNICommons#storeVectorData(long, float[][], long, long)}
*
* @param memoryAddress address to be freed.
*/
void freeVectorData(jlong);
}
}
4 changes: 2 additions & 2 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ namespace knn_jni {
namespace faiss_wrapper {
// Create an index with ids and vectors. The configuration is defined by values in the Java map, parametersJ.
// The index is serialized to indexPathJ.
void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jobjectArray vectorsJ,
void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ,
jstring indexPathJ, jobject parametersJ);

// Create an index with ids and vectors. Instead of creating a new index, this function creates the index
// based off of the template index passed in. The index is serialized to indexPathJ.
void CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jobjectArray vectorsJ, jstring indexPathJ, jbyteArray templateIndexJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ,
jobject parametersJ);

// Load an index from indexPathJ into memory.
Expand Down
4 changes: 4 additions & 0 deletions jni/include/jni_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ namespace knn_jni {
virtual std::vector<float> Convert2dJavaObjectArrayToCppFloatVector(JNIEnv *env, jobjectArray array2dJ,
int dim) = 0;

virtual void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ,
int dim, std::vector<float> *vect ) = 0;

virtual std::vector<int64_t> ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) = 0;

// --------------------------------------------------------------------------
Expand Down Expand Up @@ -164,6 +167,7 @@ namespace knn_jni {
void ReleaseLongArrayElements(JNIEnv *env, jlongArray array, jlong *elems, jint mode);
void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val);
void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf);
void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<float> *vect);

private:
std::unordered_map<std::string, jclass> cachedClasses;
Expand Down
2 changes: 1 addition & 1 deletion jni/include/nmslib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace knn_jni {
namespace nmslib_wrapper {
// Create an index with ids and vectors. The configuration is defined by values in the Java map, parametersJ.
// The index is serialized to indexPathJ.
void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jobjectArray vectorsJ,
void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddress, jint dim,
jstring indexPathJ, jobject parametersJ);

// Load an index from indexPathJ into memory. Use parametersJ to set any query time parameters
Expand Down
24 changes: 4 additions & 20 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@ extern "C" {
/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createIndex
* Signature: ([I[[FLjava/lang/String;Ljava/util/Map;)V
* Signature: ([IJILjava/lang/String;Ljava/util/Map;)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex
(JNIEnv *, jclass, jintArray, jobjectArray, jstring, jobject);
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createIndexFromTemplate
* Signature: ([I[[FLjava/lang/String;[BLjava/util/Map;)V
* Signature: ([IJILjava/lang/String;[BLjava/util/Map;)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate
(JNIEnv *, jclass, jintArray, jobjectArray, jstring, jbyteArray, jobject);
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
Expand Down Expand Up @@ -122,22 +122,6 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors
(JNIEnv *, jclass, jlong, jobjectArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: transferVectorsV2
* Signature: (J[[F)J
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectorsV2
(JNIEnv *, jclass, jlong, jobjectArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: freeVectors
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_freeVectors
(JNIEnv *, jclass, jlong);

#ifdef __cplusplus
}
#endif
Expand Down
40 changes: 40 additions & 0 deletions jni/include/org_opensearch_knn_jni_JNICommons.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

/* DO NOT EDIT THIS FILE - it is machine generated */
#include <jni.h>
/* Header for class org_opensearch_knn_jni_JNICommons */

#ifndef _Included_org_opensearch_knn_jni_JNICommons
#define _Included_org_opensearch_knn_jni_JNICommons
#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: org_opensearch_knn_jni_JNICommons
* Method: storeVectorData
* Signature: (J[[FJJ)
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData
(JNIEnv *, jclass, jlong, jobjectArray, jlong);

/*
* Class: org_opensearch_knn_jni_JNICommons
* Method: freeVectorData
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeVectorData
(JNIEnv *, jclass, jlong);

#ifdef __cplusplus
}
#endif
#endif
4 changes: 2 additions & 2 deletions jni/include/org_opensearch_knn_jni_NmslibService.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ extern "C" {
/*
* Class: org_opensearch_knn_jni_NmslibService
* Method: createIndex
* Signature: ([I[[FLjava/lang/String;Ljava/util/Map;)V
* Signature: ([IJILjava/lang/String;Ljava/util/Map;)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_createIndex
(JNIEnv *, jclass, jintArray, jobjectArray, jstring, jobject);
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject);

/*
* Class: org_opensearch_knn_jni_NmslibService
Expand Down
41 changes: 41 additions & 0 deletions jni/src/commons.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/
#ifndef OPENSEARCH_KNN_COMMONS_H
#define OPENSEARCH_KNN_COMMONS_H
#include <jni.h>

#include <vector>

#include "jni_util.h"
#include "commons.h"

jlong knn_jni::commons::storeVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ,
jobjectArray dataJ, jlong initialCapacityJ) {
std::vector<float> *vect;
if ((long) memoryAddressJ == 0) {
vect = new std::vector<float>();
vect->reserve((long)initialCapacityJ);
} else {
vect = reinterpret_cast<std::vector<float>*>(memoryAddressJ);
}
int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, dataJ);
jniUtil->Convert2dJavaObjectArrayAndStoreToFloatVector(env, dataJ, dim, vect);

return (jlong) vect;
}

void knn_jni::commons::freeVectorData(jlong memoryAddressJ) {
if (memoryAddressJ != 0) {
auto *vect = reinterpret_cast<std::vector<float>*>(memoryAddressJ);
delete vect;
}
}
#endif //OPENSEARCH_KNN_COMMONS_H
57 changes: 38 additions & 19 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,19 @@ bool isIndexIVFPQL2(faiss::Index * index);
// IndexIDMap which has member that will point to underlying index that stores the data
faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index);

void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jobjectArray vectorsJ, jstring indexPathJ, jobject parametersJ) {
void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ,
jstring indexPathJ, jobject parametersJ) {

if (idsJ == nullptr) {
throw std::runtime_error("IDs cannot be null");
}

if (vectorsJ == nullptr) {
throw std::runtime_error("Vectors cannot be null");
if (vectorsAddressJ <= 0) {
throw std::runtime_error("VectorsAddress cannot be less than 0");
}

if(dimJ <= 0) {
throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0");
}

if (indexPathJ == nullptr) {
Expand All @@ -109,16 +113,20 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN
std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ));
faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp);

// Read data set
int numVectors = jniUtil->GetJavaObjectArrayLength(env, vectorsJ);
// Read vectors from memory address
auto *inputVectors = reinterpret_cast<std::vector<float>*>(vectorsAddressJ);
int dim = (int)dimJ;
// The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value
int numVectors = (int) (inputVectors->size() / (uint64_t) dim);
if(numVectors == 0) {
throw std::runtime_error("Number of vectors cannot be 0");
}

int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);
if (numIds != numVectors) {
throw std::runtime_error("Number of IDs does not match number of vectors");
}

int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, vectorsJ);
auto dataset = jniUtil->Convert2dJavaObjectArrayToCppFloatVector(env, vectorsJ, dim);

// Create faiss index
jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION);
std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ));
Expand Down Expand Up @@ -148,22 +156,30 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN

auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get());
idMap.add_with_ids(numVectors, dataset.data(), idVector.data());
idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data());

// Write the index to disk
std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
faiss::write_index(&idMap, indexPathCpp.c_str());
// Releasing the vectorsAddressJ memory as that is not required once we have created the index.
// This is not the ideal approach, please refer this gh issue for long term solution:
// https://github.com/opensearch-project/k-NN/issues/1600
delete inputVectors;
}

void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jobjectArray vectorsJ, jstring indexPathJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ,
jbyteArray templateIndexJ, jobject parametersJ) {
if (idsJ == nullptr) {
throw std::runtime_error("IDs cannot be null");
}

if (vectorsJ == nullptr) {
throw std::runtime_error("Vectors cannot be null");
if (vectorsAddressJ <= 0) {
throw std::runtime_error("VectorsAddress cannot be less than 0");
}

if(dimJ <= 0) {
throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0");
}

if (indexPathJ == nullptr) {
Expand All @@ -183,15 +199,15 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface *
jniUtil->DeleteLocalRef(env, parametersJ);

// Read data set
int numVectors = jniUtil->GetJavaObjectArrayLength(env, vectorsJ);
// Read vectors from memory address
auto *inputVectors = reinterpret_cast<std::vector<float>*>(vectorsAddressJ);
int dim = (int)dimJ;
int numVectors = (int) (inputVectors->size() / (uint64_t) dim);
int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);
if (numIds != numVectors) {
throw std::runtime_error("Number of IDs does not match number of vectors");
}

int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, vectorsJ);
auto dataset = jniUtil->Convert2dJavaObjectArrayToCppFloatVector(env, vectorsJ, dim);

// Get vector of bytes from jbytearray
int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ);
jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr);
Expand All @@ -208,8 +224,11 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface *

auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get());
idMap.add_with_ids(numVectors, dataset.data(), idVector.data());

idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data());
// Releasing the vectorsAddressJ memory as that is not required once we have created the index.
// This is not the ideal approach, please refer this gh issue for long term solution:
// https://github.com/opensearch-project/k-NN/issues/1600
delete inputVectors;
// Write the index to disk
std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
faiss::write_index(&idMap, indexPathCpp.c_str());
Expand Down
Loading
Loading