Skip to content

Commit

Permalink
Add the interface for streaming the vectors from java to jni layer wi…
Browse files Browse the repository at this point in the history
…th initial capacity

Signed-off-by: Navneet Verma <navneev@amazon.com>
  • Loading branch information
navneet1v committed Apr 1, 2024
1 parent 771c4b5 commit 32027fe
Show file tree
Hide file tree
Showing 18 changed files with 404 additions and 73 deletions.
3 changes: 2 additions & 1 deletion jni/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ endif()
# ----------------------------------------------------------------------------

# ---------------------------------- COMMON ----------------------------------
add_library(${TARGET_LIB_COMMON} SHARED ${CMAKE_CURRENT_SOURCE_DIR}/src/jni_util.cpp)
add_library(${TARGET_LIB_COMMON} SHARED ${CMAKE_CURRENT_SOURCE_DIR}/src/jni_util.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/org_opensearch_knn_jni_JNICommons.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/commons.cpp)
target_include_directories(${TARGET_LIB_COMMON} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include $ENV{JAVA_HOME}/include $ENV{JAVA_HOME}/include/${JVM_OS_TYPE})
set_target_properties(${TARGET_LIB_COMMON} PROPERTIES SUFFIX ${LIB_EXT})
set_target_properties(${TARGET_LIB_COMMON} PROPERTIES POSITION_INDEPENDENT_CODE ON)
Expand Down Expand Up @@ -236,6 +236,7 @@ if ("${WIN32}" STREQUAL "")
tests/faiss_util_test.cpp
tests/nmslib_wrapper_test.cpp
tests/test_util.cpp
tests/commons_test.cpp
)

target_link_libraries(
Expand Down
37 changes: 37 additions & 0 deletions jni/include/commons.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/
#include "jni_util.h"
#include <jni.h>
namespace knn_jni {
namespace commons {
/**
* This is utility function that can be used to store data in native memory. This function will allocate memory for
* the data(rows*columns) with initialCapacity and return the memory address where the data is stored.
* If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created.
* For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location
* will throw Exception.
*
* @param memoryAddress The address of the memory location where data will be stored.
* @param data 2D float array containing data to be stored in native memory.
* @param initialCapacity The initial capacity of the memory location.
* @return memory address where the data is stored.
*/
jlong storeVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong);

/**
* Free up the memory allocated for the data stored in memory address. This function should be used with the memory
* address returned by {@link JNICommons#storeVectorData(long, float[][], long, long)}
*
* @param memoryAddress address to be freed.
*/
void freeVectorData(jlong);
}
}
4 changes: 4 additions & 0 deletions jni/include/jni_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ namespace knn_jni {
virtual std::vector<float> Convert2dJavaObjectArrayToCppFloatVector(JNIEnv *env, jobjectArray array2dJ,
int dim) = 0;

virtual void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ,
int dim, std::vector<float> *vect ) = 0;

virtual std::vector<int64_t> ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) = 0;

// --------------------------------------------------------------------------
Expand Down Expand Up @@ -164,6 +167,7 @@ namespace knn_jni {
void ReleaseLongArrayElements(JNIEnv *env, jlongArray array, jlong *elems, jint mode);
void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val);
void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf);
void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<float> *vect);

