Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support script score when doc value is disabled #1573

Merged
merged 4 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
* Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549)
* Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573)
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

package org.opensearch.knn.index;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.index.fielddata.LeafFieldData;
import org.opensearch.index.fielddata.ScriptDocValues;
import org.opensearch.index.fielddata.SortedBinaryDocValues;
Expand Down Expand Up @@ -39,10 +40,29 @@
@Override
public ScriptDocValues<float[]> getScriptValues() {
try {
BinaryDocValues values = DocValues.getBinary(reader, fieldName);
return new KNNVectorScriptDocValues(values, fieldName, vectorDataType);
FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(fieldName);
if (fieldInfo == null) {
return KNNVectorScriptDocValues.emptyValues(fieldName, vectorDataType);
}

DocIdSetIterator values;
if (fieldInfo.hasVectorValues()) {
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32:
values = reader.getFloatVectorValues(fieldName);
break;
case BYTE:
values = reader.getByteVectorValues(fieldName);
break;
default:
throw new IllegalStateException("Unsupported Lucene vector encoding: " + fieldInfo.getVectorEncoding());

Check warning on line 58 in src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java#L58

Added line #L58 was not covered by tests
}
} else {
values = DocValues.getBinary(reader, fieldName);
}
return KNNVectorScriptDocValues.create(values, fieldName, vectorDataType);
} catch (IOException e) {
throw new IllegalStateException("Cannot load doc values for knn vector field: " + fieldName, e);
throw new IllegalStateException("Cannot load values for knn vector field: " + fieldName, e);

Check warning on line 65 in src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java#L65

Added line #L65 was not covered by tests
}
}

Expand Down
109 changes: 98 additions & 11 deletions src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,30 @@

package org.opensearch.knn.index;

import java.io.IOException;
import java.util.Objects;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.ExceptionsHelper;
import org.opensearch.index.fielddata.ScriptDocValues;

import java.io.IOException;

