Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SearchResponse reference count leaks in ML module #103009

Merged
merged 4 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.search;

import org.elasticsearch.action.search.SearchRequestBuilder;

public enum SearchResponseUtils {
    ;

    /**
     * Executes the given search request and returns the {@code value} of the
     * response's total hits, guaranteeing the {@code SearchResponse}'s
     * reference is released ({@code decRef()}) even if reading the hit count
     * throws.
     *
     * @param request the search request builder to execute via {@code get()}
     * @return the total-hits value reported by the search response
     */
    public static long getTotalHitsValue(SearchRequestBuilder request) {
        final var response = request.get();
        try {
            return response.getHits().getTotalHits().value;
        } finally {
            // Always release the response ref to avoid a reference-count leak.
            response.decRef();
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ private void assertExecutionWithOrigin(Map<String, String> storedHeaders, Client
assertThat(headers, not(hasEntry(AuthenticationServiceField.RUN_AS_USER_HEADER, "anything")));

return client.search(new SearchRequest()).actionGet();
});
}).decRef();
}

/**
Expand All @@ -356,7 +356,7 @@ public void assertRunAsExecution(Map<String, String> storedHeaders, Consumer<Map

consumer.accept(client.threadPool().getThreadContext().getHeaders());
return client.search(new SearchRequest()).actionGet();
});
}).decRef();
}

public void testFilterSecurityHeaders() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1567,36 +1567,38 @@ public void testFeatureImportanceValues() throws Exception {

client().admin().indices().refresh(new RefreshRequest(destIndex));
SearchResponse sourceData = prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();

// obtain addition information for investigation of #90599
String modelId = getModelId(jobId);
TrainedModelMetadata modelMetadata = getModelMetadata(modelId);
assertThat(modelMetadata.getHyperparameters().size(), greaterThan(0));
StringBuilder hyperparameters = new StringBuilder(); // used to investigate #90019
for (Hyperparameters hyperparameter : modelMetadata.getHyperparameters()) {
hyperparameters.append(hyperparameter.hyperparameterName).append(": ").append(hyperparameter.value).append("\n");
}
TrainedModelDefinition modelDefinition = getModelDefinition(modelId);
Ensemble ensemble = (Ensemble) modelDefinition.getTrainedModel();
int numberTrees = ensemble.getModels().size();
String str = "Failure: failed for modelId %s numberTrees %d\n";
for (SearchHit hit : sourceData.getHits()) {
Map<String, Object> destDoc = getDestDoc(config, hit);
assertNotNull(destDoc);
Map<String, Object> resultsObject = getFieldValue(destDoc, "ml");
assertThat(resultsObject.containsKey(predictionField), is(true));
String predictionValue = (String) resultsObject.get(predictionField);
assertNotNull(predictionValue);
assertThat(resultsObject.containsKey("feature_importance"), is(true));
@SuppressWarnings("unchecked")
List<Map<String, Object>> importanceArray = (List<Map<String, Object>>) resultsObject.get("feature_importance");
assertThat(
Strings.format(str, modelId, numberTrees) + predictionValue + hyperparameters + modelDefinition,
importanceArray,
hasSize(greaterThan(0))
);
try {
// obtain addition information for investigation of #90599
String modelId = getModelId(jobId);
TrainedModelMetadata modelMetadata = getModelMetadata(modelId);
assertThat(modelMetadata.getHyperparameters().size(), greaterThan(0));
StringBuilder hyperparameters = new StringBuilder(); // used to investigate #90019
for (Hyperparameters hyperparameter : modelMetadata.getHyperparameters()) {
hyperparameters.append(hyperparameter.hyperparameterName).append(": ").append(hyperparameter.value).append("\n");
}
TrainedModelDefinition modelDefinition = getModelDefinition(modelId);
Ensemble ensemble = (Ensemble) modelDefinition.getTrainedModel();
int numberTrees = ensemble.getModels().size();
String str = "Failure: failed for modelId %s numberTrees %d\n";
for (SearchHit hit : sourceData.getHits()) {
Map<String, Object> destDoc = getDestDoc(config, hit);
assertNotNull(destDoc);
Map<String, Object> resultsObject = getFieldValue(destDoc, "ml");
assertThat(resultsObject.containsKey(predictionField), is(true));
String predictionValue = (String) resultsObject.get(predictionField);
assertNotNull(predictionValue);
assertThat(resultsObject.containsKey("feature_importance"), is(true));
@SuppressWarnings("unchecked")
List<Map<String, Object>> importanceArray = (List<Map<String, Object>>) resultsObject.get("feature_importance");
assertThat(
Strings.format(str, modelId, numberTrees) + predictionValue + hyperparameters + modelDefinition,
importanceArray,
hasSize(greaterThan(0))
);
}
} finally {
sourceData.decRef();
}

}

static void indexData(String sourceIndex) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import java.util.List;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
Expand Down Expand Up @@ -163,16 +164,16 @@ private void testDfWithAggs(AggregatorFactories.Builder aggs, Detector.Builder d
bucket.getEventCount()
);
// Confirm that it's possible to search for the same buckets by @timestamp - proves that @timestamp works as a field alias
assertThat(
assertHitCount(
prepareSearch(AnomalyDetectorsIndex.jobResultsAliasedName(jobId)).setQuery(
QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery("job_id", jobId))
.filter(QueryBuilders.termQuery("result_type", "bucket"))
.filter(
QueryBuilders.rangeQuery("@timestamp").gte(bucket.getTimestamp().getTime()).lte(bucket.getTimestamp().getTime())
)
).setTrackTotalHits(true).get().getHits().getTotalHits().value,
equalTo(1L)
).setTrackTotalHits(true),
1
);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchResponseUtils;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
Expand Down Expand Up @@ -268,14 +269,13 @@ private void testExpiredDeletion(Float customThrottle, int numUnusedState) throw

retainAllSnapshots("snapshots-retention-with-retain");

long totalModelSizeStatsBeforeDelete = prepareSearch("*").setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN_CLOSED_HIDDEN)
.setQuery(QueryBuilders.termQuery("result_type", "model_size_stats"))
.get()
.getHits()
.getTotalHits().value;
long totalNotificationsCountBeforeDelete = prepareSearch(NotificationsIndex.NOTIFICATIONS_INDEX).get()
.getHits()
.getTotalHits().value;
long totalModelSizeStatsBeforeDelete = SearchResponseUtils.getTotalHitsValue(
prepareSearch("*").setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN_CLOSED_HIDDEN)
.setQuery(QueryBuilders.termQuery("result_type", "model_size_stats"))
);
long totalNotificationsCountBeforeDelete = SearchResponseUtils.getTotalHitsValue(
prepareSearch(NotificationsIndex.NOTIFICATIONS_INDEX)
);
assertThat(totalModelSizeStatsBeforeDelete, greaterThan(0L));
assertThat(totalNotificationsCountBeforeDelete, greaterThan(0L));

Expand Down Expand Up @@ -319,14 +319,13 @@ private void testExpiredDeletion(Float customThrottle, int numUnusedState) throw
assertThat(getRecords("results-and-snapshots-retention").size(), equalTo(0));
assertThat(getModelSnapshots("results-and-snapshots-retention").size(), equalTo(1));

long totalModelSizeStatsAfterDelete = prepareSearch("*").setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN_CLOSED_HIDDEN)
.setQuery(QueryBuilders.termQuery("result_type", "model_size_stats"))
.get()
.getHits()
.getTotalHits().value;
long totalNotificationsCountAfterDelete = prepareSearch(NotificationsIndex.NOTIFICATIONS_INDEX).get()
.getHits()
.getTotalHits().value;
long totalModelSizeStatsAfterDelete = SearchResponseUtils.getTotalHitsValue(
prepareSearch("*").setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN_CLOSED_HIDDEN)
.setQuery(QueryBuilders.termQuery("result_type", "model_size_stats"))
);
long totalNotificationsCountAfterDelete = SearchResponseUtils.getTotalHitsValue(
prepareSearch(NotificationsIndex.NOTIFICATIONS_INDEX)
);
assertThat(totalModelSizeStatsAfterDelete, equalTo(totalModelSizeStatsBeforeDelete));
assertThat(totalNotificationsCountAfterDelete, greaterThanOrEqualTo(totalNotificationsCountBeforeDelete));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchResponseUtils;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
import org.elasticsearch.xpack.core.ml.action.NodeAcknowledgedResponse;
Expand Down Expand Up @@ -396,11 +397,12 @@ public void testStopOutlierDetectionWithEnoughDocumentsToScroll() throws Excepti

assertResponse(prepareSearch(config.getDest().getIndex()).setTrackTotalHits(true), searchResponse -> {
if (searchResponse.getHits().getTotalHits().value == docCount) {
searchResponse = prepareSearch(config.getDest().getIndex()).setTrackTotalHits(true)
.setQuery(QueryBuilders.existsQuery("custom_ml.outlier_score"))
.get();
logger.debug("We stopped during analysis: [{}] < [{}]", searchResponse.getHits().getTotalHits().value, docCount);
assertThat(searchResponse.getHits().getTotalHits().value, lessThan((long) docCount));
long seenCount = SearchResponseUtils.getTotalHitsValue(
prepareSearch(config.getDest().getIndex()).setTrackTotalHits(true)
.setQuery(QueryBuilders.existsQuery("custom_ml.outlier_score"))
);
logger.debug("We stopped during analysis: [{}] < [{}]", seenCount, docCount);
assertThat(seenCount, lessThan((long) docCount));
} else {
logger.debug("We stopped during reindexing: [{}] < [{}]", searchResponse.getHits().getTotalHits().value, docCount);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ public void testInferenceAggRestricted() {

SearchRequest search = new SearchRequest(index);
search.source().aggregation(termsAgg);
client().search(search).actionGet();
client().search(search).actionGet().decRef();

// Pick a license that does not allow machine learning
License.OperationMode mode = randomInvalidLicenseType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.search.aggregations.AggregationBuilders;
Expand All @@ -31,6 +30,7 @@
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Stream;

import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
import static org.hamcrest.Matchers.closeTo;

public class BucketCorrelationAggregationIT extends MlSingleNodeTestCase {
Expand Down Expand Up @@ -71,34 +71,42 @@ public void testCountCorrelation() {

AtomicLong counter = new AtomicLong();
double[] steps = Stream.generate(() -> counter.getAndAdd(2L)).limit(50).mapToDouble(l -> (double) l).toArray();
SearchResponse percentilesSearch = client().prepareSearch("data")
.addAggregation(AggregationBuilders.percentiles("percentiles").field("metric").percentiles(steps))
.setSize(0)
.setTrackTotalHits(true)
.get();
long totalHits = percentilesSearch.getHits().getTotalHits().value;
Percentiles percentiles = percentilesSearch.getAggregations().get("percentiles");
Tuple<RangeAggregationBuilder, BucketCorrelationAggregationBuilder> aggs = buildRangeAggAndSetExpectations(
percentiles,
steps,
totalHits,
"metric"
assertResponse(
client().prepareSearch("data")
.addAggregation(AggregationBuilders.percentiles("percentiles").field("metric").percentiles(steps))
.setSize(0)
.setTrackTotalHits(true),
percentilesSearch -> {
long totalHits = percentilesSearch.getHits().getTotalHits().value;
Percentiles percentiles = percentilesSearch.getAggregations().get("percentiles");
Tuple<RangeAggregationBuilder, BucketCorrelationAggregationBuilder> aggs = buildRangeAggAndSetExpectations(
percentiles,
steps,
totalHits,
"metric"
);

assertResponse(
client().prepareSearch("data")
.setSize(0)
.setTrackTotalHits(false)
.addAggregation(
AggregationBuilders.terms("buckets").field("term").subAggregation(aggs.v1()).subAggregation(aggs.v2())
),
countCorrelations -> {

Terms terms = countCorrelations.getAggregations().get("buckets");
Terms.Bucket catBucket = terms.getBucketByKey("cat");
Terms.Bucket dogBucket = terms.getBucketByKey("dog");
NumericMetricsAggregation.SingleValue approxCatCorrelation = catBucket.getAggregations().get("correlates");
NumericMetricsAggregation.SingleValue approxDogCorrelation = dogBucket.getAggregations().get("correlates");

assertThat(approxCatCorrelation.value(), closeTo(catCorrelation, 0.1));
assertThat(approxDogCorrelation.value(), closeTo(dogCorrelation, 0.1));
}
);
}
);

SearchResponse countCorrelations = client().prepareSearch("data")
.setSize(0)
.setTrackTotalHits(false)
.addAggregation(AggregationBuilders.terms("buckets").field("term").subAggregation(aggs.v1()).subAggregation(aggs.v2()))
.get();

Terms terms = countCorrelations.getAggregations().get("buckets");
Terms.Bucket catBucket = terms.getBucketByKey("cat");
Terms.Bucket dogBucket = terms.getBucketByKey("dog");
NumericMetricsAggregation.SingleValue approxCatCorrelation = catBucket.getAggregations().get("correlates");
NumericMetricsAggregation.SingleValue approxDogCorrelation = dogBucket.getAggregations().get("correlates");

assertThat(approxCatCorrelation.value(), closeTo(catCorrelation, 0.1));
assertThat(approxDogCorrelation.value(), closeTo(dogCorrelation, 0.1));
}

private static Tuple<RangeAggregationBuilder, BucketCorrelationAggregationBuilder> buildRangeAggAndSetExpectations(
Expand Down
Loading