Skip to content

Commit

Permalink
Fix bug where quantization framework does not work with training (#2100)
Browse files Browse the repository at this point in the history
* Initial implementation

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Modify integration test and fix bugs in jni

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Fix unit test

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Fix integration test after merge

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Add changelog (release notes)

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Add unit test

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Remove entry for release notes

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Add null checks

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

---------

Signed-off-by: Ryan Bogan <rbogan@amazon.com>
  • Loading branch information
ryanbogan committed Sep 12, 2024
1 parent 18c26f3 commit 5d10d64
Show file tree
Hide file tree
Showing 11 changed files with 168 additions and 61 deletions.
25 changes: 17 additions & 8 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -684,20 +684,29 @@ jobjectArray knn_jni::faiss_wrapper::QueryBinaryIndex_WithFilter(knn_jni::JNIUti
} else {
faiss::SearchParameters *searchParameters = nullptr;
faiss::SearchParametersHNSW hnswParams;
faiss::SearchParametersIVF ivfParams;
std::unique_ptr<faiss::IDGrouperBitmap> idGrouper;
std::vector<uint64_t> idGrouperBitmap;
auto hnswReader = dynamic_cast<const faiss::IndexBinaryHNSW*>(indexReader->index);
auto ivfReader = dynamic_cast<const faiss::IndexBinaryIVF*>(indexReader->index);
// TODO currently, search parameter is not supported in binary index
// To avoid test failures, we temporarily skip setting ef_search when methodParamsJ is null
if(hnswReader!= nullptr && (methodParamsJ != nullptr || parentIdsJ != nullptr)) {
// Query param efsearch supersedes ef_search provided during index setting.
hnswParams.efSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, hnswReader->hnsw.efSearch);
if (parentIdsJ != nullptr) {
idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap);
hnswParams.grp = idGrouper.get();
if (ivfReader) {
int indexNprobe = ivfReader->nprobe;
ivfParams.nprobe = commons::getIntegerMethodParameter(env, jniUtil, methodParams, NPROBES, indexNprobe);
searchParameters = &ivfParams;
} else {
auto hnswReader = dynamic_cast<const faiss::IndexBinaryHNSW*>(indexReader->index);
if(hnswReader != nullptr && (methodParamsJ != nullptr || parentIdsJ != nullptr)) {
// Query param efsearch supersedes ef_search provided during index setting.
hnswParams.efSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, hnswReader->hnsw.efSearch);
if (parentIdsJ != nullptr) {
idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap);
hnswParams.grp = idGrouper.get();
}
searchParameters = &hnswParams;
}
searchParameters = &hnswParams;
}

