diff --git a/.buildkite/scripts/lucene-snapshot/update-branch.sh b/.buildkite/scripts/lucene-snapshot/update-branch.sh index 6a2d1e3df05f7..a2a80824c984f 100755 --- a/.buildkite/scripts/lucene-snapshot/update-branch.sh +++ b/.buildkite/scripts/lucene-snapshot/update-branch.sh @@ -7,12 +7,21 @@ if [[ "$BUILDKITE_BRANCH" != "lucene_snapshot"* ]]; then exit 1 fi -echo --- Updating "$BUILDKITE_BRANCH" branch with main +if [[ "$BUILDKITE_BRANCH" == "lucene_snapshot_10" ]]; then + UPSTREAM="main" +elif [[ "$BUILDKITE_BRANCH" == "lucene_snapshot" ]]; then + UPSTREAM="8.x" +else + echo "Error: unknown branch: $BUILDKITE_BRANCH" + exit 1 +fi + +echo --- Updating "$BUILDKITE_BRANCH" branch with "$UPSTREAM" git config --global user.name elasticsearchmachine git config --global user.email 'infra-root+elasticsearchmachine@elastic.co' git checkout "$BUILDKITE_BRANCH" -git fetch origin main -git merge --no-edit origin/main +git fetch origin "$UPSTREAM" +git merge --no-edit "origin/$UPSTREAM" git push origin "$BUILDKITE_BRANCH" diff --git a/distribution/tools/java-version-checker/build.gradle b/distribution/tools/java-version-checker/build.gradle index 0a47d0652e465..3d4ec5aced29c 100644 --- a/distribution/tools/java-version-checker/build.gradle +++ b/distribution/tools/java-version-checker/build.gradle @@ -1,30 +1,10 @@ apply plugin: 'elasticsearch.build' -sourceSets { - unsupportedJdkVersionEntrypoint -} - -tasks.named(sourceSets.unsupportedJdkVersionEntrypoint.compileJavaTaskName).configure { - targetCompatibility = JavaVersion.VERSION_1_8 -} - - -tasks.named("jar") { - manifest { - attributes("Multi-Release": "true") - } - - FileCollection mainOutput = sourceSets.main.output; - from(sourceSets.unsupportedJdkVersionEntrypoint.output) - eachFile { details -> - if (details.path.equals("org/elasticsearch/tools/java_version_checker/JavaVersionChecker.class") && - mainOutput.asFileTree.contains(details.file)) { - details.relativePath = details.relativePath.prepend("META-INF/versions/17") - } - } +compileJava { + options.release = 8 } // TODO revisit forbiddenApis issues -["javadoc", "forbiddenApisMain", "forbiddenApisUnsupportedJdkVersionEntrypoint"].each { +["javadoc", "forbiddenApisMain"].each { tasks.named(it).configure { enabled = false } } diff --git a/distribution/tools/java-version-checker/src/main/java/org/elasticsearch/tools/java_version_checker/JavaVersionChecker.java b/distribution/tools/java-version-checker/src/main/java/org/elasticsearch/tools/java_version_checker/JavaVersionChecker.java index 36b0ad49d124c..672fa8fd164aa 100644 --- a/distribution/tools/java-version-checker/src/main/java/org/elasticsearch/tools/java_version_checker/JavaVersionChecker.java +++ b/distribution/tools/java-version-checker/src/main/java/org/elasticsearch/tools/java_version_checker/JavaVersionChecker.java @@ -10,9 +10,10 @@ package org.elasticsearch.tools.java_version_checker; import java.util.Arrays; +import java.util.Locale; /** - * Java 17 compatible main which just exits without error. 
+ * Java 8 compatible main to check the runtime version */ final class JavaVersionChecker { @@ -23,5 +24,27 @@ public static void main(final String[] args) { if (args.length != 0) { throw new IllegalArgumentException("expected zero arguments but was " + Arrays.toString(args)); } + + final int MIN_VERSION = 21; + final int version; + String versionString = System.getProperty("java.specification.version"); + if (versionString.equals("1.8")) { + version = 8; + } else { + version = Integer.parseInt(versionString); + } + if (version >= MIN_VERSION) { + return; + } + + final String message = String.format( + Locale.ROOT, + "The minimum required Java version is %d; your Java version %d from [%s] does not meet that requirement.", + MIN_VERSION, + version, + System.getProperty("java.home") + ); + System.err.println(message); + System.exit(1); } } diff --git a/distribution/tools/java-version-checker/src/unsupportedJdkVersionEntrypoint/java/org/elasticsearch/tools/java_version_checker/JavaVersionChecker.java b/distribution/tools/java-version-checker/src/unsupportedJdkVersionEntrypoint/java/org/elasticsearch/tools/java_version_checker/JavaVersionChecker.java deleted file mode 100644 index 49151e29201e8..0000000000000 --- a/distribution/tools/java-version-checker/src/unsupportedJdkVersionEntrypoint/java/org/elasticsearch/tools/java_version_checker/JavaVersionChecker.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.tools.java_version_checker; - -import java.util.Arrays; -import java.util.Locale; - -/** - * Java 7 compatible main which exits with an error. - */ -final class JavaVersionChecker { - - private JavaVersionChecker() {} - - public static void main(final String[] args) { - // no leniency! 
- if (args.length != 0) { - throw new IllegalArgumentException("expected zero arguments but was " + Arrays.toString(args)); - } - final String message = String.format( - Locale.ROOT, - "The minimum required Java version is 17; your Java version %s from [%s] does not meet that requirement.", - System.getProperty("java.specification.version"), - System.getProperty("java.home") - ); - System.err.println(message); - System.exit(1); - } -} diff --git a/docs/changelog/111684.yaml b/docs/changelog/111684.yaml deleted file mode 100644 index 32edb5723cb0a..0000000000000 --- a/docs/changelog/111684.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 111684 -summary: Write downloaded model parts async -area: Machine Learning -type: enhancement -issues: [] diff --git a/docs/changelog/112678.yaml b/docs/changelog/112678.yaml new file mode 100644 index 0000000000000..7a1a9d622a65f --- /dev/null +++ b/docs/changelog/112678.yaml @@ -0,0 +1,6 @@ +pr: 112678 +summary: Make "too many clauses" throw IllegalArgumentException to avoid 500s +area: Search +type: bug +issues: + - 112177 \ No newline at end of file diff --git a/docs/changelog/112888.yaml b/docs/changelog/112888.yaml new file mode 100644 index 0000000000000..48806a491e531 --- /dev/null +++ b/docs/changelog/112888.yaml @@ -0,0 +1,5 @@ +pr: 112888 +summary: Fix `getDatabaseType` for unusual MMDBs +area: Ingest Node +type: bug +issues: [] diff --git a/docs/reference/release-notes/8.15.0.asciidoc b/docs/reference/release-notes/8.15.0.asciidoc index bed1912fc1b84..c19f4f7cf989b 100644 --- a/docs/reference/release-notes/8.15.0.asciidoc +++ b/docs/reference/release-notes/8.15.0.asciidoc @@ -30,6 +30,19 @@ signed integer) may encounter errors (issue: {es-issue}111854[#111854]) `xpack.security.authc.realms.*.files.role_mapping` configuration option. As a workaround, custom role mappings can be configured using the https://www.elastic.co/guide/en/elasticsearch/reference/current/security-api-put-role-mapping.html[REST API] (issue: {es-issue}112503[#112503]) +* ES|QL queries can lead to node crashes due to Out Of Memory errors when: +** Multiple indices match the query pattern +** These indices have many conflicting field mappings +** Many of those fields are included in the request +These issues deplete heap memory, increasing the likelihood of OOM errors. (issue: {es-issue}111964[#111964], {es-issue}111358[#111358]). ++ +To work around this issue, you have a number of options: +** Downgrade to an earlier version +** Upgrade to 8.15.2 upon release +** Follow the instructions to +<> +** Change the default data view in Discover to a smaller set of indices and/or one with fewer mapping conflicts. + [[breaking-8.15.0]] [float] === Breaking changes diff --git a/docs/reference/release-notes/8.15.1.asciidoc b/docs/reference/release-notes/8.15.1.asciidoc index f480172611164..2c126cccbda9e 100644 --- a/docs/reference/release-notes/8.15.1.asciidoc +++ b/docs/reference/release-notes/8.15.1.asciidoc @@ -10,6 +10,19 @@ Also see <>. `xpack.security.authc.realms.*.files.role_mapping` configuration option. 
As a workaround, custom role mappings can be configured using the https://www.elastic.co/guide/en/elasticsearch/reference/current/security-api-put-role-mapping.html[REST API] (issue: {es-issue}112503[#112503]) +* ES|QL queries can lead to node crashes due to Out Of Memory errors when: +** Multiple indices match the query pattern +** These indices have many conflicting field mappings +** Many of those fields are included in the request +These issues deplete heap memory, increasing the likelihood of OOM errors. (issue: {es-issue}111964[#111964], {es-issue}111358[#111358]). ++ +To work around this issue, you have a number of options: +** Downgrade to an earlier version +** Upgrade to 8.15.2 upon release +** Follow the instructions to +<> +** Change the default data view in Discover to a smaller set of indices and/or one with fewer mapping conflicts. + [[bug-8.15.1]] [float] === Bug fixes diff --git a/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/AbstractChallengeRestTest.java b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/AbstractChallengeRestTest.java index 811130fa4207d..88a33d502633b 100644 --- a/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/AbstractChallengeRestTest.java +++ b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/AbstractChallengeRestTest.java @@ -44,7 +44,7 @@ public abstract class AbstractChallengeRestTest extends ESRestTestCase { private XContentBuilder contenderMappings; private Settings.Builder baselineSettings; private Settings.Builder contenderSettings; - private RestClient client; + protected RestClient client; @ClassRule() public static ElasticsearchCluster cluster = ElasticsearchCluster.local() diff --git a/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/DataGenerationHelper.java b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/DataGenerationHelper.java new file mode 100644 index 0000000000000..7fd1ccde10053 --- /dev/null +++ b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/DataGenerationHelper.java @@ -0,0 +1,141 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.datastreams.logsdb.qa; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.CheckedConsumer; +import org.elasticsearch.index.mapper.Mapper; +import org.elasticsearch.index.mapper.ObjectMapper; +import org.elasticsearch.logsdb.datageneration.DataGenerator; +import org.elasticsearch.logsdb.datageneration.DataGeneratorSpecification; +import org.elasticsearch.logsdb.datageneration.FieldDataGenerator; +import org.elasticsearch.logsdb.datageneration.datasource.DataSourceHandler; +import org.elasticsearch.logsdb.datageneration.datasource.DataSourceRequest; +import org.elasticsearch.logsdb.datageneration.datasource.DataSourceResponse; +import org.elasticsearch.logsdb.datageneration.fields.PredefinedField; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +class DataGenerationHelper { + private final ObjectMapper.Subobjects subobjects; + private final boolean keepArraySource; + + private final DataGenerator dataGenerator; + + DataGenerationHelper() { + // TODO enable subobjects: auto + // It is disabled because it currently does not have auto flattening and that results in asserts being triggered when using copy_to. + this.subobjects = ESTestCase.randomValueOtherThan( + ObjectMapper.Subobjects.AUTO, + () -> ESTestCase.randomFrom(ObjectMapper.Subobjects.values()) + ); + this.keepArraySource = ESTestCase.randomBoolean(); + + var specificationBuilder = DataGeneratorSpecification.builder().withFullyDynamicMapping(ESTestCase.randomBoolean()); + if (subobjects != ObjectMapper.Subobjects.ENABLED) { + specificationBuilder = specificationBuilder.withNestedFieldsLimit(0); + } + this.dataGenerator = new DataGenerator(specificationBuilder.withDataSourceHandlers(List.of(new DataSourceHandler() { + @Override + public DataSourceResponse.ObjectMappingParametersGenerator handle(DataSourceRequest.ObjectMappingParametersGenerator request) { + if (subobjects == ObjectMapper.Subobjects.ENABLED) { + // Use default behavior + return null; + } + + assert request.isNested() == false; + + // "enabled: false" is not compatible with subobjects: false + // "dynamic: false/strict/runtime" is not compatible with subobjects: false + return new DataSourceResponse.ObjectMappingParametersGenerator(() -> { + var parameters = new HashMap(); + parameters.put("subobjects", subobjects.toString()); + if (ESTestCase.randomBoolean()) { + parameters.put("dynamic", "true"); + } + if (ESTestCase.randomBoolean()) { + parameters.put("enabled", "true"); + } + return parameters; + }); + } + })) + .withPredefinedFields( + List.of( + // Customized because it always needs doc_values for aggregations. 
+                    new PredefinedField.WithGenerator("host.name", new FieldDataGenerator() {
+                        @Override
+                        public CheckedConsumer<XContentBuilder, IOException> mappingWriter() {
+                            return b -> b.startObject().field("type", "keyword").endObject();
+                        }
+
+                        @Override
+                        public CheckedConsumer<XContentBuilder, IOException> fieldValueGenerator() {
+                            return b -> b.value(ESTestCase.randomAlphaOfLength(5));
+                        }
+                    }),
+                    // Needed for terms query
+                    new PredefinedField.WithGenerator("method", new FieldDataGenerator() {
+                        @Override
+                        public CheckedConsumer<XContentBuilder, IOException> mappingWriter() {
+                            return b -> b.startObject().field("type", "keyword").endObject();
+                        }
+
+                        @Override
+                        public CheckedConsumer<XContentBuilder, IOException> fieldValueGenerator() {
+                            return b -> b.value(ESTestCase.randomFrom("put", "post", "get"));
+                        }
+                    }),
+
+                    // Needed for histogram aggregation
+                    new PredefinedField.WithGenerator("memory_usage_bytes", new FieldDataGenerator() {
+                        @Override
+                        public CheckedConsumer<XContentBuilder, IOException> mappingWriter() {
+                            return b -> b.startObject().field("type", "long").endObject();
+                        }
+
+                        @Override
+                        public CheckedConsumer<XContentBuilder, IOException> fieldValueGenerator() {
+                            // We can generate this using standard long field but we would get "too many buckets"
+                            return b -> b.value(ESTestCase.randomLongBetween(1000, 2000));
+                        }
+                    })
+                )
+            )
+            .build());
+    }
+
+    DataGenerator getDataGenerator() {
+        return dataGenerator;
+    }
+
+    void logsDbMapping(XContentBuilder builder) throws IOException {
+        dataGenerator.writeMapping(builder);
+    }
+
+    void standardMapping(XContentBuilder builder) throws IOException {
+        if (subobjects != ObjectMapper.Subobjects.ENABLED) {
+            dataGenerator.writeMapping(builder, Map.of("subobjects", subobjects.toString()));
+        } else {
+            dataGenerator.writeMapping(builder);
+        }
+    }
+
+    void logsDbSettings(Settings.Builder builder) {
+        if (keepArraySource) {
+            builder.put(Mapper.SYNTHETIC_SOURCE_KEEP_INDEX_SETTING.getKey(), "arrays");
+        }
+    }
+}
diff --git a/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/LogsDbVersusLogsDbReindexedIntoStandardModeChallengeRestIT.java b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/LogsDbVersusLogsDbReindexedIntoStandardModeChallengeRestIT.java
new file mode 100644
index 0000000000000..e1cafc40f706f
--- /dev/null
+++ b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/LogsDbVersusLogsDbReindexedIntoStandardModeChallengeRestIT.java
@@ -0,0 +1,80 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.datastreams.logsdb.qa;
+
+import org.elasticsearch.client.Request;
+import org.elasticsearch.client.Response;
+import org.elasticsearch.common.CheckedSupplier;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Locale;
+
+import static org.hamcrest.Matchers.equalTo;
+
+/**
+ * This test compares the behavior of a logsdb data stream and a standard index mode data stream
+ * containing data reindexed from the initial data stream.
+ * There should be no differences between the two data streams.
+ */
+public class LogsDbVersusLogsDbReindexedIntoStandardModeChallengeRestIT extends ReindexChallengeRestIT {
+    public String getBaselineDataStreamName() {
+        return "logs-apache-baseline";
+    }
+
+    public String getContenderDataStreamName() {
+        return "standard-apache-reindexed-contender";
+    }
+
+    @Override
+    public void baselineSettings(Settings.Builder builder) {
+        dataGenerationHelper.logsDbSettings(builder);
+    }
+
+    @Override
+    public void contenderSettings(Settings.Builder builder) {
+
+    }
+
+    @Override
+    public void baselineMappings(XContentBuilder builder) throws IOException {
+        dataGenerationHelper.logsDbMapping(builder);
+    }
+
+    @Override
+    public void contenderMappings(XContentBuilder builder) throws IOException {
+        dataGenerationHelper.standardMapping(builder);
+    }
+
+    @Override
+    public Response indexContenderDocuments(CheckedSupplier<List<XContentBuilder>, IOException> documentsSupplier) throws IOException {
+        var reindexRequest = new Request("POST", "/_reindex?refresh=true");
+        reindexRequest.setJsonEntity(String.format(Locale.ROOT, """
+            {
+              "source": {
+                "index": "%s"
+              },
+              "dest": {
+                "index": "%s",
+                "op_type": "create"
+              }
+            }
+            """, getBaselineDataStreamName(), getContenderDataStreamName()));
+        var response = client.performRequest(reindexRequest);
+        assertOK(response);
+
+        var body = entityAsMap(response);
+        assertThat("encountered failures when performing reindex:\n " + body, body.get("failures"), equalTo(List.of()));
+
+        return response;
+    }
+}
diff --git a/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/LogsDbVersusReindexedLogsDbChallengeRestIT.java b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/LogsDbVersusReindexedLogsDbChallengeRestIT.java
new file mode 100644
index 0000000000000..dd80917b5f080
--- /dev/null
+++ b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/LogsDbVersusReindexedLogsDbChallengeRestIT.java
@@ -0,0 +1,80 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.datastreams.logsdb.qa;
+
+import org.elasticsearch.client.Request;
+import org.elasticsearch.client.Response;
+import org.elasticsearch.common.CheckedSupplier;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Locale;
+
+import static org.hamcrest.Matchers.equalTo;
+
+/**
+ * This test compares the behavior of a logsdb data stream and a data stream containing
+ * data reindexed from the initial data stream.
+ * There should be no differences between the two data streams.
+ */
+public class LogsDbVersusReindexedLogsDbChallengeRestIT extends ReindexChallengeRestIT {
+    public String getBaselineDataStreamName() {
+        return "logs-apache-baseline";
+    }
+
+    public String getContenderDataStreamName() {
+        return "logs-apache-reindexed-contender";
+    }
+
+    @Override
+    public void baselineSettings(Settings.Builder builder) {
+        dataGenerationHelper.logsDbSettings(builder);
+    }
+
+    @Override
+    public void contenderSettings(Settings.Builder builder) {
+        dataGenerationHelper.logsDbSettings(builder);
+    }
+
+    @Override
+    public void baselineMappings(XContentBuilder builder) throws IOException {
+        dataGenerationHelper.logsDbMapping(builder);
+    }
+
+    @Override
+    public void contenderMappings(XContentBuilder builder) throws IOException {
+        dataGenerationHelper.logsDbMapping(builder);
+    }
+
+    @Override
+    public Response indexContenderDocuments(CheckedSupplier<List<XContentBuilder>, IOException> documentsSupplier) throws IOException {
+        var reindexRequest = new Request("POST", "/_reindex?refresh=true");
+        reindexRequest.setJsonEntity(String.format(Locale.ROOT, """
+            {
+              "source": {
+                "index": "%s"
+              },
+              "dest": {
+                "index": "%s",
+                "op_type": "create"
+              }
+            }
+            """, getBaselineDataStreamName(), getContenderDataStreamName()));
+        var response = client.performRequest(reindexRequest);
+        assertOK(response);
+
+        var body = entityAsMap(response);
+        assertThat("encountered failures when performing reindex:\n " + body, body.get("failures"), equalTo(List.of()));
+
+        return response;
+    }
+}
diff --git a/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/ReindexChallengeRestIT.java b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/ReindexChallengeRestIT.java
new file mode 100644
index 0000000000000..b48dce9ca4c57
--- /dev/null
+++ b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/ReindexChallengeRestIT.java
@@ -0,0 +1,46 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */ + +package org.elasticsearch.datastreams.logsdb.qa; + +import org.elasticsearch.client.Request; +import org.elasticsearch.client.Response; +import org.elasticsearch.common.CheckedSupplier; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import static org.hamcrest.Matchers.equalTo; + +public abstract class ReindexChallengeRestIT extends StandardVersusLogsIndexModeRandomDataChallengeRestIT { + @Override + public Response indexContenderDocuments(CheckedSupplier, IOException> documentsSupplier) throws IOException { + var reindexRequest = new Request("POST", "/_reindex?refresh=true"); + reindexRequest.setJsonEntity(String.format(Locale.ROOT, """ + { + "source": { + "index": "%s" + }, + "dest": { + "index": "%s", + "op_type": "create" + } + } + """, getBaselineDataStreamName(), getContenderDataStreamName())); + var response = client.performRequest(reindexRequest); + assertOK(response); + + var body = entityAsMap(response); + assertThat("encountered failures when performing reindex:\n " + body, body.get("failures"), equalTo(List.of())); + + return response; + } +} diff --git a/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/StandardVersusLogsIndexModeChallengeRestIT.java b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/StandardVersusLogsIndexModeChallengeRestIT.java index c35d5bd626b5c..0b6cc38aff37a 100644 --- a/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/StandardVersusLogsIndexModeChallengeRestIT.java +++ b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/StandardVersusLogsIndexModeChallengeRestIT.java @@ -13,11 +13,11 @@ import org.elasticsearch.client.Response; import org.elasticsearch.client.ResponseException; import org.elasticsearch.client.RestClient; +import org.elasticsearch.common.CheckedSupplier; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.time.DateFormatter; import org.elasticsearch.common.time.FormatNames; import org.elasticsearch.common.xcontent.XContentHelper; -import org.elasticsearch.core.Tuple; import org.elasticsearch.datastreams.logsdb.qa.matchers.MatchResult; import org.elasticsearch.datastreams.logsdb.qa.matchers.Matcher; import org.elasticsearch.index.query.QueryBuilders; @@ -185,7 +185,7 @@ public void testMatchAllQuery() throws IOException { int numberOfDocuments = ESTestCase.randomIntBetween(100, 200); final List documents = generateDocuments(numberOfDocuments); - assertDocumentIndexing(documents); + indexDocuments(documents); final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(QueryBuilders.matchAllQuery()) .size(numberOfDocuments); @@ -203,7 +203,7 @@ public void testTermsQuery() throws IOException { int numberOfDocuments = ESTestCase.randomIntBetween(100, 200); final List documents = generateDocuments(numberOfDocuments); - assertDocumentIndexing(documents); + indexDocuments(documents); final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(QueryBuilders.termQuery("method", "put")) .size(numberOfDocuments); @@ -221,7 +221,7 @@ public void testHistogramAggregation() throws IOException { int numberOfDocuments = ESTestCase.randomIntBetween(100, 200); final List documents = generateDocuments(numberOfDocuments); - assertDocumentIndexing(documents); + indexDocuments(documents); final SearchSourceBuilder searchSourceBuilder = new 
SearchSourceBuilder().query(QueryBuilders.matchAllQuery()) .size(numberOfDocuments) @@ -239,7 +239,7 @@ public void testTermsAggregation() throws IOException { int numberOfDocuments = ESTestCase.randomIntBetween(100, 200); final List documents = generateDocuments(numberOfDocuments); - assertDocumentIndexing(documents); + indexDocuments(documents); final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(QueryBuilders.matchAllQuery()) .size(0) @@ -257,7 +257,7 @@ public void testDateHistogramAggregation() throws IOException { int numberOfDocuments = ESTestCase.randomIntBetween(100, 200); final List documents = generateDocuments(numberOfDocuments); - assertDocumentIndexing(documents); + indexDocuments(documents); final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(QueryBuilders.matchAllQuery()) .aggregation(AggregationBuilders.dateHistogram("agg").field("@timestamp").calendarInterval(DateHistogramInterval.SECOND)) @@ -271,6 +271,28 @@ public void testDateHistogramAggregation() throws IOException { assertTrue(matchResult.getMessage(), matchResult.isMatch()); } + @Override + public Response indexBaselineDocuments(CheckedSupplier, IOException> documentsSupplier) throws IOException { + var response = super.indexBaselineDocuments(documentsSupplier); + + assertThat(response.getStatusLine().getStatusCode(), Matchers.equalTo(RestStatus.OK.getStatus())); + var baselineResponseBody = entityAsMap(response); + assertThat("errors in baseline bulk response:\n " + baselineResponseBody, baselineResponseBody.get("errors"), equalTo(false)); + + return response; + } + + @Override + public Response indexContenderDocuments(CheckedSupplier, IOException> documentsSupplier) throws IOException { + var response = super.indexContenderDocuments(documentsSupplier); + + assertThat(response.getStatusLine().getStatusCode(), Matchers.equalTo(RestStatus.OK.getStatus())); + var contenderResponseBody = entityAsMap(response); + assertThat("errors in contender bulk response:\n " + contenderResponseBody, contenderResponseBody.get("errors"), equalTo(false)); + + return response; + } + private List generateDocuments(int numberOfDocuments) throws IOException { final List documents = new ArrayList<>(); // This is static in order to be able to identify documents between test runs. 
@@ -319,15 +341,7 @@ private static List> getAggregationBuckets(final Response re return buckets; } - private void assertDocumentIndexing(List documents) throws IOException { - final Tuple tuple = indexDocuments(() -> documents, () -> documents); - - assertThat(tuple.v1().getStatusLine().getStatusCode(), Matchers.equalTo(RestStatus.OK.getStatus())); - var baselineResponseBody = entityAsMap(tuple.v1()); - assertThat("errors in baseline bulk response:\n " + baselineResponseBody, baselineResponseBody.get("errors"), equalTo(false)); - - assertThat(tuple.v2().getStatusLine().getStatusCode(), Matchers.equalTo(RestStatus.OK.getStatus())); - var contenderResponseBody = entityAsMap(tuple.v2()); - assertThat("errors in contender bulk response:\n " + contenderResponseBody, contenderResponseBody.get("errors"), equalTo(false)); + private void indexDocuments(List documents) throws IOException { + indexDocuments(() -> documents, () -> documents); } } diff --git a/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/StandardVersusLogsIndexModeRandomDataChallengeRestIT.java b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/StandardVersusLogsIndexModeRandomDataChallengeRestIT.java index 1775f613f9d84..611f7fc5a9dcd 100644 --- a/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/StandardVersusLogsIndexModeRandomDataChallengeRestIT.java +++ b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/StandardVersusLogsIndexModeRandomDataChallengeRestIT.java @@ -12,150 +12,44 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.time.DateFormatter; import org.elasticsearch.common.time.FormatNames; -import org.elasticsearch.core.CheckedConsumer; -import org.elasticsearch.index.mapper.Mapper; -import org.elasticsearch.index.mapper.ObjectMapper; -import org.elasticsearch.logsdb.datageneration.DataGenerator; -import org.elasticsearch.logsdb.datageneration.DataGeneratorSpecification; -import org.elasticsearch.logsdb.datageneration.FieldDataGenerator; -import org.elasticsearch.logsdb.datageneration.datasource.DataSourceHandler; -import org.elasticsearch.logsdb.datageneration.datasource.DataSourceRequest; -import org.elasticsearch.logsdb.datageneration.datasource.DataSourceResponse; -import org.elasticsearch.logsdb.datageneration.fields.PredefinedField; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import java.io.IOException; import java.time.Instant; -import java.util.HashMap; -import java.util.List; -import java.util.Map; /** * Challenge test (see {@link StandardVersusLogsIndexModeChallengeRestIT}) that uses randomly generated * mapping and documents in order to cover more code paths and permutations. */ public class StandardVersusLogsIndexModeRandomDataChallengeRestIT extends StandardVersusLogsIndexModeChallengeRestIT { - private final ObjectMapper.Subobjects subobjects; - private final boolean keepArraySource; - - private final DataGenerator dataGenerator; + protected final DataGenerationHelper dataGenerationHelper; public StandardVersusLogsIndexModeRandomDataChallengeRestIT() { super(); - // TODO enable subobjects: auto - // It is disabled because it currently does not have auto flattening and that results in asserts being triggered when using copy_to. 
- this.subobjects = randomValueOtherThan(ObjectMapper.Subobjects.AUTO, () -> randomFrom(ObjectMapper.Subobjects.values())); - this.keepArraySource = randomBoolean(); - - var specificationBuilder = DataGeneratorSpecification.builder().withFullyDynamicMapping(randomBoolean()); - if (subobjects != ObjectMapper.Subobjects.ENABLED) { - specificationBuilder = specificationBuilder.withNestedFieldsLimit(0); - } - this.dataGenerator = new DataGenerator(specificationBuilder.withDataSourceHandlers(List.of(new DataSourceHandler() { - @Override - public DataSourceResponse.ObjectMappingParametersGenerator handle(DataSourceRequest.ObjectMappingParametersGenerator request) { - if (subobjects == ObjectMapper.Subobjects.ENABLED) { - // Use default behavior - return null; - } - - assert request.isNested() == false; - - // "enabled: false" is not compatible with subobjects: false - // "dynamic: false/strict/runtime" is not compatible with subobjects: false - return new DataSourceResponse.ObjectMappingParametersGenerator(() -> { - var parameters = new HashMap(); - parameters.put("subobjects", subobjects.toString()); - if (ESTestCase.randomBoolean()) { - parameters.put("dynamic", "true"); - } - if (ESTestCase.randomBoolean()) { - parameters.put("enabled", "true"); - } - return parameters; - }); - } - })) - .withPredefinedFields( - List.of( - // Customized because it always needs doc_values for aggregations. - new PredefinedField.WithGenerator("host.name", new FieldDataGenerator() { - @Override - public CheckedConsumer mappingWriter() { - return b -> b.startObject().field("type", "keyword").endObject(); - } - - @Override - public CheckedConsumer fieldValueGenerator() { - return b -> b.value(randomAlphaOfLength(5)); - } - }), - // Needed for terms query - new PredefinedField.WithGenerator("method", new FieldDataGenerator() { - @Override - public CheckedConsumer mappingWriter() { - return b -> b.startObject().field("type", "keyword").endObject(); - } - - @Override - public CheckedConsumer fieldValueGenerator() { - return b -> b.value(randomFrom("put", "post", "get")); - } - }), - - // Needed for histogram aggregation - new PredefinedField.WithGenerator("memory_usage_bytes", new FieldDataGenerator() { - @Override - public CheckedConsumer mappingWriter() { - return b -> b.startObject().field("type", "long").endObject(); - } - - @Override - public CheckedConsumer fieldValueGenerator() { - // We can generate this using standard long field but we would get "too many buckets" - return b -> b.value(randomLongBetween(1000, 2000)); - } - }) - ) - ) - .build()); - } - - @Override - protected final Settings restClientSettings() { - return Settings.builder() - .put(super.restClientSettings()) - .put(org.elasticsearch.test.rest.ESRestTestCase.CLIENT_SOCKET_TIMEOUT, "9000s") - .build(); + dataGenerationHelper = new DataGenerationHelper(); } @Override public void baselineMappings(XContentBuilder builder) throws IOException { - dataGenerator.writeMapping(builder); + dataGenerationHelper.standardMapping(builder); } @Override public void contenderMappings(XContentBuilder builder) throws IOException { - if (subobjects != ObjectMapper.Subobjects.ENABLED) { - dataGenerator.writeMapping(builder, Map.of("subobjects", subobjects.toString())); - } else { - dataGenerator.writeMapping(builder); - } + dataGenerationHelper.logsDbMapping(builder); } @Override public void contenderSettings(Settings.Builder builder) { - if (keepArraySource) { - builder.put(Mapper.SYNTHETIC_SOURCE_KEEP_INDEX_SETTING.getKey(), "arrays"); - } + 
super.contenderSettings(builder);
+        dataGenerationHelper.logsDbSettings(builder);
     }
 
     @Override
     protected XContentBuilder generateDocument(final Instant timestamp) throws IOException {
         var document = XContentFactory.jsonBuilder();
-        dataGenerator.generateDocument(document, doc -> {
+        dataGenerationHelper.getDataGenerator().generateDocument(document, doc -> {
             doc.field("@timestamp", DateFormatter.forPattern(FormatNames.STRICT_DATE_OPTIONAL_TIME.getName()).format(timestamp));
         });
diff --git a/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/StandardVersusStandardReindexedIntoLogsDbChallengeRestIT.java b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/StandardVersusStandardReindexedIntoLogsDbChallengeRestIT.java
new file mode 100644
index 0000000000000..d6cfebed1445a
--- /dev/null
+++ b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/logsdb/qa/StandardVersusStandardReindexedIntoLogsDbChallengeRestIT.java
@@ -0,0 +1,50 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.datastreams.logsdb.qa;
+
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+
+/**
+ * This test compares the behavior of a standard index mode data stream and a
+ * logsdb data stream containing data reindexed from the initial data stream.
+ * There should be no differences between the two data streams.
+ */ +public class StandardVersusStandardReindexedIntoLogsDbChallengeRestIT extends ReindexChallengeRestIT { + public String getBaselineDataStreamName() { + return "standard-apache-baseline"; + } + + public String getContenderDataStreamName() { + return "logs-apache-reindexed-contender"; + } + + @Override + public void baselineSettings(Settings.Builder builder) { + + } + + @Override + public void contenderSettings(Settings.Builder builder) { + dataGenerationHelper.logsDbSettings(builder); + } + + @Override + public void baselineMappings(XContentBuilder builder) throws IOException { + dataGenerationHelper.standardMapping(builder); + } + + @Override + public void contenderMappings(XContentBuilder builder) throws IOException { + dataGenerationHelper.logsDbMapping(builder); + } +} diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/DatabaseReaderLazyLoader.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/DatabaseReaderLazyLoader.java index fab9c76dcb8d7..f1594ddaf5144 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/DatabaseReaderLazyLoader.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/DatabaseReaderLazyLoader.java @@ -34,9 +34,7 @@ import java.io.Closeable; import java.io.IOException; -import java.io.InputStream; import java.net.InetAddress; -import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.util.Objects; @@ -66,23 +64,16 @@ class DatabaseReaderLazyLoader implements GeoIpDatabase, Closeable { private final AtomicInteger currentUsages = new AtomicInteger(0); DatabaseReaderLazyLoader(GeoIpCache cache, Path databasePath, String md5) { - this(cache, databasePath, md5, createDatabaseLoader(databasePath)); - } - - DatabaseReaderLazyLoader(GeoIpCache cache, Path databasePath, String md5, CheckedSupplier loader) { this.cache = cache; this.databasePath = Objects.requireNonNull(databasePath); this.md5 = md5; - this.loader = Objects.requireNonNull(loader); + this.loader = createDatabaseLoader(databasePath); this.databaseReader = new SetOnce<>(); this.databaseType = new SetOnce<>(); } /** - * Read the database type from the database. We do this manually instead of relying on the built-in mechanism to avoid reading the - * entire database into memory merely to read the type. This is especially important to maintain on master nodes where pipelines are - * validated. If we read the entire database into memory, we could potentially run into low-memory constraints on such nodes where - * loading this data would otherwise be wasteful if they are not also ingest nodes. + * Read the database type from the database and cache it for future calls. 
* * @return the database type * @throws IOException if an I/O exception occurs reading the database type @@ -92,71 +83,13 @@ public final String getDatabaseType() throws IOException { if (databaseType.get() == null) { synchronized (databaseType) { if (databaseType.get() == null) { - final long fileSize = databaseFileSize(); - if (fileSize <= 512) { - throw new IOException("unexpected file length [" + fileSize + "] for [" + databasePath + "]"); - } - final int[] databaseTypeMarker = { 'd', 'a', 't', 'a', 'b', 'a', 's', 'e', '_', 't', 'y', 'p', 'e' }; - try (InputStream in = databaseInputStream()) { - // read the last 512 bytes - final long skipped = in.skip(fileSize - 512); - if (skipped != fileSize - 512) { - throw new IOException("failed to skip [" + (fileSize - 512) + "] bytes while reading [" + databasePath + "]"); - } - final byte[] tail = new byte[512]; - int read = 0; - do { - final int actualBytesRead = in.read(tail, read, 512 - read); - if (actualBytesRead == -1) { - throw new IOException("unexpected end of stream [" + databasePath + "] after reading [" + read + "] bytes"); - } - read += actualBytesRead; - } while (read != 512); - - // find the database_type header - int metadataOffset = -1; - int markerOffset = 0; - for (int i = 0; i < tail.length; i++) { - byte b = tail[i]; - - if (b == databaseTypeMarker[markerOffset]) { - markerOffset++; - } else { - markerOffset = 0; - } - if (markerOffset == databaseTypeMarker.length) { - metadataOffset = i + 1; - break; - } - } - - if (metadataOffset == -1) { - throw new IOException("database type marker not found"); - } - - // read the database type - final int offsetByte = tail[metadataOffset] & 0xFF; - final int type = offsetByte >>> 5; - if (type != 2) { - throw new IOException("type must be UTF-8 string"); - } - int size = offsetByte & 0x1f; - databaseType.set(new String(tail, metadataOffset + 1, size, StandardCharsets.UTF_8)); - } + databaseType.set(MMDBUtil.getDatabaseType(databasePath)); } } } return databaseType.get(); } - long databaseFileSize() throws IOException { - return Files.size(databasePath); - } - - InputStream databaseInputStream() throws IOException { - return Files.newInputStream(databasePath); - } - @Nullable @Override public CityResponse getCity(InetAddress ipAddress) { diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MMDBUtil.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MMDBUtil.java new file mode 100644 index 0000000000000..b0d4d98701704 --- /dev/null +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MMDBUtil.java @@ -0,0 +1,101 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.ingest.geoip; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; + +public final class MMDBUtil { + + private MMDBUtil() { + // utility class + } + + private static final byte[] DATABASE_TYPE_MARKER = "database_type".getBytes(StandardCharsets.UTF_8); + + // note: technically the metadata can be up to 128k long, but we only handle it correctly as long as it's less than + // or equal to this buffer size. in practice, that seems to be plenty for ordinary files. + private static final int BUFFER_SIZE = 2048; + + /** + * Read the database type from the database. We do this manually instead of relying on the built-in mechanism to avoid reading the + * entire database into memory merely to read the type. This is especially important to maintain on master nodes where pipelines are + * validated. If we read the entire database into memory, we could potentially run into low-memory constraints on such nodes where + * loading this data would otherwise be wasteful if they are not also ingest nodes. + * + * @return the database type + * @throws IOException if an I/O exception occurs reading the database type + */ + public static String getDatabaseType(final Path database) throws IOException { + final long fileSize = Files.size(database); + try (InputStream in = Files.newInputStream(database)) { + // read the last BUFFER_SIZE bytes (or the fileSize, whichever is smaller) + final long skip = fileSize > BUFFER_SIZE ? fileSize - BUFFER_SIZE : 0; + final long skipped = in.skip(skip); + if (skipped != skip) { + throw new IOException("failed to skip [" + skip + "] bytes while reading [" + database + "]"); + } + final byte[] tail = new byte[BUFFER_SIZE]; + int read = 0; + int actualBytesRead; + do { + actualBytesRead = in.read(tail, read, BUFFER_SIZE - read); + read += actualBytesRead; + } while (actualBytesRead > 0); + + // find the database_type header + int metadataOffset = -1; + int markerOffset = 0; + for (int i = 0; i < tail.length; i++) { + byte b = tail[i]; + + if (b == DATABASE_TYPE_MARKER[markerOffset]) { + markerOffset++; + } else { + markerOffset = 0; + } + if (markerOffset == DATABASE_TYPE_MARKER.length) { + metadataOffset = i + 1; + break; + } + } + + if (metadataOffset == -1) { + throw new IOException("database type marker not found"); + } + + // read the database type + final int offsetByte = fromBytes(tail[metadataOffset]); + final int type = offsetByte >>> 5; + if (type != 2) { // 2 is the type indicator in the mmdb format for a UTF-8 string + throw new IOException("type must be UTF-8 string"); + } + int size = offsetByte & 0x1f; + if (size == 29) { + // then we need to read in yet another byte and add it onto this size + // this can actually occur in practice, a 29+ character type description isn't that hard to imagine + size = 29 + fromBytes(tail[metadataOffset + 1]); + metadataOffset += 1; + } else if (size >= 30) { + // we'd need to read two or three more bytes to get the size, but this means the type length is >=285 + throw new IOException("database_type too long [size indicator == " + size + "]"); + } + + return new String(tail, metadataOffset + 1, size, StandardCharsets.UTF_8); + } + } + + private static int fromBytes(byte b1) { + return b1 & 0xFF; + } +} diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java 
b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java index 32c1979939e0e..4a5d445e3ff5b 100644 --- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java @@ -9,38 +9,52 @@ package org.elasticsearch.ingest.geoip; -import com.maxmind.geoip2.DatabaseReader; - import org.elasticsearch.common.CheckedSupplier; import org.elasticsearch.common.util.set.Sets; -import org.elasticsearch.core.PathUtils; +import org.elasticsearch.core.IOUtils; import org.elasticsearch.ingest.IngestDocument; import org.elasticsearch.ingest.RandomDocumentPicks; import org.elasticsearch.ingest.geoip.Database.Property; import org.elasticsearch.test.ESTestCase; +import org.junit.After; +import org.junit.Before; import java.io.IOException; -import java.io.InputStream; +import java.nio.file.Path; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Supplier; import static org.elasticsearch.ingest.IngestDocumentMatcher.assertIngestDocument; +import static org.elasticsearch.ingest.geoip.GeoIpTestUtils.copyDatabase; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasEntry; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; public class GeoIpProcessorTests extends ESTestCase { private static final Set ALL_PROPERTIES = Set.of(Property.values()); + // a temporary directory that mmdb files can be copied to and read from + Path tmpDir; + + @Before + public void setup() { + tmpDir = createTempDir(); + } + + @After + public void cleanup() throws IOException { + IOUtils.rm(tmpDir); + } + public void testDatabasePropertyInvariants() { // the city database is like a specialization of the country database assertThat(Sets.difference(Database.Country.properties(), Database.City.properties()), is(empty())); @@ -60,11 +74,12 @@ public void testDatabasePropertyInvariants() { } public void testCity() throws Exception { + String ip = "8.8.8.8"; GeoIpProcessor processor = new GeoIpProcessor( randomAlphaOfLength(10), null, "source_field", - loader("/GeoLite2-City.mmdb"), + loader("GeoLite2-City.mmdb"), () -> true, "target_field", ALL_PROPERTIES, @@ -74,24 +89,22 @@ public void testCity() throws Exception { ); Map document = new HashMap<>(); - document.put("source_field", "8.8.8.8"); + document.put("source_field", ip); IngestDocument ingestDocument = RandomDocumentPicks.randomIngestDocument(random(), document); processor.execute(ingestDocument); - assertThat(ingestDocument.getSourceAndMetadata().get("source_field"), equalTo("8.8.8.8")); + assertThat(ingestDocument.getSourceAndMetadata().get("source_field"), equalTo(ip)); @SuppressWarnings("unchecked") Map geoData = (Map) ingestDocument.getSourceAndMetadata().get("target_field"); + assertThat(geoData, notNullValue()); assertThat(geoData.size(), equalTo(7)); - assertThat(geoData.get("ip"), equalTo("8.8.8.8")); + assertThat(geoData.get("ip"), equalTo(ip)); assertThat(geoData.get("country_iso_code"), equalTo("US")); assertThat(geoData.get("country_name"), equalTo("United States")); assertThat(geoData.get("continent_code"), equalTo("NA")); 
assertThat(geoData.get("continent_name"), equalTo("North America")); assertThat(geoData.get("timezone"), equalTo("America/Chicago")); - Map location = new HashMap<>(); - location.put("lat", 37.751d); - location.put("lon", -97.822d); - assertThat(geoData.get("location"), equalTo(location)); + assertThat(geoData.get("location"), equalTo(Map.of("lat", 37.751d, "lon", -97.822d))); } public void testNullValueWithIgnoreMissing() throws Exception { @@ -99,7 +112,7 @@ public void testNullValueWithIgnoreMissing() throws Exception { randomAlphaOfLength(10), null, "source_field", - loader("/GeoLite2-City.mmdb"), + loader("GeoLite2-City.mmdb"), () -> true, "target_field", ALL_PROPERTIES, @@ -121,7 +134,7 @@ public void testNonExistentWithIgnoreMissing() throws Exception { randomAlphaOfLength(10), null, "source_field", - loader("/GeoLite2-City.mmdb"), + loader("GeoLite2-City.mmdb"), () -> true, "target_field", ALL_PROPERTIES, @@ -140,7 +153,7 @@ public void testNullWithoutIgnoreMissing() { randomAlphaOfLength(10), null, "source_field", - loader("/GeoLite2-City.mmdb"), + loader("GeoLite2-City.mmdb"), () -> true, "target_field", ALL_PROPERTIES, @@ -162,7 +175,7 @@ public void testNonExistentWithoutIgnoreMissing() { randomAlphaOfLength(10), null, "source_field", - loader("/GeoLite2-City.mmdb"), + loader("GeoLite2-City.mmdb"), () -> true, "target_field", ALL_PROPERTIES, @@ -177,11 +190,12 @@ public void testNonExistentWithoutIgnoreMissing() { } public void testCity_withIpV6() throws Exception { + String ip = "2602:306:33d3:8000::3257:9652"; GeoIpProcessor processor = new GeoIpProcessor( randomAlphaOfLength(10), null, "source_field", - loader("/GeoLite2-City.mmdb"), + loader("GeoLite2-City.mmdb"), () -> true, "target_field", ALL_PROPERTIES, @@ -190,17 +204,17 @@ public void testCity_withIpV6() throws Exception { "filename" ); - String address = "2602:306:33d3:8000::3257:9652"; Map document = new HashMap<>(); - document.put("source_field", address); + document.put("source_field", ip); IngestDocument ingestDocument = RandomDocumentPicks.randomIngestDocument(random(), document); processor.execute(ingestDocument); - assertThat(ingestDocument.getSourceAndMetadata().get("source_field"), equalTo(address)); + assertThat(ingestDocument.getSourceAndMetadata().get("source_field"), equalTo(ip)); @SuppressWarnings("unchecked") Map geoData = (Map) ingestDocument.getSourceAndMetadata().get("target_field"); + assertThat(geoData, notNullValue()); assertThat(geoData.size(), equalTo(10)); - assertThat(geoData.get("ip"), equalTo(address)); + assertThat(geoData.get("ip"), equalTo(ip)); assertThat(geoData.get("country_iso_code"), equalTo("US")); assertThat(geoData.get("country_name"), equalTo("United States")); assertThat(geoData.get("continent_code"), equalTo("NA")); @@ -209,18 +223,16 @@ public void testCity_withIpV6() throws Exception { assertThat(geoData.get("region_name"), equalTo("Florida")); assertThat(geoData.get("city_name"), equalTo("Homestead")); assertThat(geoData.get("timezone"), equalTo("America/New_York")); - Map location = new HashMap<>(); - location.put("lat", 25.4573d); - location.put("lon", -80.4572d); - assertThat(geoData.get("location"), equalTo(location)); + assertThat(geoData.get("location"), equalTo(Map.of("lat", 25.4573d, "lon", -80.4572d))); } public void testCityWithMissingLocation() throws Exception { + String ip = "80.231.5.0"; GeoIpProcessor processor = new GeoIpProcessor( randomAlphaOfLength(10), null, "source_field", - loader("/GeoLite2-City.mmdb"), + loader("GeoLite2-City.mmdb"), () -> true, 
"target_field", ALL_PROPERTIES, @@ -230,23 +242,25 @@ public void testCityWithMissingLocation() throws Exception { ); Map document = new HashMap<>(); - document.put("source_field", "80.231.5.0"); + document.put("source_field", ip); IngestDocument ingestDocument = RandomDocumentPicks.randomIngestDocument(random(), document); processor.execute(ingestDocument); - assertThat(ingestDocument.getSourceAndMetadata().get("source_field"), equalTo("80.231.5.0")); + assertThat(ingestDocument.getSourceAndMetadata().get("source_field"), equalTo(ip)); @SuppressWarnings("unchecked") Map geoData = (Map) ingestDocument.getSourceAndMetadata().get("target_field"); + assertThat(geoData, notNullValue()); assertThat(geoData.size(), equalTo(1)); - assertThat(geoData.get("ip"), equalTo("80.231.5.0")); + assertThat(geoData.get("ip"), equalTo(ip)); } public void testCountry() throws Exception { + String ip = "82.170.213.79"; GeoIpProcessor processor = new GeoIpProcessor( randomAlphaOfLength(10), null, "source_field", - loader("/GeoLite2-Country.mmdb"), + loader("GeoLite2-Country.mmdb"), () -> true, "target_field", ALL_PROPERTIES, @@ -256,15 +270,16 @@ public void testCountry() throws Exception { ); Map document = new HashMap<>(); - document.put("source_field", "82.170.213.79"); + document.put("source_field", ip); IngestDocument ingestDocument = RandomDocumentPicks.randomIngestDocument(random(), document); processor.execute(ingestDocument); - assertThat(ingestDocument.getSourceAndMetadata().get("source_field"), equalTo("82.170.213.79")); + assertThat(ingestDocument.getSourceAndMetadata().get("source_field"), equalTo(ip)); @SuppressWarnings("unchecked") Map geoData = (Map) ingestDocument.getSourceAndMetadata().get("target_field"); + assertThat(geoData, notNullValue()); assertThat(geoData.size(), equalTo(5)); - assertThat(geoData.get("ip"), equalTo("82.170.213.79")); + assertThat(geoData.get("ip"), equalTo(ip)); assertThat(geoData.get("country_iso_code"), equalTo("NL")); assertThat(geoData.get("country_name"), equalTo("Netherlands")); assertThat(geoData.get("continent_code"), equalTo("EU")); @@ -272,11 +287,12 @@ public void testCountry() throws Exception { } public void testCountryWithMissingLocation() throws Exception { + String ip = "80.231.5.0"; GeoIpProcessor processor = new GeoIpProcessor( randomAlphaOfLength(10), null, "source_field", - loader("/GeoLite2-Country.mmdb"), + loader("GeoLite2-Country.mmdb"), () -> true, "target_field", ALL_PROPERTIES, @@ -286,15 +302,16 @@ public void testCountryWithMissingLocation() throws Exception { ); Map document = new HashMap<>(); - document.put("source_field", "80.231.5.0"); + document.put("source_field", ip); IngestDocument ingestDocument = RandomDocumentPicks.randomIngestDocument(random(), document); processor.execute(ingestDocument); - assertThat(ingestDocument.getSourceAndMetadata().get("source_field"), equalTo("80.231.5.0")); + assertThat(ingestDocument.getSourceAndMetadata().get("source_field"), equalTo(ip)); @SuppressWarnings("unchecked") Map geoData = (Map) ingestDocument.getSourceAndMetadata().get("target_field"); + assertThat(geoData, notNullValue()); assertThat(geoData.size(), equalTo(1)); - assertThat(geoData.get("ip"), equalTo("80.231.5.0")); + assertThat(geoData.get("ip"), equalTo(ip)); } public void testAsn() throws Exception { @@ -303,7 +320,7 @@ public void testAsn() throws Exception { randomAlphaOfLength(10), null, "source_field", - loader("/GeoLite2-ASN.mmdb"), + loader("GeoLite2-ASN.mmdb"), () -> true, "target_field", ALL_PROPERTIES, @@ -320,6 +337,7 @@ 
public void testAsn() throws Exception { assertThat(ingestDocument.getSourceAndMetadata().get("source_field"), equalTo(ip)); @SuppressWarnings("unchecked") Map geoData = (Map) ingestDocument.getSourceAndMetadata().get("target_field"); + assertThat(geoData, notNullValue()); assertThat(geoData.size(), equalTo(4)); assertThat(geoData.get("ip"), equalTo(ip)); assertThat(geoData.get("asn"), equalTo(1136L)); @@ -333,7 +351,7 @@ public void testAnonymmousIp() throws Exception { randomAlphaOfLength(10), null, "source_field", - loader("/GeoIP2-Anonymous-IP-Test.mmdb"), + loader("GeoIP2-Anonymous-IP-Test.mmdb"), () -> true, "target_field", ALL_PROPERTIES, @@ -350,6 +368,7 @@ public void testAnonymmousIp() throws Exception { assertThat(ingestDocument.getSourceAndMetadata().get("source_field"), equalTo(ip)); @SuppressWarnings("unchecked") Map geoData = (Map) ingestDocument.getSourceAndMetadata().get("target_field"); + assertThat(geoData, notNullValue()); assertThat(geoData.size(), equalTo(7)); assertThat(geoData.get("ip"), equalTo(ip)); assertThat(geoData.get("hosting_provider"), equalTo(true)); @@ -366,7 +385,7 @@ public void testConnectionType() throws Exception { randomAlphaOfLength(10), null, "source_field", - loader("/GeoIP2-Connection-Type-Test.mmdb"), + loader("GeoIP2-Connection-Type-Test.mmdb"), () -> true, "target_field", ALL_PROPERTIES, @@ -383,6 +402,7 @@ public void testConnectionType() throws Exception { assertThat(ingestDocument.getSourceAndMetadata().get("source_field"), equalTo(ip)); @SuppressWarnings("unchecked") Map geoData = (Map) ingestDocument.getSourceAndMetadata().get("target_field"); + assertThat(geoData, notNullValue()); assertThat(geoData.size(), equalTo(2)); assertThat(geoData.get("ip"), equalTo(ip)); assertThat(geoData.get("connection_type"), equalTo("Satellite")); @@ -394,7 +414,7 @@ public void testDomain() throws Exception { randomAlphaOfLength(10), null, "source_field", - loader("/GeoIP2-Domain-Test.mmdb"), + loader("GeoIP2-Domain-Test.mmdb"), () -> true, "target_field", ALL_PROPERTIES, @@ -411,6 +431,7 @@ public void testDomain() throws Exception { assertThat(ingestDocument.getSourceAndMetadata().get("source_field"), equalTo(ip)); @SuppressWarnings("unchecked") Map geoData = (Map) ingestDocument.getSourceAndMetadata().get("target_field"); + assertThat(geoData, notNullValue()); assertThat(geoData.size(), equalTo(2)); assertThat(geoData.get("ip"), equalTo(ip)); assertThat(geoData.get("domain"), equalTo("ameritech.net")); @@ -422,7 +443,7 @@ public void testEnterprise() throws Exception { randomAlphaOfLength(10), null, "source_field", - loader("/GeoIP2-Enterprise-Test.mmdb"), + loader("GeoIP2-Enterprise-Test.mmdb"), () -> true, "target_field", ALL_PROPERTIES, @@ -439,6 +460,7 @@ public void testEnterprise() throws Exception { assertThat(ingestDocument.getSourceAndMetadata().get("source_field"), equalTo(ip)); @SuppressWarnings("unchecked") Map geoData = (Map) ingestDocument.getSourceAndMetadata().get("target_field"); + assertThat(geoData, notNullValue()); assertThat(geoData.size(), equalTo(24)); assertThat(geoData.get("ip"), equalTo(ip)); assertThat(geoData.get("country_iso_code"), equalTo("US")); @@ -449,10 +471,7 @@ public void testEnterprise() throws Exception { assertThat(geoData.get("region_name"), equalTo("New York")); assertThat(geoData.get("city_name"), equalTo("Chatham")); assertThat(geoData.get("timezone"), equalTo("America/New_York")); - Map location = new HashMap<>(); - location.put("lat", 42.3478); - location.put("lon", -73.5549); - 
assertThat(geoData.get("location"), equalTo(location)); + assertThat(geoData.get("location"), equalTo(Map.of("lat", 42.3478, "lon", -73.5549))); assertThat(geoData.get("asn"), equalTo(14671L)); assertThat(geoData.get("organization_name"), equalTo("FairPoint Communications")); assertThat(geoData.get("network"), equalTo("74.209.16.0/20")); @@ -475,7 +494,7 @@ public void testIsp() throws Exception { randomAlphaOfLength(10), null, "source_field", - loader("/GeoIP2-ISP-Test.mmdb"), + loader("GeoIP2-ISP-Test.mmdb"), () -> true, "target_field", ALL_PROPERTIES, @@ -492,6 +511,7 @@ public void testIsp() throws Exception { assertThat(ingestDocument.getSourceAndMetadata().get("source_field"), equalTo(ip)); @SuppressWarnings("unchecked") Map geoData = (Map) ingestDocument.getSourceAndMetadata().get("target_field"); + assertThat(geoData, notNullValue()); assertThat(geoData.size(), equalTo(8)); assertThat(geoData.get("ip"), equalTo(ip)); assertThat(geoData.get("asn"), equalTo(6167L)); @@ -508,7 +528,7 @@ public void testAddressIsNotInTheDatabase() throws Exception { randomAlphaOfLength(10), null, "source_field", - loader("/GeoLite2-City.mmdb"), + loader("GeoLite2-City.mmdb"), () -> true, "target_field", ALL_PROPERTIES, @@ -532,7 +552,7 @@ public void testInvalid() { randomAlphaOfLength(10), null, "source_field", - loader("/GeoLite2-City.mmdb"), + loader("GeoLite2-City.mmdb"), () -> true, "target_field", ALL_PROPERTIES, @@ -553,7 +573,7 @@ public void testListAllValid() throws Exception { randomAlphaOfLength(10), null, "source_field", - loader("/GeoLite2-City.mmdb"), + loader("GeoLite2-City.mmdb"), () -> true, "target_field", ALL_PROPERTIES, @@ -569,12 +589,9 @@ public void testListAllValid() throws Exception { @SuppressWarnings("unchecked") List> geoData = (List>) ingestDocument.getSourceAndMetadata().get("target_field"); - - Map location = new HashMap<>(); - location.put("lat", 37.751d); - location.put("lon", -97.822d); - assertThat(geoData.get(0).get("location"), equalTo(location)); - + assertThat(geoData, notNullValue()); + assertThat(geoData.size(), equalTo(2)); + assertThat(geoData.get(0).get("location"), equalTo(Map.of("lat", 37.751d, "lon", -97.822d))); assertThat(geoData.get(1).get("city_name"), equalTo("Hoensbroek")); } @@ -583,7 +600,7 @@ public void testListPartiallyValid() throws Exception { randomAlphaOfLength(10), null, "source_field", - loader("/GeoLite2-City.mmdb"), + loader("GeoLite2-City.mmdb"), () -> true, "target_field", ALL_PROPERTIES, @@ -599,12 +616,9 @@ public void testListPartiallyValid() throws Exception { @SuppressWarnings("unchecked") List> geoData = (List>) ingestDocument.getSourceAndMetadata().get("target_field"); - - Map location = new HashMap<>(); - location.put("lat", 37.751d); - location.put("lon", -97.822d); - assertThat(geoData.get(0).get("location"), equalTo(location)); - + assertThat(geoData, notNullValue()); + assertThat(geoData.size(), equalTo(2)); + assertThat(geoData.get(0).get("location"), equalTo(Map.of("lat", 37.751d, "lon", -97.822d))); assertThat(geoData.get(1), nullValue()); } @@ -613,7 +627,7 @@ public void testListNoMatches() throws Exception { randomAlphaOfLength(10), null, "source_field", - loader("/GeoLite2-City.mmdb"), + loader("GeoLite2-City.mmdb"), () -> true, "target_field", ALL_PROPERTIES, @@ -632,7 +646,7 @@ public void testListNoMatches() throws Exception { public void testListDatabaseReferenceCounting() throws Exception { AtomicBoolean closeCheck = new AtomicBoolean(false); - var loader = loader("/GeoLite2-City.mmdb", closeCheck); + var 
loader = loader("GeoLite2-City.mmdb", closeCheck); GeoIpProcessor processor = new GeoIpProcessor(randomAlphaOfLength(10), null, "source_field", () -> { loader.preLookup(); return loader; @@ -645,12 +659,9 @@ public void testListDatabaseReferenceCounting() throws Exception { @SuppressWarnings("unchecked") List> geoData = (List>) ingestDocument.getSourceAndMetadata().get("target_field"); - - Map location = new HashMap<>(); - location.put("lat", 37.751d); - location.put("lon", -97.822d); - assertThat(geoData.get(0).get("location"), equalTo(location)); - + assertThat(geoData, notNullValue()); + assertThat(geoData.size(), equalTo(2)); + assertThat(geoData.get(0).get("location"), equalTo(Map.of("lat", 37.751d, "lon", -97.822d))); assertThat(geoData.get(1).get("city_name"), equalTo("Hoensbroek")); // Check the loader's reference count and attempt to close @@ -664,7 +675,7 @@ public void testListFirstOnly() throws Exception { randomAlphaOfLength(10), null, "source_field", - loader("/GeoLite2-City.mmdb"), + loader("GeoLite2-City.mmdb"), () -> true, "target_field", ALL_PROPERTIES, @@ -680,11 +691,8 @@ public void testListFirstOnly() throws Exception { @SuppressWarnings("unchecked") Map geoData = (Map) ingestDocument.getSourceAndMetadata().get("target_field"); - - Map location = new HashMap<>(); - location.put("lat", 37.751d); - location.put("lon", -97.822d); - assertThat(geoData.get("location"), equalTo(location)); + assertThat(geoData, notNullValue()); + assertThat(geoData.get("location"), equalTo(Map.of("lat", 37.751d, "lon", -97.822d))); } public void testListFirstOnlyNoMatches() throws Exception { @@ -692,7 +700,7 @@ public void testListFirstOnlyNoMatches() throws Exception { randomAlphaOfLength(10), null, "source_field", - loader("/GeoLite2-City.mmdb"), + loader("GeoLite2-City.mmdb"), () -> true, "target_field", ALL_PROPERTIES, @@ -714,7 +722,7 @@ public void testInvalidDatabase() throws Exception { randomAlphaOfLength(10), null, "source_field", - loader("/GeoLite2-City.mmdb"), + loader("GeoLite2-City.mmdb"), () -> false, "target_field", ALL_PROPERTIES, @@ -782,32 +790,12 @@ private CheckedSupplier loader(final String path) { return () -> loader; } - private DatabaseReaderLazyLoader loader(final String path, final AtomicBoolean closed) { - final Supplier databaseInputStreamSupplier = () -> GeoIpProcessor.class.getResourceAsStream(path); - final CheckedSupplier loader = () -> new DatabaseReader.Builder(databaseInputStreamSupplier.get()) - .build(); - final GeoIpCache cache = new GeoIpCache(1000); - return new DatabaseReaderLazyLoader(cache, PathUtils.get(path), null, loader) { - - @Override - long databaseFileSize() throws IOException { - try (InputStream is = databaseInputStreamSupplier.get()) { - long bytesRead = 0; - do { - final byte[] bytes = new byte[1 << 10]; - final int read = is.read(bytes); - if (read == -1) break; - bytesRead += read; - } while (true); - return bytesRead; - } - } - - @Override - InputStream databaseInputStream() { - return databaseInputStreamSupplier.get(); - } + private DatabaseReaderLazyLoader loader(final String databaseName, final AtomicBoolean closed) { + Path path = tmpDir.resolve(databaseName); + copyDatabase(databaseName, path); + final GeoIpCache cache = new GeoIpCache(1000); + return new DatabaseReaderLazyLoader(cache, path, null) { @Override protected void doClose() throws IOException { if (closed != null) { diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MMDBUtilTests.java 
b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MMDBUtilTests.java new file mode 100644 index 0000000000000..3da15dbdad305 --- /dev/null +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MMDBUtilTests.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.ingest.geoip; + +import org.elasticsearch.core.IOUtils; +import org.elasticsearch.test.ESTestCase; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; + +import static org.elasticsearch.ingest.geoip.GeoIpTestUtils.copyDatabase; +import static org.hamcrest.Matchers.endsWith; +import static org.hamcrest.Matchers.hasLength; +import static org.hamcrest.Matchers.is; + +public class MMDBUtilTests extends ESTestCase { + + // a temporary directory that mmdb files can be copied to and read from + Path tmpDir; + + @Before + public void setup() { + tmpDir = createTempDir(); + } + + @After + public void cleanup() throws IOException { + IOUtils.rm(tmpDir); + } + + public void testGetDatabaseTypeGeoIP2City() throws IOException { + Path database = tmpDir.resolve("GeoIP2-City.mmdb"); + copyDatabase("GeoIP2-City-Test.mmdb", database); + + String type = MMDBUtil.getDatabaseType(database); + assertThat(type, is("GeoIP2-City")); + } + + public void testGetDatabaseTypeGeoLite2City() throws IOException { + Path database = tmpDir.resolve("GeoLite2-City.mmdb"); + copyDatabase("GeoLite2-City-Test.mmdb", database); + + String type = MMDBUtil.getDatabaseType(database); + assertThat(type, is("GeoLite2-City")); + } + + public void testSmallFileWithALongDescription() throws IOException { + Path database = tmpDir.resolve("test-description.mmdb"); + copyDatabase("test-description.mmdb", database); + + // it was once the case that we couldn't read a database_type that was 29 characters or longer + String type = MMDBUtil.getDatabaseType(database); + assertThat(type, endsWith("long database_type")); + assertThat(type, hasLength(60)); // 60 is >= 29, ;) + + // it was once the case that we couldn't process an mmdb that was smaller than 512 bytes + assertThat(Files.size(database), is(444L)); // 444 is <512 + } +} diff --git a/modules/ingest-geoip/src/test/resources/test-description.mmdb b/modules/ingest-geoip/src/test/resources/test-description.mmdb new file mode 100644 index 0000000000000..f57af8726d8d5 Binary files /dev/null and b/modules/ingest-geoip/src/test/resources/test-description.mmdb differ diff --git a/muted-tests.yml b/muted-tests.yml index d5c0a305c2728..e02f8208b7fb0 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -118,8 +118,6 @@ tests: - class: org.elasticsearch.search.retriever.RankDocRetrieverBuilderIT method: testRankDocsRetrieverWithCollapse issue: https://github.com/elastic/elasticsearch/issues/112254 -- class: org.elasticsearch.search.ccs.CCSUsageTelemetryIT - issue: https://github.com/elastic/elasticsearch/issues/112324 - class: org.elasticsearch.datastreams.logsdb.qa.StandardVersusLogsIndexModeRandomDataChallengeRestIT method: 
testMatchAllQuery issue: https://github.com/elastic/elasticsearch/issues/112374 @@ -129,14 +127,9 @@ tests: - class: org.elasticsearch.xpack.ml.integration.MlJobIT method: testMultiIndexDelete issue: https://github.com/elastic/elasticsearch/issues/112381 -- class: org.elasticsearch.action.admin.cluster.stats.CCSTelemetrySnapshotTests - method: testToXContent - issue: https://github.com/elastic/elasticsearch/issues/112325 - class: org.elasticsearch.search.retriever.RankDocRetrieverBuilderIT method: testRankDocsRetrieverWithNestedQuery issue: https://github.com/elastic/elasticsearch/issues/112421 -- class: org.elasticsearch.indices.mapping.UpdateMappingIntegrationIT - issue: https://github.com/elastic/elasticsearch/issues/112423 - class: org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroidTests method: "testAggregateIntermediate {TestCase= #2}" issue: https://github.com/elastic/elasticsearch/issues/112461 @@ -217,13 +210,14 @@ tests: - class: org.elasticsearch.xpack.ml.integration.MlJobIT method: testDeleteJobAfterMissingAliases issue: https://github.com/elastic/elasticsearch/issues/112823 -- class: org.elasticsearch.repositories.blobstore.testkit.analyze.HdfsRepositoryAnalysisRestIT - issue: https://github.com/elastic/elasticsearch/issues/112889 - class: org.elasticsearch.xpack.test.rest.XPackRestIT issue: https://github.com/elastic/elasticsearch/issues/111944 - class: org.elasticsearch.datastreams.logsdb.qa.StandardVersusLogsIndexModeRandomDataChallengeRestIT method: testDateHistogramAggregation issue: https://github.com/elastic/elasticsearch/issues/112919 +- class: org.elasticsearch.datastreams.logsdb.qa.StandardVersusLogsIndexModeRandomDataChallengeRestIT + method: testTermsQuery + issue: https://github.com/elastic/elasticsearch/issues/112462 # Examples: # diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportShardBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportShardBulkAction.java index 9e6211ba0f654..74143cc5c059b 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportShardBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportShardBulkAction.java @@ -50,6 +50,7 @@ import org.elasticsearch.index.mapper.DocumentMapper; import org.elasticsearch.index.mapper.MapperException; import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.mapper.MappingLookup; import org.elasticsearch.index.mapper.SourceToParse; import org.elasticsearch.index.seqno.SequenceNumbers; import org.elasticsearch.index.shard.IndexShard; @@ -480,7 +481,17 @@ && isConflictException(executionResult.getFailure().getCause()) } final BulkItemResponse response; if (isUpdate) { - response = processUpdateResponse((UpdateRequest) docWriteRequest, context.getConcreteIndex(), executionResult, updateResult); + assert context.getPrimary().mapperService() != null; + final MappingLookup mappingLookup = context.getPrimary().mapperService().mappingLookup(); + assert mappingLookup != null; + + response = processUpdateResponse( + (UpdateRequest) docWriteRequest, + context.getConcreteIndex(), + mappingLookup, + executionResult, + updateResult + ); } else { if (isFailed) { final Exception failure = executionResult.getFailure().getCause(); @@ -518,6 +529,7 @@ private static boolean isConflictException(final Exception e) { private static BulkItemResponse processUpdateResponse( final UpdateRequest 
updateRequest, final String concreteIndex, + final MappingLookup mappingLookup, BulkItemResponse operationResponse, final UpdateHelper.Result translate ) { @@ -555,6 +567,7 @@ private static BulkItemResponse processUpdateResponse( UpdateHelper.extractGetResult( updateRequest, concreteIndex, + mappingLookup, indexResponse.getSeqNo(), indexResponse.getPrimaryTerm(), indexResponse.getVersion(), @@ -579,6 +592,7 @@ private static BulkItemResponse processUpdateResponse( final GetResult getResult = UpdateHelper.extractGetResult( updateRequest, concreteIndex, + mappingLookup, deleteResponse.getSeqNo(), deleteResponse.getPrimaryTerm(), deleteResponse.getVersion(), diff --git a/server/src/main/java/org/elasticsearch/action/ingest/DeletePipelineRequest.java b/server/src/main/java/org/elasticsearch/action/ingest/DeletePipelineRequest.java index a3be50b282b0b..ec8e9bdd8dde9 100644 --- a/server/src/main/java/org/elasticsearch/action/ingest/DeletePipelineRequest.java +++ b/server/src/main/java/org/elasticsearch/action/ingest/DeletePipelineRequest.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.support.master.AcknowledgedRequest; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.TimeValue; import java.io.IOException; import java.util.Objects; @@ -20,8 +21,8 @@ public class DeletePipelineRequest extends AcknowledgedRequest /** * Create a new pipeline request with the id and source along with the content type of the source */ - public PutPipelineRequest(String id, BytesReference source, XContentType xContentType, Integer version) { - super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT); + public PutPipelineRequest( + TimeValue masterNodeTimeout, + TimeValue ackTimeout, + String id, + BytesReference source, + XContentType xContentType, + Integer version + ) { + super(masterNodeTimeout, ackTimeout); this.id = Objects.requireNonNull(id); this.source = Objects.requireNonNull(source); this.xContentType = Objects.requireNonNull(xContentType); this.version = version; } - public PutPipelineRequest(String id, BytesReference source, XContentType xContentType) { - this(id, source, xContentType, null); + public PutPipelineRequest( + TimeValue masterNodeTimeout, + TimeValue ackTimeout, + String id, + BytesReference source, + XContentType xContentType + ) { + this(masterNodeTimeout, ackTimeout, id, source, xContentType, null); } public PutPipelineRequest(StreamInput in) throws IOException { diff --git a/server/src/main/java/org/elasticsearch/action/ingest/ReservedPipelineAction.java b/server/src/main/java/org/elasticsearch/action/ingest/ReservedPipelineAction.java index e55447a36d17a..b5908a3c39fbc 100644 --- a/server/src/main/java/org/elasticsearch/action/ingest/ReservedPipelineAction.java +++ b/server/src/main/java/org/elasticsearch/action/ingest/ReservedPipelineAction.java @@ -101,7 +101,14 @@ public TransformState transform(Object source, TransformState prevState) throws toDelete.removeAll(entities); for (var pipelineToDelete : toDelete) { - var task = new IngestService.DeletePipelineClusterStateUpdateTask(null, new DeletePipelineRequest(pipelineToDelete)); + var task = new IngestService.DeletePipelineClusterStateUpdateTask( + null, + new DeletePipelineRequest( + RESERVED_CLUSTER_STATE_HANDLER_IGNORED_TIMEOUT, + RESERVED_CLUSTER_STATE_HANDLER_IGNORED_TIMEOUT, + pipelineToDelete + ) + ); state = wrapIngestTaskExecute(task, state); } @@ -119,7 +126,15 @@ public List fromXContent(XContentParser parser) 
throws IOExc Map content = (Map) source.get(id); try (XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON)) { builder.map(content); - result.add(new PutPipelineRequest(id, BytesReference.bytes(builder), XContentType.JSON)); + result.add( + new PutPipelineRequest( + RESERVED_CLUSTER_STATE_HANDLER_IGNORED_TIMEOUT, + RESERVED_CLUSTER_STATE_HANDLER_IGNORED_TIMEOUT, + id, + BytesReference.bytes(builder), + XContentType.JSON + ) + ); } catch (Exception e) { throw new ElasticsearchGenerationException("Failed to generate [" + source + "]", e); } diff --git a/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java b/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java index 877f4bb922d7e..0749512635f83 100644 --- a/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java +++ b/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java @@ -187,11 +187,12 @@ protected void shardOperation(final UpdateRequest request, final ActionListener< final ShardId shardId = request.getShardId(); final IndexService indexService = indicesService.indexServiceSafe(shardId.getIndex()); final IndexShard indexShard = indexService.getShard(shardId.getId()); + final MappingLookup mappingLookup = indexShard.mapperService().mappingLookup(); final UpdateHelper.Result result = deleteInferenceResults( request, updateHelper.prepare(request, indexShard, threadPool::absoluteTimeInMillis), indexService.getMetadata(), - indexShard.mapperService().mappingLookup() + mappingLookup ); switch (result.getResponseResult()) { @@ -221,6 +222,7 @@ protected void shardOperation(final UpdateRequest request, final ActionListener< UpdateHelper.extractGetResult( request, request.concreteIndex(), + mappingLookup, response.getSeqNo(), response.getPrimaryTerm(), response.getVersion(), @@ -257,6 +259,7 @@ protected void shardOperation(final UpdateRequest request, final ActionListener< UpdateHelper.extractGetResult( request, request.concreteIndex(), + mappingLookup, response.getSeqNo(), response.getPrimaryTerm(), response.getVersion(), @@ -288,6 +291,7 @@ protected void shardOperation(final UpdateRequest request, final ActionListener< UpdateHelper.extractGetResult( request, request.concreteIndex(), + mappingLookup, response.getSeqNo(), response.getPrimaryTerm(), response.getVersion(), diff --git a/server/src/main/java/org/elasticsearch/action/update/UpdateHelper.java b/server/src/main/java/org/elasticsearch/action/update/UpdateHelper.java index fe187ad2d71b8..212b99ca140d3 100644 --- a/server/src/main/java/org/elasticsearch/action/update/UpdateHelper.java +++ b/server/src/main/java/org/elasticsearch/action/update/UpdateHelper.java @@ -24,6 +24,7 @@ import org.elasticsearch.index.engine.DocumentSourceMissingException; import org.elasticsearch.index.get.GetResult; import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.mapper.MappingLookup; import org.elasticsearch.index.mapper.RoutingFieldMapper; import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.ShardId; @@ -62,26 +63,26 @@ public UpdateHelper(ScriptService scriptService, DocumentParsingProvider documen */ public Result prepare(UpdateRequest request, IndexShard indexShard, LongSupplier nowInMillis) throws IOException { final GetResult getResult = indexShard.getService().getForUpdate(request.id(), request.ifSeqNo(), request.ifPrimaryTerm()); - return prepare(indexShard.shardId(), request, getResult, nowInMillis); + return 
prepare(indexShard, request, getResult, nowInMillis); } /** * Prepares an update request by converting it into an index or delete request or an update response (no action, in the event of a * noop). */ - protected Result prepare(ShardId shardId, UpdateRequest request, final GetResult getResult, LongSupplier nowInMillis) { + protected Result prepare(IndexShard indexShard, UpdateRequest request, final GetResult getResult, LongSupplier nowInMillis) { if (getResult.isExists() == false) { // If the document didn't exist, execute the update request as an upsert - return prepareUpsert(shardId, request, getResult, nowInMillis); + return prepareUpsert(indexShard.shardId(), request, getResult, nowInMillis); } else if (getResult.internalSourceRef() == null) { // no source, we can't do anything, throw a failure... - throw new DocumentSourceMissingException(shardId, request.id()); + throw new DocumentSourceMissingException(indexShard.shardId(), request.id()); } else if (request.script() == null && request.doc() != null) { // The request has no script, it is a new doc that should be merged with the old document - return prepareUpdateIndexRequest(shardId, request, getResult, request.detectNoop()); + return prepareUpdateIndexRequest(indexShard, request, getResult, request.detectNoop()); } else { // The request has a script (or empty script), execute the script and prepare a new index request - return prepareUpdateScriptRequest(shardId, request, getResult, nowInMillis); + return prepareUpdateScriptRequest(indexShard, request, getResult, nowInMillis); } } @@ -179,7 +180,7 @@ static String calculateRouting(GetResult getResult, @Nullable IndexRequest updat * Prepare the request for merging the existing document with a new one, can optionally detect a noop change. Returns a {@code Result} * containing a new {@code IndexRequest} to be executed on the primary and replicas. */ - Result prepareUpdateIndexRequest(ShardId shardId, UpdateRequest request, GetResult getResult, boolean detectNoop) { + Result prepareUpdateIndexRequest(IndexShard indexShard, UpdateRequest request, GetResult getResult, boolean detectNoop) { final IndexRequest currentRequest = request.doc(); final String routing = calculateRouting(getResult, currentRequest); final XContentMeteringParserDecorator meteringParserDecorator = documentParsingProvider.newMeteringParserDecorator(request); @@ -197,7 +198,7 @@ Result prepareUpdateIndexRequest(ShardId shardId, UpdateRequest request, GetResu // where users repopulating multi-fields or adding synonyms, etc. if (detectNoop && noop) { UpdateResponse update = new UpdateResponse( - shardId, + indexShard.shardId(), getResult.getId(), getResult.getSeqNo(), getResult.getPrimaryTerm(), @@ -208,6 +209,7 @@ Result prepareUpdateIndexRequest(ShardId shardId, UpdateRequest request, GetResu extractGetResult( request, request.index(), + indexShard.mapperService().mappingLookup(), getResult.getSeqNo(), getResult.getPrimaryTerm(), getResult.getVersion(), @@ -238,7 +240,7 @@ Result prepareUpdateIndexRequest(ShardId shardId, UpdateRequest request, GetResu * either a new {@code IndexRequest} or {@code DeleteRequest} (depending on the script's returned "op" value) to be executed on the * primary and replicas. 
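The UpdateHelper changes in this and the surrounding hunks all serve one purpose: resolve the shard's MappingLookup once, at the point where the IndexShard is in hand, and thread it through to extractGetResult instead of re-deriving it downstream. A sketch of the caller-side shape visible in these hunks (illustration only, not new API):

// Resolve the lookup from the shard that owns the document...
final IndexShard indexShard = indexService.getShard(shardId.getId());
final MappingLookup mappingLookup = indexShard.mapperService().mappingLookup();

// ...prepare the update against the shard itself (not just its ShardId)...
final UpdateHelper.Result result = updateHelper.prepare(request, indexShard, threadPool::absoluteTimeInMillis);

// ...and hand the same lookup along when the GetResult for the response is extracted:
// UpdateHelper.extractGetResult(request, request.concreteIndex(), mappingLookup, seqNo, primaryTerm, version, ...)
// (the remaining extractGetResult arguments are unchanged by this diff and omitted here)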
*/ - Result prepareUpdateScriptRequest(ShardId shardId, UpdateRequest request, GetResult getResult, LongSupplier nowInMillis) { + Result prepareUpdateScriptRequest(IndexShard indexShard, UpdateRequest request, GetResult getResult, LongSupplier nowInMillis) { final IndexRequest currentRequest = request.doc(); final String routing = calculateRouting(getResult, currentRequest); final Tuple> sourceAndContent = XContentHelper.convertToMap(getResult.internalSourceRef(), true); @@ -288,7 +290,7 @@ Result prepareUpdateScriptRequest(ShardId shardId, UpdateRequest request, GetRes default -> { // If it was neither an INDEX or DELETE operation, treat it as a noop UpdateResponse update = new UpdateResponse( - shardId, + indexShard.shardId(), getResult.getId(), getResult.getSeqNo(), getResult.getPrimaryTerm(), @@ -299,6 +301,7 @@ Result prepareUpdateScriptRequest(ShardId shardId, UpdateRequest request, GetRes extractGetResult( request, request.index(), + indexShard.mapperService().mappingLookup(), getResult.getSeqNo(), getResult.getPrimaryTerm(), getResult.getVersion(), @@ -332,6 +335,7 @@ private T executeScript(Script script, T ctxMap) { public static GetResult extractGetResult( final UpdateRequest request, String concreteIndex, + final MappingLookup mappingLookup, long seqNo, long primaryTerm, long version, diff --git a/server/src/main/java/org/elasticsearch/rest/action/ingest/RestDeletePipelineAction.java b/server/src/main/java/org/elasticsearch/rest/action/ingest/RestDeletePipelineAction.java index f4603d3a30683..170fb8c7506ef 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/ingest/RestDeletePipelineAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/ingest/RestDeletePipelineAction.java @@ -39,9 +39,11 @@ public String getName() { @Override public RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { - DeletePipelineRequest request = new DeletePipelineRequest(restRequest.param("id")); - request.masterNodeTimeout(getMasterNodeTimeout(restRequest)); - request.ackTimeout(getAckTimeout(restRequest)); + final var request = new DeletePipelineRequest( + getMasterNodeTimeout(restRequest), + getAckTimeout(restRequest), + restRequest.param("id") + ); return channel -> client.execute(DeletePipelineTransportAction.TYPE, request, new RestToXContentListener<>(channel)); } } diff --git a/server/src/main/java/org/elasticsearch/rest/action/ingest/RestGetPipelineAction.java b/server/src/main/java/org/elasticsearch/rest/action/ingest/RestGetPipelineAction.java index 4119e5726e1b5..c44f632930591 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/ingest/RestGetPipelineAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/ingest/RestGetPipelineAction.java @@ -41,11 +41,11 @@ public String getName() { @Override public RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { - GetPipelineRequest request = new GetPipelineRequest( + final var request = new GetPipelineRequest( + getMasterNodeTimeout(restRequest), restRequest.paramAsBoolean("summary", false), Strings.splitStringByCommaToArray(restRequest.param("id")) ); - request.masterNodeTimeout(getMasterNodeTimeout(restRequest)); return channel -> client.execute( GetPipelineAction.INSTANCE, request, diff --git a/server/src/main/java/org/elasticsearch/rest/action/ingest/RestPutPipelineAction.java b/server/src/main/java/org/elasticsearch/rest/action/ingest/RestPutPipelineAction.java index 372e5cca47249..269d9b08ab66b 100644 --- 
a/server/src/main/java/org/elasticsearch/rest/action/ingest/RestPutPipelineAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/ingest/RestPutPipelineAction.java @@ -57,9 +57,14 @@ public RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient cl } Tuple sourceTuple = restRequest.contentOrSourceParam(); - PutPipelineRequest request = new PutPipelineRequest(restRequest.param("id"), sourceTuple.v2(), sourceTuple.v1(), ifVersion); - request.masterNodeTimeout(getMasterNodeTimeout(restRequest)); - request.ackTimeout(getAckTimeout(restRequest)); + final var request = new PutPipelineRequest( + getMasterNodeTimeout(restRequest), + getAckTimeout(restRequest), + restRequest.param("id"), + sourceTuple.v2(), + sourceTuple.v1(), + ifVersion + ); return channel -> client.execute(PutPipelineTransportAction.TYPE, request, new RestToXContentListener<>(channel)); } } diff --git a/server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java b/server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java index 98199a9b40315..208ca613a350b 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java +++ b/server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java @@ -202,6 +202,8 @@ public Query rewrite(Query original) throws IOException { } catch (TimeExceededException e) { timeExceeded = true; return new MatchNoDocsQuery("rewrite timed out"); + } catch (TooManyClauses e) { + throw new IllegalArgumentException("Query rewrite failed: too many clauses", e); } finally { if (profiler != null) { profiler.stopAndAddRewriteTime(rewriteTimer); diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportShardBulkActionTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportShardBulkActionTests.java index fd97d799f7c52..35ef892da59a2 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportShardBulkActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportShardBulkActionTests.java @@ -41,6 +41,7 @@ import org.elasticsearch.index.mapper.DocumentMapper; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.Mapping; +import org.elasticsearch.index.mapper.MappingLookup; import org.elasticsearch.index.mapper.MetadataFieldMapper; import org.elasticsearch.index.mapper.RootObjectMapper; import org.elasticsearch.index.shard.IndexShard; @@ -53,6 +54,9 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool.Names; import org.mockito.ArgumentCaptor; +import org.mockito.MockingDetails; +import org.mockito.Mockito; +import org.mockito.stubbing.Stubbing; import java.io.IOException; import java.util.Collections; @@ -75,6 +79,7 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockingDetails; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -263,20 +268,18 @@ public void testExecuteBulkIndexRequestWithMappingUpdates() throws Exception { Translog.Location resultLocation = new Translog.Location(42, 42, 42); Engine.IndexResult success = new FakeIndexResult(1, 1, 13, true, resultLocation, "id"); - IndexShard shard = mock(IndexShard.class); - when(shard.shardId()).thenReturn(shardId); - when(shard.applyIndexOperationOnPrimary(anyLong(), any(), any(), anyLong(), 
anyLong(), anyLong(), anyBoolean())).thenReturn( - mappingUpdate - ); - MapperService mapperService = mock(MapperService.class); - when(shard.mapperService()).thenReturn(mapperService); - addMockCloseImplementation(shard); - // merged mapping source needs to be different from previous one for the master node to be invoked + MapperService mapperService = mock(MapperService.class); DocumentMapper mergedDoc = mock(DocumentMapper.class); when(mapperService.merge(any(), any(CompressedXContent.class), any())).thenReturn(mergedDoc); when(mergedDoc.mappingSource()).thenReturn(CompressedXContent.fromJSON("{}")); + IndexShard shard = mockShard(null, mapperService); + when(shard.applyIndexOperationOnPrimary(anyLong(), any(), any(), anyLong(), anyLong(), anyLong(), anyBoolean())).thenReturn( + mappingUpdate + ); + addMockCloseImplementation(shard); + randomlySetIgnoredPrimaryResponse(items[0]); // Pretend the mappings haven't made it to the node yet @@ -489,7 +492,7 @@ public void testNoopUpdateRequest() throws Exception { DocWriteResponse noopUpdateResponse = new UpdateResponse(shardId, "id", 0, 2, 1, DocWriteResponse.Result.NOOP); - IndexShard shard = mock(IndexShard.class); + IndexShard shard = mockShard(null, null); UpdateHelper updateHelper = mock(UpdateHelper.class); when(updateHelper.prepare(any(), eq(shard), any())).thenReturn( @@ -541,11 +544,10 @@ public void testUpdateRequestWithFailure() throws Exception { Exception err = new ElasticsearchException("I'm dead <(x.x)>"); Engine.IndexResult indexResult = new Engine.IndexResult(err, 0, 0, 0, "id"); - IndexShard shard = mock(IndexShard.class); + IndexShard shard = mockShard(indexSettings, null); when(shard.applyIndexOperationOnPrimary(anyLong(), any(), any(), anyLong(), anyLong(), anyLong(), anyBoolean())).thenReturn( indexResult ); - when(shard.indexSettings()).thenReturn(indexSettings); UpdateHelper updateHelper = mock(UpdateHelper.class); when(updateHelper.prepare(any(), eq(shard), any())).thenReturn( @@ -605,11 +607,10 @@ public void testUpdateRequestWithConflictFailure() throws Exception { Exception err = new VersionConflictEngineException(shardId, "id", "I'm conflicted <(;_;)>"); Engine.IndexResult indexResult = new Engine.IndexResult(err, 0, 0, 0, "id"); - IndexShard shard = mock(IndexShard.class); + IndexShard shard = mockShard(indexSettings, null); when(shard.applyIndexOperationOnPrimary(anyLong(), any(), any(), anyLong(), anyLong(), anyLong(), anyBoolean())).thenReturn( indexResult ); - when(shard.indexSettings()).thenReturn(indexSettings); UpdateHelper updateHelper = mock(UpdateHelper.class); when(updateHelper.prepare(any(), eq(shard), any())).thenReturn( @@ -675,12 +676,10 @@ public void testUpdateRequestWithSuccess() throws Exception { boolean created = randomBoolean(); Translog.Location resultLocation = new Translog.Location(42, 42, 42); Engine.IndexResult indexResult = new FakeIndexResult(1, 1, 13, created, resultLocation, "id"); - IndexShard shard = mock(IndexShard.class); + IndexShard shard = mockShard(indexSettings, null); when(shard.applyIndexOperationOnPrimary(anyLong(), any(), any(), anyLong(), anyLong(), anyLong(), anyBoolean())).thenReturn( indexResult ); - when(shard.indexSettings()).thenReturn(indexSettings); - when(shard.shardId()).thenReturn(shardId); UpdateHelper updateHelper = mock(UpdateHelper.class); when(updateHelper.prepare(any(), eq(shard), any())).thenReturn( @@ -739,10 +738,8 @@ public void testUpdateWithDelete() throws Exception { Translog.Location resultLocation = new Translog.Location(42, 42, 42); final 
long resultSeqNo = 13; Engine.DeleteResult deleteResult = new FakeDeleteResult(1, 1, resultSeqNo, found, resultLocation, "id"); - IndexShard shard = mock(IndexShard.class); + IndexShard shard = mockShard(indexSettings, null); when(shard.applyDeleteOperationOnPrimary(anyLong(), any(), any(), anyLong(), anyLong())).thenReturn(deleteResult); - when(shard.indexSettings()).thenReturn(indexSettings); - when(shard.shardId()).thenReturn(shardId); UpdateHelper updateHelper = mock(UpdateHelper.class); when(updateHelper.prepare(any(), eq(shard), any())).thenReturn( @@ -787,7 +784,7 @@ public void testFailureDuringUpdateProcessing() throws Exception { DocWriteRequest writeRequest = new UpdateRequest("index", "id").doc(Requests.INDEX_CONTENT_TYPE, "field", "value"); BulkItemRequest primaryRequest = new BulkItemRequest(0, writeRequest); - IndexShard shard = mock(IndexShard.class); + IndexShard shard = mockShard(null, null); UpdateHelper updateHelper = mock(UpdateHelper.class); final ElasticsearchException err = new ElasticsearchException("oops"); @@ -905,7 +902,12 @@ public void testRetries() throws Exception { Translog.Location resultLocation = new Translog.Location(42, 42, 42); Engine.IndexResult success = new FakeIndexResult(1, 1, 13, true, resultLocation, "id"); - IndexShard shard = mock(IndexShard.class); + MapperService mapperService = mock(MapperService.class); + DocumentMapper mergedDocMapper = mock(DocumentMapper.class); + when(mergedDocMapper.mappingSource()).thenReturn(CompressedXContent.fromJSON("{}")); + when(mapperService.merge(any(), any(CompressedXContent.class), any())).thenReturn(mergedDocMapper); + + IndexShard shard = mockShard(indexSettings, mapperService); when(shard.applyIndexOperationOnPrimary(anyLong(), any(), any(), anyLong(), anyLong(), anyLong(), anyBoolean())).thenAnswer(ir -> { if (randomBoolean()) { return conflictedResult; @@ -916,15 +918,6 @@ public void testRetries() throws Exception { return success; } }); - when(shard.indexSettings()).thenReturn(indexSettings); - when(shard.shardId()).thenReturn(shardId); - MapperService mapperService = mock(MapperService.class); - when(shard.mapperService()).thenReturn(mapperService); - when(shard.getBulkOperationListener()).thenReturn(mock(ShardBulkStats.class)); - - DocumentMapper mergedDocMapper = mock(DocumentMapper.class); - when(mergedDocMapper.mappingSource()).thenReturn(CompressedXContent.fromJSON("{}")); - when(mapperService.merge(any(), any(CompressedXContent.class), any())).thenReturn(mergedDocMapper); UpdateHelper updateHelper = mock(UpdateHelper.class); when(updateHelper.prepare(any(), eq(shard), any())).thenReturn( @@ -1003,23 +996,20 @@ public void testForceExecutionOnRejectionAfterMappingUpdate() throws Exception { Engine.IndexResult success1 = new FakeIndexResult(1, 1, 10, true, resultLocation1, "id"); Engine.IndexResult success2 = new FakeIndexResult(1, 1, 13, true, resultLocation2, "id"); - IndexShard shard = mock(IndexShard.class); - when(shard.shardId()).thenReturn(shardId); + // merged mapping source needs to be different from previous one for the master node to be invoked + MapperService mapperService = mock(MapperService.class); + DocumentMapper mergedDoc = mock(DocumentMapper.class); + when(mapperService.merge(any(), any(CompressedXContent.class), any())).thenReturn(mergedDoc); + when(mergedDoc.mappingSource()).thenReturn(CompressedXContent.fromJSON("{}")); + + IndexShard shard = mockShard(null, mapperService); when(shard.applyIndexOperationOnPrimary(anyLong(), any(), any(), anyLong(), anyLong(), anyLong(), 
anyBoolean())).thenReturn( success1, mappingUpdate, success2 ); - when(shard.getFailedIndexResult(any(EsRejectedExecutionException.class), anyLong(), anyString())).thenCallRealMethod(); - MapperService mapperService = mock(MapperService.class); - when(shard.mapperService()).thenReturn(mapperService); addMockCloseImplementation(shard); - // merged mapping source needs to be different from previous one for the master node to be invoked - DocumentMapper mergedDoc = mock(DocumentMapper.class); - when(mapperService.merge(any(), any(CompressedXContent.class), any())).thenReturn(mergedDoc); - when(mergedDoc.mappingSource()).thenReturn(CompressedXContent.fromJSON("{}")); - randomlySetIgnoredPrimaryResponse(items[0]); AtomicInteger updateCalled = new AtomicInteger(); @@ -1130,19 +1120,18 @@ public void testNoopMappingUpdateInfiniteLoopPrevention() throws Exception { "id" ); - IndexShard shard = mockShard(); - when(shard.applyIndexOperationOnPrimary(anyLong(), any(), any(), anyLong(), anyLong(), anyLong(), anyBoolean())).thenReturn( - mappingUpdate - ); MapperService mapperService = mock(MapperService.class); - when(shard.mapperService()).thenReturn(mapperService); - DocumentMapper documentMapper = mock(DocumentMapper.class); when(documentMapper.mappingSource()).thenReturn(CompressedXContent.fromJSON("{}")); // returning the current document mapper as the merge result to simulate a noop mapping update when(mapperService.documentMapper()).thenReturn(documentMapper); when(mapperService.merge(any(), any(CompressedXContent.class), any())).thenReturn(documentMapper); + IndexShard shard = mockShard(null, mapperService); + when(shard.applyIndexOperationOnPrimary(anyLong(), any(), any(), anyLong(), anyLong(), anyLong(), anyBoolean())).thenReturn( + mappingUpdate + ); + UpdateHelper updateHelper = mock(UpdateHelper.class); when(updateHelper.prepare(any(), eq(shard), any())).thenReturn( new UpdateHelper.Result( @@ -1187,7 +1176,17 @@ public void testNoopMappingUpdateSuccessOnRetry() throws Exception { Translog.Location resultLocation = new Translog.Location(42, 42, 42); Engine.IndexResult successfulResult = new FakeIndexResult(1, 1, 10, true, resultLocation, "id"); - IndexShard shard = mockShard(); + MapperService mapperService = mock(MapperService.class); + DocumentMapper documentMapper = mock(DocumentMapper.class); + when(documentMapper.mappingSource()).thenReturn(CompressedXContent.fromJSON("{}")); + when(mapperService.documentMapper()).thenReturn(documentMapper); + // returning the current document mapper as the merge result to simulate a noop mapping update + when(mapperService.merge(any(), any(CompressedXContent.class), any())).thenReturn(documentMapper); + // on the second invocation, the mapping version is incremented + // so that the second mapping update attempt doesn't trigger the infinite loop prevention + when(mapperService.mappingVersion()).thenReturn(0L, 0L, 1L); + + IndexShard shard = mockShard(null, mapperService); when(shard.applyIndexOperationOnPrimary(anyLong(), any(), any(), anyLong(), anyLong(), anyLong(), anyBoolean())).thenReturn( // on the first invocation, return a result that attempts a mapping update // the mapping update will be a noop and the operation is retired without contacting the master @@ -1199,18 +1198,6 @@ public void testNoopMappingUpdateSuccessOnRetry() throws Exception { successfulResult ); - MapperService mapperService = mock(MapperService.class); - when(shard.mapperService()).thenReturn(mapperService); - - DocumentMapper documentMapper = mock(DocumentMapper.class); 
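These TransportShardBulkActionTests hunks keep rearranging the same mock setup, which the two-argument mockShard(IndexSettings, MapperService) helper (introduced further down in this diff) centralizes: stub the MapperService first when its behaviour matters, then let mockShard wire up shardId(), the bulk-operation listener and a default MappingLookup. A condensed illustration of the resulting pattern:

// When the test cares about mapping merges, stub the MapperService before handing it over.
MapperService mapperService = mock(MapperService.class);
DocumentMapper mergedDoc = mock(DocumentMapper.class);
when(mergedDoc.mappingSource()).thenReturn(CompressedXContent.fromJSON("{}"));
when(mapperService.merge(any(), any(CompressedXContent.class), any())).thenReturn(mergedDoc);

IndexShard shard = mockShard(null, mapperService);        // mapping-sensitive tests
IndexShard plainShard = mockShard(indexSettings, null);   // tests that only need index settings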
- when(documentMapper.mappingSource()).thenReturn(CompressedXContent.fromJSON("{}")); - when(mapperService.documentMapper()).thenReturn(documentMapper); - // returning the current document mapper as the merge result to simulate a noop mapping update - when(mapperService.merge(any(), any(CompressedXContent.class), any())).thenReturn(documentMapper); - // on the second invocation, the mapping version is incremented - // so that the second mapping update attempt doesn't trigger the infinite loop prevention - when(mapperService.mappingVersion()).thenReturn(0L, 0L, 1L); - UpdateHelper updateHelper = mock(UpdateHelper.class); when(updateHelper.prepare(any(), eq(shard), any())).thenReturn( new UpdateHelper.Result( @@ -1244,11 +1231,35 @@ public void testNoopMappingUpdateSuccessOnRetry() throws Exception { verify(mapperService, times(2)).merge(any(), any(CompressedXContent.class), any()); } - private IndexShard mockShard() { + private IndexShard mockShard(IndexSettings indexSettings, MapperService mapperService) { IndexShard shard = mock(IndexShard.class); when(shard.shardId()).thenReturn(shardId); when(shard.getBulkOperationListener()).thenReturn(mock(ShardBulkStats.class)); when(shard.getFailedIndexResult(any(Exception.class), anyLong(), anyString())).thenCallRealMethod(); + + if (indexSettings != null) { + when(shard.indexSettings()).thenReturn(indexSettings); + } + + if (mapperService != null) { + when(shard.mapperService()).thenReturn(mapperService); + if (Mockito.mockingDetails(mapperService).isMock()) { + MockingDetails details = mockingDetails(mapperService); + if (details.getStubbings() + .stream() + .map(Stubbing::getInvocation) + .noneMatch(i -> i.toString().contains(".mappingLookup()"))) { + // If the mappingLookup() method is not mocked, configure it to return an empty mapping + when(mapperService.mappingLookup()).thenReturn(MappingLookup.EMPTY); + } + } + } else { + // By default, create a mapper service that returns an empty mapping lookup + MapperService defaultMapperService = mock(MapperService.class); + when(defaultMapperService.mappingLookup()).thenReturn(MappingLookup.EMPTY); + when(shard.mapperService()).thenReturn(defaultMapperService); + } + return shard; } diff --git a/server/src/test/java/org/elasticsearch/action/ingest/PutPipelineRequestTests.java b/server/src/test/java/org/elasticsearch/action/ingest/PutPipelineRequestTests.java index 2b33ae14ce03b..f3cc25cbeb5e4 100644 --- a/server/src/test/java/org/elasticsearch/action/ingest/PutPipelineRequestTests.java +++ b/server/src/test/java/org/elasticsearch/action/ingest/PutPipelineRequestTests.java @@ -52,7 +52,13 @@ public void testToXContent() throws IOException { // End first processor pipelineBuilder.endArray(); pipelineBuilder.endObject(); - PutPipelineRequest request = new PutPipelineRequest("1", BytesReference.bytes(pipelineBuilder), xContentType); + PutPipelineRequest request = new PutPipelineRequest( + TEST_REQUEST_TIMEOUT, + TEST_REQUEST_TIMEOUT, + "1", + BytesReference.bytes(pipelineBuilder), + xContentType + ); XContentBuilder requestBuilder = XContentBuilder.builder(xContentType.xContent()); BytesReference actualRequestBody = BytesReference.bytes(request.toXContent(requestBuilder, ToXContent.EMPTY_PARAMS)); assertEquals(BytesReference.bytes(pipelineBuilder), actualRequestBody); diff --git a/server/src/test/java/org/elasticsearch/action/update/UpdateRequestTests.java b/server/src/test/java/org/elasticsearch/action/update/UpdateRequestTests.java index 2b28c68fa5a36..d8960bd902ac5 100644 --- 
a/server/src/test/java/org/elasticsearch/action/update/UpdateRequestTests.java +++ b/server/src/test/java/org/elasticsearch/action/update/UpdateRequestTests.java @@ -22,6 +22,9 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.env.Environment; import org.elasticsearch.index.get.GetResult; +import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.mapper.MappingLookup; +import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.plugins.internal.DocumentParsingProvider; import org.elasticsearch.script.MockScriptEngine; @@ -62,6 +65,7 @@ import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.notNullValue; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class UpdateRequestTests extends ESTestCase { @@ -378,9 +382,10 @@ public void testNowInScript() throws IOException { .script(mockInlineScript("ctx._source.update_timestamp = ctx._now")) .scriptedUpsert(true); long nowInMillis = randomNonNegativeLong(); + IndexShard indexShard = createMockIndexShard(new ShardId("test", "_na_", 0)); // We simulate that the document is not existing yet GetResult getResult = new GetResult("test", "2", UNASSIGNED_SEQ_NO, 0, 0, false, null, null, null); - UpdateHelper.Result result = updateHelper.prepare(new ShardId("test", "_na_", 0), updateRequest, getResult, () -> nowInMillis); + UpdateHelper.Result result = updateHelper.prepare(indexShard, updateRequest, getResult, () -> nowInMillis); Writeable action = result.action(); assertThat(action, instanceOf(IndexRequest.class)); IndexRequest indexAction = (IndexRequest) action; @@ -419,12 +424,8 @@ public void testUpsertTimeout() throws IOException { } private void runTimeoutTest(final GetResult getResult, final UpdateRequest updateRequest) { - final UpdateHelper.Result result = updateHelper.prepare( - new ShardId("test", "", 0), - updateRequest, - getResult, - ESTestCase::randomNonNegativeLong - ); + final IndexShard indexShard = createMockIndexShard(new ShardId("test", "", 0)); + final UpdateHelper.Result result = updateHelper.prepare(indexShard, updateRequest, getResult, ESTestCase::randomNonNegativeLong); final Writeable action = result.action(); assertThat(action, instanceOf(ReplicationRequest.class)); final ReplicationRequest request = (ReplicationRequest) action; @@ -586,7 +587,7 @@ public void testRoutingExtraction() throws Exception { } public void testNoopDetection() throws Exception { - ShardId shardId = new ShardId("test", "", 0); + IndexShard indexShard = createMockIndexShard(new ShardId("test", "", 0)); GetResult getResult = new GetResult("test", "1", 0, 1, 0, true, new BytesArray("{\"body\": \"foo\"}"), null, null); UpdateRequest request; @@ -594,13 +595,13 @@ public void testNoopDetection() throws Exception { request = new UpdateRequest("test", "1").fromXContent(parser); } UpdateHelper updateHelper = new UpdateHelper(mock(ScriptService.class), DocumentParsingProvider.EMPTY_INSTANCE); - UpdateHelper.Result result = updateHelper.prepareUpdateIndexRequest(shardId, request, getResult, true); + UpdateHelper.Result result = updateHelper.prepareUpdateIndexRequest(indexShard, request, getResult, true); assertThat(result.action(), instanceOf(UpdateResponse.class)); assertThat(result.getResponseResult(), equalTo(DocWriteResponse.Result.NOOP)); // Try again, with detectNoop turned off - result = updateHelper.prepareUpdateIndexRequest(shardId, request, getResult, false); + 
result = updateHelper.prepareUpdateIndexRequest(indexShard, request, getResult, false); assertThat(result.action(), instanceOf(IndexRequest.class)); assertThat(result.getResponseResult(), equalTo(DocWriteResponse.Result.UPDATED)); assertThat(result.updatedSourceAsMap().get("body").toString(), equalTo("foo")); @@ -608,7 +609,7 @@ public void testNoopDetection() throws Exception { try (var parser = createParser(JsonXContent.jsonXContent, new BytesArray("{\"doc\": {\"body\": \"bar\"}}"))) { // Change the request to be a different doc request = new UpdateRequest("test", "1").fromXContent(parser); - result = updateHelper.prepareUpdateIndexRequest(shardId, request, getResult, true); + result = updateHelper.prepareUpdateIndexRequest(indexShard, request, getResult, true); assertThat(result.action(), instanceOf(IndexRequest.class)); assertThat(result.getResponseResult(), equalTo(DocWriteResponse.Result.UPDATED)); @@ -618,13 +619,13 @@ public void testNoopDetection() throws Exception { } public void testUpdateScript() throws Exception { - ShardId shardId = new ShardId("test", "", 0); + IndexShard indexShard = createMockIndexShard(new ShardId("test", "", 0)); GetResult getResult = new GetResult("test", "1", 0, 1, 0, true, new BytesArray("{\"body\": \"bar\"}"), null, null); UpdateRequest request = new UpdateRequest("test", "1").script(mockInlineScript("ctx._source.body = \"foo\"")); UpdateHelper.Result result = updateHelper.prepareUpdateScriptRequest( - shardId, + indexShard, request, getResult, ESTestCase::randomNonNegativeLong @@ -637,7 +638,7 @@ public void testUpdateScript() throws Exception { // Now where the script changes the op to "delete" request = new UpdateRequest("test", "1").script(mockInlineScript("ctx.op = 'delete'")); - result = updateHelper.prepareUpdateScriptRequest(shardId, request, getResult, ESTestCase::randomNonNegativeLong); + result = updateHelper.prepareUpdateScriptRequest(indexShard, request, getResult, ESTestCase::randomNonNegativeLong); assertThat(result.action(), instanceOf(DeleteRequest.class)); assertThat(result.getResponseResult(), equalTo(DocWriteResponse.Result.DELETED)); @@ -650,7 +651,7 @@ public void testUpdateScript() throws Exception { request = new UpdateRequest("test", "1").script(mockInlineScript("ctx.op = 'bad'")); } - result = updateHelper.prepareUpdateScriptRequest(shardId, request, getResult, ESTestCase::randomNonNegativeLong); + result = updateHelper.prepareUpdateScriptRequest(indexShard, request, getResult, ESTestCase::randomNonNegativeLong); assertThat(result.action(), instanceOf(UpdateResponse.class)); assertThat(result.getResponseResult(), equalTo(DocWriteResponse.Result.NOOP)); @@ -668,4 +669,15 @@ public void testToString() throws IOException { scripted_upsert[false], detect_noop[true]}""")); } } + + private static IndexShard createMockIndexShard(ShardId shardId) { + MapperService mapperService = mock(MapperService.class); + when(mapperService.mappingLookup()).thenReturn(MappingLookup.EMPTY); + + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(shardId); + when(indexShard.mapperService()).thenReturn(mapperService); + + return indexShard; + } } diff --git a/server/src/test/java/org/elasticsearch/ingest/IngestServiceTests.java b/server/src/test/java/org/elasticsearch/ingest/IngestServiceTests.java index 1c02aa5a7af06..3adaf398624de 100644 --- a/server/src/test/java/org/elasticsearch/ingest/IngestServiceTests.java +++ b/server/src/test/java/org/elasticsearch/ingest/IngestServiceTests.java @@ -408,7 +408,7 @@ 
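The IngestServiceTests changes that follow are mechanical: with the old short-form constructors apparently gone, every DeletePipelineRequest and PutPipelineRequest in test code now spells out both timeouts via ESTestCase.TEST_REQUEST_TIMEOUT. The shape they all share (illustrative values only):

DeletePipelineRequest deleteRequest = new DeletePipelineRequest(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT, "_id");

PutPipelineRequest putRequest = new PutPipelineRequest(
    TEST_REQUEST_TIMEOUT,   // master node timeout
    TEST_REQUEST_TIMEOUT,   // ack timeout
    "_id",
    new BytesArray("{\"processors\": []}"),
    XContentType.JSON
);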
public void testDelete() { assertThat(ingestService.getPipeline("_id"), notNullValue()); // Delete pipeline: - DeletePipelineRequest deleteRequest = new DeletePipelineRequest("_id"); + DeletePipelineRequest deleteRequest = new DeletePipelineRequest(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT, "_id"); previousClusterState = clusterState; clusterState = executeDelete(deleteRequest, clusterState); ingestService.applyClusterState(new ClusterChangedEvent("", clusterState, previousClusterState)); @@ -713,7 +713,7 @@ public void testCrud() throws Exception { assertThat(pipeline.getProcessors().size(), equalTo(1)); assertThat(pipeline.getProcessors().get(0).getType(), equalTo("set")); - DeletePipelineRequest deleteRequest = new DeletePipelineRequest(id); + DeletePipelineRequest deleteRequest = new DeletePipelineRequest(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT, id); previousClusterState = clusterState; clusterState = executeDelete(deleteRequest, clusterState); ingestService.applyClusterState(new ClusterChangedEvent("", clusterState, previousClusterState)); @@ -803,7 +803,7 @@ public void testDeleteUsingWildcard() { assertThat(ingestService.getPipeline("q1"), notNullValue()); // Delete pipeline matching wildcard - DeletePipelineRequest deleteRequest = new DeletePipelineRequest("p*"); + DeletePipelineRequest deleteRequest = new DeletePipelineRequest(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT, "p*"); previousClusterState = clusterState; clusterState = executeDelete(deleteRequest, clusterState); ingestService.applyClusterState(new ClusterChangedEvent("", clusterState, previousClusterState)); @@ -816,13 +816,16 @@ public void testDeleteUsingWildcard() { assertThat( expectThrows( ResourceNotFoundException.class, - () -> executeFailingDelete(new DeletePipelineRequest("unknown"), finalClusterState) + () -> executeFailingDelete( + new DeletePipelineRequest(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT, "unknown"), + finalClusterState + ) ).getMessage(), equalTo("pipeline [unknown] is missing") ); // match all wildcard works on last remaining pipeline - DeletePipelineRequest matchAllDeleteRequest = new DeletePipelineRequest("*"); + DeletePipelineRequest matchAllDeleteRequest = new DeletePipelineRequest(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT, "*"); previousClusterState = clusterState; clusterState = executeDelete(matchAllDeleteRequest, clusterState); ingestService.applyClusterState(new ClusterChangedEvent("", clusterState, previousClusterState)); @@ -851,8 +854,10 @@ public void testDeleteWithExistingUnmatchedPipelines() { ClusterState finalClusterState = clusterState; assertThat( - expectThrows(ResourceNotFoundException.class, () -> executeFailingDelete(new DeletePipelineRequest("z*"), finalClusterState)) - .getMessage(), + expectThrows( + ResourceNotFoundException.class, + () -> executeFailingDelete(new DeletePipelineRequest(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT, "z*"), finalClusterState) + ).getMessage(), equalTo("pipeline [z*] is missing") ); } @@ -878,7 +883,7 @@ public void testDeleteWithIndexUsePipeline() { ingestService.applyClusterState(new ClusterChangedEvent("", clusterState, previousClusterState)); assertThat(ingestService.getPipeline("_id"), notNullValue()); - DeletePipelineRequest deleteRequest = new DeletePipelineRequest("_id"); + DeletePipelineRequest deleteRequest = new DeletePipelineRequest(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT, "_id"); { // delete pipeline which is in used of default_pipeline @@ -2637,7 +2642,14 @@ public void 
testPutPipelineWithVersionedUpdateWithoutExistingPipeline() { final Integer version = randomInt(); var pipelineString = "{\"version\": " + version + ", \"processors\": []}"; - var request = new PutPipelineRequest(pipelineId, new BytesArray(pipelineString), XContentType.JSON, version); + var request = new PutPipelineRequest( + TEST_REQUEST_TIMEOUT, + TEST_REQUEST_TIMEOUT, + pipelineId, + new BytesArray(pipelineString), + XContentType.JSON, + version + ); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> executeFailingPut(request, clusterState)); assertThat( e.getMessage(), @@ -2662,7 +2674,14 @@ public void testPutPipelineWithVersionedUpdateDoesNotMatchExistingPipeline() { .build(); final Integer requestedVersion = randomValueOtherThan(version, ESTestCase::randomInt); - var request = new PutPipelineRequest(pipelineId, new BytesArray(pipelineString), XContentType.JSON, requestedVersion); + var request = new PutPipelineRequest( + TEST_REQUEST_TIMEOUT, + TEST_REQUEST_TIMEOUT, + pipelineId, + new BytesArray(pipelineString), + XContentType.JSON, + requestedVersion + ); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> executeFailingPut(request, clusterState)); assertThat( e.getMessage(), @@ -2687,7 +2706,14 @@ public void testPutPipelineWithVersionedUpdateSpecifiesSameVersion() throws Exce .metadata(Metadata.builder().putCustom(IngestMetadata.TYPE, new IngestMetadata(Map.of(pipelineId, existingPipeline))).build()) .build(); - var request = new PutPipelineRequest(pipelineId, new BytesArray(pipelineString), XContentType.JSON, version); + var request = new PutPipelineRequest( + TEST_REQUEST_TIMEOUT, + TEST_REQUEST_TIMEOUT, + pipelineId, + new BytesArray(pipelineString), + XContentType.JSON, + version + ); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> executeFailingPut(request, clusterState)); assertThat(e.getMessage(), equalTo(Strings.format("cannot update pipeline [%s] with the same version [%s]", pipelineId, version))); } @@ -2703,7 +2729,14 @@ public void testPutPipelineWithVersionedUpdateSpecifiesValidVersion() throws Exc final int specifiedVersion = randomValueOtherThan(existingVersion, ESTestCase::randomInt); var updatedPipelineString = "{\"version\": " + specifiedVersion + ", \"processors\": []}"; - var request = new PutPipelineRequest(pipelineId, new BytesArray(updatedPipelineString), XContentType.JSON, existingVersion); + var request = new PutPipelineRequest( + TEST_REQUEST_TIMEOUT, + TEST_REQUEST_TIMEOUT, + pipelineId, + new BytesArray(updatedPipelineString), + XContentType.JSON, + existingVersion + ); var updatedState = executePut(request, clusterState); var updatedConfig = ((IngestMetadata) updatedState.metadata().custom(IngestMetadata.TYPE)).getPipelines().get(pipelineId); @@ -2721,7 +2754,14 @@ public void testPutPipelineWithVersionedUpdateIncrementsVersion() throws Excepti .build(); var updatedPipelineString = "{\"processors\": []}"; - var request = new PutPipelineRequest(pipelineId, new BytesArray(updatedPipelineString), XContentType.JSON, existingVersion); + var request = new PutPipelineRequest( + TEST_REQUEST_TIMEOUT, + TEST_REQUEST_TIMEOUT, + pipelineId, + new BytesArray(updatedPipelineString), + XContentType.JSON, + existingVersion + ); var updatedState = executePut(request, clusterState); var updatedConfig = ((IngestMetadata) updatedState.metadata().custom(IngestMetadata.TYPE)).getPipelines().get(pipelineId); diff --git 
a/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java b/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java index f7c415b529859..56a8b0f3a8c30 100644 --- a/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java @@ -29,6 +29,7 @@ import org.apache.lucene.index.Term; import org.apache.lucene.queries.spans.SpanNearQuery; import org.apache.lucene.queries.spans.SpanTermQuery; +import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanClause.Occur; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.Collector; @@ -1105,6 +1106,22 @@ private static ContextIndexSearcher newContextSearcher(IndexReader reader) throw ); } + public void testTooManyClauses() throws Exception { + indexDocs(); + var oldCount = IndexSearcher.getMaxClauseCount(); + try { + var query = new BooleanQuery.Builder().add(new BooleanClause(new MatchAllDocsQuery(), Occur.SHOULD)) + .add(new MatchAllDocsQuery(), Occur.SHOULD) + .build(); + try (TestSearchContext context = createContext(newContextSearcher(reader), query)) { + IndexSearcher.setMaxClauseCount(1); + expectThrows(IllegalArgumentException.class, context::rewrittenQuery); + } + } finally { + IndexSearcher.setMaxClauseCount(oldCount); + } + } + private static ContextIndexSearcher noCollectionContextSearcher(IndexReader reader) throws IOException { return earlyTerminationContextSearcher(reader, 0); } diff --git a/test/framework/src/main/java/org/elasticsearch/ingest/IngestPipelineTestUtils.java b/test/framework/src/main/java/org/elasticsearch/ingest/IngestPipelineTestUtils.java index 963fc634b5506..8fd3c61d4c9da 100644 --- a/test/framework/src/main/java/org/elasticsearch/ingest/IngestPipelineTestUtils.java +++ b/test/framework/src/main/java/org/elasticsearch/ingest/IngestPipelineTestUtils.java @@ -28,6 +28,7 @@ import java.io.IOException; +import static org.elasticsearch.test.ESTestCase.TEST_REQUEST_TIMEOUT; import static org.elasticsearch.test.ESTestCase.safeGet; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.xcontent.XContentFactory.jsonBuilder; @@ -46,7 +47,7 @@ private IngestPipelineTestUtils() { /* no instances */ } * @return a new {@link PutPipelineRequest} with the given {@code id} and body. 
*/ public static PutPipelineRequest putJsonPipelineRequest(String id, BytesReference source) { - return new PutPipelineRequest(id, source, XContentType.JSON); + return new PutPipelineRequest(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT, id, source, XContentType.JSON); } /** @@ -103,19 +104,23 @@ public static void putJsonPipeline(ElasticsearchClient client, String id, ToXCon public static void deletePipelinesIgnoringExceptions(ElasticsearchClient client, Iterable ids) { for (final var id : ids) { ESTestCase.safeAwait( - l -> client.execute(DeletePipelineTransportAction.TYPE, new DeletePipelineRequest(id), new ActionListener<>() { - @Override - public void onResponse(AcknowledgedResponse acknowledgedResponse) { - logger.info("delete pipeline [{}] success [acknowledged={}]", id, acknowledgedResponse.isAcknowledged()); - l.onResponse(null); - } + l -> client.execute( + DeletePipelineTransportAction.TYPE, + new DeletePipelineRequest(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT, id), + new ActionListener<>() { + @Override + public void onResponse(AcknowledgedResponse acknowledgedResponse) { + logger.info("delete pipeline [{}] success [acknowledged={}]", id, acknowledgedResponse.isAcknowledged()); + l.onResponse(null); + } - @Override - public void onFailure(Exception e) { - logger.warn(Strings.format("delete pipeline [%s] failure", id), e); - l.onResponse(null); + @Override + public void onFailure(Exception e) { + logger.warn(Strings.format("delete pipeline [%s] failure", id), e); + l.onResponse(null); + } } - }) + ) ); } } diff --git a/test/framework/src/main/java/org/elasticsearch/logsdb/datageneration/datasource/DefaultWrappersHandler.java b/test/framework/src/main/java/org/elasticsearch/logsdb/datageneration/datasource/DefaultWrappersHandler.java index 83684d6958ed5..8af26c28ef5b3 100644 --- a/test/framework/src/main/java/org/elasticsearch/logsdb/datageneration/datasource/DefaultWrappersHandler.java +++ b/test/framework/src/main/java/org/elasticsearch/logsdb/datageneration/datasource/DefaultWrappersHandler.java @@ -27,7 +27,8 @@ public DataSourceResponse.ArrayWrapper handle(DataSourceRequest.ArrayWrapper ign } private static Function, Supplier> injectNulls() { - return (values) -> () -> ESTestCase.randomBoolean() ? null : values.get(); + // Inject some nulls but majority of data should be non-null (as it likely is in reality). + return (values) -> () -> ESTestCase.randomDouble() <= 0.05 ? 
null : values.get(); } private static Function, Supplier> wrapInArray() { diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java index 0cf026e0483ab..9132474fa9415 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java @@ -1864,7 +1864,15 @@ private void postIndexAsyncActions(String[] indices, List inFlig } while (inFlightAsyncOperations.size() > MAX_IN_FLIGHT_ASYNC_INDEXES) { int waitFor = between(0, inFlightAsyncOperations.size() - 1); - safeAwait(inFlightAsyncOperations.remove(waitFor)); + try { + assertTrue( + "operation did not complete within timeout", + inFlightAsyncOperations.remove(waitFor).await(60, TimeUnit.SECONDS) + ); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + fail(e, "interrupted while waiting for operation to complete"); + } } } @@ -2657,13 +2665,20 @@ protected static void putJsonPipeline(String id, ToXContentFragment toXContent) * @return the result of running the {@link GetPipelineAction} on the given IDs, using the default {@link ESIntegTestCase#client()}. */ protected static GetPipelineResponse getPipelines(String... ids) { - return safeGet(client().execute(GetPipelineAction.INSTANCE, new GetPipelineRequest(ids))); + return safeGet(client().execute(GetPipelineAction.INSTANCE, new GetPipelineRequest(TEST_REQUEST_TIMEOUT, ids))); } /** * Delete the ingest pipeline with the given {@code id}, the default {@link ESIntegTestCase#client()}. */ protected static void deletePipeline(String id) { - assertAcked(safeGet(client().execute(DeletePipelineTransportAction.TYPE, new DeletePipelineRequest(id)))); + assertAcked( + safeGet( + client().execute( + DeletePipelineTransportAction.TYPE, + new DeletePipelineRequest(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT, id) + ) + ) + ); } } diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESSingleNodeTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESSingleNodeTestCase.java index 99de1c6848dc9..459d5573d7c12 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESSingleNodeTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESSingleNodeTestCase.java @@ -524,6 +524,13 @@ protected final void putJsonPipeline(String id, ToXContentFragment toXContent) t * Delete the ingest pipeline with the given {@code id}, the default {@link ESSingleNodeTestCase#client()}. 
*/ protected final void deletePipeline(String id) { - assertAcked(safeGet(client().execute(DeletePipelineTransportAction.TYPE, new DeletePipelineRequest(id)))); + assertAcked( + safeGet( + client().execute( + DeletePipelineTransportAction.TYPE, + new DeletePipelineRequest(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT, id) + ) + ) + ); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/template/IndexTemplateRegistry.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/template/IndexTemplateRegistry.java index 87092c45bf032..8849377e6ad7e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/template/IndexTemplateRegistry.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/template/IndexTemplateRegistry.java @@ -20,6 +20,7 @@ import org.elasticsearch.action.ingest.PutPipelineTransportAction; import org.elasticsearch.action.support.GroupedActionListener; import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.support.master.MasterNodeRequest; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterState; @@ -725,11 +726,12 @@ private void putIngestPipeline(final IngestPipelineConfig pipelineConfig, final final Executor executor = threadPool.generic(); executor.execute(() -> { PutPipelineRequest request = new PutPipelineRequest( + MasterNodeRequest.INFINITE_MASTER_NODE_TIMEOUT, + MasterNodeRequest.INFINITE_MASTER_NODE_TIMEOUT, pipelineConfig.getId(), pipelineConfig.loadConfig(), pipelineConfig.getXContentType() ); - request.masterNodeTimeout(TimeValue.MAX_VALUE); executeAsyncWithOrigin( client.threadPool().getThreadContext(), diff --git a/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/EnrichPolicyReindexPipeline.java b/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/EnrichPolicyReindexPipeline.java index 8d9da1ba631f6..7cddd7e037742 100644 --- a/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/EnrichPolicyReindexPipeline.java +++ b/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/EnrichPolicyReindexPipeline.java @@ -22,6 +22,8 @@ import java.io.IOException; import java.io.UncheckedIOException; +import static org.elasticsearch.xpack.enrich.EnrichPolicyRunner.ENRICH_MASTER_REQUEST_TIMEOUT; + /** * Manages the definitions and lifecycle of the ingest pipeline used by the reindex operation within the Enrich Policy execution. 
*/ @@ -68,7 +70,17 @@ static boolean exists(ClusterState clusterState) { */ public static void create(Client client, ActionListener listener) { final BytesReference pipeline = BytesReference.bytes(currentEnrichPipelineDefinition(XContentType.JSON)); - client.execute(PutPipelineTransportAction.TYPE, new PutPipelineRequest(pipelineName(), pipeline, XContentType.JSON), listener); + client.execute( + PutPipelineTransportAction.TYPE, + new PutPipelineRequest( + ENRICH_MASTER_REQUEST_TIMEOUT, + ENRICH_MASTER_REQUEST_TIMEOUT, + pipelineName(), + pipeline, + XContentType.JSON + ), + listener + ); } private static XContentBuilder currentEnrichPipelineDefinition(XContentType xContentType) { diff --git a/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/EnrichPolicyRunner.java b/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/EnrichPolicyRunner.java index a279eac5befc6..69bc54457785c 100644 --- a/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/EnrichPolicyRunner.java +++ b/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/EnrichPolicyRunner.java @@ -97,6 +97,11 @@ public class EnrichPolicyRunner { static final String ENRICH_INDEX_README_TEXT = "This index is managed by Elasticsearch and should not be modified in any way."; + /** + * Timeout for enrich-related requests that interact with the master node. Possibly this should be longer and/or configurable. + */ + static final TimeValue ENRICH_MASTER_REQUEST_TIMEOUT = TimeValue.THIRTY_SECONDS; + private final String policyName; private final EnrichPolicy policy; private final ExecuteEnrichPolicyTask task; @@ -685,10 +690,7 @@ private void setIndexReadOnly(ActionListener listener) { } private void waitForIndexGreen(ActionListener listener) { - ClusterHealthRequest request = new ClusterHealthRequest( - TimeValue.THIRTY_SECONDS /* TODO should this be longer/configurable? 
*/ , - enrichIndexName - ).waitForGreenStatus(); + ClusterHealthRequest request = new ClusterHealthRequest(ENRICH_MASTER_REQUEST_TIMEOUT, enrichIndexName).waitForGreenStatus(); enrichOriginClient().admin().cluster().health(request, listener); } diff --git a/x-pack/plugin/esql/build.gradle b/x-pack/plugin/esql/build.gradle index 26cf53b334b1e..0225664918b7b 100644 --- a/x-pack/plugin/esql/build.gradle +++ b/x-pack/plugin/esql/build.gradle @@ -11,7 +11,7 @@ esplugin { name 'x-pack-esql' description 'The plugin that powers ESQL for Elasticsearch' classname 'org.elasticsearch.xpack.esql.plugin.EsqlPlugin' - extendedPlugins = ['x-pack-esql-core', 'lang-painless'] + extendedPlugins = ['x-pack-esql-core', 'lang-painless', 'x-pack-ml'] } base { @@ -22,6 +22,7 @@ dependencies { compileOnly project(path: xpackModule('core')) compileOnly project(':modules:lang-painless:spi') compileOnly project(xpackModule('esql-core')) + compileOnly project(xpackModule('ml')) implementation project('compute') implementation project('compute:ann') implementation project(':libs:elasticsearch-dissect') diff --git a/x-pack/plugin/esql/compute/build.gradle b/x-pack/plugin/esql/compute/build.gradle index 971bfd39c231f..81d1a6f5360ca 100644 --- a/x-pack/plugin/esql/compute/build.gradle +++ b/x-pack/plugin/esql/compute/build.gradle @@ -11,11 +11,14 @@ base { dependencies { compileOnly project(':server') compileOnly project('ann') + compileOnly project(xpackModule('ml')) annotationProcessor project('gen') implementation 'com.carrotsearch:hppc:0.8.1' testImplementation project(':test:framework') testImplementation(project(xpackModule('esql-core'))) + testImplementation(project(xpackModule('core'))) + testImplementation(project(xpackModule('ml'))) } def projectDirectory = project.layout.projectDirectory diff --git a/x-pack/plugin/esql/compute/src/main/java/module-info.java b/x-pack/plugin/esql/compute/src/main/java/module-info.java index dc8cda0fbe3c8..1739c90467c2c 100644 --- a/x-pack/plugin/esql/compute/src/main/java/module-info.java +++ b/x-pack/plugin/esql/compute/src/main/java/module-info.java @@ -7,6 +7,7 @@ module org.elasticsearch.compute { + requires org.apache.lucene.analysis.common; requires org.apache.lucene.core; requires org.elasticsearch.base; requires org.elasticsearch.server; @@ -15,6 +16,7 @@ // required due to dependency on org.elasticsearch.common.util.concurrent.AbstractAsyncTask requires org.apache.logging.log4j; requires org.elasticsearch.logging; + requires org.elasticsearch.ml; requires org.elasticsearch.tdigest; requires org.elasticsearch.geo; requires hppc; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec new file mode 100644 index 0000000000000..076f3ee092ecf --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec @@ -0,0 +1,13 @@ +categorize +required_capability: categorize + +FROM sample_data + | STATS count=COUNT(), values=VALUES(message) BY category=CATEGORIZE(message) + | SORT count DESC, category ASC +; + +count:long | values:keyword | category:integer +3 | [Connected to 10.1.0.1, Connected to 10.1.0.2, Connected to 10.1.0.3] | 0 +3 | [Connection error] | 1 +1 | [Disconnected] | 2 +; diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeEvaluator.java new file 
mode 100644 index 0000000000000..93bc8ce3e2a1b --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeEvaluator.java @@ -0,0 +1,131 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.function.grouping; + +import java.lang.IllegalArgumentException; +import java.lang.Override; +import java.lang.String; +import java.util.function.Function; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefVector; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.function.Warnings; +import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer; +import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link Categorize}. + * This class is generated. Do not edit it. + */ +public final class CategorizeEvaluator implements EvalOperator.ExpressionEvaluator { + private final Warnings warnings; + + private final EvalOperator.ExpressionEvaluator v; + + private final CategorizationAnalyzer analyzer; + + private final TokenListCategorizer.CloseableTokenListCategorizer categorizer; + + private final DriverContext driverContext; + + public CategorizeEvaluator(Source source, EvalOperator.ExpressionEvaluator v, + CategorizationAnalyzer analyzer, + TokenListCategorizer.CloseableTokenListCategorizer categorizer, DriverContext driverContext) { + this.v = v; + this.analyzer = analyzer; + this.categorizer = categorizer; + this.driverContext = driverContext; + this.warnings = Warnings.createWarnings(driverContext.warningsMode(), source); + } + + @Override + public Block eval(Page page) { + try (BytesRefBlock vBlock = (BytesRefBlock) v.eval(page)) { + BytesRefVector vVector = vBlock.asVector(); + if (vVector == null) { + return eval(page.getPositionCount(), vBlock); + } + return eval(page.getPositionCount(), vVector).asBlock(); + } + } + + public IntBlock eval(int positionCount, BytesRefBlock vBlock) { + try(IntBlock.Builder result = driverContext.blockFactory().newIntBlockBuilder(positionCount)) { + BytesRef vScratch = new BytesRef(); + position: for (int p = 0; p < positionCount; p++) { + if (vBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (vBlock.getValueCount(p) != 1) { + if (vBlock.getValueCount(p) > 1) { + warnings.registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + result.appendInt(Categorize.process(vBlock.getBytesRef(vBlock.getFirstValueIndex(p), vScratch), this.analyzer, this.categorizer)); + } + return result.build(); + } + } + + public IntVector eval(int positionCount, BytesRefVector vVector) { + try(IntVector.FixedBuilder result = 
driverContext.blockFactory().newIntVectorFixedBuilder(positionCount)) { + BytesRef vScratch = new BytesRef(); + position: for (int p = 0; p < positionCount; p++) { + result.appendInt(p, Categorize.process(vVector.getBytesRef(p, vScratch), this.analyzer, this.categorizer)); + } + return result.build(); + } + } + + @Override + public String toString() { + return "CategorizeEvaluator[" + "v=" + v + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(v, analyzer, categorizer); + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory v; + + private final Function analyzer; + + private final Function categorizer; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory v, + Function analyzer, + Function categorizer) { + this.source = source; + this.v = v; + this.analyzer = analyzer; + this.categorizer = categorizer; + } + + @Override + public CategorizeEvaluator get(DriverContext context) { + return new CategorizeEvaluator(source, v.get(context), analyzer.apply(context), categorizer.apply(context), context); + } + + @Override + public String toString() { + return "CategorizeEvaluator[" + "v=" + v + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index 49076a1d65e72..ce8c20e4fbf11 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -33,6 +33,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Values; import org.elasticsearch.xpack.esql.expression.function.aggregate.WeightedAvg; import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; +import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Least; @@ -383,7 +384,10 @@ private FunctionDefinition[][] functions() { } private static FunctionDefinition[][] snapshotFunctions() { - return new FunctionDefinition[][] { new FunctionDefinition[] { def(Rate.class, Rate::withUnresolvedTimestamp, "rate") } }; + return new FunctionDefinition[][] { + new FunctionDefinition[] { + def(Categorize.class, Categorize::new, "categorize"), + def(Rate.class, Rate::withUnresolvedTimestamp, "rate") } }; } public EsqlFunctionRegistry snapshotRegistry() { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java new file mode 100644 index 0000000000000..82c836a6f9d49 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java @@ -0,0 +1,159 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.grouping; + +import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.core.WhitespaceTokenizer; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.util.BytesRefHash; +import org.elasticsearch.compute.ann.Evaluator; +import org.elasticsearch.compute.ann.Fixed; +import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; +import org.elasticsearch.index.analysis.CharFilterFactory; +import org.elasticsearch.index.analysis.CustomAnalyzer; +import org.elasticsearch.index.analysis.TokenFilterFactory; +import org.elasticsearch.index.analysis.TokenizerFactory; +import org.elasticsearch.xpack.esql.capabilities.Validatable; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash; +import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary; +import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer; +import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer; + +import java.io.IOException; +import java.util.List; +import java.util.function.Function; + +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isString; + +/** + * Categorizes text messages. + * + * This implementation is incomplete and comes with the following caveats: + * - it only works correctly on a single node. 
+ * - when running on multiple nodes, category IDs of the different nodes are + * aggregated, even though the same ID can correspond to a totally different + * category + * - the output consists of category IDs, which should be replaced by category + * regexes or keys + * + * TODO(jan, nik): fix this + */ +public class Categorize extends GroupingFunction implements Validatable { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "Categorize", + Categorize::new + ); + + private final Expression field; + + @FunctionInfo(returnType = { "integer" }, description = "Categorizes text messages") + public Categorize( + Source source, + @Param(name = "field", type = { "text", "keyword" }, description = "Expression to categorize") Expression field + ) { + super(source, List.of(field)); + this.field = field; + } + + private Categorize(StreamInput in) throws IOException { + this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + source().writeTo(out); + out.writeNamedWriteable(field); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + @Override + public boolean foldable() { + return field.foldable(); + } + + @Evaluator + static int process( + BytesRef v, + @Fixed(includeInToString = false, build = true) CategorizationAnalyzer analyzer, + @Fixed(includeInToString = false, build = true) TokenListCategorizer.CloseableTokenListCategorizer categorizer + ) { + String s = v.utf8ToString(); + try (TokenStream ts = analyzer.tokenStream("text", s)) { + return categorizer.computeCategory(ts, s.length(), 1).getId(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public ExpressionEvaluator.Factory toEvaluator(Function toEvaluator) { + return new CategorizeEvaluator.Factory( + source(), + toEvaluator.apply(field), + context -> new CategorizationAnalyzer( + // TODO(jan): get the correct analyzer in here, see CategorizationAnalyzerConfig::buildStandardCategorizationAnalyzer + new CustomAnalyzer( + TokenizerFactory.newFactory("whitespace", WhitespaceTokenizer::new), + new CharFilterFactory[0], + new TokenFilterFactory[0] + ), + true + ), + context -> new TokenListCategorizer.CloseableTokenListCategorizer( + new CategorizationBytesRefHash(new BytesRefHash(2048, context.bigArrays())), + CategorizationPartOfSpeechDictionary.getInstance(), + 0.70f + ) + ); + } + + @Override + protected TypeResolution resolveType() { + return isString(field(), sourceText(), DEFAULT); + } + + @Override + public DataType dataType() { + return DataType.INTEGER; + } + + @Override + public Expression replaceChildren(List newChildren) { + return new Categorize(source(), newChildren.get(0)); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, Categorize::new, field); + } + + public Expression field() { + return field; + } + + @Override + public String toString() { + return "Categorize{field=" + field + "}"; + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/EsqlScalarFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/EsqlScalarFunction.java index 563847473c992..14b0c872a3b86 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/EsqlScalarFunction.java +++ 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/EsqlScalarFunction.java @@ -16,6 +16,7 @@ import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper; import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; +import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Least; @@ -73,6 +74,7 @@ public static List getNamedWriteables() { entries.add(Atan2.ENTRY); entries.add(Bucket.ENTRY); entries.add(Case.ENTRY); + entries.add(Categorize.ENTRY); entries.add(CIDRMatch.ENTRY); entries.add(Coalesce.ENTRY); entries.add(Concat.ENTRY); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialContains.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialContains.java index 24519ae28721f..7f578565f81f2 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialContains.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialContains.java @@ -91,9 +91,23 @@ protected boolean geometryRelatesGeometry(BytesRef left, BytesRef right) throws } @Override - protected boolean geometryRelatesGeometries(MultiValuesCombiner left, MultiValuesCombiner right) throws IOException { - Component2D[] rightComponent2Ds = asLuceneComponent2Ds(crsType, right.combined()); - return geometryRelatesGeometries(left, rightComponent2Ds); + protected void processSourceAndSource(BooleanBlock.Builder builder, int position, BytesRefBlock left, BytesRefBlock right) + throws IOException { + if (right.getValueCount(position) < 1) { + builder.appendNull(); + } else { + processSourceAndConstant(builder, position, left, asLuceneComponent2Ds(crsType, right, position)); + } + } + + @Override + protected void processPointDocValuesAndSource( + BooleanBlock.Builder builder, + int position, + LongBlock leftValue, + BytesRefBlock rightValue + ) throws IOException { + processPointDocValuesAndConstant(builder, position, leftValue, asLuceneComponent2Ds(crsType, rightValue, position)); } private boolean geometryRelatesGeometries(BytesRef left, Component2D[] rightComponent2Ds) throws IOException { @@ -102,11 +116,6 @@ private boolean geometryRelatesGeometries(BytesRef left, Component2D[] rightComp return geometryRelatesGeometries(leftDocValueReader, rightComponent2Ds); } - private boolean geometryRelatesGeometries(MultiValuesCombiner left, Component2D[] rightComponent2Ds) throws IOException { - GeometryDocValueReader leftDocValueReader = asGeometryDocValueReader(coordinateEncoder, shapeIndexer, left.combined()); - return geometryRelatesGeometries(leftDocValueReader, rightComponent2Ds); - } - private boolean geometryRelatesGeometries(GeometryDocValueReader leftDocValueReader, Component2D[] rightComponent2Ds) throws IOException { for (Component2D rightComponent2D : rightComponent2Ds) { @@ -123,18 +132,28 @@ private void processSourceAndConstant(BooleanBlock.Builder builder, int position if (left.getValueCount(position) < 1) { builder.appendNull(); } else { - MultiValuesBytesRef leftValues = new MultiValuesBytesRef(left, position); - 
builder.appendBoolean(geometryRelatesGeometries(leftValues, right)); + final GeometryDocValueReader reader = asGeometryDocValueReader(coordinateEncoder, shapeIndexer, left, position); + builder.appendBoolean(geometryRelatesGeometries(reader, right)); } } - private void processPointDocValuesAndConstant(BooleanBlock.Builder builder, int p, LongBlock left, @Fixed Component2D[] right) - throws IOException { - if (left.getValueCount(p) < 1) { + private void processPointDocValuesAndConstant( + BooleanBlock.Builder builder, + int position, + LongBlock left, + @Fixed Component2D[] right + ) throws IOException { + if (left.getValueCount(position) < 1) { builder.appendNull(); } else { - MultiValuesLong leftValues = new MultiValuesLong(left, p, spatialCoordinateType::longAsPoint); - builder.appendBoolean(geometryRelatesGeometries(leftValues, right)); + final GeometryDocValueReader reader = asGeometryDocValueReader( + coordinateEncoder, + shapeIndexer, + left, + position, + spatialCoordinateType::longAsPoint + ); + builder.appendBoolean(geometryRelatesGeometries(reader, right)); } } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesFunction.java index a73f8d08c6397..68ca793089499 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesFunction.java @@ -17,10 +17,6 @@ import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.operator.EvalOperator; -import org.elasticsearch.geometry.Geometry; -import org.elasticsearch.geometry.GeometryCollection; -import org.elasticsearch.geometry.MultiPoint; -import org.elasticsearch.geometry.Point; import org.elasticsearch.index.mapper.ShapeIndexer; import org.elasticsearch.lucene.spatial.Component2DVisitor; import org.elasticsearch.lucene.spatial.CoordinateEncoder; @@ -33,8 +29,6 @@ import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper; import java.io.IOException; -import java.util.ArrayList; -import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; @@ -169,23 +163,13 @@ protected boolean geometryRelatesGeometry(GeometryDocValueReader reader, Compone return visitor.matches(); } - protected boolean geometryRelatesGeometries(MultiValuesCombiner left, MultiValuesCombiner right) throws IOException { - Component2D rightComponent2D = asLuceneComponent2D(crsType, right.combined()); - return geometryRelatesGeometry(left, rightComponent2D); - } - - private boolean geometryRelatesGeometry(MultiValuesCombiner left, Component2D rightComponent2D) throws IOException { - GeometryDocValueReader leftDocValueReader = asGeometryDocValueReader(coordinateEncoder, shapeIndexer, left.combined()); - return geometryRelatesGeometry(leftDocValueReader, rightComponent2D); - } - protected void processSourceAndConstant(BooleanBlock.Builder builder, int position, BytesRefBlock left, @Fixed Component2D right) throws IOException { if (left.getValueCount(position) < 1) { builder.appendNull(); } else { - MultiValuesBytesRef leftValues = new MultiValuesBytesRef(left, position); - builder.appendBoolean(geometryRelatesGeometry(leftValues, right)); + final GeometryDocValueReader reader 
= asGeometryDocValueReader(coordinateEncoder, shapeIndexer, left, position); + builder.appendBoolean(geometryRelatesGeometry(reader, right)); } } @@ -194,9 +178,9 @@ protected void processSourceAndSource(BooleanBlock.Builder builder, int position if (left.getValueCount(position) < 1 || right.getValueCount(position) < 1) { builder.appendNull(); } else { - MultiValuesBytesRef leftValues = new MultiValuesBytesRef(left, position); - MultiValuesBytesRef rightValues = new MultiValuesBytesRef(right, position); - builder.appendBoolean(geometryRelatesGeometries(leftValues, rightValues)); + final GeometryDocValueReader reader = asGeometryDocValueReader(coordinateEncoder, shapeIndexer, left, position); + final Component2D component2D = asLuceneComponent2D(crsType, right, position); + builder.appendBoolean(geometryRelatesGeometry(reader, component2D)); } } @@ -209,8 +193,14 @@ protected void processPointDocValuesAndConstant( if (leftValue.getValueCount(position) < 1) { builder.appendNull(); } else { - MultiValuesLong leftValues = new MultiValuesLong(leftValue, position, spatialCoordinateType::longAsPoint); - builder.appendBoolean(geometryRelatesGeometry(leftValues, rightValue)); + final GeometryDocValueReader reader = asGeometryDocValueReader( + coordinateEncoder, + shapeIndexer, + leftValue, + position, + spatialCoordinateType::longAsPoint + ); + builder.appendBoolean(geometryRelatesGeometry(reader, rightValue)); } } @@ -223,100 +213,16 @@ protected void processPointDocValuesAndSource( if (leftValue.getValueCount(position) < 1 || rightValue.getValueCount(position) < 1) { builder.appendNull(); } else { - MultiValuesLong leftValues = new MultiValuesLong(leftValue, position, spatialCoordinateType::longAsPoint); - MultiValuesBytesRef rightValues = new MultiValuesBytesRef(rightValue, position); - builder.appendBoolean(geometryRelatesGeometries(leftValues, rightValues)); - } - } - } - - /** - * When dealing with ST_CONTAINS and ST_WITHIN we need to pre-combine the field geometries for multi-values in order - * to perform the relationship check. This means instead of relying on the generated evaluators to iterate over all - * values in a multi-value field, the entire block is passed into the spatial function, and we combine the values into - * a geometry collection or multipoint. - */ - protected interface MultiValuesCombiner { - Geometry combined(); - } - - /** - * Values read from source will be encoded as WKB in BytesRefBlock. The block contains multiple rows, and within - * each row multiple values, so we need to efficiently iterate over only the values required for the requested row. - * This class works for point and shape fields, because both are extracted into the same block encoding. - * However, we do detect if all values in the field are actually points and create a MultiPoint instead of a GeometryCollection. 
- */ - protected static class MultiValuesBytesRef implements MultiValuesCombiner { - private final BytesRefBlock valueBlock; - private final int valueCount; - private final BytesRef scratch = new BytesRef(); - private final int firstValue; - - MultiValuesBytesRef(BytesRefBlock valueBlock, int position) { - this.valueBlock = valueBlock; - this.firstValue = valueBlock.getFirstValueIndex(position); - this.valueCount = valueBlock.getValueCount(position); - } - - @Override - public Geometry combined() { - int valueIndex = firstValue; - boolean allPoints = true; - if (valueCount == 1) { - return fromBytesRef(valueBlock.getBytesRef(valueIndex, scratch)); - } - List geometries = new ArrayList<>(); - while (valueIndex < firstValue + valueCount) { - geometries.add(fromBytesRef(valueBlock.getBytesRef(valueIndex++, scratch))); - if (geometries.getLast() instanceof Point == false) { - allPoints = false; - } - } - return allPoints ? new MultiPoint(asPointList(geometries)) : new GeometryCollection<>(geometries); - } - - private List asPointList(List geometries) { - List points = new ArrayList<>(geometries.size()); - for (Geometry geometry : geometries) { - points.add((Point) geometry); - } - return points; - } - - protected Geometry fromBytesRef(BytesRef bytesRef) { - return SpatialCoordinateTypes.UNSPECIFIED.wkbToGeometry(bytesRef); - } - } - - /** - * Point values read from doc-values will be encoded as in LogBlock. The block contains multiple rows, and within - * each row multiple values, so we need to efficiently iterate over only the values required for the requested row. - * Since the encoding differs for GEO and CARTESIAN, we need the decoder function to be passed in the constructor. - */ - protected static class MultiValuesLong implements MultiValuesCombiner { - private final LongBlock valueBlock; - private final Function decoder; - private final int valueCount; - private final int firstValue; - - MultiValuesLong(LongBlock valueBlock, int position, Function decoder) { - this.valueBlock = valueBlock; - this.decoder = decoder; - this.firstValue = valueBlock.getFirstValueIndex(position); - this.valueCount = valueBlock.getValueCount(position); - } - - @Override - public Geometry combined() { - int valueIndex = firstValue; - if (valueCount == 1) { - return decoder.apply(valueBlock.getLong(valueIndex)); - } - List points = new ArrayList<>(); - while (valueIndex < firstValue + valueCount) { - points.add(decoder.apply(valueBlock.getLong(valueIndex++))); + final GeometryDocValueReader reader = asGeometryDocValueReader( + coordinateEncoder, + shapeIndexer, + leftValue, + position, + spatialCoordinateType::longAsPoint + ); + final Component2D component2D = asLuceneComponent2D(crsType, rightValue, position); + builder.appendBoolean(geometryRelatesGeometry(reader, component2D)); } - return new MultiPoint(points); } } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesUtils.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesUtils.java index 6997eb7fa9528..6ae99ea8165cd 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesUtils.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesUtils.java @@ -13,9 +13,13 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.geo.LuceneGeometriesUtils; import 
org.elasticsearch.common.geo.Orientation; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.geometry.Circle; import org.elasticsearch.geometry.Geometry; import org.elasticsearch.geometry.GeometryCollection; +import org.elasticsearch.geometry.MultiPoint; +import org.elasticsearch.geometry.Point; import org.elasticsearch.geometry.ShapeType; import org.elasticsearch.index.mapper.GeoShapeIndexer; import org.elasticsearch.index.mapper.ShapeIndexer; @@ -31,19 +35,18 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.function.Function; import static org.elasticsearch.xpack.esql.core.expression.Foldables.valueOf; public class SpatialRelatesUtils { - /** - * This function is used to convert a spatial constant to a lucene Component2D. - * When both left and right sides are constants, we convert the left to a doc-values byte array and the right to a Component2D. - */ + /** Converts an {@link Expression} into a {@link Component2D}. */ static Component2D asLuceneComponent2D(BinarySpatialFunction.SpatialCrsType crsType, Expression expression) { return asLuceneComponent2D(crsType, makeGeometryFromLiteral(expression)); } + /** Converts a {@link Geometry} into a {@link Component2D}. */ static Component2D asLuceneComponent2D(BinarySpatialFunction.SpatialCrsType crsType, Geometry geometry) { if (crsType == BinarySpatialFunction.SpatialCrsType.GEO) { var luceneGeometries = LuceneGeometriesUtils.toLatLonGeometry(geometry, true, t -> {}); @@ -54,15 +57,23 @@ static Component2D asLuceneComponent2D(BinarySpatialFunction.SpatialCrsT } } + /** Converts a {@link BytesRefBlock} at a given {@code position} into a {@link Component2D}. */ + static Component2D asLuceneComponent2D(BinarySpatialFunction.SpatialCrsType type, BytesRefBlock valueBlock, int position) { + return asLuceneComponent2D(type, asGeometry(valueBlock, position)); + } + /** - * This function is used to convert a spatial constant to an array of lucene Component2Ds. - * When both left and right sides are constants, we convert the left to a doc-values byte array and the right to a Component2D[]. + * Converts an {@link Expression} into a {@link Component2D} array. * The reason for generating an array instead of a single component is for multi-shape support with ST_CONTAINS. */ static Component2D[] asLuceneComponent2Ds(BinarySpatialFunction.SpatialCrsType crsType, Expression expression) { return asLuceneComponent2Ds(crsType, makeGeometryFromLiteral(expression)); } + /** + * Converts a {@link Geometry} into a {@link Component2D} array. + * The reason for generating an array instead of a single component is for multi-shape support with ST_CONTAINS. + */ static Component2D[] asLuceneComponent2Ds(BinarySpatialFunction.SpatialCrsType crsType, Geometry geometry) { if (crsType == BinarySpatialFunction.SpatialCrsType.GEO) { var luceneGeometries = LuceneGeometriesUtils.toLatLonGeometry(geometry, true, t -> {}); @@ -73,10 +84,12 @@ static Component2D[] asLuceneComponent2Ds(BinarySpatialFunction.SpatialCrsType c } } - /** - * This function is used to convert a spatial constant to a doc-values byte array. - * When both left and right sides are constants, we convert the left to a doc-values byte array and the right to a Component2D. - */ + /** Converts a {@link BytesRefBlock} at a given {@code position} into a {@link Component2D} array.
*/ + static Component2D[] asLuceneComponent2Ds(BinarySpatialFunction.SpatialCrsType type, BytesRefBlock valueBlock, int position) { + return asLuceneComponent2Ds(type, asGeometry(valueBlock, position)); + } + + /** Converts an {@link Expression} into a {@link GeometryDocValueReader} */ static GeometryDocValueReader asGeometryDocValueReader(BinarySpatialFunction.SpatialCrsType crsType, Expression expression) throws IOException { Geometry geometry = makeGeometryFromLiteral(expression); @@ -92,11 +105,7 @@ } - /** - * Converting shapes into doc-values byte arrays is needed under two situations: - * - If both left and right are constants, we convert the right to Component2D and the left to doc-values for comparison - * - If the right is a constant and no lucene push-down was possible, we get WKB in the left and convert it to doc-values for comparison - */ + /** Converts a {@link Geometry} into a {@link GeometryDocValueReader} */ static GeometryDocValueReader asGeometryDocValueReader(CoordinateEncoder encoder, ShapeIndexer shapeIndexer, Geometry geometry) throws IOException { GeometryDocValueReader reader = new GeometryDocValueReader(); @@ -110,6 +119,50 @@ return reader; } + /** Converts a {@link LongBlock} at a given {@code position} into a {@link GeometryDocValueReader} */ + static GeometryDocValueReader asGeometryDocValueReader( + CoordinateEncoder encoder, + ShapeIndexer shapeIndexer, + LongBlock valueBlock, + int position, + Function decoder + ) throws IOException { + final int firstValueIndex = valueBlock.getFirstValueIndex(position); + final int valueCount = valueBlock.getValueCount(position); + if (valueCount == 1) { + return asGeometryDocValueReader(encoder, shapeIndexer, decoder.apply(valueBlock.getLong(firstValueIndex))); + } + final List points = new ArrayList<>(valueCount); + for (int i = 0; i < valueCount; i++) { + points.add(decoder.apply(valueBlock.getLong(firstValueIndex + i))); + } + return asGeometryDocValueReader(encoder, shapeIndexer, new MultiPoint(points)); + } + + /** Converts a {@link BytesRefBlock} at a given {@code position} into a {@link GeometryDocValueReader} */ + static GeometryDocValueReader asGeometryDocValueReader( + CoordinateEncoder encoder, + ShapeIndexer shapeIndexer, + BytesRefBlock valueBlock, + int position + ) throws IOException { + return asGeometryDocValueReader(encoder, shapeIndexer, asGeometry(valueBlock, position)); + } + + private static Geometry asGeometry(BytesRefBlock valueBlock, int position) { + final BytesRef scratch = new BytesRef(); + final int firstValueIndex = valueBlock.getFirstValueIndex(position); + final int valueCount = valueBlock.getValueCount(position); + if (valueCount == 1) { + return SpatialCoordinateTypes.UNSPECIFIED.wkbToGeometry(valueBlock.getBytesRef(firstValueIndex, scratch)); + } + final List geometries = new ArrayList<>(valueCount); + for (int i = 0; i < valueCount; i++) { + geometries.add(SpatialCoordinateTypes.UNSPECIFIED.wkbToGeometry(valueBlock.getBytesRef(firstValueIndex + i, scratch))); + } + return new GeometryCollection<>(geometries); + } + /** * This function is used in two places, when evaluating a spatial constant in the SpatialRelatesFunction, as well as when * we do lucene-pushdown of spatial functions.
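Note on the SpatialRelatesUtils hunk above: a block position that holds several WKB values is now folded into a single Geometry before the doc-values comparison, exactly as the new asGeometry helper does, while the LongBlock (doc-values) overload decodes points and wraps them in a MultiPoint. A minimal sketch of that single-vs-multi dispatch, written against a plain list of WKB values rather than a real BytesRefBlock position; the method name and the wkbValues parameter are illustrative only, and all decoding calls are the ones used in the hunk above (if placed inside SpatialRelatesUtils, the referenced types are already imported):

// Sketch only, not part of the change set: mirrors asGeometry(BytesRefBlock, int)
// using a hypothetical List<BytesRef> in place of a block position.
static Geometry combineWkbValues(List<BytesRef> wkbValues) {
    if (wkbValues.size() == 1) {
        // single-valued position: decode the lone WKB blob as-is
        return SpatialCoordinateTypes.UNSPECIFIED.wkbToGeometry(wkbValues.get(0));
    }
    // multi-valued position: decode every value and combine them into one collection
    List<Geometry> geometries = new ArrayList<>(wkbValues.size());
    for (BytesRef wkb : wkbValues) {
        geometries.add(SpatialCoordinateTypes.UNSPECIFIED.wkbToGeometry(wkb));
    }
    return new GeometryCollection<>(geometries);
}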
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlFeatures.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlFeatures.java index 4b6a38a3e8762..f19e6523aa075 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlFeatures.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlFeatures.java @@ -179,9 +179,14 @@ public class EsqlFeatures implements FeatureSpecification { */ public static final NodeFeature RESOLVE_FIELDS_API = new NodeFeature("esql.resolve_fields_api"); + /** + * Support categorize + */ + public static final NodeFeature CATEGORIZE = new NodeFeature("esql.categorize"); + private Set snapshotBuildFeatures() { assert Build.current().isSnapshot() : Build.current(); - return Set.of(METRICS_SYNTAX); + return Set.of(METRICS_SYNTAX, CATEGORIZE); } @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeTests.java new file mode 100644 index 0000000000000..f93389d5cb659 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeTests.java @@ -0,0 +1,55 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.grouping; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; + +import static org.hamcrest.Matchers.equalTo; + +public class CategorizeTests extends AbstractScalarFunctionTestCase { + public CategorizeTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + List suppliers = new ArrayList<>(); + for (DataType dataType : List.of(DataType.KEYWORD, DataType.TEXT)) { + suppliers.add( + new TestCaseSupplier( + "text with " + dataType.typeName(), + List.of(dataType), + () -> new TestCaseSupplier.TestCase( + List.of(new TestCaseSupplier.TypedData(new BytesRef("blah blah blah"), dataType, "f")), + "CategorizeEvaluator[v=Attribute[channel=0]]", + DataType.INTEGER, + equalTo(0) + ) + ) + ); + } + return parameterSuppliersFromTypedDataWithDefaultChecks(true, suppliers, (v, p) -> "string"); + } + + @Override + protected Expression build(Source source, List args) { + return new Categorize(source, args.get(0)); + } +} diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoader.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoader.java index a63d911e9d40d..e927c46e6bd29 100644 --- 
a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoader.java +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoader.java @@ -15,17 +15,12 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.Setting; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.tasks.Task; -import org.elasticsearch.threadpool.ExecutorBuilder; -import org.elasticsearch.threadpool.FixedExecutorBuilder; import org.elasticsearch.xpack.core.ml.packageloader.action.GetTrainedModelPackageConfigAction; import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPackageAction; import org.elasticsearch.xpack.ml.packageloader.action.ModelDownloadTask; -import org.elasticsearch.xpack.ml.packageloader.action.ModelImporter; import org.elasticsearch.xpack.ml.packageloader.action.TransportGetTrainedModelPackageConfigAction; import org.elasticsearch.xpack.ml.packageloader.action.TransportLoadTrainedModelPackage; @@ -49,6 +44,9 @@ public class MachineLearningPackageLoader extends Plugin implements ActionPlugin Setting.Property.Dynamic ); + // re-using thread pool setup by the ml plugin + public static final String UTILITY_THREAD_POOL_NAME = "ml_utility"; + // This link will be invalid for serverless, but serverless will never be // air-gapped, so this message should never be needed. private static final String MODEL_REPOSITORY_DOCUMENTATION_LINK = format( @@ -56,8 +54,6 @@ public class MachineLearningPackageLoader extends Plugin implements ActionPlugin Build.current().version().replaceFirst("^(\\d+\\.\\d+).*", "$1") ); - public static final String MODEL_DOWNLOAD_THREADPOOL_NAME = "model_download"; - public MachineLearningPackageLoader() {} @Override @@ -85,24 +81,6 @@ public List getNamedWriteables() { ); } - @Override - public List> getExecutorBuilders(Settings settings) { - return List.of(modelDownloadExecutor(settings)); - } - - public static FixedExecutorBuilder modelDownloadExecutor(Settings settings) { - // Threadpool with a fixed number of threads for - // downloading the model definition files - return new FixedExecutorBuilder( - settings, - MODEL_DOWNLOAD_THREADPOOL_NAME, - ModelImporter.NUMBER_OF_STREAMS, - -1, // unbounded queue size - "xpack.ml.model_download_thread_pool", - EsExecutors.TaskTrackingConfig.DO_NOT_TRACK - ); - } - @Override public List getBootstrapChecks() { return List.of(new BootstrapCheck() { diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporter.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporter.java index 86711804ed03c..33d5d5982d2b0 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporter.java +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporter.java @@ -10,248 +10,124 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.RefCountingListener; -import 
org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.ActionType; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesArray; -import org.elasticsearch.core.Nullable; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.TaskCancelledException; -import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction; import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig; -import org.elasticsearch.xpack.ml.packageloader.MachineLearningPackageLoader; +import java.io.IOException; import java.io.InputStream; import java.net.URI; import java.net.URISyntaxException; -import java.util.ArrayList; -import java.util.List; import java.util.Objects; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.atomic.AtomicInteger; import static org.elasticsearch.core.Strings.format; /** - * For downloading and the vocabulary and model definition file and - * indexing those files in Elasticsearch. - * Holding the large model definition file in memory will consume - * too much memory, instead it is streamed in chunks and each chunk - * written to the index in a non-blocking request. - * The model files may be installed from a local file or download - * from a server. The server download uses {@link #NUMBER_OF_STREAMS} - * connections each using the Range header to split the stream by byte - * range. There is a complication in that the final part of the model - * definition must be uploaded last as writing this part causes an index - * refresh. - * When read from file a single thread is used to read the file - * stream, split into chunks and index those chunks. + * A helper class for abstracting out the use of the ModelLoaderUtils to make dependency injection testing easier. 
*/ -public class ModelImporter { +class ModelImporter { private static final int DEFAULT_CHUNK_SIZE = 1024 * 1024; // 1MB - public static final int NUMBER_OF_STREAMS = 5; private static final Logger logger = LogManager.getLogger(ModelImporter.class); private final Client client; private final String modelId; private final ModelPackageConfig config; private final ModelDownloadTask task; - private final ExecutorService executorService; - private final AtomicInteger progressCounter = new AtomicInteger(); - private final URI uri; - ModelImporter(Client client, String modelId, ModelPackageConfig packageConfig, ModelDownloadTask task, ThreadPool threadPool) - throws URISyntaxException { + ModelImporter(Client client, String modelId, ModelPackageConfig packageConfig, ModelDownloadTask task) { this.client = client; this.modelId = Objects.requireNonNull(modelId); this.config = Objects.requireNonNull(packageConfig); this.task = Objects.requireNonNull(task); - this.executorService = threadPool.executor(MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME); - this.uri = ModelLoaderUtils.resolvePackageLocation( - config.getModelRepository(), - config.getPackagedModelId() + ModelLoaderUtils.MODEL_FILE_EXTENSION - ); } - public void doImport(ActionListener listener) { - executorService.execute(() -> doImportInternal(listener)); - } + public void doImport() throws URISyntaxException, IOException, ElasticsearchStatusException { + long size = config.getSize(); - private void doImportInternal(ActionListener finalListener) { - assert ThreadPool.assertCurrentThreadPool(MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME) - : format( - "Model download must execute from [%s] but thread is [%s]", - MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME, - Thread.currentThread().getName() - ); + // Uploading other artefacts of the model first, that way the model is last and a simple search can be used to check if the + // download is complete + if (Strings.isNullOrEmpty(config.getVocabularyFile()) == false) { + uploadVocabulary(); - ModelLoaderUtils.VocabularyParts vocabularyParts = null; - try { - if (config.getVocabularyFile() != null) { - vocabularyParts = ModelLoaderUtils.loadVocabulary( - ModelLoaderUtils.resolvePackageLocation(config.getModelRepository(), config.getVocabularyFile()) - ); - } + logger.debug(() -> format("[%s] imported model vocabulary [%s]", modelId, config.getVocabularyFile())); + } - // simple round up - int totalParts = (int) ((config.getSize() + DEFAULT_CHUNK_SIZE - 1) / DEFAULT_CHUNK_SIZE); + URI uri = ModelLoaderUtils.resolvePackageLocation( + config.getModelRepository(), + config.getPackagedModelId() + ModelLoaderUtils.MODEL_FILE_EXTENSION + ); - if (ModelLoaderUtils.uriIsFile(uri) == false) { - var ranges = ModelLoaderUtils.split(config.getSize(), NUMBER_OF_STREAMS, DEFAULT_CHUNK_SIZE); - var downloaders = new ArrayList(ranges.size()); - for (var range : ranges) { - downloaders.add(new ModelLoaderUtils.HttpStreamChunker(uri, range, DEFAULT_CHUNK_SIZE)); - } - downloadModelDefinition(config.getSize(), totalParts, vocabularyParts, downloaders, finalListener); - } else { - InputStream modelInputStream = ModelLoaderUtils.getFileInputStream(uri); - ModelLoaderUtils.InputStreamChunker chunkIterator = new ModelLoaderUtils.InputStreamChunker( - modelInputStream, - DEFAULT_CHUNK_SIZE - ); - readModelDefinitionFromFile(config.getSize(), totalParts, chunkIterator, vocabularyParts, finalListener); - } - } catch (Exception e) { - finalListener.onFailure(e); - return; - } - } + 
InputStream modelInputStream = ModelLoaderUtils.getInputStreamFromModelRepository(uri); - void downloadModelDefinition( - long size, - int totalParts, - @Nullable ModelLoaderUtils.VocabularyParts vocabularyParts, - List downloaders, - ActionListener finalListener - ) { - try (var countingListener = new RefCountingListener(1, ActionListener.wrap(ignore -> executorService.execute(() -> { - var finalDownloader = downloaders.get(downloaders.size() - 1); - downloadFinalPart(size, totalParts, finalDownloader, finalListener.delegateFailureAndWrap((l, r) -> { - checkDownloadComplete(downloaders); - l.onResponse(AcknowledgedResponse.TRUE); - })); - }), finalListener::onFailure))) { - // Uploading other artefacts of the model first, that way the model is last and a simple search can be used to check if the - // download is complete - if (vocabularyParts != null) { - uploadVocabulary(vocabularyParts, countingListener); - } + ModelLoaderUtils.InputStreamChunker chunkIterator = new ModelLoaderUtils.InputStreamChunker(modelInputStream, DEFAULT_CHUNK_SIZE); - // Download all but the final split. - // The final split is a single chunk - for (int streamSplit = 0; streamSplit < downloaders.size() - 1; ++streamSplit) { - final var downloader = downloaders.get(streamSplit); - var rangeDownloadedListener = countingListener.acquire(); // acquire to keep the counting listener from closing - executorService.execute( - () -> downloadPartInRange(size, totalParts, downloader, executorService, countingListener, rangeDownloadedListener) - ); - } - } - } + // simple round up + int totalParts = (int) ((size + DEFAULT_CHUNK_SIZE - 1) / DEFAULT_CHUNK_SIZE); - private void downloadPartInRange( - long size, - int totalParts, - ModelLoaderUtils.HttpStreamChunker downloadChunker, - ExecutorService executorService, - RefCountingListener countingListener, - ActionListener rangeFullyDownloadedListener - ) { - assert ThreadPool.assertCurrentThreadPool(MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME) - : format( - "Model download must execute from [%s] but thread is [%s]", - MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME, - Thread.currentThread().getName() + for (int part = 0; part < totalParts - 1; ++part) { + task.setProgress(totalParts, part); + BytesArray definition = chunkIterator.next(); + + PutTrainedModelDefinitionPartAction.Request modelPartRequest = new PutTrainedModelDefinitionPartAction.Request( + modelId, + definition, + part, + size, + totalParts, + true ); - if (countingListener.isFailing()) { - rangeFullyDownloadedListener.onResponse(null); // the error has already been reported elsewhere - return; + executeRequestIfNotCancelled(PutTrainedModelDefinitionPartAction.INSTANCE, modelPartRequest); } - try { - throwIfTaskCancelled(); - var bytesAndIndex = downloadChunker.next(); - task.setProgress(totalParts, progressCounter.getAndIncrement()); - - indexPart(bytesAndIndex.partIndex(), totalParts, size, bytesAndIndex.bytes(), countingListener.acquire(ack -> {})); - } catch (Exception e) { - rangeFullyDownloadedListener.onFailure(e); - return; - } + // get the last part, this time verify the checksum and size + BytesArray definition = chunkIterator.next(); - if (downloadChunker.hasNext()) { - executorService.execute( - () -> downloadPartInRange( - size, - totalParts, - downloadChunker, - executorService, - countingListener, - rangeFullyDownloadedListener - ) + if (config.getSha256().equals(chunkIterator.getSha256()) == false) { + String message = format( + "Model sha256 checksums do not match, 
expected [%s] but got [%s]", + config.getSha256(), + chunkIterator.getSha256() ); - } else { - rangeFullyDownloadedListener.onResponse(null); + + throw new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR); } - } - private void downloadFinalPart( - long size, - int totalParts, - ModelLoaderUtils.HttpStreamChunker downloader, - ActionListener lastPartWrittenListener - ) { - assert ThreadPool.assertCurrentThreadPool(MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME) - : format( - "Model download must execute from [%s] but thread is [%s]", - MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME, - Thread.currentThread().getName() + if (config.getSize() != chunkIterator.getTotalBytesRead()) { + String message = format( + "Model size does not match, expected [%d] but got [%d]", + config.getSize(), + chunkIterator.getTotalBytesRead() ); - try { - var bytesAndIndex = downloader.next(); - task.setProgress(totalParts, progressCounter.getAndIncrement()); - - indexPart(bytesAndIndex.partIndex(), totalParts, size, bytesAndIndex.bytes(), lastPartWrittenListener); - } catch (Exception e) { - lastPartWrittenListener.onFailure(e); + throw new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR); } - } - - void readModelDefinitionFromFile( - long size, - int totalParts, - ModelLoaderUtils.InputStreamChunker chunkIterator, - @Nullable ModelLoaderUtils.VocabularyParts vocabularyParts, - ActionListener finalListener - ) { - try (var countingListener = new RefCountingListener(1, ActionListener.wrap(ignore -> executorService.execute(() -> { - finalListener.onResponse(AcknowledgedResponse.TRUE); - }), finalListener::onFailure))) { - try { - if (vocabularyParts != null) { - uploadVocabulary(vocabularyParts, countingListener); - } - for (int part = 0; part < totalParts; ++part) { - throwIfTaskCancelled(); - task.setProgress(totalParts, part); - BytesArray definition = chunkIterator.next(); - indexPart(part, totalParts, size, definition, countingListener.acquire(ack -> {})); - } - task.setProgress(totalParts, totalParts); + PutTrainedModelDefinitionPartAction.Request finalModelPartRequest = new PutTrainedModelDefinitionPartAction.Request( + modelId, + definition, + totalParts - 1, + size, + totalParts, + true + ); - checkDownloadComplete(chunkIterator, totalParts); - } catch (Exception e) { - countingListener.acquire().onFailure(e); - } - } + executeRequestIfNotCancelled(PutTrainedModelDefinitionPartAction.INSTANCE, finalModelPartRequest); + logger.debug(format("finished importing model [%s] using [%d] parts", modelId, totalParts)); } - private void uploadVocabulary(ModelLoaderUtils.VocabularyParts vocabularyParts, RefCountingListener countingListener) { + private void uploadVocabulary() throws URISyntaxException { + ModelLoaderUtils.VocabularyParts vocabularyParts = ModelLoaderUtils.loadVocabulary( + ModelLoaderUtils.resolvePackageLocation(config.getModelRepository(), config.getVocabularyFile()) + ); + PutTrainedModelVocabularyAction.Request request = new PutTrainedModelVocabularyAction.Request( modelId, vocabularyParts.vocab(), @@ -260,58 +136,17 @@ private void uploadVocabulary(ModelLoaderUtils.VocabularyParts vocabularyParts, true ); - client.execute(PutTrainedModelVocabularyAction.INSTANCE, request, countingListener.acquire(r -> { - logger.debug(() -> format("[%s] imported model vocabulary [%s]", modelId, config.getVocabularyFile())); - })); - } - - private void indexPart(int partIndex, int totalParts, long totalSize, BytesArray bytes, ActionListener 
listener) { - PutTrainedModelDefinitionPartAction.Request modelPartRequest = new PutTrainedModelDefinitionPartAction.Request( - modelId, - bytes, - partIndex, - totalSize, - totalParts, - true - ); - - client.execute(PutTrainedModelDefinitionPartAction.INSTANCE, modelPartRequest, listener); - } - - private void checkDownloadComplete(List downloaders) { - long totalBytesRead = downloaders.stream().mapToLong(ModelLoaderUtils.HttpStreamChunker::getTotalBytesRead).sum(); - int totalParts = downloaders.stream().mapToInt(ModelLoaderUtils.HttpStreamChunker::getCurrentPart).sum(); - checkSize(totalBytesRead); - logger.debug(format("finished importing model [%s] using [%d] parts", modelId, totalParts)); + executeRequestIfNotCancelled(PutTrainedModelVocabularyAction.INSTANCE, request); } - private void checkDownloadComplete(ModelLoaderUtils.InputStreamChunker fileInputStream, int totalParts) { - checkSha256(fileInputStream.getSha256()); - checkSize(fileInputStream.getTotalBytesRead()); - logger.debug(format("finished importing model [%s] using [%d] parts", modelId, totalParts)); - } - - private void checkSha256(String sha256) { - if (config.getSha256().equals(sha256) == false) { - String message = format("Model sha256 checksums do not match, expected [%s] but got [%s]", config.getSha256(), sha256); - - throw new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR); - } - } - - private void checkSize(long definitionSize) { - if (config.getSize() != definitionSize) { - String message = format("Model size does not match, expected [%d] but got [%d]", config.getSize(), definitionSize); - throw new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR); - } - } - - private void throwIfTaskCancelled() { + private void executeRequestIfNotCancelled( + ActionType action, + Request request + ) { if (task.isCancelled()) { - logger.info("Model [{}] download task cancelled", modelId); - throw new TaskCancelledException( - format("Model [%s] download task cancelled with reason [%s]", modelId, task.getReasonCancelled()) - ); + throw new TaskCancelledException(format("task cancelled with reason [%s]", task.getReasonCancelled())); } + + client.execute(action, request).actionGet(); } } diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java index 42bfbb249b623..2f3f9cbf3f32c 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java @@ -17,7 +17,6 @@ import org.elasticsearch.common.io.Streams; import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; -import org.elasticsearch.core.Nullable; import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.XContentParser; @@ -35,20 +34,16 @@ import java.security.AccessController; import java.security.MessageDigest; import java.security.PrivilegedAction; -import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; import java.util.stream.Collectors; import static 
java.net.HttpURLConnection.HTTP_MOVED_PERM; import static java.net.HttpURLConnection.HTTP_MOVED_TEMP; import static java.net.HttpURLConnection.HTTP_NOT_FOUND; import static java.net.HttpURLConnection.HTTP_OK; -import static java.net.HttpURLConnection.HTTP_PARTIAL; import static java.net.HttpURLConnection.HTTP_SEE_OTHER; /** @@ -66,73 +61,6 @@ final class ModelLoaderUtils { record VocabularyParts(List vocab, List merges, List scores) {} - // Range in bytes - record RequestRange(long rangeStart, long rangeEnd, int startPart, int numParts) { - public String bytesRange() { - return "bytes=" + rangeStart + "-" + rangeEnd; - } - } - - static class HttpStreamChunker { - - record BytesAndPartIndex(BytesArray bytes, int partIndex) {} - - private final InputStream inputStream; - private final int chunkSize; - private final AtomicLong totalBytesRead = new AtomicLong(); - private final AtomicInteger currentPart; - private final int lastPartNumber; - - HttpStreamChunker(URI uri, RequestRange range, int chunkSize) { - var inputStream = getHttpOrHttpsInputStream(uri, range); - this.inputStream = inputStream; - this.chunkSize = chunkSize; - this.lastPartNumber = range.startPart() + range.numParts(); - this.currentPart = new AtomicInteger(range.startPart()); - } - - // This ctor exists for testing purposes only. - HttpStreamChunker(InputStream inputStream, RequestRange range, int chunkSize) { - this.inputStream = inputStream; - this.chunkSize = chunkSize; - this.lastPartNumber = range.startPart() + range.numParts(); - this.currentPart = new AtomicInteger(range.startPart()); - } - - public boolean hasNext() { - return currentPart.get() < lastPartNumber; - } - - public BytesAndPartIndex next() throws IOException { - int bytesRead = 0; - byte[] buf = new byte[chunkSize]; - - while (bytesRead < chunkSize) { - int read = inputStream.read(buf, bytesRead, chunkSize - bytesRead); - // EOF?? 
- if (read == -1) { - break; - } - bytesRead += read; - } - - if (bytesRead > 0) { - totalBytesRead.addAndGet(bytesRead); - return new BytesAndPartIndex(new BytesArray(buf, 0, bytesRead), currentPart.getAndIncrement()); - } else { - return new BytesAndPartIndex(BytesArray.EMPTY, currentPart.get()); - } - } - - public long getTotalBytesRead() { - return totalBytesRead.get(); - } - - public int getCurrentPart() { - return currentPart.get(); - } - } - static class InputStreamChunker { private final InputStream inputStream; @@ -173,14 +101,14 @@ public int getTotalBytesRead() { } } - static InputStream getInputStreamFromModelRepository(URI uri) { + static InputStream getInputStreamFromModelRepository(URI uri) throws IOException { String scheme = uri.getScheme().toLowerCase(Locale.ROOT); // if you add a scheme here, also add it to the bootstrap check in {@link MachineLearningPackageLoader#validateModelRepository} switch (scheme) { case "http": case "https": - return getHttpOrHttpsInputStream(uri, null); + return getHttpOrHttpsInputStream(uri); case "file": return getFileInputStream(uri); default: @@ -188,11 +116,6 @@ static InputStream getInputStreamFromModelRepository(URI uri) { } } - static boolean uriIsFile(URI uri) { - String scheme = uri.getScheme().toLowerCase(Locale.ROOT); - return "file".equals(scheme); - } - static VocabularyParts loadVocabulary(URI uri) { if (uri.getPath().endsWith(".json")) { try (InputStream vocabInputStream = getInputStreamFromModelRepository(uri)) { @@ -251,7 +174,7 @@ private ModelLoaderUtils() {} @SuppressWarnings("'java.lang.SecurityManager' is deprecated and marked for removal ") @SuppressForbidden(reason = "we need socket connection to download") - private static InputStream getHttpOrHttpsInputStream(URI uri, @Nullable RequestRange range) { + private static InputStream getHttpOrHttpsInputStream(URI uri) throws IOException { assert uri.getUserInfo() == null : "URI's with credentials are not supported"; @@ -263,30 +186,18 @@ private static InputStream getHttpOrHttpsInputStream(URI uri, @Nullable RequestR PrivilegedAction privilegedHttpReader = () -> { try { HttpURLConnection conn = (HttpURLConnection) uri.toURL().openConnection(); - if (range != null) { - conn.setRequestProperty("Range", range.bytesRange()); - } switch (conn.getResponseCode()) { case HTTP_OK: - case HTTP_PARTIAL: return conn.getInputStream(); - case HTTP_MOVED_PERM: case HTTP_MOVED_TEMP: case HTTP_SEE_OTHER: throw new IllegalStateException("redirects aren't supported yet"); case HTTP_NOT_FOUND: throw new ResourceNotFoundException("{} not found", uri); - case 416: // Range not satisfiable, for some reason not in the list of constants - throw new IllegalStateException("Invalid request range [" + range.bytesRange() + "]"); default: int responseCode = conn.getResponseCode(); - throw new ElasticsearchStatusException( - "error during downloading {}. 
Got response code {}", - RestStatus.fromCode(responseCode), - uri, - responseCode - ); + throw new ElasticsearchStatusException("error during downloading {}", RestStatus.fromCode(responseCode), uri); } } catch (IOException e) { throw new UncheckedIOException(e); @@ -298,7 +209,7 @@ private static InputStream getHttpOrHttpsInputStream(URI uri, @Nullable RequestR @SuppressWarnings("'java.lang.SecurityManager' is deprecated and marked for removal ") @SuppressForbidden(reason = "we need load model data from a file") - static InputStream getFileInputStream(URI uri) { + private static InputStream getFileInputStream(URI uri) { SecurityManager sm = System.getSecurityManager(); if (sm != null) { @@ -321,53 +232,4 @@ static InputStream getFileInputStream(URI uri) { return AccessController.doPrivileged(privilegedFileReader); } - /** - * Split a stream of size {@code sizeInBytes} into {@code numberOfStreams} +1 - * ranges aligned on {@code chunkSizeBytes} boundaries. Each range contains a - * whole number of chunks. - * The first {@code numberOfStreams} ranges will be split evenly (in terms of - * number of chunks not the byte size), the final range split - * is for the single final chunk and will be no more than {@code chunkSizeBytes} - * in size. The separate range for the final chunk is because when streaming and - * uploading a large model definition, writing the last part has to handled - * as a special case. - * @param sizeInBytes The total size of the stream - * @param numberOfStreams Divide the bulk of the size into this many streams. - * @param chunkSizeBytes The size of each chunk - * @return List of {@code numberOfStreams} + 1 ranges. - */ - static List split(long sizeInBytes, int numberOfStreams, long chunkSizeBytes) { - int numberOfChunks = (int) ((sizeInBytes + chunkSizeBytes - 1) / chunkSizeBytes); - - var ranges = new ArrayList(); - - int baseChunksPerStream = numberOfChunks / numberOfStreams; - int remainder = numberOfChunks % numberOfStreams; - long startOffset = 0; - int startChunkIndex = 0; - - for (int i = 0; i < numberOfStreams - 1; i++) { - int numChunksInStream = (i < remainder) ? 
baseChunksPerStream + 1 : baseChunksPerStream; - long rangeEnd = startOffset + (numChunksInStream * chunkSizeBytes) - 1; // range index is 0 based - ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, numChunksInStream)); - startOffset = rangeEnd + 1; // range is inclusive start and end - startChunkIndex += numChunksInStream; - } - - // Want the final range request to be a single chunk - if (baseChunksPerStream > 1) { - int numChunksExcludingFinal = baseChunksPerStream - 1; - long rangeEnd = startOffset + (numChunksExcludingFinal * chunkSizeBytes) - 1; - ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, numChunksExcludingFinal)); - - startOffset = rangeEnd + 1; - startChunkIndex += numChunksExcludingFinal; - } - - // The final range is a single chunk the end of which should not exceed sizeInBytes - long rangeEnd = Math.min(sizeInBytes, startOffset + (baseChunksPerStream * chunkSizeBytes)) - 1; - ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, 1)); - - return ranges; - } } diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportGetTrainedModelPackageConfigAction.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportGetTrainedModelPackageConfigAction.java index 68f869742d9e5..ba50f2f6a6b74 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportGetTrainedModelPackageConfigAction.java +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportGetTrainedModelPackageConfigAction.java @@ -77,7 +77,7 @@ protected void masterOperation(Task task, Request request, ClusterState state, A String packagedModelId = request.getPackagedModelId(); logger.debug(() -> format("Fetch package manifest for [%s] from [%s]", packagedModelId, repository)); - threadPool.executor(MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME).execute(() -> { + threadPool.executor(MachineLearningPackageLoader.UTILITY_THREAD_POOL_NAME).execute(() -> { try { URI uri = ModelLoaderUtils.resolvePackageLocation(repository, packagedModelId + ModelLoaderUtils.METADATA_FILE_EXTENSION); InputStream inputStream = ModelLoaderUtils.getInputStreamFromModelRepository(uri); diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java index 8ca029d01d3c0..70dcee165d3f6 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java @@ -37,12 +37,14 @@ import org.elasticsearch.xpack.core.ml.action.NodeAcknowledgedResponse; import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPackageAction; import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPackageAction.Request; +import org.elasticsearch.xpack.ml.packageloader.MachineLearningPackageLoader; import java.io.IOException; import java.net.MalformedURLException; import java.net.URISyntaxException; import java.util.Map; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import static 
org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; @@ -96,13 +98,11 @@ protected void masterOperation(Task task, Request request, ClusterState state, A parentTaskAssigningClient, request.getModelId(), request.getModelPackageConfig(), - downloadTask, - threadPool + downloadTask ); - var downloadCompleteListener = request.isWaitForCompletion() ? listener : ActionListener.noop(); - - importModel(client, taskManager, request, modelImporter, downloadCompleteListener, downloadTask); + threadPool.executor(MachineLearningPackageLoader.UTILITY_THREAD_POOL_NAME) + .execute(() -> importModel(client, taskManager, request, modelImporter, listener, downloadTask)); } catch (Exception e) { taskManager.unregister(downloadTask); listener.onFailure(e); @@ -136,12 +136,16 @@ static void importModel( ActionListener listener, Task task ) { - final String modelId = request.getModelId(); - final long relativeStartNanos = System.nanoTime(); + String modelId = request.getModelId(); + final AtomicReference exceptionRef = new AtomicReference<>(); + + try { + final long relativeStartNanos = System.nanoTime(); - logAndWriteNotificationAtLevel(auditClient, modelId, "starting model import", Level.INFO); + logAndWriteNotificationAtLevel(auditClient, modelId, "starting model import", Level.INFO); + + modelImporter.doImport(); - var finishListener = ActionListener.wrap(success -> { final long totalRuntimeNanos = System.nanoTime() - relativeStartNanos; logAndWriteNotificationAtLevel( auditClient, @@ -149,25 +153,29 @@ static void importModel( format("finished model import after [%d] seconds", TimeUnit.NANOSECONDS.toSeconds(totalRuntimeNanos)), Level.INFO ); - listener.onResponse(AcknowledgedResponse.TRUE); - }, exception -> listener.onFailure(processException(auditClient, modelId, exception))); - - modelImporter.doImport(ActionListener.runAfter(finishListener, () -> taskManager.unregister(task))); - } + } catch (TaskCancelledException e) { + recordError(auditClient, modelId, exceptionRef, e, Level.WARNING); + } catch (ElasticsearchException e) { + recordError(auditClient, modelId, exceptionRef, e, Level.ERROR); + } catch (MalformedURLException e) { + recordError(auditClient, modelId, "an invalid URL", exceptionRef, e, Level.ERROR, RestStatus.INTERNAL_SERVER_ERROR); + } catch (URISyntaxException e) { + recordError(auditClient, modelId, "an invalid URL syntax", exceptionRef, e, Level.ERROR, RestStatus.INTERNAL_SERVER_ERROR); + } catch (IOException e) { + recordError(auditClient, modelId, "an IOException", exceptionRef, e, Level.ERROR, RestStatus.SERVICE_UNAVAILABLE); + } catch (Exception e) { + recordError(auditClient, modelId, "an Exception", exceptionRef, e, Level.ERROR, RestStatus.INTERNAL_SERVER_ERROR); + } finally { + taskManager.unregister(task); + + if (request.isWaitForCompletion()) { + if (exceptionRef.get() != null) { + listener.onFailure(exceptionRef.get()); + } else { + listener.onResponse(AcknowledgedResponse.TRUE); + } - static Exception processException(Client auditClient, String modelId, Exception e) { - if (e instanceof TaskCancelledException te) { - return recordError(auditClient, modelId, te, Level.WARNING); - } else if (e instanceof ElasticsearchException es) { - return recordError(auditClient, modelId, es, Level.ERROR); - } else if (e instanceof MalformedURLException) { - return recordError(auditClient, modelId, "an invalid URL", e, Level.ERROR, RestStatus.BAD_REQUEST); - } else if (e instanceof URISyntaxException) { - return recordError(auditClient, 
modelId, "an invalid URL syntax", e, Level.ERROR, RestStatus.BAD_REQUEST); - } else if (e instanceof IOException) { - return recordError(auditClient, modelId, "an IOException", e, Level.ERROR, RestStatus.SERVICE_UNAVAILABLE); - } else { - return recordError(auditClient, modelId, "an Exception", e, Level.ERROR, RestStatus.INTERNAL_SERVER_ERROR); + } } } @@ -205,16 +213,30 @@ public ModelDownloadTask createTask(long id, String type, String action, TaskId } } - private static Exception recordError(Client client, String modelId, ElasticsearchException e, Level level) { + private static void recordError( + Client client, + String modelId, + AtomicReference exceptionRef, + ElasticsearchException e, + Level level + ) { String message = format("Model importing failed due to [%s]", e.getDetailedMessage()); logAndWriteNotificationAtLevel(client, modelId, message, level); - return e; + exceptionRef.set(e); } - private static Exception recordError(Client client, String modelId, String failureType, Exception e, Level level, RestStatus status) { + private static void recordError( + Client client, + String modelId, + String failureType, + AtomicReference exceptionRef, + Exception e, + Level level, + RestStatus status + ) { String message = format("Model importing failed due to %s [%s]", failureType, e); logAndWriteNotificationAtLevel(client, modelId, message, level); - return new ElasticsearchStatusException(message, status, e); + exceptionRef.set(new ElasticsearchStatusException(message, status, e)); } private static void logAndWriteNotificationAtLevel(Client client, String modelId, String message, Level level) { diff --git a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoaderTests.java b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoaderTests.java index 2e487b6a9624c..967d1b4ba4b6a 100644 --- a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoaderTests.java +++ b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoaderTests.java @@ -7,13 +7,9 @@ package org.elasticsearch.xpack.ml.packageloader; -import org.elasticsearch.common.settings.Setting; -import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.PathUtils; import org.elasticsearch.test.ESTestCase; -import java.util.List; - import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.oneOf; @@ -84,12 +80,4 @@ public void testValidateModelRepository() { assertEquals("xpack.ml.model_repository does not support authentication", e.getMessage()); } - - public void testThreadPoolHasSingleThread() { - var fixedThreadPool = MachineLearningPackageLoader.modelDownloadExecutor(Settings.EMPTY); - List> settings = fixedThreadPool.getRegisteredSettings(); - var sizeSettting = settings.stream().filter(s -> s.getKey().startsWith("xpack.ml.model_download_thread_pool")).findFirst(); - assertTrue(sizeSettting.isPresent()); - assertEquals(5, sizeSettting.get().get(Settings.EMPTY)); - } } diff --git a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTaskTests.java b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTaskTests.java index 3a682fb6a5094..0afd08c70cf45 100644 --- 
a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTaskTests.java +++ b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTaskTests.java @@ -20,7 +20,14 @@ public class ModelDownloadTaskTests extends ESTestCase { public void testStatus() { - var task = testTask(); + var task = new ModelDownloadTask( + 0L, + MODEL_IMPORT_TASK_TYPE, + MODEL_IMPORT_TASK_ACTION, + downloadModelTaskDescription("foo"), + TaskId.EMPTY_TASK_ID, + Map.of() + ); task.setProgress(100, 0); var taskInfo = task.taskInfo("node", true); @@ -32,15 +39,4 @@ public void testStatus() { status = Strings.toString(taskInfo.status()); assertThat(status, containsString("{\"total_parts\":100,\"downloaded_parts\":1}")); } - - public static ModelDownloadTask testTask() { - return new ModelDownloadTask( - 0L, - MODEL_IMPORT_TASK_TYPE, - MODEL_IMPORT_TASK_ACTION, - downloadModelTaskDescription("foo"), - TaskId.EMPTY_TASK_ID, - Map.of() - ); - } } diff --git a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporterTests.java b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporterTests.java deleted file mode 100644 index 99efb331a350c..0000000000000 --- a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporterTests.java +++ /dev/null @@ -1,316 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.ml.packageloader.action; - -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.LatchedActionListener; -import org.elasticsearch.action.support.ActionTestUtils; -import org.elasticsearch.action.support.master.AcknowledgedResponse; -import org.elasticsearch.client.internal.Client; -import org.elasticsearch.common.hash.MessageDigests; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.threadpool.TestThreadPool; -import org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction; -import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig; -import org.elasticsearch.xpack.ml.packageloader.MachineLearningPackageLoader; -import org.junit.After; -import org.junit.Before; - -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.net.URISyntaxException; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.atomic.AtomicReference; - -import static org.hamcrest.Matchers.containsString; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -public class ModelImporterTests extends ESTestCase { - - private TestThreadPool threadPool; - - @Before - 
public void createThreadPool() { - threadPool = createThreadPool(MachineLearningPackageLoader.modelDownloadExecutor(Settings.EMPTY)); - } - - @After - public void closeThreadPool() { - threadPool.close(); - } - - public void testDownloadModelDefinition() throws InterruptedException, URISyntaxException { - var client = mockClient(false); - var task = ModelDownloadTaskTests.testTask(); - var config = mockConfigWithRepoLinks(); - var vocab = new ModelLoaderUtils.VocabularyParts(List.of(), List.of(), List.of()); - - int totalParts = 5; - int chunkSize = 10; - long size = totalParts * chunkSize; - var modelDef = modelDefinition(totalParts, chunkSize); - var streamers = mockHttpStreamChunkers(modelDef, chunkSize, 2); - - var digest = computeDigest(modelDef); - when(config.getSha256()).thenReturn(digest); - when(config.getSize()).thenReturn(size); - - var importer = new ModelImporter(client, "foo", config, task, threadPool); - - var latch = new CountDownLatch(1); - var latchedListener = new LatchedActionListener(ActionTestUtils.assertNoFailureListener(ignore -> {}), latch); - importer.downloadModelDefinition(size, totalParts, vocab, streamers, latchedListener); - - latch.await(); - verify(client, times(totalParts)).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any(), any()); - assertEquals(totalParts - 1, task.getStatus().downloadProgress().downloadedParts()); - assertEquals(totalParts, task.getStatus().downloadProgress().totalParts()); - } - - public void testReadModelDefinitionFromFile() throws InterruptedException, URISyntaxException { - var client = mockClient(false); - var task = ModelDownloadTaskTests.testTask(); - var config = mockConfigWithRepoLinks(); - var vocab = new ModelLoaderUtils.VocabularyParts(List.of(), List.of(), List.of()); - - int totalParts = 3; - int chunkSize = 10; - long size = totalParts * chunkSize; - var modelDef = modelDefinition(totalParts, chunkSize); - - var digest = computeDigest(modelDef); - when(config.getSha256()).thenReturn(digest); - when(config.getSize()).thenReturn(size); - - var importer = new ModelImporter(client, "foo", config, task, threadPool); - var streamChunker = new ModelLoaderUtils.InputStreamChunker(new ByteArrayInputStream(modelDef), chunkSize); - - var latch = new CountDownLatch(1); - var latchedListener = new LatchedActionListener(ActionTestUtils.assertNoFailureListener(ignore -> {}), latch); - importer.readModelDefinitionFromFile(size, totalParts, streamChunker, vocab, latchedListener); - - latch.await(); - verify(client, times(totalParts)).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any(), any()); - assertEquals(totalParts, task.getStatus().downloadProgress().downloadedParts()); - assertEquals(totalParts, task.getStatus().downloadProgress().totalParts()); - } - - public void testSizeMismatch() throws InterruptedException, URISyntaxException { - var client = mockClient(false); - var task = mock(ModelDownloadTask.class); - var config = mockConfigWithRepoLinks(); - - int totalParts = 5; - int chunkSize = 10; - long size = totalParts * chunkSize; - var modelDef = modelDefinition(totalParts, chunkSize); - var streamers = mockHttpStreamChunkers(modelDef, chunkSize, 2); - - var digest = computeDigest(modelDef); - when(config.getSha256()).thenReturn(digest); - when(config.getSize()).thenReturn(size - 1); // expected size and read size are different - - var exceptionHolder = new AtomicReference(); - - var latch = new CountDownLatch(1); - var latchedListener = new LatchedActionListener( - 
ActionTestUtils.assertNoSuccessListener(exceptionHolder::set), - latch - ); - - var importer = new ModelImporter(client, "foo", config, task, threadPool); - importer.downloadModelDefinition(size, totalParts, null, streamers, latchedListener); - - latch.await(); - assertThat(exceptionHolder.get().getMessage(), containsString("Model size does not match")); - verify(client, times(totalParts)).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any(), any()); - } - - public void testDigestMismatch() throws InterruptedException, URISyntaxException { - var client = mockClient(false); - var task = mock(ModelDownloadTask.class); - var config = mockConfigWithRepoLinks(); - - int totalParts = 5; - int chunkSize = 10; - long size = totalParts * chunkSize; - var modelDef = modelDefinition(totalParts, chunkSize); - var streamers = mockHttpStreamChunkers(modelDef, chunkSize, 2); - - when(config.getSha256()).thenReturn("0x"); // digest is different - when(config.getSize()).thenReturn(size); - - var exceptionHolder = new AtomicReference(); - var latch = new CountDownLatch(1); - var latchedListener = new LatchedActionListener( - ActionTestUtils.assertNoSuccessListener(exceptionHolder::set), - latch - ); - - var importer = new ModelImporter(client, "foo", config, task, threadPool); - // Message digest can only be calculated for the file reader - var streamChunker = new ModelLoaderUtils.InputStreamChunker(new ByteArrayInputStream(modelDef), chunkSize); - importer.readModelDefinitionFromFile(size, totalParts, streamChunker, null, latchedListener); - - latch.await(); - assertThat(exceptionHolder.get().getMessage(), containsString("Model sha256 checksums do not match")); - verify(client, times(totalParts)).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any(), any()); - } - - public void testPutFailure() throws InterruptedException, URISyntaxException { - var client = mockClient(true); // client will fail put - var task = mock(ModelDownloadTask.class); - var config = mockConfigWithRepoLinks(); - - int totalParts = 4; - int chunkSize = 10; - long size = totalParts * chunkSize; - var modelDef = modelDefinition(totalParts, chunkSize); - var streamers = mockHttpStreamChunkers(modelDef, chunkSize, 1); - - var exceptionHolder = new AtomicReference(); - var latch = new CountDownLatch(1); - var latchedListener = new LatchedActionListener( - ActionTestUtils.assertNoSuccessListener(exceptionHolder::set), - latch - ); - - var importer = new ModelImporter(client, "foo", config, task, threadPool); - importer.downloadModelDefinition(size, totalParts, null, streamers, latchedListener); - - latch.await(); - assertThat(exceptionHolder.get().getMessage(), containsString("put model part failed")); - verify(client, times(1)).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any(), any()); - } - - public void testReadFailure() throws IOException, InterruptedException, URISyntaxException { - var client = mockClient(true); - var task = mock(ModelDownloadTask.class); - var config = mockConfigWithRepoLinks(); - - int totalParts = 4; - int chunkSize = 10; - long size = totalParts * chunkSize; - - var streamer = mock(ModelLoaderUtils.HttpStreamChunker.class); - when(streamer.hasNext()).thenReturn(true); - when(streamer.next()).thenThrow(new IOException("stream failed")); // fail the read - - var exceptionHolder = new AtomicReference(); - var latch = new CountDownLatch(1); - var latchedListener = new LatchedActionListener( - ActionTestUtils.assertNoSuccessListener(exceptionHolder::set), - latch - ); - - var 
importer = new ModelImporter(client, "foo", config, task, threadPool); - importer.downloadModelDefinition(size, totalParts, null, List.of(streamer), latchedListener); - - latch.await(); - assertThat(exceptionHolder.get().getMessage(), containsString("stream failed")); - } - - @SuppressWarnings("unchecked") - public void testUploadVocabFailure() throws InterruptedException, URISyntaxException { - var client = mock(Client.class); - doAnswer(invocation -> { - ActionListener listener = (ActionListener) invocation.getArguments()[2]; - listener.onFailure(new ElasticsearchStatusException("put vocab failed", RestStatus.BAD_REQUEST)); - return null; - }).when(client).execute(eq(PutTrainedModelVocabularyAction.INSTANCE), any(), any()); - - var task = mock(ModelDownloadTask.class); - var config = mockConfigWithRepoLinks(); - - var vocab = new ModelLoaderUtils.VocabularyParts(List.of(), List.of(), List.of()); - - var exceptionHolder = new AtomicReference(); - var latch = new CountDownLatch(1); - var latchedListener = new LatchedActionListener( - ActionTestUtils.assertNoSuccessListener(exceptionHolder::set), - latch - ); - - var importer = new ModelImporter(client, "foo", config, task, threadPool); - importer.downloadModelDefinition(100, 5, vocab, List.of(), latchedListener); - - latch.await(); - assertThat(exceptionHolder.get().getMessage(), containsString("put vocab failed")); - verify(client, times(1)).execute(eq(PutTrainedModelVocabularyAction.INSTANCE), any(), any()); - verify(client, never()).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any(), any()); - } - - private List mockHttpStreamChunkers(byte[] modelDef, int chunkSize, int numStreams) { - var ranges = ModelLoaderUtils.split(modelDef.length, numStreams, chunkSize); - - var result = new ArrayList(ranges.size()); - for (var range : ranges) { - int len = range.numParts() * chunkSize; - var modelDefStream = new ByteArrayInputStream(modelDef, (int) range.rangeStart(), len); - result.add(new ModelLoaderUtils.HttpStreamChunker(modelDefStream, range, chunkSize)); - } - - return result; - } - - private byte[] modelDefinition(int totalParts, int chunkSize) { - var bytes = new byte[totalParts * chunkSize]; - for (int i = 0; i < totalParts; i++) { - System.arraycopy(randomByteArrayOfLength(chunkSize), 0, bytes, i * chunkSize, chunkSize); - } - return bytes; - } - - private String computeDigest(byte[] modelDef) { - var digest = MessageDigests.sha256(); - digest.update(modelDef); - return MessageDigests.toHexString(digest.digest()); - } - - @SuppressWarnings("unchecked") - private Client mockClient(boolean failPutPart) { - var client = mock(Client.class); - doAnswer(invocation -> { - ActionListener listener = (ActionListener) invocation.getArguments()[2]; - if (failPutPart) { - listener.onFailure(new IllegalStateException("put model part failed")); - } else { - listener.onResponse(AcknowledgedResponse.TRUE); - } - return null; - }).when(client).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any(), any()); - - doAnswer(invocation -> { - ActionListener listener = (ActionListener) invocation.getArguments()[2]; - listener.onResponse(AcknowledgedResponse.TRUE); - return null; - }).when(client).execute(eq(PutTrainedModelVocabularyAction.INSTANCE), any(), any()); - - return client; - } - - private ModelPackageConfig mockConfigWithRepoLinks() { - var config = mock(ModelPackageConfig.class); - when(config.getModelRepository()).thenReturn("https://models.models"); - when(config.getPackagedModelId()).thenReturn("my-model"); - return config; 
- } -} diff --git a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtilsTests.java b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtilsTests.java index f421a7b44e7f1..661cd12f99957 100644 --- a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtilsTests.java +++ b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtilsTests.java @@ -17,7 +17,6 @@ import java.nio.charset.StandardCharsets; import static org.hamcrest.Matchers.contains; -import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.core.Is.is; public class ModelLoaderUtilsTests extends ESTestCase { @@ -81,13 +80,14 @@ public void testSha256AndSize() throws IOException { assertEquals(64, expectedDigest.length()); int chunkSize = randomIntBetween(100, 10_000); - int totalParts = (bytes.length + chunkSize - 1) / chunkSize; ModelLoaderUtils.InputStreamChunker inputStreamChunker = new ModelLoaderUtils.InputStreamChunker( new ByteArrayInputStream(bytes), chunkSize ); + int totalParts = (bytes.length + chunkSize - 1) / chunkSize; + for (int part = 0; part < totalParts - 1; ++part) { assertEquals(chunkSize, inputStreamChunker.next().length()); } @@ -112,40 +112,4 @@ public void testParseVocabulary() throws IOException { assertThat(parsedVocab.merges(), contains("mergefoo", "mergebar", "mergebaz")); assertThat(parsedVocab.scores(), contains(1.0, 2.0, 3.0)); } - - public void testSplitIntoRanges() { - long totalSize = randomLongBetween(10_000, 50_000_000); - int numStreams = randomIntBetween(1, 10); - int chunkSize = 1024; - var ranges = ModelLoaderUtils.split(totalSize, numStreams, chunkSize); - assertThat(ranges, hasSize(numStreams + 1)); - - int expectedNumChunks = (int) ((totalSize + chunkSize - 1) / chunkSize); - assertThat(ranges.stream().mapToInt(ModelLoaderUtils.RequestRange::numParts).sum(), is(expectedNumChunks)); - - long startBytes = 0; - int startPartIndex = 0; - for (int i = 0; i < ranges.size() - 1; i++) { - assertThat(ranges.get(i).rangeStart(), is(startBytes)); - long end = startBytes + ((long) ranges.get(i).numParts() * chunkSize) - 1; - assertThat(ranges.get(i).rangeEnd(), is(end)); - long expectedNumBytesInRange = (long) chunkSize * ranges.get(i).numParts() - 1; - assertThat(ranges.get(i).rangeEnd() - ranges.get(i).rangeStart(), is(expectedNumBytesInRange)); - assertThat(ranges.get(i).startPart(), is(startPartIndex)); - - startBytes = end + 1; - startPartIndex += ranges.get(i).numParts(); - } - - var finalRange = ranges.get(ranges.size() - 1); - assertThat(finalRange.rangeStart(), is(startBytes)); - assertThat(finalRange.rangeEnd(), is(totalSize - 1)); - assertThat(finalRange.numParts(), is(1)); - } - - public void testRangeRequestBytesRange() { - long start = randomLongBetween(0, 2 << 10); - long end = randomLongBetween(start + 1, 2 << 11); - assertEquals("bytes=" + start + "-" + end, new ModelLoaderUtils.RequestRange(start, end, 0, 1).bytesRange()); - } } diff --git a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java index cbcfd5b760779..a3f59e13f2f5b 100644 --- 
a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java +++ b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java @@ -33,7 +33,7 @@ import static org.hamcrest.core.Is.is; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -42,7 +42,7 @@ public class TransportLoadTrainedModelPackageTests extends ESTestCase { private static final String MODEL_IMPORT_FAILURE_MSG_FORMAT = "Model importing failed due to %s [%s]"; public void testSendsFinishedUploadNotification() { - var uploader = createUploader(null); + var uploader = mock(ModelImporter.class); var taskManager = mock(TaskManager.class); var task = mock(Task.class); var client = mock(Client.class); @@ -63,49 +63,49 @@ public void testSendsFinishedUploadNotification() { assertThat(notificationArg.getValue().getMessage(), CoreMatchers.containsString("finished model import after")); } - public void testSendsErrorNotificationForInternalError() throws Exception { + public void testSendsErrorNotificationForInternalError() throws URISyntaxException, IOException { ElasticsearchStatusException exception = new ElasticsearchStatusException("exception", RestStatus.INTERNAL_SERVER_ERROR); String message = format("Model importing failed due to [%s]", exception.toString()); assertUploadCallsOnFailure(exception, message, Level.ERROR); } - public void testSendsErrorNotificationForMalformedURL() throws Exception { + public void testSendsErrorNotificationForMalformedURL() throws URISyntaxException, IOException { MalformedURLException exception = new MalformedURLException("exception"); String message = format(MODEL_IMPORT_FAILURE_MSG_FORMAT, "an invalid URL", exception.toString()); - assertUploadCallsOnFailure(exception, message, RestStatus.BAD_REQUEST, Level.ERROR); + assertUploadCallsOnFailure(exception, message, RestStatus.INTERNAL_SERVER_ERROR, Level.ERROR); } - public void testSendsErrorNotificationForURISyntax() throws Exception { + public void testSendsErrorNotificationForURISyntax() throws URISyntaxException, IOException { URISyntaxException exception = mock(URISyntaxException.class); String message = format(MODEL_IMPORT_FAILURE_MSG_FORMAT, "an invalid URL syntax", exception.toString()); - assertUploadCallsOnFailure(exception, message, RestStatus.BAD_REQUEST, Level.ERROR); + assertUploadCallsOnFailure(exception, message, RestStatus.INTERNAL_SERVER_ERROR, Level.ERROR); } - public void testSendsErrorNotificationForIOException() throws Exception { + public void testSendsErrorNotificationForIOException() throws URISyntaxException, IOException { IOException exception = mock(IOException.class); String message = format(MODEL_IMPORT_FAILURE_MSG_FORMAT, "an IOException", exception.toString()); assertUploadCallsOnFailure(exception, message, RestStatus.SERVICE_UNAVAILABLE, Level.ERROR); } - public void testSendsErrorNotificationForException() throws Exception { + public void testSendsErrorNotificationForException() throws URISyntaxException, IOException { RuntimeException exception = mock(RuntimeException.class); String message = format(MODEL_IMPORT_FAILURE_MSG_FORMAT, "an Exception", exception.toString()); assertUploadCallsOnFailure(exception, 
message, RestStatus.INTERNAL_SERVER_ERROR, Level.ERROR); } - public void testSendsWarningNotificationForTaskCancelledException() throws Exception { + public void testSendsWarningNotificationForTaskCancelledException() throws URISyntaxException, IOException { TaskCancelledException exception = new TaskCancelledException("cancelled"); String message = format("Model importing failed due to [%s]", exception.toString()); assertUploadCallsOnFailure(exception, message, Level.WARNING); } - public void testCallsOnResponseWithAcknowledgedResponse() throws Exception { + public void testCallsOnResponseWithAcknowledgedResponse() throws URISyntaxException, IOException { var client = mock(Client.class); var taskManager = mock(TaskManager.class); var task = mock(Task.class); @@ -134,13 +134,15 @@ public void testDoesNotCallListenerWhenNotWaitingForCompletion() { ); } - private void assertUploadCallsOnFailure(Exception exception, String message, RestStatus status, Level level) throws Exception { + private void assertUploadCallsOnFailure(Exception exception, String message, RestStatus status, Level level) throws URISyntaxException, + IOException { var esStatusException = new ElasticsearchStatusException(message, status, exception); assertNotificationAndOnFailure(exception, esStatusException, message, level); } - private void assertUploadCallsOnFailure(ElasticsearchException exception, String message, Level level) throws Exception { + private void assertUploadCallsOnFailure(ElasticsearchException exception, String message, Level level) throws URISyntaxException, + IOException { assertNotificationAndOnFailure(exception, exception, message, level); } @@ -149,7 +151,7 @@ private void assertNotificationAndOnFailure( ElasticsearchException onFailureException, String message, Level level - ) throws Exception { + ) throws URISyntaxException, IOException { var client = mock(Client.class); var taskManager = mock(TaskManager.class); var task = mock(Task.class); @@ -177,18 +179,11 @@ private void assertNotificationAndOnFailure( verify(taskManager).unregister(task); } - @SuppressWarnings("unchecked") - private ModelImporter createUploader(Exception exception) { + private ModelImporter createUploader(Exception exception) throws URISyntaxException, IOException { ModelImporter uploader = mock(ModelImporter.class); - doAnswer(invocation -> { - ActionListener listener = (ActionListener) invocation.getArguments()[0]; - if (exception != null) { - listener.onFailure(exception); - } else { - listener.onResponse(AcknowledgedResponse.TRUE); - } - return null; - }).when(uploader).doImport(any(ActionListener.class)); + if (exception != null) { + doThrow(exception).when(uploader).doImport(); + } return uploader; } diff --git a/x-pack/plugin/ml/src/main/java/module-info.java b/x-pack/plugin/ml/src/main/java/module-info.java index 0f3fdd836feca..4984fa8912e28 100644 --- a/x-pack/plugin/ml/src/main/java/module-info.java +++ b/x-pack/plugin/ml/src/main/java/module-info.java @@ -37,6 +37,8 @@ exports org.elasticsearch.xpack.ml; exports org.elasticsearch.xpack.ml.action; + exports org.elasticsearch.xpack.ml.aggs.categorization; exports org.elasticsearch.xpack.ml.autoscaling; + exports org.elasticsearch.xpack.ml.job.categorization; exports org.elasticsearch.xpack.ml.notifications; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java index 
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java
index cedaced0f57ee..e55736cf43607 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java
@@ -94,11 +94,11 @@ protected CategorizeTextAggregator(
                 true
             );
         }
-        this.categorizers = bigArrays().newObjectArray(1);
+        this.categorizers = context.bigArrays().newObjectArray(1);
         this.similarityThreshold = similarityThreshold;
-        this.bucketOrds = LongKeyedBucketOrds.build(bigArrays(), CardinalityUpperBound.MANY);
+        this.bucketOrds = LongKeyedBucketOrds.build(context.bigArrays(), CardinalityUpperBound.MANY);
         this.bucketCountThresholds = bucketCountThresholds;
-        this.bytesRefHash = new CategorizationBytesRefHash(new BytesRefHash(2048, bigArrays()));
+        this.bytesRefHash = new CategorizationBytesRefHash(new BytesRefHash(2048, context.bigArrays()));
         // TODO: make it possible to choose a language instead of or as well as English for the part-of-speech dictionary
         this.partOfSpeechDictionary = CategorizationPartOfSpeechDictionary.getInstance();
     }
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategorizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategorizer.java
index e1f2404ee56b5..d0088edcb0805 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategorizer.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategorizer.java
@@ -14,6 +14,8 @@
 import org.apache.lucene.util.Accountable;
 import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
 import org.elasticsearch.search.aggregations.AggregationReduceContext;
 import org.elasticsearch.search.aggregations.InternalAggregations;
 import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategory.TokenAndWeight;
@@ -40,6 +42,25 @@
  */
 public class TokenListCategorizer implements Accountable {

+    /**
+     * TokenListCategorizer that takes ownership of the CategorizationBytesRefHash and releases it when closed.
+     */
+    public static class CloseableTokenListCategorizer extends TokenListCategorizer implements Releasable {
+
+        public CloseableTokenListCategorizer(
+            CategorizationBytesRefHash bytesRefHash,
+            CategorizationPartOfSpeechDictionary partOfSpeechDictionary,
+            float threshold
+        ) {
+            super(bytesRefHash, partOfSpeechDictionary, threshold);
+        }
+
+        @Override
+        public void close() {
+            Releasables.close(super.bytesRefHash);
+        }
+    }
+
     public static final int MAX_TOKENS = 100;
     private static final long SHALLOW_SIZE = shallowSizeOfInstance(TokenListCategorizer.class);
     private static final long SHALLOW_SIZE_OF_ARRAY_LIST = shallowSizeOfInstance(ArrayList.class);
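Per its Javadoc, CloseableTokenListCategorizer takes ownership of the hash it is handed and releases it in close(), so a single try-with-resources covers both objects. A sketch of that pattern under stated assumptions: the 0.70f similarity threshold and the 2048 capacity are arbitrary example values, and the surrounding class is invented for illustration.

```java
import java.io.IOException;

import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BytesRefHash;
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary;
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer.CloseableTokenListCategorizer;

public class CloseableCategorizerExample {

    static void categorize() throws IOException {
        CategorizationBytesRefHash hash = new CategorizationBytesRefHash(
            new BytesRefHash(2048, BigArrays.NON_RECYCLING_INSTANCE)
        );
        try (
            CloseableTokenListCategorizer categorizer = new CloseableTokenListCategorizer(
                hash,
                CategorizationPartOfSpeechDictionary.getInstance(),
                0.70f
            )
        ) {
            // ... feed tokenized log messages to the categorizer ...
        }
        // categorizer.close() released the hash via Releasables.close(); do not close the hash again here.
    }
}
```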
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/AbstractMlTokenizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/AbstractMlTokenizer.java
index c701216b1984b..d7e7683ce0071 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/AbstractMlTokenizer.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/AbstractMlTokenizer.java
@@ -16,8 +16,11 @@
 public abstract class AbstractMlTokenizer extends Tokenizer {

+    @SuppressWarnings("this-escape")
     protected final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
+    @SuppressWarnings("this-escape")
     protected final OffsetAttribute offsetAtt = addAttribute(OffsetAttribute.class);
+    @SuppressWarnings("this-escape")
     protected final PositionIncrementAttribute posIncrAtt = addAttribute(PositionIncrementAttribute.class);

     /**
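The @SuppressWarnings("this-escape") annotations quiet a javac lint introduced in JDK 21: it flags constructors (including field initializers, which run as part of construction) that let `this` reach overridable code before subclasses are fully initialized, here presumably via the inherited addAttribute(...) calls. A generic illustration of the hazard the lint targets; the classes below are invented purely to show the mechanism:

```java
class Base {
    Base() {
        describe(); // javac -Xlint:this-escape warns here: 'this' escapes to overridable code
    }

    void describe() {}
}

class Derived extends Base {
    // Not assigned until after Base's constructor has returned.
    private final int threshold = Integer.parseInt("42");

    @Override
    void describe() {
        // Invoked from Base(): prints threshold=0, not 42, because Derived is only partially constructed.
        System.out.println("threshold=" + threshold);
    }
}
```

Suppressing the warning asserts the call is safe in this particular hierarchy; the alternative would be moving the Lucene attribute setup out of the field initializers.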
diff --git a/x-pack/plugin/security/qa/multi-cluster/build.gradle b/x-pack/plugin/security/qa/multi-cluster/build.gradle
index 625b6806ab520..c7b8f81bb7876 100644
--- a/x-pack/plugin/security/qa/multi-cluster/build.gradle
+++ b/x-pack/plugin/security/qa/multi-cluster/build.gradle
@@ -23,6 +23,8 @@ dependencies {
   // esql with enrich
   clusterModules project(':x-pack:plugin:esql')
   clusterModules project(':x-pack:plugin:enrich')
+  clusterModules project(':x-pack:plugin:autoscaling')
+  clusterModules project(':x-pack:plugin:ml')
   clusterModules(project(":modules:ingest-common"))
 }
diff --git a/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/java/org/elasticsearch/xpack/remotecluster/RemoteClusterSecurityEsqlIT.java b/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/java/org/elasticsearch/xpack/remotecluster/RemoteClusterSecurityEsqlIT.java
index f5f9410a145cc..1a236ccb6aa06 100644
--- a/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/java/org/elasticsearch/xpack/remotecluster/RemoteClusterSecurityEsqlIT.java
+++ b/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/java/org/elasticsearch/xpack/remotecluster/RemoteClusterSecurityEsqlIT.java
@@ -56,11 +56,14 @@ public class RemoteClusterSecurityEsqlIT extends AbstractRemoteClusterSecurityTe
         fulfillingCluster = ElasticsearchCluster.local()
             .name("fulfilling-cluster")
             .nodes(3)
+            .module("x-pack-autoscaling")
             .module("x-pack-esql")
             .module("x-pack-enrich")
+            .module("x-pack-ml")
             .module("ingest-common")
             .apply(commonClusterConfig)
             .setting("remote_cluster.port", "0")
+            .setting("xpack.ml.enabled", "false")
             .setting("xpack.security.remote_cluster_server.ssl.enabled", () -> String.valueOf(SSL_ENABLED_REF.get()))
             .setting("xpack.security.remote_cluster_server.ssl.key", "remote-cluster.key")
             .setting("xpack.security.remote_cluster_server.ssl.certificate", "remote-cluster.crt")
@@ -73,10 +76,13 @@ public class RemoteClusterSecurityEsqlIT extends AbstractRemoteClusterSecurityTe
         queryCluster = ElasticsearchCluster.local()
             .name("query-cluster")
+            .module("x-pack-autoscaling")
             .module("x-pack-esql")
             .module("x-pack-enrich")
+            .module("x-pack-ml")
             .module("ingest-common")
             .apply(commonClusterConfig)
+            .setting("xpack.ml.enabled", "false")
             .setting("xpack.security.remote_cluster_client.ssl.enabled", () -> String.valueOf(SSL_ENABLED_REF.get()))
             .setting("xpack.security.remote_cluster_client.ssl.certificate_authorities", "remote-cluster-ca.crt")
             .setting("xpack.security.authc.token.enabled", "true")
diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/apikey/ApiKeySingleNodeTests.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/apikey/ApiKeySingleNodeTests.java
index ccdf7704f221a..bc413ff2001ab 100644
--- a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/apikey/ApiKeySingleNodeTests.java
+++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/apikey/ApiKeySingleNodeTests.java
@@ -318,12 +318,12 @@ public void testGrantApiKeyForUserWithRunAs() throws IOException {
         clientWithGrantedKey.execute(TransportClusterHealthAction.TYPE, new ClusterHealthRequest(TEST_REQUEST_TIMEOUT)).actionGet();
         // If the API key is granted with limiting descriptors, it should not be able to read pipeline
         if (grantApiKeyRequest.getApiKeyRequest().getRoleDescriptors().isEmpty()) {
-            clientWithGrantedKey.execute(GetPipelineAction.INSTANCE, new GetPipelineRequest()).actionGet();
+            clientWithGrantedKey.execute(GetPipelineAction.INSTANCE, new GetPipelineRequest(TEST_REQUEST_TIMEOUT)).actionGet();
         } else {
             assertThat(
                 expectThrows(
                     ElasticsearchSecurityException.class,
-                    () -> clientWithGrantedKey.execute(GetPipelineAction.INSTANCE, new GetPipelineRequest()).actionGet()
+                    () -> clientWithGrantedKey.execute(GetPipelineAction.INSTANCE, new GetPipelineRequest(TEST_REQUEST_TIMEOUT)).actionGet()
                 ).getMessage(),
                 containsString("unauthorized")
             );
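The GetPipelineRequest change appears to reflect request constructors now taking an explicit (master-node) timeout instead of a deprecated no-arg default; the test passes ESTestCase's TEST_REQUEST_TIMEOUT. A small illustrative helper under that assumption; the 30-second value and the helper class itself are made up for the example:

```java
import org.elasticsearch.action.ingest.GetPipelineRequest;
import org.elasticsearch.core.TimeValue;

public final class PipelineRequests {

    private PipelineRequests() {}

    // With no pipeline ids the request fetches every pipeline, matching the call in the test above.
    static GetPipelineRequest allPipelines() {
        return new GetPipelineRequest(TimeValue.timeValueSeconds(30));
    }
}
```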
diff --git a/x-pack/qa/multi-cluster-search-security/legacy-with-basic-license/build.gradle b/x-pack/qa/multi-cluster-search-security/legacy-with-basic-license/build.gradle
index b5b8495870259..6d41c4eddf31c 100644
--- a/x-pack/qa/multi-cluster-search-security/legacy-with-basic-license/build.gradle
+++ b/x-pack/qa/multi-cluster-search-security/legacy-with-basic-license/build.gradle
@@ -23,11 +23,14 @@ def fulfillingCluster = testClusters.register('fulfilling-cluster') {
   module ':modules:data-streams'
   module ':x-pack:plugin:mapper-constant-keyword'
   module ':x-pack:plugin:async-search'
+  module ':x-pack:plugin:autoscaling'
   module ':x-pack:plugin:esql-core'
   module ':x-pack:plugin:esql'
+  module ':x-pack:plugin:ml'
   module ':modules:ingest-common'
   module ':x-pack:plugin:enrich'
   user username: "test_user", password: "x-pack-test-password"
+  setting 'xpack.ml.enabled', 'false'
 }

 def queryingCluster = testClusters.register('querying-cluster') {
@@ -38,13 +41,15 @@ def queryingCluster = testClusters.register('querying-cluster') {
   module ':modules:data-streams'
   module ':x-pack:plugin:mapper-constant-keyword'
   module ':x-pack:plugin:async-search'
+  module ':x-pack:plugin:autoscaling'
   module ':x-pack:plugin:esql-core'
   module ':x-pack:plugin:esql'
+  module ':x-pack:plugin:ml'
   module ':modules:ingest-common'
   module ':x-pack:plugin:enrich'
   setting 'cluster.remote.connections_per_cluster', "1"
   user username: "test_user", password: "x-pack-test-password"
-
+  setting 'xpack.ml.enabled', 'false'
   setting 'cluster.remote.my_remote_cluster.skip_unavailable', 'false'
   if (proxyMode) {
     setting 'cluster.remote.my_remote_cluster.mode', 'proxy'