@RequiredArgsConstructor
public final class KNNVectorScriptDocValues extends ScriptDocValues<float[]> {
@RequiredArgsConstructor(access = AccessLevel.PRIVATE)
public abstract class KNNVectorScriptDocValues extends ScriptDocValues<float[]> {

private final BinaryDocValues binaryDocValues;
private final DocIdSetIterator vectorValues;
private final String fieldName;
@Getter
private final VectorDataType vectorDataType;
private boolean docExists = false;

@Override
public void setNextDocId(int docId) throws IOException {
if (binaryDocValues.advanceExact(docId)) {
docExists = true;
return;
}
docExists = false;
docExists = vectorValues.docID() == docId || vectorValues.advance(docId) == docId;
}

public float[] getValue() {
Expand All @@ -43,12 +43,14 @@
throw new IllegalStateException(errorMessage);
}
try {
return vectorDataType.getVectorFromDocValues(binaryDocValues.binaryValue());
return doGetValue();
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
}

protected abstract float[] doGetValue() throws IOException;

@Override
public int size() {
return docExists ? 1 : 0;
Expand All @@ -58,4 +60,89 @@
public float[] get(int i) {
throw new UnsupportedOperationException("knn vector does not support this operation");
}

/**
 * Builds the {@link KNNVectorScriptDocValues} implementation matching the runtime type of {@code values}.
 *
 * <p>The supported iterator types are checked in this order: {@link ByteVectorValues},
 * {@link FloatVectorValues}, then {@link BinaryDocValues}.
 *
 * @param values The DocIdSetIterator representing the vector values; must not be null.
 * @param fieldName The name of the field.
 * @param vectorDataType The data type of the vector.
 * @return A KNNVectorScriptDocValues object based on the type of the values.
 * @throws NullPointerException If values is null.
 * @throws IllegalArgumentException If the type of values is unsupported.
 */
public static KNNVectorScriptDocValues create(DocIdSetIterator values, String fieldName, VectorDataType vectorDataType) {
    Objects.requireNonNull(values, "values must not be null");

    // Guard-style dispatch: return as soon as a supported iterator type is recognised.
    if (values instanceof ByteVectorValues) {
        return new KNNByteVectorScriptDocValues((ByteVectorValues) values, fieldName, vectorDataType);
    }
    if (values instanceof FloatVectorValues) {
        return new KNNFloatVectorScriptDocValues((FloatVectorValues) values, fieldName, vectorDataType);
    }
    if (values instanceof BinaryDocValues) {
        return new KNNNativeVectorScriptDocValues((BinaryDocValues) values, fieldName, vectorDataType);
    }

    // Nothing matched: reject the iterator and report its concrete class.
    throw new IllegalArgumentException("Unsupported values type: " + values.getClass());
}

/**
 * Script doc values backed by {@link ByteVectorValues}. The byte vector returned by the underlying
 * values is widened component-by-component into a {@code float[]}.
 */
private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptDocValues {
    private final ByteVectorValues byteVectorValues;

    KNNByteVectorScriptDocValues(ByteVectorValues byteVectorValues, String fieldName, VectorDataType vectorDataType) {
        // The same object serves as the doc-id iterator in the base class and as the vector source here.
        super(byteVectorValues, fieldName, vectorDataType);
        this.byteVectorValues = byteVectorValues;
    }

    /**
     * Reads the byte vector from the underlying values and converts it to floats.
     *
     * @return a newly allocated float array with one entry per byte of the source vector
     * @throws IOException if the underlying values fail to produce the vector
     */
    @Override
    protected float[] doGetValue() throws IOException {
        final byte[] raw = byteVectorValues.vectorValue();
        final float[] widened = new float[raw.length];
        int position = 0;
        for (byte component : raw) {
            // Implicit byte -> float widening; equivalent to an explicit (float) cast.
            widened[position++] = component;
        }
        return widened;
    }
}

/**
 * Script doc values backed by {@link FloatVectorValues}. The float vector is handed back exactly as
 * the underlying values return it, without copying or conversion.
 */
private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues {
    private final FloatVectorValues floatVectorValues;

    KNNFloatVectorScriptDocValues(FloatVectorValues floatVectorValues, String fieldName, VectorDataType vectorDataType) {
        // The same object serves as the doc-id iterator in the base class and as the vector source here.
        super(floatVectorValues, fieldName, vectorDataType);
        this.floatVectorValues = floatVectorValues;
    }

    /**
     * @return the array returned by {@link FloatVectorValues#vectorValue()}
     * @throws IOException if the underlying values fail to produce the vector
     */
    @Override
    protected float[] doGetValue() throws IOException {
        final float[] vector = floatVectorValues.vectorValue();
        return vector;
    }
}

/**
 * Script doc values backed by {@link BinaryDocValues}. The binary value is decoded into a
 * {@code float[]} by the configured {@link VectorDataType}.
 */
private static final class KNNNativeVectorScriptDocValues extends KNNVectorScriptDocValues {
    private final BinaryDocValues binaryDocValues;

    KNNNativeVectorScriptDocValues(BinaryDocValues binaryDocValues, String fieldName, VectorDataType vectorDataType) {
        // The same object serves as the doc-id iterator in the base class and as the value source here.
        super(binaryDocValues, fieldName, vectorDataType);
        this.binaryDocValues = binaryDocValues;
    }

    /**
     * @return the vector obtained by passing the binary value to
     *         {@code getVectorDataType().getVectorFromDocValues(...)}
     * @throws IOException if the underlying doc values fail to produce the binary value
     */
    @Override
    protected float[] doGetValue() throws IOException {
        final VectorDataType dataType = getVectorDataType();
        return dataType.getVectorFromDocValues(binaryDocValues.binaryValue());
    }
}

/**
* Creates an empty KNNVectorScriptDocValues object based on the provided field name and vector data type.
*
* @param fieldName The name of the field.
* @param type The data type of the vector.
* @return An empty KNNVectorScriptDocValues object.
*/
public static KNNVectorScriptDocValues emptyValues(String fieldName, VectorDataType type) {
return new KNNVectorScriptDocValues(DocIdSetIterator.empty(), fieldName, type) {
@Override
protected float[] doGetValue() throws IOException {
throw new UnsupportedOperationException("empty values");

Check warning on line 144 in src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java#L144

Added line #L144 was not covered by tests
}
};
}
}
6 changes: 3 additions & 3 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ public void testEndToEnd_whenMethodIsHNSWFlat_thenSucceed() {

List<Float> actualScores = parseSearchResponseScore(responseBody, fieldName);
for (int j = 0; j < k; j++) {
float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList()));
float[] primitiveArray = knnResults.get(j).getVector();
assertEquals(
KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType),
actualScores.get(j),
Expand Down Expand Up @@ -258,7 +258,7 @@ public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() {

List<Float> actualScores = parseSearchResponseScore(responseBody, fieldName);
for (int j = 0; j < k; j++) {
float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList()));
float[] primitiveArray = knnResults.get(j).getVector();
assertEquals(
KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType),
actualScores.get(j),
Expand Down Expand Up @@ -828,7 +828,7 @@ public void testEndToEnd_whenMethodIsHNSWPQAndHyperParametersNotSet_thenSucceed(

List<Float> actualScores = parseSearchResponseScore(responseBody, fieldName);
for (int j = 0; j < k; j++) {
float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList()));
float[] primitiveArray = knnResults.get(j).getVector();
assertEquals(
KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType),
actualScores.get(j),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@

package org.opensearch.knn.index;

import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.knn.KNNTestCase;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.document.BinaryDocValuesField;
Expand All @@ -13,7 +22,6 @@
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.store.Directory;
import org.junit.Assert;
import org.junit.Before;
Expand All @@ -24,6 +32,7 @@ public class KNNVectorScriptDocValuesTests extends KNNTestCase {

private static final String MOCK_INDEX_FIELD_NAME = "test-index-field-name";
private static final float[] SAMPLE_VECTOR_DATA = new float[] { 1.0f, 2.0f };
private static final byte[] SAMPLE_BYTE_VECTOR_DATA = new byte[] { 1, 2 };
private KNNVectorScriptDocValues scriptDocValues;
private Directory directory;
private DirectoryReader reader;
Expand All @@ -32,26 +41,39 @@ public class KNNVectorScriptDocValuesTests extends KNNTestCase {
public void setUp() throws Exception {
super.setUp();
directory = newDirectory();
createKNNVectorDocument(directory);
Class<? extends DocIdSetIterator> valuesClass = randomFrom(BinaryDocValues.class, ByteVectorValues.class, FloatVectorValues.class);
createKNNVectorDocument(directory, valuesClass);
reader = DirectoryReader.open(directory);
LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
scriptDocValues = new KNNVectorScriptDocValues(
leafReaderContext.reader().getBinaryDocValues(MOCK_INDEX_FIELD_NAME),
MOCK_INDEX_FIELD_NAME,
VectorDataType.FLOAT
);
LeafReader leafReader = reader.getContext().leaves().get(0).reader();
DocIdSetIterator vectorValues;
if (BinaryDocValues.class.equals(valuesClass)) {
vectorValues = DocValues.getBinary(leafReader, MOCK_INDEX_FIELD_NAME);
} else if (ByteVectorValues.class.equals(valuesClass)) {
vectorValues = leafReader.getByteVectorValues(MOCK_INDEX_FIELD_NAME);
} else {
vectorValues = leafReader.getFloatVectorValues(MOCK_INDEX_FIELD_NAME);
}

scriptDocValues = KNNVectorScriptDocValues.create(vectorValues, MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT);
}

private void createKNNVectorDocument(Directory directory) throws IOException {
private void createKNNVectorDocument(Directory directory, Class<? extends DocIdSetIterator> valuesClass) throws IOException {
IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random()));
IndexWriter writer = new IndexWriter(directory, conf);
Document knnDocument = new Document();
knnDocument.add(
new BinaryDocValuesField(
Field field;
if (BinaryDocValues.class.equals(valuesClass)) {
field = new BinaryDocValuesField(
MOCK_INDEX_FIELD_NAME,
new VectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA, new FieldType()).binaryValue()
)
);
);
} else if (ByteVectorValues.class.equals(valuesClass)) {
field = new KnnByteVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_BYTE_VECTOR_DATA);
} else {
field = new KnnFloatVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA);
}

knnDocument.add(field);
writer.addDocument(knnDocument);
writer.commit();
writer.close();
Expand Down Expand Up @@ -83,4 +105,18 @@ public void testSize() throws IOException {
public void testGet() throws IOException {
expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0));
}