try {
indexReader->search(1, reinterpret_cast<uint8_t*>(rawQueryvector), kJ, dis.data(), ids.data(), searchParameters);
} catch (...) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public enum CompressionLevel {
x32(32, "32x", new RescoreContext(3.0f), Set.of(Mode.ON_DISK));

// Internally, an empty string is easier to deal with them null. However, from the mapping,
// we do not want users to pass in the empty string and instead want null. So we make the conversion herex
// we do not want users to pass in the empty string and instead want null. So we make the conversion here
public static final String[] NAMES_ARRAY = new String[] {
NOT_CONFIGURED.getName(),
x1.getName(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
package org.opensearch.knn.index.memory;

import lombok.Getter;
import lombok.Setter;
import org.apache.lucene.index.LeafReaderContext;
import org.opensearch.knn.common.featureflags.KNNFeatureFlags;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.query.KNNWeight;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.index.engine.KNNEngine;
Expand Down Expand Up @@ -252,6 +254,9 @@ class TrainingDataAllocation implements NativeMemoryAllocation {
private volatile boolean closed;
private long memoryAddress;
private final int size;
@Getter
@Setter
private QuantizationConfig quantizationConfig = QuantizationConfig.EMPTY;

// Implement reader/writer with semaphores to deal with passing lock conditions between threads
private int readCount;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@

package org.opensearch.knn.index.memory;

import lombok.Getter;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.Nullable;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.util.IndexUtil;
import org.opensearch.knn.index.VectorDataType;

Expand Down Expand Up @@ -171,6 +173,8 @@ public static class TrainingDataEntryContext extends NativeMemoryEntryContext<Na
private final int maxVectorCount;
private final int searchSize;
private final VectorDataType vectorDataType;
@Getter
private final QuantizationConfig quantizationConfig;

/**
* Constructor
Expand All @@ -191,7 +195,8 @@ public TrainingDataEntryContext(
ClusterService clusterService,
int maxVectorCount,
int searchSize,
VectorDataType vectorDataType
VectorDataType vectorDataType,
QuantizationConfig quantizationConfig
) {
super(generateKey(trainIndexName, trainFieldName));
this.size = size;
Expand All @@ -202,6 +207,7 @@ public TrainingDataEntryContext(
this.maxVectorCount = maxVectorCount;
this.searchSize = searchSize;
this.vectorDataType = vectorDataType;
this.quantizationConfig = quantizationConfig;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import lombok.extern.log4j.Log4j2;
import org.opensearch.core.action.ActionListener;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.util.IndexUtil;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.index.engine.KNNEngine;
Expand Down Expand Up @@ -171,6 +172,9 @@ public NativeMemoryAllocation.TrainingDataAllocation load(
nativeMemoryEntryContext.getVectorDataType()
);

QuantizationConfig quantizationConfig = nativeMemoryEntryContext.getQuantizationConfig();
trainingDataAllocation.setQuantizationConfig(quantizationConfig);

TrainingDataConsumer vectorDataConsumer = nativeMemoryEntryContext.getVectorDataType()
.getTrainingDataConsumer(trainingDataAllocation);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.knn.index.engine.KNNLibraryIndexingContext;
import org.opensearch.knn.index.engine.KNNMethodConfigContext;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.index.memory.NativeMemoryEntryContext;
import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy;
Expand Down Expand Up @@ -45,6 +48,15 @@ public TrainingModelTransportAction(TransportService transportService, ActionFil

@Override
protected void doExecute(Task task, TrainingModelRequest request, ActionListener<TrainingModelResponse> listener) {
KNNMethodContext knnMethodContext = request.getKnnMethodContext();
KNNMethodConfigContext knnMethodConfigContext = request.getKnnMethodConfigContext();
QuantizationConfig quantizationConfig = QuantizationConfig.EMPTY;

if (knnMethodContext != null && request.getKnnMethodConfigContext() != null) {
KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine()
.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext);
quantizationConfig = knnLibraryIndexingContext.getQuantizationConfig();
}

NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = new NativeMemoryEntryContext.TrainingDataEntryContext(
request.getTrainingDataSizeInKB(),
Expand All @@ -54,7 +66,8 @@ protected void doExecute(Task task, TrainingModelRequest request, ActionListener
clusterService,
request.getMaximumVectorCount(),
request.getSearchSize(),
request.getVectorDataType()
request.getVectorDataType(),
quantizationConfig
);

// Allocation representing size model will occupy in memory during training
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,19 @@

import org.apache.commons.lang.ArrayUtils;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.jni.JNICommons;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
import org.opensearch.knn.quantization.factory.QuantizerFactory;
import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.requests.TrainingRequest;
import org.opensearch.knn.quantization.quantizer.Quantizer;
import org.opensearch.search.SearchHit;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

Expand All @@ -25,23 +34,37 @@
*/
public class FloatTrainingDataConsumer extends TrainingDataConsumer {

private final QuantizationConfig quantizationConfig;

/**
 * Creates a consumer that feeds float training vectors into the given native memory allocation.
 *
 * @param trainingDataAllocation NativeMemoryAllocation that contains information about native memory allocation.
 */
public FloatTrainingDataConsumer(NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation) {
    super(trainingDataAllocation);
    // Cache the allocation's quantization config; accept() uses it to decide whether
    // incoming float vectors must be quantized before transfer to native memory.
    this.quantizationConfig = trainingDataAllocation.getQuantizationConfig();
}

@Override
public void accept(List<?> floats) {
    // No (usable) quantization config for this batch: transfer the raw float vectors
    // to native memory exactly as before the quantization framework existed.
    if (!isValidFloatsAndQuantizationConfig(floats)) {
        float[][] primitiveVectors = floats.stream()
            .map(vector -> ArrayUtils.toPrimitive((Float[]) vector))
            .toArray(float[][]::new);
        long transferredAddress = JNIService.transferVectors(trainingDataAllocation.getMemoryAddress(), primitiveVectors);
        trainingDataAllocation.setMemoryAddress(transferredAddress);
        return;
    }

    // Quantization is configured: convert the float vectors to packed binary vectors
    // and store those in native memory instead of the raw floats.
    try {
        List<byte[]> quantizedVectors = quantizeVectors(floats);
        long updatedAddress = JNICommons.storeBinaryVectorData(
            trainingDataAllocation.getMemoryAddress(),
            quantizedVectors.toArray(new byte[0][0]),
            quantizedVectors.size()
        );
        trainingDataAllocation.setMemoryAddress(updatedAddress);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
}

@Override
Expand All @@ -64,4 +87,29 @@ public void processTrainingVectors(SearchResponse searchResponse, int vectorsToA

accept(vectors);
}

/**
 * Trains a scalar quantizer on the given batch of boxed float vectors and returns the
 * quantized (binary) form of every vector in the batch, in input order.
 *
 * @param vectors batch of {@code Float[]} training vectors
 * @return one quantized byte vector per input vector
 * @throws IOException if the quantizer fails while training or quantizing
 */
private List<byte[]> quantizeVectors(List<?> vectors) throws IOException {
    // Build a quantizer for the configured quantization type and train it on this batch.
    ScalarQuantizationParams params = new ScalarQuantizationParams(quantizationConfig.getQuantizationType());
    Quantizer<float[], byte[]> quantizer = QuantizerFactory.getQuantizer(params);

    // Training request that exposes each boxed vector by position, unboxed on demand.
    TrainingRequest<float[]> request = new TrainingRequest<float[]>(vectors.size()) {
        @Override
        public float[] getVectorAtThePosition(int position) {
            return ArrayUtils.toPrimitive((Float[]) vectors.get(position));
        }
    };
    QuantizationState state = quantizer.train(request);

    // A single reusable output buffer; getQuantizedVectorCopy() snapshots each result.
    BinaryQuantizationOutput output = new BinaryQuantizationOutput(quantizationConfig.getQuantizationType().getId());
    List<byte[]> quantizedVectors = new ArrayList<>(vectors.size());
    for (Object vector : vectors) {
        quantizer.quantize(ArrayUtils.toPrimitive((Float[]) vector), state, output);
        quantizedVectors.add(output.getQuantizedVectorCopy());
    }
    return quantizedVectors;
}

/**
 * Returns true only when there is a non-empty batch of vectors AND a real (non-EMPTY)
 * quantization config — i.e. when the batch should be quantized rather than passed raw.
 */
private boolean isValidFloatsAndQuantizationConfig(List<?> floats) {
    if (floats == null || floats.isEmpty()) {
        return false;
    }
    return quantizationConfig != null && quantizationConfig != QuantizationConfig.EMPTY;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import com.google.common.collect.ImmutableMap;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.util.IndexUtil;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
Expand Down Expand Up @@ -124,7 +125,8 @@ public void testTrainingDataEntryContext_load() {
null,
0,
0,
VectorDataType.DEFAULT
VectorDataType.DEFAULT,
QuantizationConfig.EMPTY
);

NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = new NativeMemoryAllocation.TrainingDataAllocation(
Expand All @@ -149,7 +151,8 @@ public void testTrainingDataEntryContext_getTrainIndexName() {
null,
0,
0,
VectorDataType.DEFAULT
VectorDataType.DEFAULT,
QuantizationConfig.EMPTY
);

assertEquals(trainIndexName, trainingDataEntryContext.getTrainIndexName());
Expand All @@ -165,7 +168,8 @@ public void testTrainingDataEntryContext_getTrainFieldName() {
null,
0,
0,
VectorDataType.DEFAULT
VectorDataType.DEFAULT,
QuantizationConfig.EMPTY
);

assertEquals(trainFieldName, trainingDataEntryContext.getTrainFieldName());
Expand All @@ -181,7 +185,8 @@ public void testTrainingDataEntryContext_getMaxVectorCount() {
null,
maxVectorCount,
0,
VectorDataType.DEFAULT
VectorDataType.DEFAULT,
QuantizationConfig.EMPTY
);

assertEquals(maxVectorCount, trainingDataEntryContext.getMaxVectorCount());
Expand All @@ -197,7 +202,8 @@ public void testTrainingDataEntryContext_getSearchSize() {
null,
0,
searchSize,
VectorDataType.DEFAULT
VectorDataType.DEFAULT,
QuantizationConfig.EMPTY
);

assertEquals(searchSize, trainingDataEntryContext.getSearchSize());
Expand All @@ -213,7 +219,8 @@ public void testTrainingDataEntryContext_getIndicesService() {
clusterService,
0,
0,
VectorDataType.DEFAULT
VectorDataType.DEFAULT,
QuantizationConfig.EMPTY
);

assertEquals(clusterService, trainingDataEntryContext.getClusterService());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.opensearch.knn.TestUtils;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.jni.JNICommons;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.index.query.KNNQueryResult;
Expand Down Expand Up @@ -180,7 +181,8 @@ public void testTrainingLoadStrategy_load() {
null,
0,
0,
VectorDataType.FLOAT
VectorDataType.FLOAT,
QuantizationConfig.EMPTY
);

// Load the allocation. Initially, the memory address should be 0. However, after the readlock is obtained,
Expand Down
Loading

0 comments on commit 5d10d64

Please sign in to comment.