diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 0d016a60b..3f32003ac 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -25,11 +25,10 @@ import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.RamUsageEstimator; import org.opensearch.common.StopWatch; -import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; +import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; -import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import org.opensearch.knn.plugin.stats.KNNGraphValue; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; @@ -39,6 +38,7 @@ import java.util.List; import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType; +import static org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory.getVectorValues; /** * A KNNVectorsWriter class for writing the vector data strcutures and flat vectors for Native Engines. @@ -47,15 +47,11 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter { private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(NativeEngines990KnnVectorsWriter.class); - private static final String FLUSH_OPERATION = "flush"; - private static final String MERGE_OPERATION = "merge"; - private final SegmentWriteState segmentWriteState; private final FlatVectorsWriter flatVectorsWriter; private KNN990QuantizationStateWriter quantizationStateWriter; private final List> fields = new ArrayList<>(); private boolean finished; - private final QuantizationService quantizationService = QuantizationService.getInstance(); public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, FlatVectorsWriter flatVectorsWriter) { this.segmentWriteState = segmentWriteState; @@ -84,14 +80,27 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { flatVectorsWriter.flush(maxDoc, sortMap); for (final NativeEngineFieldVectorsWriter field : fields) { - trainAndIndex( - field.getFieldInfo(), - (vectorDataType, fieldInfo, fieldVectorsWriter) -> getKNNVectorValues(vectorDataType, fieldVectorsWriter), - NativeIndexWriter::flushIndex, - field, - KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS, - FLUSH_OPERATION - ); + final FieldInfo fieldInfo = field.getFieldInfo(); + final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); + int totalLiveDocs = getLiveDocs(getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors())); + if (totalLiveDocs > 0) { + KNNVectorValues knnVectorValues = getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors()); + + final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValues, totalLiveDocs); + final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); + + knnVectorValues = getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors()); + + StopWatch stopWatch = new StopWatch().start(); + + writer.flushIndex(knnVectorValues, totalLiveDocs); + + long time_in_millis = stopWatch.stop().totalTime().millis(); + KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis); + log.debug("Flush took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName()); + } else { + log.debug("[Flush] No live docs for field {}", fieldInfo.getName()); + } } } @@ -100,15 +109,26 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState // This will ensure that we are merging the FlatIndex during force merge. flatVectorsWriter.mergeOneField(fieldInfo, mergeState); - // For merge, pick values from flat vector and reindex again. This will use the flush operation to create graphs - trainAndIndex( - fieldInfo, - this::getKNNVectorValuesForMerge, - NativeIndexWriter::mergeIndex, - mergeState, - KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS, - MERGE_OPERATION - ); + final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); + int totalLiveDocs = getLiveDocs(getKNNVectorValuesForMerge(vectorDataType, fieldInfo, mergeState)); + if (totalLiveDocs == 0) { + log.debug("[Merge] No live docs for field {}", fieldInfo.getName()); + return; + } + + KNNVectorValues knnVectorValues = getKNNVectorValuesForMerge(vectorDataType, fieldInfo, mergeState); + final QuantizationState quantizationState = train(fieldInfo, knnVectorValues, totalLiveDocs); + final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); + + knnVectorValues = getKNNVectorValuesForMerge(vectorDataType, fieldInfo, mergeState); + + StopWatch stopWatch = new StopWatch().start(); + + writer.mergeIndex(knnVectorValues, totalLiveDocs); + + long time_in_millis = stopWatch.stop().totalTime().millis(); + KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis); + log.debug("Merge took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName()); } /** @@ -157,18 +177,6 @@ public long ramBytesUsed() { .sum(); } - /** - * Retrieves the {@link KNNVectorValues} for a specific field based on the vector data type and field writer. - * - * @param vectorDataType The {@link VectorDataType} representing the type of vectors stored. - * @param field The {@link NativeEngineFieldVectorsWriter} representing the field from which to retrieve vectors. - * @param The type of vectors being processed. - * @return The {@link KNNVectorValues} associated with the field. - */ - private KNNVectorValues getKNNVectorValues(final VectorDataType vectorDataType, final NativeEngineFieldVectorsWriter field) { - return (KNNVectorValues) KNNVectorValuesFactory.getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors()); - } - /** * Retrieves the {@link KNNVectorValues} for a specific field during a merge operation, based on the vector data type. * @@ -187,84 +195,28 @@ private KNNVectorValues getKNNVectorValuesForMerge( switch (fieldInfo.getVectorEncoding()) { case FLOAT32: FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); - return (KNNVectorValues) KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedFloats); + return getVectorValues(vectorDataType, mergedFloats); case BYTE: ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); - return (KNNVectorValues) KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedBytes); + return getVectorValues(vectorDataType, mergedBytes); default: throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]"); } } - /** - * Functional interface representing an operation that indexes the provided {@link KNNVectorValues}. - * - * @param The type of vectors being processed. - */ - @FunctionalInterface - private interface IndexOperation { - void buildAndWrite(NativeIndexWriter writer, KNNVectorValues knnVectorValues, int totalLiveDocs) throws IOException; - } - - /** - * Functional interface representing a method that retrieves {@link KNNVectorValues} based on - * the vector data type, field information, and the merge state. - * - * @param The type of the data representing the vector (e.g., {@link VectorDataType}). - * @param The metadata about the field. - * @param The state of the merge operation. - * @param The result of the retrieval, typically {@link KNNVectorValues}. - */ - @FunctionalInterface - private interface VectorValuesRetriever { - Result apply(DataType vectorDataType, FieldInfo fieldInfo, MergeState mergeState) throws IOException; - } + private QuantizationState train(final FieldInfo fieldInfo, final KNNVectorValues knnVectorValues, final int totalLiveDocs) + throws IOException { - /** - * Unified method for processing a field during either the indexing or merge operation. This method retrieves vector values - * based on the provided vector data type and applies the specified index operation, potentially including quantization if needed. - * - * @param fieldInfo The {@link FieldInfo} object containing metadata about the field. - * @param vectorValuesRetriever A functional interface that retrieves {@link KNNVectorValues} based on the vector data type, - * field information, and additional context (e.g., merge state or field writer). - * @param indexOperation A functional interface that performs the indexing operation using the retrieved - * {@link KNNVectorValues}. - * @param VectorProcessingContext The additional context required for retrieving the vector values (e.g., {@link MergeState} or {@link NativeEngineFieldVectorsWriter}). - * From Flush we need NativeFieldWriter which contains total number of vectors while from Merge we need merge state which contains vector information - * @param The type of vectors being processed. - * @param The type of the context needed for retrieving the vector values. - * @throws IOException If an I/O error occurs during the processing. - */ - private void trainAndIndex( - final FieldInfo fieldInfo, - final VectorValuesRetriever> vectorValuesRetriever, - final IndexOperation indexOperation, - final C VectorProcessingContext, - final KNNGraphValue graphBuildTime, - final String operationName - ) throws IOException { - final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); - // Count the docIds - int totalLiveDocs = getLiveDocs(vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext)); - if (totalLiveDocs == 0) { - log.debug("No live docs for field " + fieldInfo.name); - return; - } + final QuantizationService quantizationService = QuantizationService.getInstance(); + final QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo); QuantizationState quantizationState = null; - QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo); - if (quantizationParams != null) { + if (quantizationParams != null && totalLiveDocs > 0) { initQuantizationStateWriterIfNecessary(); - KNNVectorValues knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext); quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs); quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState); } - NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); - KNNVectorValues knnVectors = vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext); - StopWatch stopWatch = new StopWatch().start(); - indexOperation.buildAndWrite(writer, knnVectors, totalLiveDocs); - long time_in_millis = stopWatch.stop().totalTime().millis(); - graphBuildTime.incrementBy(time_in_millis); - log.warn("Graph build took " + time_in_millis + " ms for " + operationName); + + return quantizationState; } /** diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java new file mode 100644 index 000000000..ad72f5b24 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java @@ -0,0 +1,288 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN990Codec; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import lombok.RequiredArgsConstructor; +import lombok.SneakyThrows; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.VectorEncoding; +import org.mockito.Mock; +import org.mockito.MockedConstruction; +import org.mockito.MockedStatic; +import org.mockito.MockitoAnnotations; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; +import org.opensearch.knn.index.quantizationservice.QuantizationService; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.index.vectorvalues.TestVectorValues; +import org.opensearch.knn.plugin.stats.KNNGraphValue; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.stream.IntStream; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.$; +import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockConstruction; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@RequiredArgsConstructor +public class NativeEngines990KnnVectorsWriterFlushTests extends OpenSearchTestCase { + + @Mock + private FlatVectorsWriter flatVectorsWriter; + @Mock + private SegmentWriteState segmentWriteState; + @Mock + private QuantizationParams quantizationParams; + @Mock + private QuantizationState quantizationState; + @Mock + private QuantizationService quantizationService; + @Mock + private NativeIndexWriter nativeIndexWriter; + + private NativeEngines990KnnVectorsWriter objectUnderTest; + + private final String description; + private final List> vectorsPerField; + + @Override + public void setUp() throws Exception { + super.setUp(); + MockitoAnnotations.openMocks(this); + objectUnderTest = new NativeEngines990KnnVectorsWriter(segmentWriteState, flatVectorsWriter); + } + + @ParametersFactory + public static Collection data() { + return Arrays.asList( + $$( + $("Single field", List.of(Map.of(0, new float[] { 1, 2, 3 }, 1, new float[] { 2, 3, 4 }, 2, new float[] { 3, 4, 5 }))), + $("Single field, no total live docs", List.of()), + $( + "Multi Field", + List.of( + Map.of(0, new float[] { 1, 2, 3 }, 1, new float[] { 2, 3, 4 }, 2, new float[] { 3, 4, 5 }), + Map.of( + 0, + new float[] { 1, 2, 3, 4 }, + 1, + new float[] { 2, 3, 4, 5 }, + 2, + new float[] { 3, 4, 5, 6 }, + 3, + new float[] { 4, 5, 6, 7 } + ) + ) + ) + ) + ); + } + + @SneakyThrows + public void testFlush() { + // Given + List> expectedVectorValues = new ArrayList<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + expectedVectorValues.add(knnVectorValues); + + }); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + + try { + objectUnderTest.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + }); + + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + objectUnderTest.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + assertNotEquals(0L, (long) KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue()); + } + + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + try { + verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + } + + @SneakyThrows + public void testFlush_WithQuantization() { + // Given + List> expectedVectorValues = new ArrayList<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + expectedVectorValues.add(knnVectorValues); + + }); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + + try { + objectUnderTest.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); + try { + when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size())) + .thenReturn(quantizationState); + } catch (Exception e) { + throw new RuntimeException(e); + } + + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)) + .thenReturn(nativeIndexWriter); + }); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + objectUnderTest.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeHeader(segmentWriteState); + assertTrue(KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue() > 0L); + } else { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + } + + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + try { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(i, quantizationState); + verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + } + + private FieldInfo fieldInfo(int fieldNumber, VectorEncoding vectorEncoding, Map attributes) { + FieldInfo fieldInfo = mock(FieldInfo.class); + when(fieldInfo.getFieldNumber()).thenReturn(fieldNumber); + when(fieldInfo.getVectorEncoding()).thenReturn(vectorEncoding); + when(fieldInfo.attributes()).thenReturn(attributes); + attributes.forEach((key, value) -> when(fieldInfo.getAttribute(key)).thenReturn(value)); + return fieldInfo; + } + + private NativeEngineFieldVectorsWriter nativeEngineFieldVectorsWriter(FieldInfo fieldInfo, Map vectors) { + NativeEngineFieldVectorsWriter fieldVectorsWriter = mock(NativeEngineFieldVectorsWriter.class); + DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet(); + vectors.keySet().stream().sorted().forEach(docsWithFieldSet::add); + when(fieldVectorsWriter.getFieldInfo()).thenReturn(fieldInfo); + when(fieldVectorsWriter.getVectors()).thenReturn(vectors); + when(fieldVectorsWriter.getDocsWithField()).thenReturn(docsWithFieldSet); + return fieldVectorsWriter; + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java new file mode 100644 index 000000000..440e8bbc5 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java @@ -0,0 +1,240 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN990Codec; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import lombok.RequiredArgsConstructor; +import lombok.SneakyThrows; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.VectorEncoding; +import org.mockito.Mock; +import org.mockito.MockedConstruction; +import org.mockito.MockedStatic; +import org.mockito.MockitoAnnotations; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; +import org.opensearch.knn.index.quantizationservice.QuantizationService; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.index.vectorvalues.TestVectorValues; +import org.opensearch.knn.plugin.stats.KNNGraphValue; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Map; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.$; +import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockConstruction; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +@RequiredArgsConstructor +public class NativeEngines990KnnVectorsWriterMergeTests extends OpenSearchTestCase { + + @Mock + private FlatVectorsWriter flatVectorsWriter; + @Mock + private SegmentWriteState segmentWriteState; + @Mock + private QuantizationParams quantizationParams; + @Mock + private QuantizationState quantizationState; + @Mock + private QuantizationService quantizationService; + @Mock + private NativeIndexWriter nativeIndexWriter; + @Mock + private FloatVectorValues floatVectorValues; + @Mock + private MergeState mergeState; + + private NativeEngines990KnnVectorsWriter objectUnderTest; + + private final String description; + private final Map mergedVectors; + + @Override + public void setUp() throws Exception { + super.setUp(); + MockitoAnnotations.openMocks(this); + objectUnderTest = new NativeEngines990KnnVectorsWriter(segmentWriteState, flatVectorsWriter); + } + + @ParametersFactory + public static Collection data() { + return Arrays.asList( + $$( + $("Merge one field", Map.of(0, new float[] { 1, 2, 3 }, 1, new float[] { 2, 3, 4 }, 2, new float[] { 3, 4, 5 })), + $("Merge, no live docs", Map.of()) + ) + ); + } + + @SneakyThrows + public void testMerge() { + // Given + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(mergedVectors.values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedStatic mergedVectorValuesMockedStatic = mockStatic( + KnnVectorsWriter.MergedVectorValues.class + ); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + final FieldInfo fieldInfo = fieldInfo( + 0, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, mergedVectors); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + + mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)) + .thenReturn(floatVectorValues); + knnVectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues)) + .thenReturn(knnVectorValues); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).mergeIndex(any(), anyInt()); + + // When + objectUnderTest.mergeOneField(fieldInfo, mergeState); + + // Then + verify(flatVectorsWriter).mergeOneField(fieldInfo, mergeState); + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + if (!mergedVectors.isEmpty()) { + verify(nativeIndexWriter).mergeIndex(knnVectorValues, mergedVectors.size()); + assertTrue(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() > 0L); + } else { + verifyNoInteractions(nativeIndexWriter); + } + } + } + + @SneakyThrows + public void testMerge_WithQuantization() { + // Given + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(mergedVectors.values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + MockedStatic mergedVectorValuesMockedStatic = mockStatic( + KnnVectorsWriter.MergedVectorValues.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + + final FieldInfo fieldInfo = fieldInfo( + 0, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, mergedVectors); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + + mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)) + .thenReturn(floatVectorValues); + knnVectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues)) + .thenReturn(knnVectorValues); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); + try { + when(quantizationService.train(quantizationParams, knnVectorValues, mergedVectors.size())).thenReturn(quantizationState); + } catch (Exception e) { + throw new RuntimeException(e); + } + + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)) + .thenReturn(nativeIndexWriter); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).mergeIndex(any(), anyInt()); + + // When + objectUnderTest.mergeOneField(fieldInfo, mergeState); + + // Then + verify(flatVectorsWriter).mergeOneField(fieldInfo, mergeState); + if (!mergedVectors.isEmpty()) { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeHeader(segmentWriteState); + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(0, quantizationState); + verify(nativeIndexWriter).mergeIndex(knnVectorValues, mergedVectors.size()); + assertTrue(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() > 0L); + } else { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + verifyNoInteractions(nativeIndexWriter); + } + + } + } + + private FieldInfo fieldInfo(int fieldNumber, VectorEncoding vectorEncoding, Map attributes) { + FieldInfo fieldInfo = mock(FieldInfo.class); + when(fieldInfo.getFieldNumber()).thenReturn(fieldNumber); + when(fieldInfo.getVectorEncoding()).thenReturn(vectorEncoding); + when(fieldInfo.attributes()).thenReturn(attributes); + attributes.forEach((key, value) -> when(fieldInfo.getAttribute(key)).thenReturn(value)); + return fieldInfo; + } + + private NativeEngineFieldVectorsWriter nativeEngineFieldVectorsWriter(FieldInfo fieldInfo, Map vectors) { + NativeEngineFieldVectorsWriter fieldVectorsWriter = mock(NativeEngineFieldVectorsWriter.class); + DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet(); + vectors.keySet().stream().sorted().forEach(docsWithFieldSet::add); + when(fieldVectorsWriter.getFieldInfo()).thenReturn(fieldInfo); + when(fieldVectorsWriter.getVectors()).thenReturn(vectors); + when(fieldVectorsWriter.getDocsWithField()).thenReturn(docsWithFieldSet); + return fieldVectorsWriter; + } +} diff --git a/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java b/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java index 3bf79b004..0f15d5240 100644 --- a/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java +++ b/src/test/java/org/opensearch/knn/index/vectorvalues/TestVectorValues.java @@ -184,7 +184,11 @@ public static class PreDefinedFloatVectorValues extends FloatVectorValues { public PreDefinedFloatVectorValues(final List vectors) { super(); this.count = vectors.size(); - this.dimension = vectors.get(0).length; + if (!vectors.isEmpty()) { + this.dimension = vectors.get(0).length; + } else { + this.dimension = 0; + } this.vectors = vectors; this.current = -1; vector = new float[dimension];