diff --git a/gradle.properties b/gradle.properties index af24e4c7..4ebab2ea 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,5 +1,5 @@ ltrVersion = 1.5.8 -elasticsearchVersion = 8.11.4 -luceneVersion = 9.8.0 +elasticsearchVersion = 8.14.0 +luceneVersion = 9.10.0 ow2Version = 8.0.1 antlrVersion = 4.5.1-1 diff --git a/src/javaRestTest/java/com/o19s/es/ltr/action/BaseIntegrationTest.java b/src/javaRestTest/java/com/o19s/es/ltr/action/BaseIntegrationTest.java index 2bf1d326..56a96807 100644 --- a/src/javaRestTest/java/com/o19s/es/ltr/action/BaseIntegrationTest.java +++ b/src/javaRestTest/java/com/o19s/es/ltr/action/BaseIntegrationTest.java @@ -26,7 +26,7 @@ import com.o19s.es.ltr.ranker.parser.LtrRankerParserFactory; import com.o19s.es.ltr.ranker.ranklib.RankLibScriptEngine; import org.elasticsearch.action.DocWriteResponse; -import org.elasticsearch.action.admin.indices.create.CreateIndexAction; +import org.elasticsearch.action.admin.indices.create.TransportCreateIndexAction; import org.elasticsearch.action.admin.indices.create.CreateIndexResponse; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.core.Nullable; @@ -70,7 +70,7 @@ protected Collection> getPlugins() { public void createStore(String name) throws Exception { assert IndexFeatureStore.isIndexStore(name); - CreateIndexResponse resp = client().execute(CreateIndexAction.INSTANCE, IndexFeatureStore.buildIndexRequest(name)).get(); + CreateIndexResponse resp = client().execute(TransportCreateIndexAction.TYPE, IndexFeatureStore.buildIndexRequest(name)).get(); assertTrue(resp.isAcknowledged()); } diff --git a/src/javaRestTest/java/com/o19s/es/ltr/logging/LoggingIT.java b/src/javaRestTest/java/com/o19s/es/ltr/logging/LoggingIT.java index 76b61cb3..86684c0d 100644 --- a/src/javaRestTest/java/com/o19s/es/ltr/logging/LoggingIT.java +++ b/src/javaRestTest/java/com/o19s/es/ltr/logging/LoggingIT.java @@ -50,6 +50,7 @@ import java.util.Map; import java.util.stream.Collectors; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.instanceOf; @@ -213,9 +214,7 @@ public void testLog() throws Exception { new LoggingSearchExtBuilder() .addQueryLogging("first_log", "test", false) .addRescoreLogging("second_log", 0, true))); - - SearchResponse resp = client().prepareSearch("test_index").setSource(sourceBuilder).get(); - assertSearchHits(docs, resp); + assertResponse(client().prepareSearch("test_index").setSource(sourceBuilder), resp -> assertSearchHits(docs, resp)); sbuilder.featureSetName(null); sbuilder.modelName("my_model"); sbuilder.boost(random().nextInt(3)); @@ -234,8 +233,7 @@ public void testLog() throws Exception { .addQueryLogging("first_log", "test", false) .addRescoreLogging("second_log", 0, true))); - SearchResponse resp2 = client().prepareSearch("test_index").setSource(sourceBuilder).get(); - assertSearchHits(docs, resp2); + assertResponse(client().prepareSearch("test_index").setSource(sourceBuilder), resp -> assertSearchHits(docs, resp)); query = QueryBuilders.boolQuery() .must(new WrapperQueryBuilder(sbuilder.toString())) @@ -254,8 +252,7 @@ public void testLog() throws Exception { new LoggingSearchExtBuilder() .addQueryLogging("first_log", "test", false) .addRescoreLogging("second_log", 0, true))); - SearchResponse resp3 = client().prepareSearch("test_index").setSource(sourceBuilder).get(); - assertSearchHits(docs, resp3); + assertResponse(client().prepareSearch("test_index").setSource(sourceBuilder), resp -> assertSearchHits(docs, resp)); query = QueryBuilders.boolQuery().filter(QueryBuilders.idsQuery().addIds(ids)); sourceBuilder = new SearchSourceBuilder().query(query) @@ -268,8 +265,7 @@ public void testLog() throws Exception { .addRescoreLogging("first_log", 0, false) .addRescoreLogging("second_log", 1, true))); - SearchResponse resp4 = client().prepareSearch("test_index").setSource(sourceBuilder).get(); - assertSearchHits(docs, resp4); + assertResponse(client().prepareSearch("test_index").setSource(sourceBuilder), resp -> assertSearchHits(docs, resp)); } public void testLogExtraLogging() throws Exception { @@ -304,8 +300,7 @@ public void testLogExtraLogging() throws Exception { .addQueryLogging("first_log", "test", false) .addRescoreLogging("second_log", 0, true))); - SearchResponse resp = client().prepareSearch("test_index").setSource(sourceBuilder).get(); - assertSearchHitsExtraLogging(docs, resp); + assertResponse(client().prepareSearch("test_index").setSource(sourceBuilder), resp -> assertSearchHitsExtraLogging(docs, resp)); sbuilder.featureSetName(null); sbuilder.modelName("my_model"); sbuilder.boost(random().nextInt(3)); @@ -324,8 +319,7 @@ public void testLogExtraLogging() throws Exception { .addQueryLogging("first_log", "test", false) .addRescoreLogging("second_log", 0, true))); - SearchResponse resp2 = client().prepareSearch("test_index").setSource(sourceBuilder).get(); - assertSearchHitsExtraLogging(docs, resp2); + assertResponse(client().prepareSearch("test_index").setSource(sourceBuilder), resp -> assertSearchHitsExtraLogging(docs, resp)); query = QueryBuilders.boolQuery() .must(new WrapperQueryBuilder(sbuilder.toString())) @@ -344,8 +338,7 @@ public void testLogExtraLogging() throws Exception { new LoggingSearchExtBuilder() .addQueryLogging("first_log", "test", false) .addRescoreLogging("second_log", 0, true))); - SearchResponse resp3 = client().prepareSearch("test_index").setSource(sourceBuilder).get(); - assertSearchHitsExtraLogging(docs, resp3); + assertResponse(client().prepareSearch("test_index").setSource(sourceBuilder), resp -> assertSearchHitsExtraLogging(docs, resp)); } public void testLogWithFeatureScoreCache() throws Exception { @@ -386,8 +379,7 @@ public void testLogWithFeatureScoreCache() throws Exception { .addQueryLogging("first_log", "test", false) .addRescoreLogging("second_log", 0, true))); - SearchResponse resp = client().prepareSearch("test_index").setSource(sourceBuilder).get(); - assertSearchHits(docs, resp); + assertResponse(client().prepareSearch("test_index").setSource(sourceBuilder), resp -> assertSearchHits(docs, resp)); sbuilder.featureSetName(null); sbuilder.modelName("my_model"); sbuilder_rescore.featureSetName(null); @@ -404,8 +396,7 @@ public void testLogWithFeatureScoreCache() throws Exception { .addQueryLogging("first_log", "test", false) .addRescoreLogging("second_log", 0, true))); - SearchResponse resp2 = client().prepareSearch("test_index").setSource(sourceBuilder).get(); - assertSearchHits(docs, resp2); + assertResponse(client().prepareSearch("test_index").setSource(sourceBuilder), resp -> assertSearchHits(docs, resp)); query = QueryBuilders.boolQuery() .must(new WrapperQueryBuilder(sbuilder.toString())) @@ -424,8 +415,7 @@ public void testLogWithFeatureScoreCache() throws Exception { new LoggingSearchExtBuilder() .addQueryLogging("first_log", "test", false) .addRescoreLogging("second_log", 0, true))); - SearchResponse resp3 = client().prepareSearch("test_index").setSource(sourceBuilder).get(); - assertSearchHits(docs, resp3); + assertResponse(client().prepareSearch("test_index").setSource(sourceBuilder), resp -> assertSearchHits(docs, resp)); query = QueryBuilders.boolQuery().filter(QueryBuilders.idsQuery().addIds(ids)); sourceBuilder = new SearchSourceBuilder().query(query) @@ -438,8 +428,7 @@ public void testLogWithFeatureScoreCache() throws Exception { .addRescoreLogging("first_log", 0, false) .addRescoreLogging("second_log", 1, true))); - SearchResponse resp4 = client().prepareSearch("test_index").setSource(sourceBuilder).get(); - assertSearchHits(docs, resp4); + assertResponse(client().prepareSearch("test_index").setSource(sourceBuilder), resp -> assertSearchHits(docs, resp)); } public void testScriptLogInternalParams() throws Exception { @@ -467,17 +456,17 @@ public void testScriptLogInternalParams() throws Exception { new LoggingSearchExtBuilder() .addQueryLogging("first_log", "test", false))); - SearchResponse resp = client().prepareSearch("test_index").setSource(sourceBuilder).get(); - - SearchHits hits = resp.getHits(); - SearchHit testHit = hits.getAt(0); - Map>> logs = testHit.getFields().get("_ltrlog").getValue(); + assertResponse( client().prepareSearch("test_index").setSource(sourceBuilder), resp -> { + SearchHits hits = resp.getHits(); + SearchHit testHit = hits.getAt(0); + Map>> logs = testHit.getFields().get("_ltrlog").getValue(); - assertTrue(logs.containsKey("first_log")); - List> log = logs.get("first_log"); + assertTrue(logs.containsKey("first_log")); + List> log = logs.get("first_log"); - assertEquals(log.get(0).get("name"), "test_inject"); - assertTrue((Float)log.get(0).get("value") > 0.0F); + assertEquals(log.get(0).get("name"), "test_inject"); + assertTrue((Float)log.get(0).get("value") > 0.0F); + }); } public void testScriptLogExternalParams() throws Exception { @@ -513,17 +502,17 @@ public void testScriptLogExternalParams() throws Exception { new LoggingSearchExtBuilder() .addQueryLogging("first_log", "test", false))); - SearchResponse resp = client().prepareSearch("test_index").setSource(sourceBuilder).get(); + assertResponse(client().prepareSearch("test_index").setSource(sourceBuilder), resp -> { + SearchHits hits = resp.getHits(); + SearchHit testHit = hits.getAt(0); + Map>> logs = testHit.getFields().get("_ltrlog").getValue(); - SearchHits hits = resp.getHits(); - SearchHit testHit = hits.getAt(0); - Map>> logs = testHit.getFields().get("_ltrlog").getValue(); - - assertTrue(logs.containsKey("first_log")); - List> log = logs.get("first_log"); + assertTrue(logs.containsKey("first_log")); + List> log = logs.get("first_log"); - assertEquals(log.get(0).get("name"), "test_inject"); - assertTrue((Float)log.get(0).get("value") > 0.0F); + assertEquals(log.get(0).get("name"), "test_inject"); + assertTrue((Float)log.get(0).get("value") > 0.0F); + }); } public void testScriptLogInvalidExternalParams() throws Exception { diff --git a/src/javaRestTest/java/com/o19s/es/ltr/query/StoredLtrQueryIT.java b/src/javaRestTest/java/com/o19s/es/ltr/query/StoredLtrQueryIT.java index 21077057..aa8f0411 100644 --- a/src/javaRestTest/java/com/o19s/es/ltr/query/StoredLtrQueryIT.java +++ b/src/javaRestTest/java/com/o19s/es/ltr/query/StoredLtrQueryIT.java @@ -31,7 +31,6 @@ import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; import com.o19s.es.ltr.logging.LoggingSearchExtBuilder; import org.elasticsearch.action.search.SearchRequestBuilder; -import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; @@ -51,6 +50,7 @@ import java.util.List; import java.util.Map; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; import static org.hamcrest.CoreMatchers.containsString; /** @@ -118,22 +118,23 @@ public void testScriptFeatureUseCaseMissingFeatureNaiveAdditiveDecisionTree() th new LoggingSearchExtBuilder() .addQueryLogging("log", "test", false))); - SearchResponse resp = client().prepareSearch("test_index").setSource(sourceBuilder).get(); - SearchHit hit = resp.getHits().getAt(0); - assertTrue(hit.getFields().containsKey("_ltrlog")); - Map>> logs = hit.getFields().get("_ltrlog").getValue(); - assertTrue(logs.containsKey("log")); - List> log = logs.get("log"); + assertResponse( client().prepareSearch("test_index").setSource(sourceBuilder), resp -> { + SearchHit hit = resp.getHits().getAt(0); + assertTrue(hit.getFields().containsKey("_ltrlog")); + Map>> logs = hit.getFields().get("_ltrlog").getValue(); + assertTrue(logs.containsKey("log")); + List> log = logs.get("log"); - // verify that text_feature1 has a missing value, and that the reported score results from the model taking the - // corresponding branch, along with the explanation - String explanation = hit.getExplanation().getDetails()[0].getDescription(); - assertThat(explanation, containsString("default value of NaN used")); + // verify that text_feature1 has a missing value, and that the reported score results from the model taking the + // corresponding branch, along with the explanation + String explanation = hit.getExplanation().getDetails()[0].getDescription(); + assertThat(explanation, containsString("default value of NaN used")); - assertEquals("text_feature1", log.get(0).get("name")); - assertEquals(null, log.get(0).get("value")); + assertEquals("text_feature1", log.get(0).get("name")); + assertEquals(null, log.get(0).get("value")); - assertEquals(0.2F, hit.getScore(), Math.ulp(0.2F)); + assertEquals(0.2F, hit.getScore(), Math.ulp(0.2F)); + }); } public void testScriptFeatureUseCase() throws Exception { @@ -168,10 +169,11 @@ public void testScriptFeatureUseCase() throws Exception { .setQueryWeight(0) .setRescoreQueryWeight(1)); - SearchResponse sr = sb.get(); - assertEquals(1, sr.getHits().getTotalHits().value); - assertThat(sr.getHits().getAt(0).getScore(), Matchers.greaterThanOrEqualTo(29.0f)); - assertThat(sr.getHits().getAt(0).getScore(), Matchers.lessThanOrEqualTo(30.0f)); + assertResponse(sb, sr -> { + assertEquals(1, sr.getHits().getTotalHits().value); + assertThat(sr.getHits().getAt(0).getScore(), Matchers.greaterThanOrEqualTo(29.0f)); + assertThat(sr.getHits().getAt(0).getScore(), Matchers.lessThanOrEqualTo(30.0f)); + }); } public void testFullUsecase() throws Exception { @@ -217,9 +219,9 @@ public void testFullUsecase() throws Exception { buildIndex(); Map params = new HashMap<>(); - boolean negativeScore = false; - params.put("query", negativeScore ? "bonjour" : "hello"); - params.put("multiplier", negativeScore ? Integer.parseInt("-1") : 1.0); + final boolean negativeScore1 = false; + params.put("query", negativeScore1 ? "bonjour" : "hello"); + params.put("multiplier", negativeScore1 ? Integer.parseInt("-1") : 1.0); params.put("dependent_feature", new HashMap<>()); SearchRequestBuilder sb = client().prepareSearch("test_index") .setQuery(QueryBuilders.matchQuery("field1", "world")) @@ -229,18 +231,18 @@ public void testFullUsecase() throws Exception { .setQueryWeight(0) .setRescoreQueryWeight(1)); - SearchResponse sr = sb.get(); - assertEquals(1, sr.getHits().getTotalHits().value); - - if (negativeScore) { - assertThat(sr.getHits().getAt(0).getScore(), Matchers.lessThanOrEqualTo(-10.0f)); - } else { - assertThat(sr.getHits().getAt(0).getScore(), Matchers.greaterThanOrEqualTo(10.0f)); - } - - negativeScore = true; - params.put("query", negativeScore ? "bonjour" : "hello"); - params.put("multiplier", negativeScore ? -1 : 1.0); + assertResponse(sb, sr -> { + assertEquals(1, sr.getHits().getTotalHits().value); + if (negativeScore1) { + assertThat(sr.getHits().getAt(0).getScore(), Matchers.lessThanOrEqualTo(-10.0f)); + } else { + assertThat(sr.getHits().getAt(0).getScore(), Matchers.greaterThanOrEqualTo(10.0f)); + } + }); + + final boolean negativeScore2 = true; + params.put("query", negativeScore2 ? "bonjour" : "hello"); + params.put("multiplier", negativeScore2 ? -1 : 1.0); params.put("dependent_feature", new HashMap<>()); sb = client().prepareSearch("test_index") .setQuery(QueryBuilders.matchQuery("field1", "world")) @@ -250,14 +252,16 @@ public void testFullUsecase() throws Exception { .setQueryWeight(0) .setRescoreQueryWeight(1)); - sr = sb.get(); - assertEquals(1, sr.getHits().getTotalHits().value); + assertResponse(sb, sr -> { + assertEquals(1, sr.getHits().getTotalHits().value); + + if (negativeScore2) { + assertThat(sr.getHits().getAt(0).getScore(), Matchers.lessThanOrEqualTo(-10.0f)); + } else { + assertThat(sr.getHits().getAt(0).getScore(), Matchers.greaterThanOrEqualTo(10.0f)); + } + }); - if (negativeScore) { - assertThat(sr.getHits().getAt(0).getScore(), Matchers.lessThanOrEqualTo(-10.0f)); - } else { - assertThat(sr.getHits().getAt(0).getScore(), Matchers.greaterThanOrEqualTo(10.0f)); - } // Test profiling sb = client().prepareSearch("test_index") @@ -269,8 +273,7 @@ public void testFullUsecase() throws Exception { .setQueryWeight(0) .setRescoreQueryWeight(1)); - sr = sb.get(); - assertThat(sr.getProfileResults().isEmpty(), Matchers.equalTo(false)); + assertResponse(sb, sr -> assertThat(sr.getProfileResults().isEmpty(), Matchers.equalTo(false))); //we use only feature4 score and ignore other scores params.put("query", "hello"); sb = client().prepareSearch("test_index") @@ -281,10 +284,11 @@ public void testFullUsecase() throws Exception { .setQueryWeight(0) .setRescoreQueryWeight(1)); - sr = sb.get(); - assertEquals(1, sr.getHits().getTotalHits().value); - assertThat(sr.getHits().getAt(0).getScore(), Matchers.greaterThan(0.0f)); - assertThat(sr.getHits().getAt(0).getScore(), Matchers.lessThanOrEqualTo(1.0f)); + assertResponse(sb, sr -> { + assertEquals(1, sr.getHits().getTotalHits().value); + assertThat(sr.getHits().getAt(0).getScore(), Matchers.greaterThan(0.0f)); + assertThat(sr.getHits().getAt(0).getScore(), Matchers.lessThanOrEqualTo(1.0f)); + }); //we use feature 5 with query time positive int multiplier passed to feature5 params.put("query", "hello"); @@ -296,10 +300,11 @@ public void testFullUsecase() throws Exception { .setScoreMode(QueryRescoreMode.Total) .setQueryWeight(0) .setRescoreQueryWeight(1)); - sr = sb.get(); - assertEquals(1, sr.getHits().getTotalHits().value); - assertThat(sr.getHits().getAt(0).getScore(), Matchers.greaterThan(28.0f)); - assertThat(sr.getHits().getAt(0).getScore(), Matchers.lessThan(30.0f)); + assertResponse(sb, sr -> { + assertEquals(1, sr.getHits().getTotalHits().value); + assertThat(sr.getHits().getAt(0).getScore(), Matchers.greaterThan(28.0f)); + assertThat(sr.getHits().getAt(0).getScore(), Matchers.lessThan(30.0f)); + }); //we use feature 5 with query time negative double multiplier passed to feature5 params.put("query", "hello"); @@ -311,10 +316,11 @@ public void testFullUsecase() throws Exception { .setScoreMode(QueryRescoreMode.Total) .setQueryWeight(0) .setRescoreQueryWeight(1)); - sr = sb.get(); - assertEquals(1, sr.getHits().getTotalHits().value); - assertThat(sr.getHits().getAt(0).getScore(), Matchers.lessThan(-28.0f)); - assertThat(sr.getHits().getAt(0).getScore(), Matchers.greaterThan(-30.0f)); + assertResponse(sb, sr -> { + assertEquals(1, sr.getHits().getTotalHits().value); + assertThat(sr.getHits().getAt(0).getScore(), Matchers.lessThan(-28.0f)); + assertThat(sr.getHits().getAt(0).getScore(), Matchers.greaterThan(-30.0f)); + }); //we use feature1 and feature6(ScriptFeature) params.put("query", "hello"); @@ -326,9 +332,10 @@ public void testFullUsecase() throws Exception { .setScoreMode(QueryRescoreMode.Total) .setQueryWeight(0) .setRescoreQueryWeight(1)); - sr = sb.get(); - assertEquals(1, sr.getHits().getTotalHits().value); - assertThat(sr.getHits().getAt(0).getScore(), Matchers.greaterThan(0.2876f + 2.876f)); + assertResponse(sb, sr -> { + assertEquals(1, sr.getHits().getTotalHits().value); + assertThat(sr.getHits().getAt(0).getScore(), Matchers.greaterThan(0.2876f + 2.876f)); + }); StoredLtrModel model = getElement(StoredLtrModel.class, StoredLtrModel.TYPE, "my_model"); CachesStatsNodesResponse stats = client().execute(CachesStatsAction.INSTANCE, diff --git a/src/main/java/com/o19s/es/explore/ExplorerQuery.java b/src/main/java/com/o19s/es/explore/ExplorerQuery.java index 76556f0f..efe2fe86 100644 --- a/src/main/java/com/o19s/es/explore/ExplorerQuery.java +++ b/src/main/java/com/o19s/es/explore/ExplorerQuery.java @@ -103,7 +103,7 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo StatisticsHelper ttf_stats = new StatisticsHelper(); for (Term term : terms) { - TermStates ctx = TermStates.build(searcher.getTopReaderContext(), term, scoreMode.needsScores()); + TermStates ctx = TermStates.build(searcher, term, scoreMode.needsScores()); if(ctx != null && ctx.docFreq() > 0){ TermStatistics tStats = searcher.termStatistics(term, ctx.docFreq(), ctx.totalTermFreq()); df_stats.add(tStats.docFreq()); diff --git a/src/main/java/com/o19s/es/explore/PostingsExplorerQuery.java b/src/main/java/com/o19s/es/explore/PostingsExplorerQuery.java index 0ee2e930..9a78cefe 100644 --- a/src/main/java/com/o19s/es/explore/PostingsExplorerQuery.java +++ b/src/main/java/com/o19s/es/explore/PostingsExplorerQuery.java @@ -17,7 +17,6 @@ package com.o19s.es.explore; import com.o19s.es.ltr.utils.CheckedBiFunction; -import org.apache.lucene.index.IndexReaderContext; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PostingsEnum; import org.apache.lucene.index.ReaderUtil; @@ -78,9 +77,8 @@ public int hashCode() { @Override public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - IndexReaderContext context = searcher.getTopReaderContext(); assert scoreMode.needsScores() : "Should not be used in filtering mode"; - return new PostingsExplorerWeight(this, this.term, TermStates.build(context, this.term, + return new PostingsExplorerWeight(this, this.term, TermStates.build(searcher, this.term, scoreMode.needsScores()), this.type); } diff --git a/src/main/java/com/o19s/es/ltr/LtrQueryParserPlugin.java b/src/main/java/com/o19s/es/ltr/LtrQueryParserPlugin.java index d499f323..dc86730b 100644 --- a/src/main/java/com/o19s/es/ltr/LtrQueryParserPlugin.java +++ b/src/main/java/com/o19s/es/ltr/LtrQueryParserPlugin.java @@ -74,11 +74,9 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.node.DiscoveryNodes; -import org.elasticsearch.cluster.routing.allocation.AllocationService; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.core.CheckedFunction; -import org.elasticsearch.indices.IndicesService; -import org.elasticsearch.telemetry.TelemetryProvider; +import org.elasticsearch.features.NodeFeature; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry.Entry; @@ -89,8 +87,6 @@ import org.elasticsearch.common.settings.SettingsFilter; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.env.Environment; -import org.elasticsearch.env.NodeEnvironment; import org.elasticsearch.index.Index; import org.elasticsearch.index.analysis.PreConfiguredTokenFilter; import org.elasticsearch.index.analysis.PreConfiguredTokenizer; @@ -99,15 +95,11 @@ import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.ScriptPlugin; import org.elasticsearch.plugins.SearchPlugin; -import org.elasticsearch.repositories.RepositoriesService; import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; import org.elasticsearch.script.ScriptContext; import org.elasticsearch.script.ScriptEngine; -import org.elasticsearch.script.ScriptService; import org.elasticsearch.search.fetch.FetchSubPhase; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.watcher.ResourceWatcherService; import java.io.IOException; import java.util.ArrayList; @@ -117,6 +109,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Predicate; import java.util.function.Supplier; import static java.util.Arrays.asList; @@ -171,10 +164,15 @@ public ScriptEngine getScriptEngine(Settings settings, Collection getRestHandlers(Settings settings, RestController restController, - ClusterSettings clusterSettings, IndexScopedSettings indexScopedSettings, - SettingsFilter settingsFilter, IndexNameExpressionResolver indexNameExpressionResolver, - Supplier nodesInCluster) { + public List getRestHandlers(Settings settings, + NamedWriteableRegistry namedWriteableRegistry, + RestController restController, + ClusterSettings clusterSettings, + IndexScopedSettings indexScopedSettings, + SettingsFilter settingsFilter, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier nodesInCluster, + Predicate clusterSupportsFeature) { List list = new ArrayList<>(); for (String type : ValidatingLtrQueryBuilder.SUPPORTED_TYPES) { @@ -242,21 +240,8 @@ public List> getSettings() { } @Override - public Collection createComponents(Client client, - ClusterService clusterService, - ThreadPool threadPool, - ResourceWatcherService resourceWatcherService, - ScriptService scriptService, - NamedXContentRegistry xContentRegistry, - Environment environment, - NodeEnvironment nodeEnvironment, - NamedWriteableRegistry namedWriteableRegistry, - IndexNameExpressionResolver indexNameExpressionResolver, - Supplier repositoriesServiceSupplier, - TelemetryProvider telemetryProvider, - AllocationService allocationService, - IndicesService indicesService) { - clusterService.addListener(event -> { + public Collection createComponents(PluginServices services) { + services.clusterService().addListener(event -> { for (Index i : event.indicesDeleted()) { if (IndexFeatureStore.isIndexStore(i.getName())) { caches.evict(i.getName()); @@ -264,9 +249,10 @@ public Collection createComponents(Client client, } }); - Scripting.initScriptService(scriptService); + Scripting.initScriptService(services.scriptService()); - return asList(caches, parserFactory, getStats(client, clusterService, indexNameExpressionResolver)); + return asList(caches, parserFactory, + getStats(services.client(), services.clusterService(), services.indexNameExpressionResolver())); } private LTRStats getStats(Client client, ClusterService clusterService, IndexNameExpressionResolver indexNameExpressionResolver) { diff --git a/src/main/java/com/o19s/es/ltr/action/AddFeaturesToSetAction.java b/src/main/java/com/o19s/es/ltr/action/AddFeaturesToSetAction.java index 4588b984..87b5536f 100644 --- a/src/main/java/com/o19s/es/ltr/action/AddFeaturesToSetAction.java +++ b/src/main/java/com/o19s/es/ltr/action/AddFeaturesToSetAction.java @@ -29,7 +29,6 @@ import org.elasticsearch.client.internal.ElasticsearchClient; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable.Reader; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; @@ -44,12 +43,7 @@ public class AddFeaturesToSetAction extends ActionType public static final String NAME = "cluster:admin/ltr/store/add-features-to-set"; protected AddFeaturesToSetAction() { - super(NAME, AddFeaturesToSetResponse::new); - } - - @Override - public Reader getResponseReader() { - return AddFeaturesToSetResponse::new; + super(NAME); } public static class AddFeaturesToSetRequestBuilder extends ActionRequestBuilder { diff --git a/src/main/java/com/o19s/es/ltr/action/CachesStatsAction.java b/src/main/java/com/o19s/es/ltr/action/CachesStatsAction.java index 6b8a7a56..30594d17 100644 --- a/src/main/java/com/o19s/es/ltr/action/CachesStatsAction.java +++ b/src/main/java/com/o19s/es/ltr/action/CachesStatsAction.java @@ -43,7 +43,7 @@ public class CachesStatsAction extends ActionType { public static final CachesStatsAction INSTANCE = new CachesStatsAction(); protected CachesStatsAction() { - super(NAME, CachesStatsNodesResponse::new); + super(NAME); } public static class CachesStatsNodesRequest extends BaseNodesRequest { diff --git a/src/main/java/com/o19s/es/ltr/action/ClearCachesAction.java b/src/main/java/com/o19s/es/ltr/action/ClearCachesAction.java index 71d7ea4f..c9b38b06 100644 --- a/src/main/java/com/o19s/es/ltr/action/ClearCachesAction.java +++ b/src/main/java/com/o19s/es/ltr/action/ClearCachesAction.java @@ -33,7 +33,6 @@ import java.io.IOException; import java.util.List; import java.util.Objects; -import org.elasticsearch.common.io.stream.Writeable.Reader; import static org.elasticsearch.action.ValidateActions.addValidationError; @@ -42,12 +41,7 @@ public class ClearCachesAction extends ActionType { public static final ClearCachesAction INSTANCE = new ClearCachesAction(); private ClearCachesAction() { - super(NAME, ClearCachesNodesResponse::new); - } - - @Override - public Reader getResponseReader() { - return ClearCachesNodesResponse::new; + super(NAME); } public static class RequestBuilder extends ActionRequestBuilder { diff --git a/src/main/java/com/o19s/es/ltr/action/CreateModelFromSetAction.java b/src/main/java/com/o19s/es/ltr/action/CreateModelFromSetAction.java index d28dc7ac..d00a3ed8 100644 --- a/src/main/java/com/o19s/es/ltr/action/CreateModelFromSetAction.java +++ b/src/main/java/com/o19s/es/ltr/action/CreateModelFromSetAction.java @@ -42,7 +42,7 @@ public class CreateModelFromSetAction extends ActionType { public static final FeatureStoreAction INSTANCE = new FeatureStoreAction(); protected FeatureStoreAction() { - super(NAME, FeatureStoreResponse::new); - } - - @Override - public Reader getResponseReader() { - return FeatureStoreResponse::new; + super(NAME); } public static class FeatureStoreRequestBuilder diff --git a/src/main/java/com/o19s/es/ltr/action/LTRStatsAction.java b/src/main/java/com/o19s/es/ltr/action/LTRStatsAction.java index 5af8af63..97c7cab3 100644 --- a/src/main/java/com/o19s/es/ltr/action/LTRStatsAction.java +++ b/src/main/java/com/o19s/es/ltr/action/LTRStatsAction.java @@ -27,7 +27,7 @@ public class LTRStatsAction extends ActionType nodeResponses, diff --git a/src/main/java/com/o19s/es/ltr/action/ListStoresAction.java b/src/main/java/com/o19s/es/ltr/action/ListStoresAction.java index 12cde4c5..afa61f7d 100644 --- a/src/main/java/com/o19s/es/ltr/action/ListStoresAction.java +++ b/src/main/java/com/o19s/es/ltr/action/ListStoresAction.java @@ -27,7 +27,6 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.io.stream.Writeable.Reader; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; @@ -43,12 +42,7 @@ public class ListStoresAction extends ActionType { public static final ListStoresAction INSTANCE = new ListStoresAction(); private ListStoresAction() { - super(NAME, ListStoresActionResponse::new); - } - - @Override - public Reader getResponseReader() { - return ListStoresActionResponse::new; + super(NAME); } public static class ListStoresActionRequest extends MasterNodeReadRequest { diff --git a/src/main/java/com/o19s/es/ltr/action/TransportAddFeatureToSetAction.java b/src/main/java/com/o19s/es/ltr/action/TransportAddFeatureToSetAction.java index 1b3efe07..6aeaed8e 100644 --- a/src/main/java/com/o19s/es/ltr/action/TransportAddFeatureToSetAction.java +++ b/src/main/java/com/o19s/es/ltr/action/TransportAddFeatureToSetAction.java @@ -68,7 +68,8 @@ public TransportAddFeatureToSetAction(Settings settings, ThreadPool threadPool, IndexNameExpressionResolver indexNameExpressionResolver, ClusterService clusterService, TransportSearchAction searchAction, TransportGetAction getAction, TransportFeatureStoreAction featureStoreAction) { - super(AddFeaturesToSetAction.NAME, transportService, actionFilters, AddFeaturesToSetRequest::new); + super(AddFeaturesToSetAction.NAME, transportService, actionFilters, + AddFeaturesToSetRequest::new, threadPool.executor(ThreadPool.Names.MANAGEMENT)); this.clusterService = clusterService; this.searchAction = searchAction; this.getAction = getAction; diff --git a/src/main/java/com/o19s/es/ltr/action/TransportCacheStatsAction.java b/src/main/java/com/o19s/es/ltr/action/TransportCacheStatsAction.java index 723a88ac..44fcf4b7 100644 --- a/src/main/java/com/o19s/es/ltr/action/TransportCacheStatsAction.java +++ b/src/main/java/com/o19s/es/ltr/action/TransportCacheStatsAction.java @@ -46,8 +46,8 @@ public TransportCacheStatsAction(Settings settings, ThreadPool threadPool, ClusterService clusterService, TransportService transportService, ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, Caches caches) { - super(CachesStatsAction.NAME, threadPool, clusterService, transportService, - actionFilters, CachesStatsNodesRequest::new, CachesStatsNodeRequest::new, + super(CachesStatsAction.NAME, clusterService, transportService, + actionFilters, CachesStatsNodeRequest::new, threadPool.executor(ThreadPool.Names.MANAGEMENT)); this.caches = caches; } diff --git a/src/main/java/com/o19s/es/ltr/action/TransportClearCachesAction.java b/src/main/java/com/o19s/es/ltr/action/TransportClearCachesAction.java index ebee0014..e8044671 100644 --- a/src/main/java/com/o19s/es/ltr/action/TransportClearCachesAction.java +++ b/src/main/java/com/o19s/es/ltr/action/TransportClearCachesAction.java @@ -47,8 +47,8 @@ public TransportClearCachesAction(Settings settings, ThreadPool threadPool, ClusterService clusterService, TransportService transportService, ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, Caches caches) { - super(ClearCachesAction.NAME, threadPool, clusterService, transportService, actionFilters, - ClearCachesNodesRequest::new, ClearCachesNodeRequest::new, threadPool.executor(ThreadPool.Names.MANAGEMENT)); + super(ClearCachesAction.NAME, clusterService, transportService, actionFilters, + ClearCachesNodeRequest::new, threadPool.executor(ThreadPool.Names.MANAGEMENT)); this.caches = caches; } diff --git a/src/main/java/com/o19s/es/ltr/action/TransportCreateModelFromSetAction.java b/src/main/java/com/o19s/es/ltr/action/TransportCreateModelFromSetAction.java index d7b350e5..e3099079 100644 --- a/src/main/java/com/o19s/es/ltr/action/TransportCreateModelFromSetAction.java +++ b/src/main/java/com/o19s/es/ltr/action/TransportCreateModelFromSetAction.java @@ -33,6 +33,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; @@ -50,7 +51,8 @@ public TransportCreateModelFromSetAction(Settings settings, ThreadPool threadPoo IndexNameExpressionResolver indexNameExpressionResolver, ClusterService clusterService, TransportGetAction getAction, TransportFeatureStoreAction featureStoreAction) { - super(CreateModelFromSetAction.NAME, transportService, actionFilters, CreateModelFromSetRequest::new); + super(CreateModelFromSetAction.NAME, transportService, actionFilters, + CreateModelFromSetRequest::new, EsExecutors.DIRECT_EXECUTOR_SERVICE); this.clusterService = clusterService; this.getAction = getAction; this.featureStoreAction = featureStoreAction; diff --git a/src/main/java/com/o19s/es/ltr/action/TransportFeatureStoreAction.java b/src/main/java/com/o19s/es/ltr/action/TransportFeatureStoreAction.java index 351c305b..83fac239 100644 --- a/src/main/java/com/o19s/es/ltr/action/TransportFeatureStoreAction.java +++ b/src/main/java/com/o19s/es/ltr/action/TransportFeatureStoreAction.java @@ -30,9 +30,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.index.IndexAction; +import org.elasticsearch.action.index.TransportIndexAction; import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.action.search.SearchAction; import org.elasticsearch.action.search.SearchRequestBuilder; import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.action.support.ActionFilters; @@ -41,6 +40,7 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; @@ -62,7 +62,8 @@ public TransportFeatureStoreAction(TransportService transportService, ClusterService clusterService, Client client, LtrRankerParserFactory factory, TransportClearCachesAction clearCachesAction) { - super(FeatureStoreAction.NAME, false, transportService, actionFilters, FeatureStoreRequest::new); + super(FeatureStoreAction.NAME, false, transportService, actionFilters, + FeatureStoreRequest::new, EsExecutors.DIRECT_EXECUTOR_SERVICE); this.factory = factory; this.clusterService = clusterService; this.clearCachesAction = clearCachesAction; @@ -152,7 +153,7 @@ private void validate(FeatureValidation validation, Runnable onSuccess) { ValidatingLtrQueryBuilder ltrBuilder = new ValidatingLtrQueryBuilder(element, validation, factory); - SearchRequestBuilder builder = new SearchRequestBuilder(client, SearchAction.INSTANCE); + SearchRequestBuilder builder = new SearchRequestBuilder(client); builder.setIndices(validation.getIndex()); builder.setQuery(ltrBuilder); builder.setFrom(0); @@ -179,7 +180,7 @@ private void store(FeatureStoreRequest request, Task task, ActionListener clearCachesNodesRequest = buildClearCache(request); IndexRequest indexRequest = buildIndexRequest(task, request); - client.execute(IndexAction.INSTANCE, indexRequest, wrap( + client.execute(TransportIndexAction.TYPE, indexRequest, wrap( (r) -> { // Run and forget, log only if something bad happens // but don't wait for the action to be done nor set the parent task. diff --git a/src/main/java/com/o19s/es/ltr/action/TransportLTRStatsAction.java b/src/main/java/com/o19s/es/ltr/action/TransportLTRStatsAction.java index 297ea671..3e46ed48 100644 --- a/src/main/java/com/o19s/es/ltr/action/TransportLTRStatsAction.java +++ b/src/main/java/com/o19s/es/ltr/action/TransportLTRStatsAction.java @@ -33,8 +33,8 @@ public TransportLTRStatsAction(ThreadPool threadPool, TransportService transportService, ActionFilters actionFilters, LTRStats ltrStats) { - super(LTRStatsAction.NAME, threadPool, clusterService, transportService, - actionFilters, LTRStatsNodesRequest::new, LTRStatsNodeRequest::new, + super(LTRStatsAction.NAME, clusterService, transportService, + actionFilters, LTRStatsNodeRequest::new, threadPool.executor(ThreadPool.Names.MANAGEMENT)); this.ltrStats = ltrStats; } diff --git a/src/main/java/com/o19s/es/ltr/action/TransportListStoresAction.java b/src/main/java/com/o19s/es/ltr/action/TransportListStoresAction.java index 11dac743..66c2605d 100644 --- a/src/main/java/com/o19s/es/ltr/action/TransportListStoresAction.java +++ b/src/main/java/com/o19s/es/ltr/action/TransportListStoresAction.java @@ -32,6 +32,7 @@ import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.Tuple; import org.elasticsearch.tasks.Task; import org.elasticsearch.common.inject.Inject; @@ -64,7 +65,7 @@ public TransportListStoresAction(TransportService transportService, IndexNameExpressionResolver indexNameExpressionResolver, Client client) { super(ListStoresAction.NAME, transportService, clusterService, threadPool, actionFilters, ListStoresActionRequest::new, indexNameExpressionResolver, ListStoresActionResponse::new, - threadPool.executor(ThreadPool.Names.SAME)); + EsExecutors.DIRECT_EXECUTOR_SERVICE); this.client = client; } diff --git a/src/main/java/com/o19s/es/ltr/feature/FeatureValidation.java b/src/main/java/com/o19s/es/ltr/feature/FeatureValidation.java index 7bf96ef8..f1e79fef 100644 --- a/src/main/java/com/o19s/es/ltr/feature/FeatureValidation.java +++ b/src/main/java/com/o19s/es/ltr/feature/FeatureValidation.java @@ -60,7 +60,7 @@ public FeatureValidation(String index, Map params) { public FeatureValidation(StreamInput input) throws IOException { this.index = input.readString(); - this.params = input.readMap(); + this.params = input.readGenericMap(); } @Override diff --git a/src/main/java/com/o19s/es/ltr/feature/store/ScriptFeature.java b/src/main/java/com/o19s/es/ltr/feature/store/ScriptFeature.java index 57f639be..81a94799 100644 --- a/src/main/java/com/o19s/es/ltr/feature/store/ScriptFeature.java +++ b/src/main/java/com/o19s/es/ltr/feature/store/ScriptFeature.java @@ -319,7 +319,7 @@ static class LtrScriptWeight extends Weight { if (scoreMode.needsScores()) { for (Term t : terms) { - TermStates ctx = TermStates.build(searcher.getTopReaderContext(), t, true); + TermStates ctx = TermStates.build(searcher, t, true); if (ctx != null && ctx.docFreq() > 0) { searcher.collectionStatistics(t.field()); searcher.termStatistics(t, ctx.docFreq(), ctx.totalTermFreq()); diff --git a/src/main/java/com/o19s/es/ltr/feature/store/index/IndexFeatureStore.java b/src/main/java/com/o19s/es/ltr/feature/store/index/IndexFeatureStore.java index 55d4d480..0f1a6257 100644 --- a/src/main/java/com/o19s/es/ltr/feature/store/index/IndexFeatureStore.java +++ b/src/main/java/com/o19s/es/ltr/feature/store/index/IndexFeatureStore.java @@ -181,7 +181,7 @@ public CompiledLtrModel loadModel(String name) throws IOException { public E getAndParse(String name, Class eltClass, String type) throws IOException { GetResponse response = internalGet(generateId(type, name)).get(); if (response.isExists()) { - return parse(eltClass, type, response.getSourceAsBytes()); + return parse(eltClass, type, response.getSourceAsBytesRef()); } else { return null; } diff --git a/src/main/java/com/o19s/es/ltr/query/StoredLtrQueryBuilder.java b/src/main/java/com/o19s/es/ltr/query/StoredLtrQueryBuilder.java index f72176be..28c2aa89 100644 --- a/src/main/java/com/o19s/es/ltr/query/StoredLtrQueryBuilder.java +++ b/src/main/java/com/o19s/es/ltr/query/StoredLtrQueryBuilder.java @@ -90,7 +90,7 @@ public StoredLtrQueryBuilder(FeatureStoreLoader storeLoader, StreamInput input) modelName = input.readOptionalString(); featureScoreCacheFlag = input.readOptionalBoolean(); featureSetName = input.readOptionalString(); - params = input.readMap(); + params = input.readGenericMap(); if (input.getTransportVersion().onOrAfter(TransportVersions.V_7_0_0)) { String[] activeFeat = input.readOptionalStringArray(); activeFeatures = activeFeat == null ? null : Arrays.asList(activeFeat); diff --git a/src/main/java/com/o19s/es/ltr/rest/RestSearchStoreElements.java b/src/main/java/com/o19s/es/ltr/rest/RestSearchStoreElements.java index 1c56aefe..cae8c84d 100644 --- a/src/main/java/com/o19s/es/ltr/rest/RestSearchStoreElements.java +++ b/src/main/java/com/o19s/es/ltr/rest/RestSearchStoreElements.java @@ -3,7 +3,7 @@ import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.rest.RestRequest; -import org.elasticsearch.rest.action.RestChunkedToXContentListener; +import org.elasticsearch.rest.action.RestRefCountedChunkedToXContentListener; import java.util.List; @@ -51,7 +51,7 @@ RestChannelConsumer search(NodeClient client, String type, String indexName, Res .setQuery(qb) .setSize(size) .setFrom(from) - .execute(new RestChunkedToXContentListener<>(channel)); + .execute(new RestRefCountedChunkedToXContentListener<>(channel)); } } diff --git a/src/main/java/com/o19s/es/ltr/stats/suppliers/StoreStatsSupplier.java b/src/main/java/com/o19s/es/ltr/stats/suppliers/StoreStatsSupplier.java index f36a3289..dc5084ce 100644 --- a/src/main/java/com/o19s/es/ltr/stats/suppliers/StoreStatsSupplier.java +++ b/src/main/java/com/o19s/es/ltr/stats/suppliers/StoreStatsSupplier.java @@ -108,6 +108,7 @@ private Map> createStoreStatsResponse(MultiSearchReq .forEach(bucket -> updateCount(bucket, storeStat)); } } + msr.decRef(); return stats; } catch (InterruptedException | ExecutionException e) { LOG.error("Error retrieving store stats", e); diff --git a/src/main/java/com/o19s/es/termstat/TermStatQuery.java b/src/main/java/com/o19s/es/termstat/TermStatQuery.java index 224561c3..3cbda2dd 100644 --- a/src/main/java/com/o19s/es/termstat/TermStatQuery.java +++ b/src/main/java/com/o19s/es/termstat/TermStatQuery.java @@ -106,7 +106,7 @@ static class TermStatWeight extends Weight { // This is needed for proper DFS_QUERY_THEN_FETCH support if (scoreMode.needsScores()) { for (Term t : terms) { - TermStates ctx = TermStates.build(searcher.getTopReaderContext(), t, true); + TermStates ctx = TermStates.build(searcher, t, true); if (ctx != null && ctx.docFreq() > 0) { searcher.collectionStatistics(t.field()); diff --git a/src/test/java/com/o19s/es/ltr/ShardStatsIT.java b/src/test/java/com/o19s/es/ltr/ShardStatsIT.java index 7e384f1b..ea688c09 100644 --- a/src/test/java/com/o19s/es/ltr/ShardStatsIT.java +++ b/src/test/java/com/o19s/es/ltr/ShardStatsIT.java @@ -3,7 +3,6 @@ import com.o19s.es.TestExpressionsPlugin; import com.o19s.es.explore.ExplorerQueryBuilder; import com.o19s.es.termstat.TermStatQueryBuilder; -import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchType; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.TermQueryBuilder; @@ -14,7 +13,8 @@ import java.util.Arrays; import java.util.Collection; -import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailures; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; import static org.hamcrest.Matchers.equalTo; /* @@ -60,14 +60,15 @@ public void testDfsExplorer() throws Exception { .query(q) .statsType("min_raw_df"); - final SearchResponse r = client().prepareSearch("idx") - .setSearchType(SearchType.DFS_QUERY_THEN_FETCH) - .setQuery(eq).get(); + assertResponse( + client().prepareSearch("idx").setSearchType(SearchType.DFS_QUERY_THEN_FETCH).setQuery(eq), + r -> { + assertNoFailures(r); - assertSearchResponse(r); - - SearchHits hits = r.getHits(); - assertThat(hits.getAt(0).getScore(), equalTo(4.0f)); + SearchHits hits = r.getHits(); + assertThat(hits.getAt(0).getScore(), equalTo(4.0f)); + } + ); } public void testNonDfsExplorer() throws Exception { @@ -79,14 +80,15 @@ public void testNonDfsExplorer() throws Exception { .query(q) .statsType("min_raw_df"); - final SearchResponse r = client().prepareSearch("idx") - .setSearchType(SearchType.QUERY_THEN_FETCH) - .setQuery(eq).get(); - - assertSearchResponse(r); + assertResponse( + client().prepareSearch("idx").setSearchType(SearchType.QUERY_THEN_FETCH).setQuery(eq), + r -> { + assertNoFailures(r); - SearchHits hits = r.getHits(); - assertThat(hits.getAt(0).getScore(), equalTo(2.0f)); + SearchHits hits = r.getHits(); + assertThat(hits.getAt(0).getScore(), equalTo(2.0f)); + } + ); } public void testDfsTSQ() throws Exception { @@ -99,15 +101,15 @@ public void testDfsTSQ() throws Exception { .terms(new String[]{"zzz"}) .fields(new String[]{"s"}); - final SearchResponse r = client().prepareSearch("idx") - .setSearchType(SearchType.DFS_QUERY_THEN_FETCH) - .setQuery(tsq) - .get(); + assertResponse( + client().prepareSearch("idx").setSearchType(SearchType.DFS_QUERY_THEN_FETCH).setQuery(tsq), + r -> { + assertNoFailures(r); - assertSearchResponse(r); - - SearchHits hits = r.getHits(); - assertThat(hits.getAt(0).getScore(), equalTo(4.0f)); + SearchHits hits = r.getHits(); + assertThat(hits.getAt(0).getScore(), equalTo(4.0f)); + } + ); } public void testNonDfsTSQ() throws Exception { @@ -120,14 +122,14 @@ public void testNonDfsTSQ() throws Exception { .terms(new String[]{"zzz"}) .fields(new String[]{"s"}); - final SearchResponse r = client().prepareSearch("idx") - .setSearchType(SearchType.QUERY_THEN_FETCH) - .setQuery(tsq) - .get(); - - assertSearchResponse(r); + assertResponse( + client().prepareSearch("idx").setSearchType(SearchType.QUERY_THEN_FETCH).setQuery(tsq), + r -> { + assertNoFailures(r); - SearchHits hits = r.getHits(); - assertThat(hits.getAt(0).getScore(), equalTo(2.0f)); + SearchHits hits = r.getHits(); + assertThat(hits.getAt(0).getScore(), equalTo(2.0f)); + } + ); } } diff --git a/src/test/java/com/o19s/es/ltr/logging/LoggingFetchSubPhaseTests.java b/src/test/java/com/o19s/es/ltr/logging/LoggingFetchSubPhaseTests.java index c499cb1a..ea1d9a40 100644 --- a/src/test/java/com/o19s/es/ltr/logging/LoggingFetchSubPhaseTests.java +++ b/src/test/java/com/o19s/es/ltr/logging/LoggingFetchSubPhaseTests.java @@ -121,7 +121,7 @@ public void testLogging() throws IOException { LoggingFetchSubPhaseProcessor processor = new LoggingFetchSubPhaseProcessor(() -> new Tuple<>(weight, loggers)); SearchHit[] hits = preprocessRandomHits(processor); - for (SearchHit hit : hits) { + for (SearchHit hit : hits) try { assertTrue(docs.containsKey(hit.getId())); Document d = docs.get(hit.getId()); assertTrue(hit.getFields().containsKey("_ltrlog")); @@ -149,6 +149,8 @@ public void testLogging() throws IOException { expectedScore = Math.log1p(expectedScore+1); assertEquals((float) expectedScore, (Float)log1.get(1).get("value"), Math.ulp((float)expectedScore)); assertEquals((float) expectedScore, (Float)log1.get(1).get("value"), Math.ulp((float)expectedScore)); + } finally { + hit.decRef(); } } @@ -219,7 +221,8 @@ public Query buildFunctionScore() { "score", FLOAT, CoreValuesSourceType.NUMERIC, - (dv, n) -> { throw new UnsupportedOperationException(); })); + (dv, n) -> { throw new UnsupportedOperationException(); }, + false)); return new FunctionScoreQuery(new MatchAllDocsQuery(), fieldValueFactorFunction, CombineFunction.MULTIPLY, 0F, Float.MAX_VALUE); }