private:
std::unordered_map<std::string, jclass> cachedClasses;
Expand Down
8 changes: 0 additions & 8 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,6 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors
(JNIEnv *, jclass, jlong, jobjectArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: transferVectorsV2
* Signature: (J[[F)J
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectorsV2
(JNIEnv *, jclass, jlong, jobjectArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: freeVectors
Expand Down
40 changes: 40 additions & 0 deletions jni/include/org_opensearch_knn_jni_JNICommons.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

/* DO NOT EDIT THIS FILE - it is machine generated */
#include <jni.h>
/* Header for class org_opensearch_knn_jni_JNICommons */

#ifndef _Included_org_opensearch_knn_jni_JNICommons
#define _Included_org_opensearch_knn_jni_JNICommons
#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: org_opensearch_knn_jni_JNICommons
* Method: storeVectorData
* Signature: (J[[FJJ)
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData
(JNIEnv *, jclass, jlong, jobjectArray, jlong);

/*
* Class: org_opensearch_knn_jni_JNICommons
* Method: freeVectorData
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeVectorData
(JNIEnv *, jclass, jlong);

#ifdef __cplusplus
}
#endif
#endif
41 changes: 41 additions & 0 deletions jni/src/commons.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/
#ifndef OPENSEARCH_KNN_COMMONS_H
#define OPENSEARCH_KNN_COMMONS_H
#include <jni.h>

#include <vector>

#include "jni_util.h"
#include "commons.h"

jlong knn_jni::commons::storeVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ,
jobjectArray dataJ, jlong initialCapacityJ) {
std::vector<float> *vect;
if ((long) memoryAddressJ == 0) {
vect = new std::vector<float>();
vect->reserve((long)initialCapacityJ);
} else {
vect = reinterpret_cast<std::vector<float>*>(memoryAddressJ);
}
int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, dataJ);
jniUtil->Convert2dJavaObjectArrayAndStoreToFloatVector(env, dataJ, dim, vect);

return (jlong) vect;
}

void knn_jni::commons::freeVectorData(jlong memoryAddressJ) {
if (memoryAddressJ != 0) {
auto *vect = reinterpret_cast<std::vector<float>*>(memoryAddressJ);
delete vect;
}
}
#endif //OPENSEARCH_KNN_COMMONS_H
11 changes: 8 additions & 3 deletions jni/src/jni_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,13 @@ int knn_jni::JNIUtil::ConvertJavaObjectToCppInteger(JNIEnv *env, jobject objectJ

std::vector<float> knn_jni::JNIUtil::Convert2dJavaObjectArrayToCppFloatVector(JNIEnv *env, jobjectArray array2dJ,
int dim) {
std::vector<float> vect;
Convert2dJavaObjectArrayAndStoreToFloatVector(env, array2dJ, dim, &vect);
return vect;
}

void knn_jni::JNIUtil::Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ,
int dim, std::vector<float> *vect) {

if (array2dJ == nullptr) {
throw std::runtime_error("Array cannot be null");
Expand All @@ -231,7 +238,6 @@ std::vector<float> knn_jni::JNIUtil::Convert2dJavaObjectArrayToCppFloatVector(JN
int numVectors = env->GetArrayLength(array2dJ);
this->HasExceptionInStack(env);

std::vector<float> floatVectorCpp;
for (int i = 0; i < numVectors; ++i) {
auto vectorArray = (jfloatArray)env->GetObjectArrayElement(array2dJ, i);
this->HasExceptionInStack(env, "Unable to get object array element");
Expand All @@ -247,13 +253,12 @@ std::vector<float> knn_jni::JNIUtil::Convert2dJavaObjectArrayToCppFloatVector(JN
}

for(int j = 0; j < dim; ++j) {
floatVectorCpp.push_back(vector[j]);
vect->push_back(vector[j]);
}
env->ReleaseFloatArrayElements(vectorArray, vector, JNI_ABORT);
}
this->HasExceptionInStack(env);
env->DeleteLocalRef(array2dJ);
return floatVectorCpp;
}

std::vector<int64_t> knn_jni::JNIUtil::ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) {
Expand Down
19 changes: 0 additions & 19 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

#include <jni.h>

#include <algorithm>
#include <vector>

#include "faiss_wrapper.h"
Expand Down Expand Up @@ -191,24 +190,6 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors
return (jlong) vect;
}

JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectorsV2(JNIEnv * env, jclass cls,
jlong vectorsPointerJ,
jobjectArray vectorsJ)
{
std::vector<float> *vect;
if ((long) vectorsPointerJ == 0) {
vect = new std::vector<float>;
} else {
vect = reinterpret_cast<std::vector<float>*>(vectorsPointerJ);
}

int dim = jniUtil.GetInnerDimensionOf2dJavaFloatArray(env, vectorsJ);
auto dataset = jniUtil.Convert2dJavaObjectArrayToCppFloatVector(env, vectorsJ, dim);
vect->insert(vect->end(), dataset.begin(), dataset.end());

return (jlong) vect;
}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_freeVectors(JNIEnv * env, jclass cls,
jlong vectorsPointerJ)
{
Expand Down
60 changes: 60 additions & 0 deletions jni/src/org_opensearch_knn_jni_JNICommons.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

#include "org_opensearch_knn_jni_JNICommons.h"

#include <jni.h>
#include "commons.h"
#include "jni_util.h"

static knn_jni::JNIUtil jniUtil;
static const jint KNN_JNICOMMONS_JNI_VERSION = JNI_VERSION_1_1;

jint JNI_OnLoad(JavaVM* vm, void* reserved) {
// Obtain the JNIEnv from the VM and confirm JNI_VERSION
JNIEnv* env;
if (vm->GetEnv((void**)&env, KNN_JNICOMMONS_JNI_VERSION) != JNI_OK) {
return JNI_ERR;
}

jniUtil.Initialize(env);

return KNN_JNICOMMONS_JNI_VERSION;
}

void JNI_OnUnload(JavaVM *vm, void *reserved) {
JNIEnv* env;
vm->GetEnv((void**)&env, KNN_JNICOMMONS_JNI_VERSION);
jniUtil.Uninitialize(env);
}


JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData(JNIEnv * env, jclass cls,
jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ)

{
try {
return knn_jni::commons::storeVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
return (long)memoryAddressJ;
}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeVectorData(JNIEnv * env, jclass cls,
jlong memoryAddressJ)
{
try {
return knn_jni::commons::freeVectorData(memoryAddressJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
}
76 changes: 76 additions & 0 deletions jni/tests/commons_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/


#include "test_util.h"
#include <vector>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "jni_util.h"
#include "commons.h"

TEST(CommonsTests, BasicAssertions) {
long dim = 3;
long totalNumberOfVector = 5;
std::vector<std::vector<float>> data;
for(int i = 0 ; i < totalNumberOfVector - 1 ; i++) {
std::vector<float> vector;
for(int j = 0 ; j < dim ; j ++) {
vector.push_back((float)j);
}
data.push_back(vector);
}
JNIEnv *jniEnv = nullptr;

testing::NiceMock<test_util::MockJNIUtil> mockJNIUtil;

jlong memoryAddress = knn_jni::commons::storeVectorData(&mockJNIUtil, jniEnv, (jlong)0,
reinterpret_cast<jobjectArray>(&data), (jlong)(totalNumberOfVector * dim) , (jlong)0);
ASSERT_NE(memoryAddress, 0);
auto *vect = reinterpret_cast<std::vector<float>*>(memoryAddress);
ASSERT_EQ(vect->size(), totalNumberOfVector*dim);

// Check by inserting more vectors at same memory location
jlong oldMemoryAddress = memoryAddress;
std::vector<std::vector<float>> data2;
std::vector<float> vector;
for(int j = 0 ; j < dim ; j ++) {
vector.push_back((float)j);
}
data2.push_back(vector);
memoryAddress = knn_jni::commons::storeVectorData(&mockJNIUtil, jniEnv, memoryAddress,
reinterpret_cast<jobjectArray>(&data2), (jlong)(totalNumberOfVector * dim) , (jlong)(data.size() * dim));
ASSERT_NE(memoryAddress, 0);
ASSERT_EQ(memoryAddress, oldMemoryAddress);
vect = reinterpret_cast<std::vector<float>*>(memoryAddress);
int currentIndex = 0;
ASSERT_EQ(vect->size(), totalNumberOfVector*dim);

// Validate if all vectors data are at correct location
for(auto & i : data) {
for(float j : i) {
ASSERT_FLOAT_EQ(vect->at(currentIndex), j);
currentIndex++;
}
}

for(auto & i : data2) {
for(float j : i) {
ASSERT_FLOAT_EQ(vect->at(currentIndex), j);
currentIndex++;
}
}

// Validate if more data is provided then the initial capacity throw exception
ASSERT_THROW(
knn_jni::commons::storeVectorData(&mockJNIUtil, jniEnv, memoryAddress,
reinterpret_cast<jobjectArray>(&data2), (jlong)(totalNumberOfVector * dim), (jlong)(vect->size())), std::runtime_error);
}
8 changes: 8 additions & 0 deletions jni/tests/test_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ test_util::MockJNIUtil::MockJNIUtil() {
return data;
});

ON_CALL(*this, Convert2dJavaObjectArrayAndStoreToFloatVector)
.WillByDefault([this](JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<float>* data) {
for (const auto &v :
(*reinterpret_cast<std::vector<std::vector<float>> *>(array2dJ)))
for (auto item : v) data->push_back(item);
});


// arrayJ is re-interpreted as std::vector<int64_t> *
ON_CALL(*this, ConvertJavaIntArrayToCppIntVector)
.WillByDefault([this](JNIEnv *env, jintArray arrayJ) {
Expand Down
Loading

0 comments on commit 32027fe

Please sign in to comment.