Skip to content

Commit

Permalink
Fix remaining leaked SearchResponse issues in :server:test
Browse files Browse the repository at this point in the history
Same as elastic#102896, handling almost all of the remaining spots (just a handful of tricky ones left that I'll
open a separate PR for).
  • Loading branch information
original-brownbear committed Dec 2, 2023
1 parent 76a6dd6 commit 9cd7c9a
Show file tree
Hide file tree
Showing 14 changed files with 719 additions and 565 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.List;

import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;

Expand Down Expand Up @@ -62,15 +63,17 @@ public void testKnnSearchRemovedVector() throws IOException {

float[] queryVector = randomVector();
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 20, 50, null).boost(5.0f);
SearchResponse response = client().prepareSearch("index")
.setKnnSearch(List.of(knnSearch))
.setQuery(QueryBuilders.matchQuery("text", "goodnight"))
.setSize(10)
.get();

// Originally indexed 20 documents, but deleted vector field with an update, so only 19 should be hit
assertHitCount(response, 19);
assertEquals(10, response.getHits().getHits().length);
assertResponse(
client().prepareSearch("index")
.setKnnSearch(List.of(knnSearch))
.setQuery(QueryBuilders.matchQuery("text", "goodnight"))
.setSize(10),
response -> {
// Originally indexed 20 documents, but deleted vector field with an update, so only 19 should be hit
assertHitCount(response, 19);
assertEquals(10, response.getHits().getHits().length);
}
);
// Make sure we still have 20 docs
assertHitCount(client().prepareSearch("index").setSize(0).setTrackTotalHits(true), 20);
}
Expand Down Expand Up @@ -104,19 +107,22 @@ public void testKnnWithQuery() throws IOException {

float[] queryVector = randomVector();
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null).boost(5.0f);
SearchResponse response = client().prepareSearch("index")
.setKnnSearch(List.of(knnSearch))
.setQuery(QueryBuilders.matchQuery("text", "goodnight"))
.addFetchField("*")
.setSize(10)
.get();

// The total hits is k plus the number of text matches
assertHitCount(response, 15);
assertEquals(10, response.getHits().getHits().length);

// Because of the boost, vector results should appear first
assertNotNull(response.getHits().getAt(0).field("vector"));
assertResponse(
client().prepareSearch("index")
.setKnnSearch(List.of(knnSearch))
.setQuery(QueryBuilders.matchQuery("text", "goodnight"))
.addFetchField("*")
.setSize(10),
response -> {

// The total hits is k plus the number of text matches
assertHitCount(response, 15);
assertEquals(10, response.getHits().getHits().length);

// Because of the boost, vector results should appear first
assertNotNull(response.getHits().getAt(0).field("vector"));
}
);
}

public void testKnnFilter() throws IOException {
Expand Down Expand Up @@ -150,13 +156,13 @@ public void testKnnFilter() throws IOException {
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null).addFilterQuery(
QueryBuilders.termsQuery("field", "second")
);
SearchResponse response = client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).addFetchField("*").setSize(10).get();

assertHitCount(response, 5);
assertEquals(5, response.getHits().getHits().length);
for (SearchHit hit : response.getHits().getHits()) {
assertEquals("second", hit.field("field").getValue());
}
assertResponse(client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).addFetchField("*").setSize(10), response -> {
assertHitCount(response, 5);
assertEquals(5, response.getHits().getHits().length);
for (SearchHit hit : response.getHits().getHits()) {
assertEquals("second", hit.field("field").getValue());
}
});
}

public void testKnnFilterWithRewrite() throws IOException {
Expand Down Expand Up @@ -193,10 +199,10 @@ public void testKnnFilterWithRewrite() throws IOException {
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null).addFilterQuery(
QueryBuilders.termsLookupQuery("field", new TermsLookup("index", "lookup-doc", "other-field"))
);
SearchResponse response = client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).setSize(10).get();

assertHitCount(response, 5);
assertEquals(5, response.getHits().getHits().length);
assertResponse(client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).setSize(10), response -> {
assertHitCount(response, 5);
assertEquals(5, response.getHits().getHits().length);
});
}