/**
 * Verifies that {@code KNNVectorScriptDocValues.create} throws IllegalArgumentException when it is
 * given the iterator returned by {@code DocValues.emptyNumeric()}.
 */
public void testUnsupportedValues() throws IOException {
expectThrows(
IllegalArgumentException.class,
() -> KNNVectorScriptDocValues.create(DocValues.emptyNumeric(), MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT)
);
}

/**
 * Verifies that the object returned by {@code KNNVectorScriptDocValues.emptyValues} reports a size of
 * zero both before and after it has been positioned on a document id.
 *
 * @throws IOException if positioning the values on a document fails
 */
public void testEmptyValues() throws IOException {
    KNNVectorScriptDocValues values = KNNVectorScriptDocValues.emptyValues(MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT);
    assertEquals(0, values.size());
    // Position the empty instance itself rather than the shared scriptDocValues fixture, so that the
    // second assertion exercises setNextDocId on the object under test.
    values.setNextDocId(0);
    assertEquals(0, values.size());
}
}
9 changes: 4 additions & 5 deletions src/test/java/org/opensearch/knn/index/LuceneEngineIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Floats;
import lombok.SneakyThrows;
import org.apache.commons.lang.math.RandomUtils;
import org.apache.hc.core5.http.io.entity.EntityUtils;
Expand Down Expand Up @@ -307,14 +306,14 @@ public void testIndexReopening() throws Exception {
final float[] searchVector = TEST_QUERY_VECTORS[0];
final int k = 1 + RandomUtils.nextInt(TEST_INDEX_VECTORS.length);

final List<Float[]> knnResultsBeforeIndexClosure = queryResults(searchVector, k);
final List<float[]> knnResultsBeforeIndexClosure = queryResults(searchVector, k);

closeIndex(INDEX_NAME);
openIndex(INDEX_NAME);

ensureGreen(INDEX_NAME);

final List<Float[]> knnResultsAfterIndexClosure = queryResults(searchVector, k);
final List<float[]> knnResultsAfterIndexClosure = queryResults(searchVector, k);

assertArrayEquals(knnResultsBeforeIndexClosure.toArray(), knnResultsAfterIndexClosure.toArray());
}
Expand Down Expand Up @@ -365,15 +364,15 @@ private void validateQueries(SpaceType spaceType, String fieldName) throws Excep

