Skip to content

Commit

Permalink
Implemented the Streaming Feature to stream vectors from Java to JNI …
Browse files Browse the repository at this point in the history
…layer to enable creation of larger segments for vector indices

Changes include:
1. Add the interface for streaming the vectors from java to jni layer with initial capacity (opensearch-project#1586)
2. Integrating storeVectors interfaces with createIndex and createIndexTemplate functions. (opensearch-project#1588)
3. Update KNN80BinaryDocValues reader count live docs and use live docs as initial capacity to initialize vector address(opensearch-project#1595)

Signed-off-by: Navneet Verma <navneev@amazon.com>
  • Loading branch information
navneet1v committed Apr 9, 2024
1 parent 5de10fc commit 1fc15f1
Show file tree
Hide file tree
Showing 41 changed files with 980 additions and 388 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Enhancements
* Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549)
* Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573)
* Implemented the Streaming Feature to stream vectors from Java to JNI layer to enable creation of larger segments for vector indices []()
### Bug Fixes
### Infrastructure
* Add micro-benchmark module in k-NN plugin for benchmark streaming vectors to JNI layer functionality. [#1583](https://github.com/opensearch-project/k-NN/pull/1583)
Expand Down
37 changes: 37 additions & 0 deletions jni/include/commons.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/
#include "jni_util.h"
#include <jni.h>
namespace knn_jni {
namespace commons {
/**
* This is utility function that can be used to store data in native memory. This function will allocate memory for
* the data(rows*columns) with initialCapacity and return the memory address where the data is stored.
* If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created.
* For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location
* will throw Exception.
*
* @param memoryAddress The address of the memory location where data will be stored.
* @param data 2D float array containing data to be stored in native memory.
* @param initialCapacity The initial capacity of the memory location.
* @return memory address where the data is stored.
*/
jlong storeVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong);

/**
* Free up the memory allocated for the data stored in memory address. This function should be used with the memory
* address returned by {@link JNICommons#storeVectorData(long, float[][], long, long)}
*
* @param memoryAddress address to be freed.
*/
void freeVectorData(jlong);
}
}
4 changes: 2 additions & 2 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ namespace knn_jni {
namespace faiss_wrapper {
// Create an index with ids and vectors. The configuration is defined by values in the Java map, parametersJ.
// The index is serialized to indexPathJ.
void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jobjectArray vectorsJ,
void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ,
jstring indexPathJ, jobject parametersJ);

// Create an index with ids and vectors. Instead of creating a new index, this function creates the index
// based off of the template index passed in. The index is serialized to indexPathJ.
void CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jobjectArray vectorsJ, jstring indexPathJ, jbyteArray templateIndexJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ,
jobject parametersJ);

// Load an index from indexPathJ into memory.
Expand Down
4 changes: 4 additions & 0 deletions jni/include/jni_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ namespace knn_jni {
virtual std::vector<float> Convert2dJavaObjectArrayToCppFloatVector(JNIEnv *env, jobjectArray array2dJ,
int dim) = 0;

virtual void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ,
int dim, std::vector<float> *vect ) = 0;

virtual std::vector<int64_t> ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) = 0;

// --------------------------------------------------------------------------
Expand Down Expand Up @@ -164,6 +167,7 @@ namespace knn_jni {
void ReleaseLongArrayElements(JNIEnv *env, jlongArray array, jlong *elems, jint mode);
void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val);
void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf);
void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<float> *vect);

