diff --git a/docs/changelog/113051.yaml b/docs/changelog/113051.yaml new file mode 100644 index 000000000000..9be68f9f2b03 --- /dev/null +++ b/docs/changelog/113051.yaml @@ -0,0 +1,5 @@ +pr: 113051 +summary: Add Search Inference ID To Semantic Text Mapping +area: Mapping +type: enhancement +issues: [] diff --git a/muted-tests.yml b/muted-tests.yml index f8106f2e4174..868a8c0e716e 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -221,9 +221,6 @@ tests: - class: org.elasticsearch.xpack.inference.rest.ServerSentEventsRestActionListenerTests method: testErrorMidStream issue: https://github.com/elastic/elasticsearch/issues/113179 -- class: org.elasticsearch.xpack.core.security.authz.RoleDescriptorTests - method: testHasPrivilegesOtherThanIndex - issue: https://github.com/elastic/elasticsearch/issues/113202 - class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT method: test {categorize.Categorize SYNC} issue: https://github.com/elastic/elasticsearch/issues/113054 @@ -305,6 +302,9 @@ tests: - class: org.elasticsearch.xpack.ml.integration.MlJobIT method: testOutOfOrderData issue: https://github.com/elastic/elasticsearch/issues/113477 +- class: org.elasticsearch.upgrades.UpgradeClusterClientYamlTestSuiteIT + method: test {p0=mixed_cluster/100_analytics_usage/Basic test for usage stats on analytics indices} + issue: https://github.com/elastic/elasticsearch/issues/113497 # Examples: # diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 2cc50a85668c..6b1d73a58c87 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -223,6 +223,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_AGGREGATION_OPERATOR_STATUS_FINISH_NANOS = def(8_747_00_0); public static final TransportVersion ML_TELEMETRY_MEMORY_ADDED = def(8_748_00_0); public static final TransportVersion ILM_ADD_SEARCHABLE_SNAPSHOT_TOTAL_SHARDS_PER_NODE = def(8_749_00_0); + public static final TransportVersion SEMANTIC_TEXT_SEARCH_INFERENCE_ID = def(8_750_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java index be0943f8f306..271c60e829a8 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java @@ -23,6 +23,8 @@ import java.util.List; import java.util.Objects; +import static org.elasticsearch.TransportVersions.SEMANTIC_TEXT_SEARCH_INFERENCE_ID; + /** * Contains inference field data for fields. * As inference is done in the coordinator node to avoid re-doing it at shard / replica level, the coordinator needs to check for the need @@ -32,21 +34,33 @@ */ public final class InferenceFieldMetadata implements SimpleDiffable, ToXContentFragment { private static final String INFERENCE_ID_FIELD = "inference_id"; + private static final String SEARCH_INFERENCE_ID_FIELD = "search_inference_id"; private static final String SOURCE_FIELDS_FIELD = "source_fields"; private final String name; private final String inferenceId; + private final String searchInferenceId; private final String[] sourceFields; public InferenceFieldMetadata(String name, String inferenceId, String[] sourceFields) { + this(name, inferenceId, inferenceId, sourceFields); + } + + public InferenceFieldMetadata(String name, String inferenceId, String searchInferenceId, String[] sourceFields) { this.name = Objects.requireNonNull(name); this.inferenceId = Objects.requireNonNull(inferenceId); + this.searchInferenceId = Objects.requireNonNull(searchInferenceId); this.sourceFields = Objects.requireNonNull(sourceFields); } public InferenceFieldMetadata(StreamInput input) throws IOException { this.name = input.readString(); this.inferenceId = input.readString(); + if (input.getTransportVersion().onOrAfter(SEMANTIC_TEXT_SEARCH_INFERENCE_ID)) { + this.searchInferenceId = input.readString(); + } else { + this.searchInferenceId = this.inferenceId; + } this.sourceFields = input.readStringArray(); } @@ -54,6 +68,9 @@ public InferenceFieldMetadata(StreamInput input) throws IOException { public void writeTo(StreamOutput out) throws IOException { out.writeString(name); out.writeString(inferenceId); + if (out.getTransportVersion().onOrAfter(SEMANTIC_TEXT_SEARCH_INFERENCE_ID)) { + out.writeString(searchInferenceId); + } out.writeStringArray(sourceFields); } @@ -64,12 +81,13 @@ public boolean equals(Object o) { InferenceFieldMetadata that = (InferenceFieldMetadata) o; return Objects.equals(name, that.name) && Objects.equals(inferenceId, that.inferenceId) + && Objects.equals(searchInferenceId, that.searchInferenceId) && Arrays.equals(sourceFields, that.sourceFields); } @Override public int hashCode() { - int result = Objects.hash(name, inferenceId); + int result = Objects.hash(name, inferenceId, searchInferenceId); result = 31 * result + Arrays.hashCode(sourceFields); return result; } @@ -82,6 +100,10 @@ public String getInferenceId() { return inferenceId; } + public String getSearchInferenceId() { + return searchInferenceId; + } + public String[] getSourceFields() { return sourceFields; } @@ -94,6 +116,9 @@ public static Diff readDiffFrom(StreamInput in) throws I public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(name); builder.field(INFERENCE_ID_FIELD, inferenceId); + if (searchInferenceId.equals(inferenceId) == false) { + builder.field(SEARCH_INFERENCE_ID_FIELD, searchInferenceId); + } builder.array(SOURCE_FIELDS_FIELD, sourceFields); return builder.endObject(); } @@ -106,6 +131,7 @@ public static InferenceFieldMetadata fromXContent(XContentParser parser) throws String currentFieldName = null; String inferenceId = null; + String searchInferenceId = null; List inputFields = new ArrayList<>(); while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { @@ -113,6 +139,8 @@ public static InferenceFieldMetadata fromXContent(XContentParser parser) throws } else if (token == XContentParser.Token.VALUE_STRING) { if (INFERENCE_ID_FIELD.equals(currentFieldName)) { inferenceId = parser.text(); + } else if (SEARCH_INFERENCE_ID_FIELD.equals(currentFieldName)) { + searchInferenceId = parser.text(); } } else if (token == XContentParser.Token.START_ARRAY) { if (SOURCE_FIELDS_FIELD.equals(currentFieldName)) { @@ -128,6 +156,11 @@ public static InferenceFieldMetadata fromXContent(XContentParser parser) throws parser.skipChildren(); } } - return new InferenceFieldMetadata(name, inferenceId, inputFields.toArray(String[]::new)); + return new InferenceFieldMetadata( + name, + inferenceId, + searchInferenceId == null ? inferenceId : searchInferenceId, + inputFields.toArray(String[]::new) + ); } } diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java index 6107246cf8ff..2d5805696320 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java @@ -61,13 +61,15 @@ protected boolean supportsUnknownFields() { private static InferenceFieldMetadata createTestItem() { String name = randomAlphaOfLengthBetween(3, 10); String inferenceId = randomIdentifier(); + String searchInferenceId = randomIdentifier(); String[] inputFields = generateRandomStringArray(5, 10, false, false); - return new InferenceFieldMetadata(name, inferenceId, inputFields); + return new InferenceFieldMetadata(name, inferenceId, searchInferenceId, inputFields); } public void testNullCtorArgsThrowException() { - assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata(null, "inferenceId", new String[0])); - assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", null, new String[0])); - assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", null)); + assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata(null, "inferenceId", "searchInferenceId", new String[0])); + assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", null, "searchInferenceId", new String[0])); + assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", null, new String[0])); + assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", "searchInferenceId", null)); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/RoleDescriptorTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/RoleDescriptorTests.java index d7b9f9ddd5b5..8e1bc7af1bdc 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/RoleDescriptorTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/RoleDescriptorTests.java @@ -1341,7 +1341,8 @@ public void testHasPrivilegesOtherThanIndex() { || roleDescriptor.hasConfigurableClusterPrivileges() || roleDescriptor.hasApplicationPrivileges() || roleDescriptor.hasRunAs() - || roleDescriptor.hasRemoteIndicesPrivileges(); + || roleDescriptor.hasRemoteIndicesPrivileges() + || roleDescriptor.hasWorkflowsRestriction(); assertThat(roleDescriptor.hasUnsupportedPrivilegesInsideAPIKeyConnectedRemoteCluster(), equalTo(expected)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java index 12a32ecdc6d4..fd330a8cf6cc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java @@ -9,6 +9,7 @@ import org.elasticsearch.features.FeatureSpecification; import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder; import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder; @@ -23,7 +24,8 @@ public class InferenceFeatures implements FeatureSpecification { public Set getFeatures() { return Set.of( TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED, - RandomRankRetrieverBuilder.RANDOM_RERANKER_RETRIEVER_SUPPORTED + RandomRankRetrieverBuilder.RANDOM_RERANKER_RETRIEVER_SUPPORTED, + SemanticTextFieldMapper.SEMANTIC_TEXT_SEARCH_INFERENCE_ID ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 81dfba769136..0483296cd2c6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -18,6 +18,7 @@ import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Tuple; +import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.fielddata.FieldDataContext; @@ -79,6 +80,8 @@ * A {@link FieldMapper} for semantic text fields. */ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFieldMapper { + public static final NodeFeature SEMANTIC_TEXT_SEARCH_INFERENCE_ID = new NodeFeature("semantic_text.search_inference_id"); + public static final String CONTENT_TYPE = "semantic_text"; private final IndexSettings indexSettings; @@ -103,6 +106,13 @@ public static class Builder extends FieldMapper.Builder { } }); + private final Parameter searchInferenceId = Parameter.stringParam( + "search_inference_id", + true, + mapper -> ((SemanticTextFieldType) mapper.fieldType()).searchInferenceId, + null + ).acceptsNull(); + private final Parameter modelSettings = new Parameter<>( "model_settings", true, @@ -117,6 +127,17 @@ public static class Builder extends FieldMapper.Builder { private Function inferenceFieldBuilder; + public static Builder from(SemanticTextFieldMapper mapper) { + Builder builder = new Builder( + mapper.leafName(), + mapper.fieldType().indexVersionCreated, + mapper.fieldType().getChunksField().bitsetProducer(), + mapper.indexSettings + ); + builder.init(mapper); + return builder; + } + public Builder( String name, IndexVersion indexVersionCreated, @@ -140,6 +161,11 @@ public Builder setInferenceId(String id) { return this; } + public Builder setSearchInferenceId(String id) { + this.searchInferenceId.setValue(id); + return this; + } + public Builder setModelSettings(SemanticTextField.ModelSettings value) { this.modelSettings.setValue(value); return this; @@ -147,15 +173,17 @@ public Builder setModelSettings(SemanticTextField.ModelSettings value) { @Override protected Parameter[] getParameters() { - return new Parameter[] { inferenceId, modelSettings, meta }; + return new Parameter[] { inferenceId, searchInferenceId, modelSettings, meta }; } @Override protected void merge(FieldMapper mergeWith, Conflicts conflicts, MapperMergeContext mapperMergeContext) { - super.merge(mergeWith, conflicts, mapperMergeContext); + SemanticTextFieldMapper semanticMergeWith = (SemanticTextFieldMapper) mergeWith; + semanticMergeWith = copySettings(semanticMergeWith, mapperMergeContext); + + super.merge(semanticMergeWith, conflicts, mapperMergeContext); conflicts.check(); - var semanticMergeWith = (SemanticTextFieldMapper) mergeWith; - var context = mapperMergeContext.createChildContext(mergeWith.leafName(), ObjectMapper.Dynamic.FALSE); + var context = mapperMergeContext.createChildContext(semanticMergeWith.leafName(), ObjectMapper.Dynamic.FALSE); var inferenceField = inferenceFieldBuilder.apply(context.getMapperBuilderContext()); var mergedInferenceField = inferenceField.merge(semanticMergeWith.fieldType().getInferenceField(), context); inferenceFieldBuilder = c -> mergedInferenceField; @@ -181,6 +209,7 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) { new SemanticTextFieldType( fullName, inferenceId.getValue(), + searchInferenceId.getValue(), modelSettings.getValue(), inferenceField, indexVersionCreated, @@ -190,6 +219,25 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) { indexSettings ); } + + /** + * As necessary, copy settings from this builder to the passed-in mapper. + * Used to preserve {@link SemanticTextField.ModelSettings} when updating a semantic text mapping to one where the model settings + * are not specified. + * + * @param mapper The mapper + * @return A mapper with the copied settings applied + */ + private SemanticTextFieldMapper copySettings(SemanticTextFieldMapper mapper, MapperMergeContext mapperMergeContext) { + SemanticTextFieldMapper returnedMapper = mapper; + if (mapper.fieldType().getModelSettings() == null) { + Builder builder = from(mapper); + builder.setModelSettings(modelSettings.getValue()); + returnedMapper = builder.build(mapperMergeContext.getMapperBuilderContext()); + } + + return returnedMapper; + } } private SemanticTextFieldMapper( @@ -211,9 +259,7 @@ public Iterator iterator() { @Override public FieldMapper.Builder getMergeBuilder() { - return new Builder(leafName(), fieldType().indexVersionCreated, fieldType().getChunksField().bitsetProducer(), indexSettings).init( - this - ); + return Builder.from(this); } @Override @@ -267,7 +313,7 @@ protected void parseCreateField(DocumentParserContext context) throws IOExceptio } } else { Conflicts conflicts = new Conflicts(fullFieldName); - canMergeModelSettings(field.inference().modelSettings(), fieldType().getModelSettings(), conflicts); + canMergeModelSettings(fieldType().getModelSettings(), field.inference().modelSettings(), conflicts); try { conflicts.check(); } catch (Exception exc) { @@ -316,7 +362,7 @@ public InferenceFieldMetadata getMetadata(Set sourcePaths) { String[] copyFields = sourcePaths.toArray(String[]::new); // ensure consistent order Arrays.sort(copyFields); - return new InferenceFieldMetadata(fullPath(), fieldType().inferenceId, copyFields); + return new InferenceFieldMetadata(fullPath(), fieldType().getInferenceId(), fieldType().getSearchInferenceId(), copyFields); } @Override @@ -335,6 +381,7 @@ public Object getOriginalValue(Map sourceAsMap) { public static class SemanticTextFieldType extends SimpleMappedFieldType { private final String inferenceId; + private final String searchInferenceId; private final SemanticTextField.ModelSettings modelSettings; private final ObjectMapper inferenceField; private final IndexVersion indexVersionCreated; @@ -342,6 +389,7 @@ public static class SemanticTextFieldType extends SimpleMappedFieldType { public SemanticTextFieldType( String name, String inferenceId, + String searchInferenceId, SemanticTextField.ModelSettings modelSettings, ObjectMapper inferenceField, IndexVersion indexVersionCreated, @@ -349,6 +397,7 @@ public SemanticTextFieldType( ) { super(name, true, false, false, TextSearchInfo.NONE, meta); this.inferenceId = inferenceId; + this.searchInferenceId = searchInferenceId; this.modelSettings = modelSettings; this.inferenceField = inferenceField; this.indexVersionCreated = indexVersionCreated; @@ -363,6 +412,10 @@ public String getInferenceId() { return inferenceId; } + public String getSearchInferenceId() { + return searchInferenceId == null ? inferenceId : searchInferenceId; + } + public SemanticTextField.ModelSettings getModelSettings() { return modelSettings; } @@ -428,14 +481,7 @@ public QueryBuilder semanticQuery(InferenceResults inferenceResults, float boost case SPARSE_EMBEDDING -> { if (inferenceResults instanceof TextExpansionResults == false) { throw new IllegalArgumentException( - "Field [" - + name() - + "] expected query inference results to be of type [" - + TextExpansionResults.NAME - + "]," - + " got [" - + inferenceResults.getWriteableName() - + "]. Has the inference endpoint configuration changed?" + generateQueryInferenceResultsTypeMismatchMessage(inferenceResults, TextExpansionResults.NAME) ); } @@ -454,14 +500,7 @@ public QueryBuilder semanticQuery(InferenceResults inferenceResults, float boost case TEXT_EMBEDDING -> { if (inferenceResults instanceof MlTextEmbeddingResults == false) { throw new IllegalArgumentException( - "Field [" - + name() - + "] expected query inference results to be of type [" - + MlTextEmbeddingResults.NAME - + "]," - + " got [" - + inferenceResults.getWriteableName() - + "]. Has the inference endpoint configuration changed?" + generateQueryInferenceResultsTypeMismatchMessage(inferenceResults, MlTextEmbeddingResults.NAME) ); } @@ -469,13 +508,7 @@ public QueryBuilder semanticQuery(InferenceResults inferenceResults, float boost float[] inference = textEmbeddingResults.getInferenceAsFloat(); if (inference.length != modelSettings.dimensions()) { throw new IllegalArgumentException( - "Field [" - + name() - + "] expected query inference results with " - + modelSettings.dimensions() - + " dimensions, got " - + inference.length - + " dimensions. Has the inference endpoint configuration changed?" + generateDimensionCountMismatchMessage(inference.length, modelSettings.dimensions()) ); } @@ -484,7 +517,7 @@ public QueryBuilder semanticQuery(InferenceResults inferenceResults, float boost default -> throw new IllegalStateException( "Field [" + name() - + "] configured to use an inference endpoint with an unsupported task type [" + + "] is configured to use an inference endpoint with an unsupported task type [" + modelSettings.taskType() + "]" ); @@ -493,6 +526,51 @@ public QueryBuilder semanticQuery(InferenceResults inferenceResults, float boost return new NestedQueryBuilder(nestedFieldPath, childQueryBuilder, ScoreMode.Max).boost(boost).queryName(queryName); } + + private String generateQueryInferenceResultsTypeMismatchMessage(InferenceResults inferenceResults, String expectedResultsType) { + StringBuilder sb = new StringBuilder( + "Field [" + + name() + + "] expected query inference results to be of type [" + + expectedResultsType + + "]," + + " got [" + + inferenceResults.getWriteableName() + + "]." + ); + + return generateInvalidQueryInferenceResultsMessage(sb); + } + + private String generateDimensionCountMismatchMessage(int inferenceDimCount, int expectedDimCount) { + StringBuilder sb = new StringBuilder( + "Field [" + + name() + + "] expected query inference results with " + + expectedDimCount + + " dimensions, got " + + inferenceDimCount + + " dimensions." + ); + + return generateInvalidQueryInferenceResultsMessage(sb); + } + + private String generateInvalidQueryInferenceResultsMessage(StringBuilder baseMessageBuilder) { + if (searchInferenceId != null && searchInferenceId.equals(inferenceId) == false) { + baseMessageBuilder.append( + " Is the search inference endpoint [" + + searchInferenceId + + "] compatible with the inference endpoint [" + + inferenceId + + "]?" + ); + } else { + baseMessageBuilder.append(" Has the configuration for inference endpoint [" + inferenceId + "] changed?"); + } + + return baseMessageBuilder.toString(); + } } private static ObjectMapper createInferenceField( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 7f21f94d3327..9f7fcb1ef407 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -284,7 +284,7 @@ private static String getInferenceIdForForField(Collection indexM String inferenceId = null; for (IndexMetadata indexMetadata : indexMetadataCollection) { InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(fieldName); - String indexInferenceId = inferenceFieldMetadata != null ? inferenceFieldMetadata.getInferenceId() : null; + String indexInferenceId = inferenceFieldMetadata != null ? inferenceFieldMetadata.getSearchInferenceId() : null; if (indexInferenceId != null) { if (inferenceId != null && inferenceId.equals(indexInferenceId) == false) { throw new IllegalArgumentException("Field [" + fieldName + "] has multiple inference IDs associated with it"); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index bb0691c69117..1697b33fedd9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -23,6 +23,7 @@ import org.apache.lucene.search.join.QueryBitSetProducer; import org.apache.lucene.search.join.ScoreMode; import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest; +import org.elasticsearch.common.CheckedBiFunction; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.compress.CompressedXContent; @@ -140,6 +141,7 @@ public MappedFieldType getMappedFieldType() { "fake-inference-id", null, null, + null, IndexVersion.current(), Map.of() ); @@ -210,13 +212,28 @@ public void testUpdatesToInferenceIdNotSupported() throws IOException { public void testDynamicUpdate() throws IOException { final String fieldName = "semantic"; final String inferenceId = "test_service"; + final String searchInferenceId = "search_test_service"; - MapperService mapperService = mapperServiceForFieldWithModelSettings( - fieldName, - inferenceId, - new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null) - ); - assertSemanticTextField(mapperService, fieldName, true); + { + MapperService mapperService = mapperServiceForFieldWithModelSettings( + fieldName, + inferenceId, + new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null) + ); + assertSemanticTextField(mapperService, fieldName, true); + assertSearchInferenceId(mapperService, fieldName, inferenceId); + } + + { + MapperService mapperService = mapperServiceForFieldWithModelSettings( + fieldName, + inferenceId, + searchInferenceId, + new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null) + ); + assertSemanticTextField(mapperService, fieldName, true); + assertSearchInferenceId(mapperService, fieldName, searchInferenceId); + } } public void testUpdateModelSettings() throws IOException { @@ -260,19 +277,11 @@ public void testUpdateModelSettings() throws IOException { assertSemanticTextField(mapperService, fieldName, true); } { - Exception exc = expectThrows( - IllegalArgumentException.class, - () -> merge( - mapperService, - mapping( - b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject() - ) - ) - ); - assertThat( - exc.getMessage(), - containsString("Cannot update parameter [model_settings] " + "from [task_type=sparse_embedding] to [null]") + merge( + mapperService, + mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject()) ); + assertSemanticTextField(mapperService, fieldName, true); } { Exception exc = expectThrows( @@ -305,7 +314,60 @@ public void testUpdateModelSettings() throws IOException { } } - static void assertSemanticTextField(MapperService mapperService, String fieldName, boolean expectedModelSettings) { + public void testUpdateSearchInferenceId() throws IOException { + final String inferenceId = "test_inference_id"; + final String searchInferenceId1 = "test_search_inference_id_1"; + final String searchInferenceId2 = "test_search_inference_id_2"; + + CheckedBiFunction buildMapping = (f, sid) -> mapping(b -> { + b.startObject(f).field("type", "semantic_text").field("inference_id", inferenceId); + if (sid != null) { + b.field("search_inference_id", sid); + } + b.endObject(); + }); + + for (int depth = 1; depth < 5; depth++) { + String fieldName = randomFieldName(depth); + MapperService mapperService = createMapperService(buildMapping.apply(fieldName, null)); + assertSemanticTextField(mapperService, fieldName, false); + assertSearchInferenceId(mapperService, fieldName, inferenceId); + + merge(mapperService, buildMapping.apply(fieldName, searchInferenceId1)); + assertSemanticTextField(mapperService, fieldName, false); + assertSearchInferenceId(mapperService, fieldName, searchInferenceId1); + + merge(mapperService, buildMapping.apply(fieldName, searchInferenceId2)); + assertSemanticTextField(mapperService, fieldName, false); + assertSearchInferenceId(mapperService, fieldName, searchInferenceId2); + + merge(mapperService, buildMapping.apply(fieldName, null)); + assertSemanticTextField(mapperService, fieldName, false); + assertSearchInferenceId(mapperService, fieldName, inferenceId); + + mapperService = mapperServiceForFieldWithModelSettings( + fieldName, + inferenceId, + new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null) + ); + assertSemanticTextField(mapperService, fieldName, true); + assertSearchInferenceId(mapperService, fieldName, inferenceId); + + merge(mapperService, buildMapping.apply(fieldName, searchInferenceId1)); + assertSemanticTextField(mapperService, fieldName, true); + assertSearchInferenceId(mapperService, fieldName, searchInferenceId1); + + merge(mapperService, buildMapping.apply(fieldName, searchInferenceId2)); + assertSemanticTextField(mapperService, fieldName, true); + assertSearchInferenceId(mapperService, fieldName, searchInferenceId2); + + merge(mapperService, buildMapping.apply(fieldName, null)); + assertSemanticTextField(mapperService, fieldName, true); + assertSearchInferenceId(mapperService, fieldName, inferenceId); + } + } + + private static void assertSemanticTextField(MapperService mapperService, String fieldName, boolean expectedModelSettings) { Mapper mapper = mapperService.mappingLookup().getMapper(fieldName); assertNotNull(mapper); assertThat(mapper, instanceOf(SemanticTextFieldMapper.class)); @@ -347,21 +409,34 @@ static void assertSemanticTextField(MapperService mapperService, String fieldNam } } + private static void assertSearchInferenceId(MapperService mapperService, String fieldName, String expectedSearchInferenceId) { + var fieldType = mapperService.fieldType(fieldName); + assertNotNull(fieldType); + assertThat(fieldType, instanceOf(SemanticTextFieldMapper.SemanticTextFieldType.class)); + SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType = (SemanticTextFieldMapper.SemanticTextFieldType) fieldType; + assertEquals(expectedSearchInferenceId, semanticTextFieldType.getSearchInferenceId()); + } + public void testSuccessfulParse() throws IOException { for (int depth = 1; depth < 4; depth++) { final String fieldName1 = randomFieldName(depth); final String fieldName2 = randomFieldName(depth + 1); + final String searchInferenceId = randomAlphaOfLength(8); + final boolean setSearchInferenceId = randomBoolean(); Model model1 = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); Model model2 = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); XContentBuilder mapping = mapping(b -> { - addSemanticTextMapping(b, fieldName1, model1.getInferenceEntityId()); - addSemanticTextMapping(b, fieldName2, model2.getInferenceEntityId()); + addSemanticTextMapping(b, fieldName1, model1.getInferenceEntityId(), setSearchInferenceId ? searchInferenceId : null); + addSemanticTextMapping(b, fieldName2, model2.getInferenceEntityId(), setSearchInferenceId ? searchInferenceId : null); }); MapperService mapperService = createMapperService(mapping); - SemanticTextFieldMapperTests.assertSemanticTextField(mapperService, fieldName1, false); - SemanticTextFieldMapperTests.assertSemanticTextField(mapperService, fieldName2, false); + assertSemanticTextField(mapperService, fieldName1, false); + assertSearchInferenceId(mapperService, fieldName1, setSearchInferenceId ? searchInferenceId : model1.getInferenceEntityId()); + assertSemanticTextField(mapperService, fieldName2, false); + assertSearchInferenceId(mapperService, fieldName2, setSearchInferenceId ? searchInferenceId : model2.getInferenceEntityId()); + DocumentMapper documentMapper = mapperService.documentMapper(); ParsedDocument doc = documentMapper.parse( source( @@ -449,7 +524,7 @@ public void testSuccessfulParse() throws IOException { } public void testMissingInferenceId() throws IOException { - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id", null))); IllegalArgumentException ex = expectThrows( DocumentParsingException.class, IllegalArgumentException.class, @@ -468,7 +543,7 @@ public void testMissingInferenceId() throws IOException { } public void testMissingModelSettings() throws IOException { - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id", null))); IllegalArgumentException ex = expectThrows( DocumentParsingException.class, IllegalArgumentException.class, @@ -480,7 +555,7 @@ public void testMissingModelSettings() throws IOException { } public void testMissingTaskType() throws IOException { - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id", null))); IllegalArgumentException ex = expectThrows( DocumentParsingException.class, IllegalArgumentException.class, @@ -540,12 +615,24 @@ private MapperService mapperServiceForFieldWithModelSettings( String inferenceId, SemanticTextField.ModelSettings modelSettings ) throws IOException { + return mapperServiceForFieldWithModelSettings(fieldName, inferenceId, null, modelSettings); + } + + private MapperService mapperServiceForFieldWithModelSettings( + String fieldName, + String inferenceId, + String searchInferenceId, + SemanticTextField.ModelSettings modelSettings + ) throws IOException { + String mappingParams = "type=semantic_text,inference_id=" + inferenceId; + if (searchInferenceId != null) { + mappingParams += ",search_inference_id=" + searchInferenceId; + } + MapperService mapperService = createMapperService(mapping(b -> {})); mapperService.merge( "_doc", - new CompressedXContent( - Strings.toString(PutMappingRequest.simpleMapping(fieldName, "type=semantic_text,inference_id=" + inferenceId)) - ), + new CompressedXContent(Strings.toString(PutMappingRequest.simpleMapping(fieldName, mappingParams))), MapperService.MergeReason.MAPPING_UPDATE ); @@ -615,10 +702,18 @@ protected void assertExistsQuery(MappedFieldType fieldType, Query query, LuceneD assertThat(query, instanceOf(MatchNoDocsQuery.class)); } - private static void addSemanticTextMapping(XContentBuilder mappingBuilder, String fieldName, String modelId) throws IOException { + private static void addSemanticTextMapping( + XContentBuilder mappingBuilder, + String fieldName, + String inferenceId, + String searchInferenceId + ) throws IOException { mappingBuilder.startObject(fieldName); mappingBuilder.field("type", SemanticTextFieldMapper.CONTENT_TYPE); - mappingBuilder.field("inference_id", modelId); + mappingBuilder.field("inference_id", inferenceId); + if (searchInferenceId != null) { + mappingBuilder.field("search_inference_id", searchInferenceId); + } mappingBuilder.endObject(); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index c2b99923bae6..f54ce8918307 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -79,9 +79,11 @@ public class SemanticQueryBuilderTests extends AbstractQueryTestCase randomFrom(DenseVectorFieldMapper.ElementType.values()) ); // TODO: Support bit elements once KNN bit vector queries are available + useSearchInferenceId = randomBoolean(); } @Override @@ -126,11 +129,14 @@ protected Settings createTestIndexSettings() { @Override protected void initializeAdditionalMappings(MapperService mapperService) throws IOException { + String mappingConfig = "type=semantic_text,inference_id=" + INFERENCE_ID; + if (useSearchInferenceId) { + mappingConfig += ",search_inference_id=" + SEARCH_INFERENCE_ID; + } + mapperService.merge( "_doc", - new CompressedXContent( - Strings.toString(PutMappingRequest.simpleMapping(SEMANTIC_TEXT_FIELD, "type=semantic_text,inference_id=" + INFERENCE_ID)) - ), + new CompressedXContent(Strings.toString(PutMappingRequest.simpleMapping(SEMANTIC_TEXT_FIELD, mappingConfig))), MapperService.MergeReason.MAPPING_UPDATE ); @@ -244,6 +250,7 @@ protected Object simulateMethod(Method method, Object[] args) { InferenceAction.Request request = (InferenceAction.Request) args[1]; assertThat(request.getTaskType(), equalTo(TaskType.ANY)); assertThat(request.getInputType(), equalTo(InputType.SEARCH)); + assertThat(request.getInferenceEntityId(), equalTo(useSearchInferenceId ? SEARCH_INFERENCE_ID : INFERENCE_ID)); List input = request.getInput(); assertThat(input.size(), equalTo(1)); diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/40_semantic_text_query.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/40_semantic_text_query.yml index 932ee4854f44..2070b3752791 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/40_semantic_text_query.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/40_semantic_text_query.yml @@ -18,6 +18,21 @@ setup: } } + - do: + inference.put: + task_type: sparse_embedding + inference_id: sparse-inference-id-2 + body: > + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + - do: inference.put: task_type: text_embedding @@ -35,6 +50,23 @@ setup: } } + - do: + inference.put: + task_type: text_embedding + inference_id: dense-inference-id-2 + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 10, + "api_key": "abc64", + "similarity": "COSINE" + }, + "task_settings": { + } + } + - do: indices.create: index: test-sparse-index @@ -142,6 +174,51 @@ setup: - match: { hits.hits.0._id: "doc_1" } - length: { hits.hits.0._source.inference_field.inference.chunks: 1 } +--- +"Query using a sparse embedding model via a search inference ID": + - requires: + cluster_features: "semantic_text.search_inference_id" + reason: search_inference_id introduced in 8.16.0 + + - skip: + features: [ "headers", "close_to" ] + + - do: + indices.put_mapping: + index: test-sparse-index + body: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + search_inference_id: sparse-inference-id-2 + + - do: + index: + index: test-sparse-index + id: doc_1 + body: + inference_field: [ "inference test", "another inference test" ] + non_inference_field: "non inference test" + refresh: true + + - do: + headers: + # Force JSON content type so that we use a parser that interprets the floating-point score as a double + Content-Type: application/json + search: + index: test-sparse-index + body: + query: + semantic: + field: "inference_field" + query: "inference test" + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - close_to: { hits.hits.0._score: { value: 3.7837332e17, error: 1e10 } } + - length: { hits.hits.0._source.inference_field.inference.chunks: 2 } + --- "Query using a dense embedding model": - skip: @@ -286,6 +363,51 @@ setup: - close_to: { hits.hits.0._score: { value: 1.0, error: 0.0001 } } - length: { hits.hits.0._source.inference_field.inference.chunks: 2 } +--- +"Query using a dense embedding model via a search inference ID": + - requires: + cluster_features: "semantic_text.search_inference_id" + reason: search_inference_id introduced in 8.16.0 + + - skip: + features: [ "headers", "close_to" ] + + - do: + indices.put_mapping: + index: test-dense-index + body: + properties: + inference_field: + type: semantic_text + inference_id: dense-inference-id + search_inference_id: dense-inference-id-2 + + - do: + index: + index: test-dense-index + id: doc_1 + body: + inference_field: ["inference test", "another inference test"] + non_inference_field: "non inference test" + refresh: true + + - do: + headers: + # Force JSON content type so that we use a parser that interprets the floating-point score as a double + Content-Type: application/json + search: + index: test-dense-index + body: + query: + semantic: + field: "inference_field" + query: "inference test" + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - close_to: { hits.hits.0._score: { value: 1.0, error: 0.0001 } } + - length: { hits.hits.0._source.inference_field.inference.chunks: 2 } + --- "Apply boost and query name": - skip: @@ -581,3 +703,139 @@ setup: - match: { error.type: "resource_not_found_exception" } - match: { error.reason: "Inference endpoint not found [invalid-inference-id]" } + +--- +"Query a field with a search inference ID that uses the wrong task type": + - requires: + cluster_features: "semantic_text.search_inference_id" + reason: search_inference_id introduced in 8.16.0 + + - do: + indices.put_mapping: + index: test-sparse-index + body: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + search_inference_id: dense-inference-id + + - do: + index: + index: test-sparse-index + id: doc_1 + body: + inference_field: [ "inference test", "another inference test" ] + non_inference_field: "non inference test" + refresh: true + + - do: + catch: bad_request + search: + index: test-sparse-index + body: + query: + semantic: + field: "inference_field" + query: "inference test" + + - match: { error.caused_by.type: "illegal_argument_exception" } + - match: { error.caused_by.reason: "Field [inference_field] expected query inference results to be of type + [text_expansion_result], got [text_embedding_result]. Is the search inference + endpoint [dense-inference-id] compatible with the inference endpoint + [sparse-inference-id]?" } + +--- +"Query a field with a search inference ID that uses the wrong dimension count": + - requires: + cluster_features: "semantic_text.search_inference_id" + reason: search_inference_id introduced in 8.16.0 + + - do: + inference.put: + task_type: text_embedding + inference_id: dense-inference-id-20-dims + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 20, + "api_key": "abc64", + "similarity": "COSINE" + }, + "task_settings": { + } + } + + - do: + indices.put_mapping: + index: test-dense-index + body: + properties: + inference_field: + type: semantic_text + inference_id: dense-inference-id + search_inference_id: dense-inference-id-20-dims + + - do: + index: + index: test-dense-index + id: doc_1 + body: + inference_field: ["inference test", "another inference test"] + non_inference_field: "non inference test" + refresh: true + + - do: + catch: bad_request + search: + index: test-dense-index + body: + query: + semantic: + field: "inference_field" + query: "inference test" + + - match: { error.caused_by.type: "illegal_argument_exception" } + - match: { error.caused_by.reason: "Field [inference_field] expected query inference results with 10 dimensions, got + 20 dimensions. Is the search inference endpoint [dense-inference-id-20-dims] + compatible with the inference endpoint [dense-inference-id]?" } + +--- +"Query a field with an invalid search inference ID": + - requires: + cluster_features: "semantic_text.search_inference_id" + reason: search_inference_id introduced in 8.16.0 + + - do: + indices.put_mapping: + index: test-dense-index + body: + properties: + inference_field: + type: semantic_text + inference_id: dense-inference-id + search_inference_id: invalid-inference-id + + - do: + index: + index: test-dense-index + id: doc_1 + body: + inference_field: [ "inference test", "another inference test" ] + non_inference_field: "non inference test" + refresh: true + + - do: + catch: missing + search: + index: test-dense-index + body: + query: + semantic: + field: "inference_field" + query: "inference test" + + - match: { error.type: "resource_not_found_exception" } + - match: { error.reason: "Inference endpoint not found [invalid-inference-id]" } diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml index f6a707391460..51595d40737a 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml @@ -112,8 +112,8 @@ setup: - match: { error.caused_by.type: "illegal_argument_exception" } - match: { error.caused_by.reason: "Field [inference_field] expected query inference results to be of type - [text_expansion_result], got [text_embedding_result]. Has the inference endpoint - configuration changed?" } + [text_expansion_result], got [text_embedding_result]. Has the configuration for + inference endpoint [sparse-inference-id] changed?" } --- "text_embedding changed to sparse_embedding": @@ -149,8 +149,8 @@ setup: - match: { error.caused_by.type: "illegal_argument_exception" } - match: { error.caused_by.reason: "Field [inference_field] expected query inference results to be of type - [text_embedding_result], got [text_expansion_result]. Has the inference endpoint - configuration changed?" } + [text_embedding_result], got [text_expansion_result]. Has the configuration for + inference endpoint [dense-inference-id] changed?" } --- "text_embedding dimension count changed": @@ -188,4 +188,5 @@ setup: - match: { error.caused_by.type: "illegal_argument_exception" } - match: { error.caused_by.reason: "Field [inference_field] expected query inference results with 10 dimensions, got - 20 dimensions. Has the inference endpoint configuration changed?" } + 20 dimensions. Has the configuration for inference endpoint [dense-inference-id] + changed?" }