public void testMultiKnnClauses() throws IOException {
Expand Down Expand Up @@ -239,26 +245,29 @@ public void testMultiKnnClauses() throws IOException {
float[] queryVector = randomVector(20f, 21f);
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null).boost(5.0f);
KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50, null).boost(10.0f);
SearchResponse response = client().prepareSearch("index")
.setKnnSearch(List.of(knnSearch, knnSearch2))
.setQuery(QueryBuilders.matchQuery("text", "goodnight"))
.addFetchField("*")
.setSize(10)
.addAggregation(AggregationBuilders.stats("stats").field("number"))
.get();

// The total hits is k plus the number of text matches
assertHitCount(response, 20);
assertEquals(10, response.getHits().getHits().length);
InternalStats agg = response.getAggregations().get("stats");
assertThat(agg.getCount(), equalTo(20L));
assertThat(agg.getMax(), equalTo(3.0));
assertThat(agg.getMin(), equalTo(1.0));
assertThat(agg.getAvg(), equalTo(2.25));
assertThat(agg.getSum(), equalTo(45.0));

// Because of the boost & vector distributions, vector_2 results should appear first
assertNotNull(response.getHits().getAt(0).field("vector_2"));
assertResponse(
client().prepareSearch("index")
.setKnnSearch(List.of(knnSearch, knnSearch2))
.setQuery(QueryBuilders.matchQuery("text", "goodnight"))
.addFetchField("*")
.setSize(10)
.addAggregation(AggregationBuilders.stats("stats").field("number")),
response -> {

// The total hits is k plus the number of text matches
assertHitCount(response, 20);
assertEquals(10, response.getHits().getHits().length);
InternalStats agg = response.getAggregations().get("stats");
assertThat(agg.getCount(), equalTo(20L));
assertThat(agg.getMax(), equalTo(3.0));
assertThat(agg.getMin(), equalTo(1.0));
assertThat(agg.getAvg(), equalTo(2.25));
assertThat(agg.getSum(), equalTo(45.0));

// Because of the boost & vector distributions, vector_2 results should appear first
assertNotNull(response.getHits().getAt(0).field("vector_2"));
}
);
}

public void testMultiKnnClausesSameDoc() throws IOException {
Expand Down Expand Up @@ -298,38 +307,42 @@ public void testMultiKnnClausesSameDoc() throws IOException {
// Having the same query vector and same docs should mean our KNN scores are linearly combined if the same doc is matched
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null);
KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50, null);
SearchResponse responseOneKnn = client().prepareSearch("index")
.setKnnSearch(List.of(knnSearch))
.addFetchField("*")
.setSize(10)
.addAggregation(AggregationBuilders.stats("stats").field("number"))
.get();
SearchResponse responseBothKnn = client().prepareSearch("index")
.setKnnSearch(List.of(knnSearch, knnSearch2))
.addFetchField("*")
.setSize(10)
.addAggregation(AggregationBuilders.stats("stats").field("number"))
.get();

// The total hits is k matched docs
assertHitCount(responseOneKnn, 5);
assertHitCount(responseBothKnn, 5);
assertEquals(5, responseOneKnn.getHits().getHits().length);
assertEquals(5, responseBothKnn.getHits().getHits().length);