private:
std::unordered_map<std::string, jclass> cachedClasses;
Expand Down
2 changes: 1 addition & 1 deletion jni/include/nmslib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace knn_jni {
namespace nmslib_wrapper {
// Create an index with ids and vectors. The configuration is defined by values in the Java map, parametersJ.
// The index is serialized to indexPathJ.
void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jobjectArray vectorsJ,
void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddress, jint dim,
jstring indexPathJ, jobject parametersJ);

// Load an index from indexPathJ into memory. Use parametersJ to set any query time parameters
Expand Down
24 changes: 4 additions & 20 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@ extern "C" {
/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createIndex
* Signature: ([I[[FLjava/lang/String;Ljava/util/Map;)V
* Signature: ([IJILjava/lang/String;Ljava/util/Map;)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex
(JNIEnv *, jclass, jintArray, jobjectArray, jstring, jobject);
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createIndexFromTemplate
* Signature: ([I[[FLjava/lang/String;[BLjava/util/Map;)V
* Signature: ([IJILjava/lang/String;[BLjava/util/Map;)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate
(JNIEnv *, jclass, jintArray, jobjectArray, jstring, jbyteArray, jobject);
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
Expand Down Expand Up @@ -122,22 +122,6 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors
(JNIEnv *, jclass, jlong, jobjectArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: transferVectorsV2
* Signature: (J[[F)J
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectorsV2
(JNIEnv *, jclass, jlong, jobjectArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: freeVectors
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_freeVectors
(JNIEnv *, jclass, jlong);

#ifdef __cplusplus
}
#endif
Expand Down
40 changes: 40 additions & 0 deletions jni/include/org_opensearch_knn_jni_JNICommons.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

/* DO NOT EDIT THIS FILE - it is machine generated */
#include <jni.h>
/* Header for class org_opensearch_knn_jni_JNICommons */

#ifndef _Included_org_opensearch_knn_jni_JNICommons
#define _Included_org_opensearch_knn_jni_JNICommons
#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: org_opensearch_knn_jni_JNICommons
* Method: storeVectorData
* Signature: (J[[FJJ)
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData
(JNIEnv *, jclass, jlong, jobjectArray, jlong);

/*
* Class: org_opensearch_knn_jni_JNICommons
* Method: freeVectorData
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeVectorData
(JNIEnv *, jclass, jlong);

#ifdef __cplusplus
}
#endif
#endif
4 changes: 2 additions & 2 deletions jni/include/org_opensearch_knn_jni_NmslibService.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ extern "C" {
/*
* Class: org_opensearch_knn_jni_NmslibService
* Method: createIndex
* Signature: ([I[[FLjava/lang/String;Ljava/util/Map;)V
* Signature: ([IJILjava/lang/String;Ljava/util/Map;)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_createIndex
(JNIEnv *, jclass, jintArray, jobjectArray, jstring, jobject);
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject);

/*
* Class: org_opensearch_knn_jni_NmslibService
Expand Down
41 changes: 41 additions & 0 deletions jni/src/commons.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/
#ifndef OPENSEARCH_KNN_COMMONS_H
#define OPENSEARCH_KNN_COMMONS_H
#include <jni.h>

#include <vector>

#include "jni_util.h"
#include "commons.h"

jlong knn_jni::commons::storeVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ,
jobjectArray dataJ, jlong initialCapacityJ) {
std::vector<float> *vect;
if ((long) memoryAddressJ == 0) {
vect = new std::vector<float>();
vect->reserve((long)initialCapacityJ);
} else {
vect = reinterpret_cast<std::vector<float>*>(memoryAddressJ);
}
int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, dataJ);
jniUtil->Convert2dJavaObjectArrayAndStoreToFloatVector(env, dataJ, dim, vect);

return (jlong) vect;
}

void knn_jni::commons::freeVectorData(jlong memoryAddressJ) {
if (memoryAddressJ != 0) {
auto *vect = reinterpret_cast<std::vector<float>*>(memoryAddressJ);
delete vect;
}
}
#endif //OPENSEARCH_KNN_COMMONS_H
48 changes: 30 additions & 18 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,19 @@ bool isIndexIVFPQL2(faiss::Index * index);
// IndexIDMap which has member that will point to underlying index that stores the data
faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index);

void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jobjectArray vectorsJ, jstring indexPathJ, jobject parametersJ) {
void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ,
jstring indexPathJ, jobject parametersJ) {

if (idsJ == nullptr) {
throw std::runtime_error("IDs cannot be null");
}

if (vectorsJ == nullptr) {
throw std::runtime_error("Vectors cannot be null");
if (vectorsAddressJ <= 0) {
throw std::runtime_error("VectorsAddress cannot be less than 0");
}

if(dimJ <= 0) {
throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0");
}

if (indexPathJ == nullptr) {
Expand All @@ -109,16 +113,20 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN
std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ));
faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp);

// Read data set
int numVectors = jniUtil->GetJavaObjectArrayLength(env, vectorsJ);
// Read vectors from memory address
auto *inputVectors = reinterpret_cast<std::vector<float>*>(vectorsAddressJ);
int dim = (int)dimJ;
// The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value
int numVectors = (int) (inputVectors->size() / (uint64_t) dim);
if(numVectors == 0) {
throw std::runtime_error("Number of vectors cannot be 0");
}

int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);
if (numIds != numVectors) {
throw std::runtime_error("Number of IDs does not match number of vectors");
}

int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, vectorsJ);
auto dataset = jniUtil->Convert2dJavaObjectArrayToCppFloatVector(env, vectorsJ, dim);

// Create faiss index
jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION);
std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ));
Expand Down Expand Up @@ -148,22 +156,26 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN

auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get());
idMap.add_with_ids(numVectors, dataset.data(), idVector.data());
idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data());

// Write the index to disk
std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
faiss::write_index(&idMap, indexPathCpp.c_str());
}

void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jobjectArray vectorsJ, jstring indexPathJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ,
jbyteArray templateIndexJ, jobject parametersJ) {
if (idsJ == nullptr) {
throw std::runtime_error("IDs cannot be null");
}

if (vectorsJ == nullptr) {
throw std::runtime_error("Vectors cannot be null");
if (vectorsAddressJ <= 0) {
throw std::runtime_error("VectorsAddress cannot be less than 0");
}

if(dimJ <= 0) {
throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0");
}

if (indexPathJ == nullptr) {
Expand All @@ -183,15 +195,15 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface *
jniUtil->DeleteLocalRef(env, parametersJ);

// Read data set
int numVectors = jniUtil->GetJavaObjectArrayLength(env, vectorsJ);
// Read vectors from memory address
auto *inputVectors = reinterpret_cast<std::vector<float>*>(vectorsAddressJ);
int dim = (int)dimJ;
int numVectors = (int) (inputVectors->size() / (uint64_t) dim);
int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);
if (numIds != numVectors) {
throw std::runtime_error("Number of IDs does not match number of vectors");
}

int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, vectorsJ);
auto dataset = jniUtil->Convert2dJavaObjectArrayToCppFloatVector(env, vectorsJ, dim);

// Get vector of bytes from jbytearray
int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ);
jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr);
Expand All @@ -208,7 +220,7 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface *

auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get());
idMap.add_with_ids(numVectors, dataset.data(), idVector.data());
idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data());

// Write the index to disk
std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
Expand Down
11 changes: 8 additions & 3 deletions jni/src/jni_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,13 @@ int knn_jni::JNIUtil::ConvertJavaObjectToCppInteger(JNIEnv *env, jobject objectJ

std::vector<float> knn_jni::JNIUtil::Convert2dJavaObjectArrayToCppFloatVector(JNIEnv *env, jobjectArray array2dJ,
int dim) {
std::vector<float> vect;
Convert2dJavaObjectArrayAndStoreToFloatVector(env, array2dJ, dim, &vect);
return vect;
}

void knn_jni::JNIUtil::Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ,
int dim, std::vector<float> *vect) {

if (array2dJ == nullptr) {
throw std::runtime_error("Array cannot be null");
Expand All @@ -231,7 +238,6 @@ std::vector<float> knn_jni::JNIUtil::Convert2dJavaObjectArrayToCppFloatVector(JN
int numVectors = env->GetArrayLength(array2dJ);
this->HasExceptionInStack(env);

std::vector<float> floatVectorCpp;
for (int i = 0; i < numVectors; ++i) {
auto vectorArray = (jfloatArray)env->GetObjectArrayElement(array2dJ, i);
this->HasExceptionInStack(env, "Unable to get object array element");
Expand All @@ -247,13 +253,12 @@ std::vector<float> knn_jni::JNIUtil::Convert2dJavaObjectArrayToCppFloatVector(JN
}

for(int j = 0; j < dim; ++j) {
floatVectorCpp.push_back(vector[j]);
vect->push_back(vector[j]);
}
env->ReleaseFloatArrayElements(vectorArray, vector, JNI_ABORT);
}
this->HasExceptionInStack(env);
env->DeleteLocalRef(array2dJ);
return floatVectorCpp;
}

std::vector<int64_t> knn_jni::JNIUtil::ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) {
Expand Down
Loading

0 comments on commit 1fc15f1

Please sign in to comment.