List<Float> actualScores = parseSearchResponseScore(responseBody, fieldName);
for (int j = 0; j < k; j++) {
float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList()));
float[] primitiveArray = knnResults.get(j).getVector();
float distance = TestUtils.computeDistFromSpaceType(spaceType, primitiveArray, queryVector);
float rawScore = VECTOR_SIMILARITY_TO_SCORE.get(spaceType.getVectorSimilarityFunction()).apply(distance);
assertEquals(KNNEngine.LUCENE.score(rawScore, spaceType), actualScores.get(j), 0.0001);
}
}
}

private List<Float[]> queryResults(final float[] searchVector, final int k) throws Exception {
private List<float[]> queryResults(final float[] searchVector, final int k) throws Exception {
final String responseBody = EntityUtils.toString(
searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, searchVector, k), k).getEntity()
);
Expand Down
4 changes: 1 addition & 3 deletions src/test/java/org/opensearch/knn/index/NmslibIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,9 @@

import java.io.IOException;
import java.net.URL;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.containsString;

Expand Down Expand Up @@ -115,7 +113,7 @@ public void testEndToEnd() throws Exception {

List<Float> actualScores = parseSearchResponseScore(responseBody, fieldName);
for (int j = 0; j < k; j++) {
float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList()));
float[] primitiveArray = knnResults.get(j).getVector();
assertEquals(
KNNEngine.NMSLIB.score(KNNScoringUtil.l1Norm(testData.queries[i], primitiveArray), spaceType),
actualScores.get(j),
Expand Down
Loading
Loading