for (int i = 0; i < responseOneKnn.getHits().getHits().length; i++) {
SearchHit oneHit = responseOneKnn.getHits().getHits()[i];
SearchHit bothHit = responseBothKnn.getHits().getHits()[i];
assertThat(bothHit.getId(), equalTo(oneHit.getId()));
assertThat(bothHit.getScore(), greaterThan(oneHit.getScore()));
}
InternalStats oneAgg = responseOneKnn.getAggregations().get("stats");
InternalStats bothAgg = responseBothKnn.getAggregations().get("stats");
assertThat(bothAgg.getCount(), equalTo(oneAgg.getCount()));
assertThat(bothAgg.getAvg(), equalTo(oneAgg.getAvg()));
assertThat(bothAgg.getMax(), equalTo(oneAgg.getMax()));
assertThat(bothAgg.getSum(), equalTo(oneAgg.getSum()));
assertThat(bothAgg.getMin(), equalTo(oneAgg.getMin()));
assertResponse(
client().prepareSearch("index")
.setKnnSearch(List.of(knnSearch))
.addFetchField("*")
.setSize(10)
.addAggregation(AggregationBuilders.stats("stats").field("number")),
responseOneKnn -> assertResponse(
client().prepareSearch("index")
.setKnnSearch(List.of(knnSearch, knnSearch2))
.addFetchField("*")
.setSize(10)
.addAggregation(AggregationBuilders.stats("stats").field("number")),
responseBothKnn -> {

// The total hits is k matched docs
assertHitCount(responseOneKnn, 5);
assertHitCount(responseBothKnn, 5);
assertEquals(5, responseOneKnn.getHits().getHits().length);
assertEquals(5, responseBothKnn.getHits().getHits().length);

for (int i = 0; i < responseOneKnn.getHits().getHits().length; i++) {
SearchHit oneHit = responseOneKnn.getHits().getHits()[i];
SearchHit bothHit = responseBothKnn.getHits().getHits()[i];
assertThat(bothHit.getId(), equalTo(oneHit.getId()));
assertThat(bothHit.getScore(), greaterThan(oneHit.getScore()));
}
InternalStats oneAgg = responseOneKnn.getAggregations().get("stats");
InternalStats bothAgg = responseBothKnn.getAggregations().get("stats");
assertThat(bothAgg.getCount(), equalTo(oneAgg.getCount()));
assertThat(bothAgg.getAvg(), equalTo(oneAgg.getAvg()));
assertThat(bothAgg.getMax(), equalTo(oneAgg.getMax()));
assertThat(bothAgg.getSum(), equalTo(oneAgg.getSum()));
assertThat(bothAgg.getMin(), equalTo(oneAgg.getMin()));
}
)
);
}

public void testKnnFilteredAlias() throws IOException {
Expand Down Expand Up @@ -366,10 +379,11 @@ public void testKnnFilteredAlias() throws IOException {

float[] queryVector = randomVector();
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 10, 50, null);
SearchResponse response = client().prepareSearch("test-alias").setKnnSearch(List.of(knnSearch)).setSize(10).get();

assertHitCount(response, expectedHits);
assertEquals(expectedHits, response.getHits().getHits().length);
final int expectedHitCount = expectedHits;
assertResponse(client().prepareSearch("test-alias").setKnnSearch(List.of(knnSearch)).setSize(10), response -> {
assertHitCount(response, expectedHitCount);
assertEquals(expectedHitCount, response.getHits().getHits().length);
});
}

public void testKnnSearchAction() throws IOException {
Expand Down Expand Up @@ -399,14 +413,14 @@ public void testKnnSearchAction() throws IOException {
// Since there's no kNN search action at the transport layer, we just emulate
// how the action works (it builds a kNN query under the hood)
float[] queryVector = randomVector();
SearchResponse response = client().prepareSearch("index1", "index2")
.setQuery(new KnnVectorQueryBuilder("vector", queryVector, 5, null))
.setSize(2)
.get();

// The total hits is num_cands * num_shards, since the query gathers num_cands hits from each shard
assertHitCount(response, 5 * 2);
assertEquals(2, response.getHits().getHits().length);
assertResponse(
client().prepareSearch("index1", "index2").setQuery(new KnnVectorQueryBuilder("vector", queryVector, 5, null)).setSize(2),
response -> {
// The total hits is num_cands * num_shards, since the query gathers num_cands hits from each shard
assertHitCount(response, 5 * 2);
assertEquals(2, response.getHits().getHits().length);
}
);
}

public void testKnnVectorsWith4096Dims() throws IOException {
Expand Down Expand Up @@ -434,11 +448,11 @@ public void testKnnVectorsWith4096Dims() throws IOException {

float[] queryVector = randomVector(4096);
KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 50, null).boost(5.0f);
SearchResponse response = client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).addFetchField("*").setSize(10).get();

assertHitCount(response, 3);
assertEquals(3, response.getHits().getHits().length);
assertEquals(4096, response.getHits().getAt(0).field("vector").getValues().size());
assertResponse(client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).addFetchField("*").setSize(10), response -> {
assertHitCount(response, 3);
assertEquals(3, response.getHits().getHits().length);
assertEquals(4096, response.getHits().getAt(0).field("vector").getValues().size());
});
}

private float[] randomVector() {
Expand Down
Loading

0 comments on commit 9cd7c9a

Please sign in to comment.