
Commit

Adds rescore parameter in neural search
This is required to enable rescoring for on-disk mode indices.

Signed-off-by: Tejas Shah <shatejas@amazon.com>
shatejas committed Sep 4, 2024
1 parent b8e2b35 commit afdc8c3
Showing 14 changed files with 136 additions and 38 deletions.
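
For orientation, below is a minimal sketch of how a caller could attach the new rescore parameter to a neural query once this change is in place. It mirrors the builder calls exercised in the integration-test diffs further down; the class name, field name, and ef_search value are illustrative placeholders, not part of this commit.

import java.util.Map;

import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;

public class NeuralRescoreSketch {
    // Builds a neural query carrying k-NN method parameters and a rescore
    // context, mirroring the builders constructed in the tests below.
    static NeuralQueryBuilder neuralQueryWithRescore(final String modelId) {
        NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
        neuralQueryBuilder.fieldName("passage_embedding"); // placeholder field name
        neuralQueryBuilder.modelId(modelId);
        // Query text, k, etc. are set through the same fluent setters
        // (omitted here; see the test changes below).
        neuralQueryBuilder.methodParameters(Map.of("ef_search", 100));
        // With a rescore context, k-NN re-ranks the approximate candidates
        // (e.g. from an on-disk, quantized index) using full-precision vectors.
        neuralQueryBuilder.rescoreContext(RescoreContext.getDefault());
        return neuralQueryBuilder;
    }
}
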
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.16...2.x)
### Features
### Enhancements
- Adds rescore parameter support ([#885](https://github.com/opensearch-project/neural-search/pull/885))
### Bug Fixes
- Fixed merge logic in hybrid query for multiple shards case ([#877](https://github.com/opensearch-project/neural-search/pull/877))
### Infrastructure
(next changed file)
@@ -18,6 +18,8 @@
import static org.opensearch.neuralsearch.util.TestUtils.TEXT_EMBEDDING_PROCESSOR;
import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_NORMALIZATION_METHOD;
import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD;

import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;

@@ -69,8 +71,10 @@ private void validateNormalizationProcessor(final String fileName, final String
modelId = getModelId(getIngestionPipeline(pipelineName), TEXT_EMBEDDING_PROCESSOR);
loadModel(modelId);
addDocuments(getIndexNameForTest(), false);
validateTestIndex(modelId, getIndexNameForTest(), searchPipelineName);
validateTestIndex(modelId, getIndexNameForTest(), searchPipelineName, Map.of("ef_search", 100));
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
validateTestIndex(getIndexNameForTest(), searchPipelineName, hybridQueryBuilder);
hybridQueryBuilder = getQueryBuilder(modelId, Map.of("ef_search", 100), RescoreContext.getDefault());
validateTestIndex(getIndexNameForTest(), searchPipelineName, hybridQueryBuilder);
} finally {
wipeOfTestResources(getIndexNameForTest(), pipelineName, modelId, searchPipelineName);
}
@@ -98,15 +102,10 @@ private void createSearchPipeline(final String pipelineName) {
);
}

private void validateTestIndex(final String modelId, final String index, final String searchPipeline) {
validateTestIndex(modelId, index, searchPipeline, null);
}

private void validateTestIndex(final String modelId, final String index, final String searchPipeline, Map<String, ?> methodParameters) {
private void validateTestIndex(final String index, final String searchPipeline, HybridQueryBuilder queryBuilder) {
int docCount = getDocCount(index);
assertEquals(6, docCount);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, methodParameters);
Map<String, Object> searchResponseAsMap = search(index, hybridQueryBuilder, null, 1, Map.of("search_pipeline", searchPipeline));
Map<String, Object> searchResponseAsMap = search(index, queryBuilder, null, 1, Map.of("search_pipeline", searchPipeline));
assertNotNull(searchResponseAsMap);
int hits = getHitCount(searchResponseAsMap);
assertEquals(1, hits);
@@ -116,7 +115,7 @@ private void validateTestIndex(final String modelId, final String index, final S
}
}

private HybridQueryBuilder getQueryBuilder(final String modelId, Map<String, ?> methodParameters) {
private HybridQueryBuilder getQueryBuilder(final String modelId, Map<String, ?> methodParameters, RescoreContext rescoreContext) {
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
neuralQueryBuilder.fieldName("passage_embedding");
neuralQueryBuilder.modelId(modelId);
Expand All @@ -125,6 +124,9 @@ private HybridQueryBuilder getQueryBuilder(final String modelId, Map<String, ?>
if (methodParameters != null) {
neuralQueryBuilder.methodParameters(methodParameters);
}
if (rescoreContext != null) {
neuralQueryBuilder.rescoreContext(rescoreContext);
}

MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", QUERY);

(next changed file)
@@ -61,6 +61,7 @@ private void validateIndexQuery(final String modelId) {
0.01f,
null,
null,
null,
null
);
Map<String, Object> responseWithMinScoreQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMinScoreQuery, 1);
Expand All @@ -76,6 +77,7 @@ private void validateIndexQuery(final String modelId) {
null,
null,
null,
null,
null
);
Map<String, Object> responseWithMaxDistanceQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1);
(next changed file)
@@ -63,6 +63,7 @@ private void validateTestIndex(final String modelId) throws Exception {
null,
null,
null,
null,
null
);
Map<String, Object> response = search(getIndexNameForTest(), neuralQueryBuilder, 1);
(next changed file)
@@ -16,6 +16,8 @@
import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_NORMALIZATION_METHOD;
import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD;
import static org.opensearch.neuralsearch.util.TestUtils.getModelId;

import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;

@@ -59,11 +61,13 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
int totalDocsCountMixed;
if (isFirstMixedRound()) {
totalDocsCountMixed = NUM_DOCS_PER_ROUND;
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder);
addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_MIXED, null, null);
} else {
totalDocsCountMixed = 2 * NUM_DOCS_PER_ROUND;
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder);
}
break;
case UPGRADED:
@@ -72,8 +76,10 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
int totalDocsCountUpgraded = 3 * NUM_DOCS_PER_ROUND;
loadModel(modelId);
addDocument(getIndexNameForTest(), "2", TEST_FIELD, TEXT_UPGRADED, null, null);
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId);
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, Map.of("ef_search", 100));
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder);
hybridQueryBuilder = getQueryBuilder(modelId, Map.of("ef_search", 100), RescoreContext.getDefault());
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder);
} finally {
wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, SEARCH_PIPELINE_NAME);
}
@@ -83,16 +89,11 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
}
}

private void validateTestIndexOnUpgrade(final int numberOfDocs, final String modelId) throws Exception {
validateTestIndexOnUpgrade(numberOfDocs, modelId, null);
}

private void validateTestIndexOnUpgrade(final int numberOfDocs, final String modelId, Map<String, ?> methodParameters)
private void validateTestIndexOnUpgrade(final int numberOfDocs, final String modelId, HybridQueryBuilder hybridQueryBuilder)
throws Exception {
int docCount = getDocCount(getIndexNameForTest());
assertEquals(numberOfDocs, docCount);
loadModel(modelId);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, methodParameters);
Map<String, Object> searchResponseAsMap = search(
getIndexNameForTest(),
hybridQueryBuilder,
Expand All @@ -109,7 +110,11 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod
}
}

private HybridQueryBuilder getQueryBuilder(final String modelId, final Map<String, ?> methodParameters) {
private HybridQueryBuilder getQueryBuilder(
final String modelId,
final Map<String, ?> methodParameters,
final RescoreContext rescoreContext
) {
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
neuralQueryBuilder.fieldName("passage_embedding");
neuralQueryBuilder.modelId(modelId);
Expand All @@ -118,6 +123,9 @@ private HybridQueryBuilder getQueryBuilder(final String modelId, final Map<Strin
if (methodParameters != null) {
neuralQueryBuilder.methodParameters(methodParameters);
}
if (rescoreContext != null) {
neuralQueryBuilder.rescoreContext(rescoreContext);
}

MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", QUERY);

(next changed file)
@@ -87,6 +87,7 @@ private void validateIndexQueryOnUpgrade(final int numberOfDocs, final String mo
0.01f,
null,
null,
null,
null
);
Map<String, Object> responseWithMinScore = search(getIndexNameForTest(), neuralQueryBuilderWithMinScoreQuery, 1);
Expand All @@ -102,6 +103,7 @@ private void validateIndexQueryOnUpgrade(final int numberOfDocs, final String mo
null,
null,
null,
null,
null
);
Map<String, Object> responseWithMaxScore = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1);
(next changed file)
@@ -86,6 +86,7 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod
null,
null,
null,
null,
null
);
Map<String, Object> responseWithKQuery = search(getIndexNameForTest(), neuralQueryBuilderWithKQuery, 1);
(next changed file)
@@ -8,6 +8,7 @@
import static org.opensearch.knn.index.query.KNNQueryBuilder.MAX_DISTANCE_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.METHOD_PARAMS_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.MIN_SCORE_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.RESCORE_FIELD;
import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersion;
import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport;
import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForRadialSearch;
@@ -40,6 +41,8 @@
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.query.parser.MethodParametersParser;
import org.opensearch.knn.index.query.parser.RescoreParser;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.neuralsearch.common.MinClusterVersionUtil;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

@@ -101,6 +104,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) {
private Supplier<float[]> vectorSupplier;
private QueryBuilder filter;
private Map<String, ?> methodParameters;
private RescoreContext rescoreContext;

/**
* Constructor from stream input
@@ -131,6 +135,7 @@ public NeuralQueryBuilder(StreamInput in) throws IOException {
if (isClusterOnOrAfterMinReqVersion(METHOD_PARAMS_FIELD.getPreferredName())) {
this.methodParameters = MethodParametersParser.streamInput(in, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
}
this.rescoreContext = RescoreParser.streamInput(in);
}

@Override
Expand All @@ -156,6 +161,7 @@ protected void doWriteTo(StreamOutput out) throws IOException {
if (isClusterOnOrAfterMinReqVersion(METHOD_PARAMS_FIELD.getPreferredName())) {
MethodParametersParser.streamOutput(out, methodParameters, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
}
RescoreParser.streamOutput(out, rescoreContext);
}

@Override
Expand All @@ -181,6 +187,9 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
if (Objects.nonNull(methodParameters)) {
MethodParametersParser.doXContent(xContentBuilder, methodParameters);
}
if (Objects.nonNull(rescoreContext)) {
RescoreParser.doXContent(xContentBuilder, rescoreContext);
}
printBoostAndQueryName(xContentBuilder);
xContentBuilder.endObject();
xContentBuilder.endObject();
@@ -276,6 +285,8 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n
neuralQueryBuilder.filter(parseInnerQueryBuilder(parser));
} else if (METHOD_PARAMS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
neuralQueryBuilder.methodParameters(MethodParametersParser.fromXContent(parser));
} else if (RESCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
neuralQueryBuilder.rescoreContext(RescoreParser.fromXContent(parser));
}
} else {
throw new ParsingException(
@@ -308,6 +319,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
.maxDistance(maxDistance)
.minScore(minScore)
.k(k)
.methodParameters(methodParameters)
.rescoreContext(rescoreContext)
.build();
}

@@ -335,7 +348,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
minScore(),
vectorSetOnce::get,
filter(),
methodParameters()
methodParameters(),
rescoreContext()
);
}

(next changed file)
@@ -97,6 +97,7 @@ public void testResultProcessor_whenOneShardAndQueryMatches_thenSuccessful() {
null,
null,
null,
null,
null
);
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
@@ -148,6 +149,7 @@ public void testResultProcessor_whenDefaultProcessorConfigAndQueryMatches_thenSu
null,
null,
null,
null,
null
);
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
@@ -188,6 +190,7 @@ public void testQueryMatches_whenMultipleShards_thenSuccessful() {
null,
null,
null,
null,
null
);
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
(next changed file)
@@ -224,7 +224,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf

HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder();
hybridQueryBuilderDefaultNorm.add(
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null)
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null)
);
hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

@@ -249,7 +249,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf

HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder();
hybridQueryBuilderL2Norm.add(
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null)
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null)
);
hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

@@ -299,7 +299,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess

HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder();
hybridQueryBuilderDefaultNorm.add(
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null)
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null)
);
hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

@@ -324,7 +324,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess

HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder();
hybridQueryBuilderL2Norm.add(
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null)
new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null, null, null, null, null)
);
hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));

(remaining changed files not shown)
