diff --git a/docs/changelog/112652.yaml b/docs/changelog/112652.yaml new file mode 100644 index 0000000000000..c7ddcd4bffdc8 --- /dev/null +++ b/docs/changelog/112652.yaml @@ -0,0 +1,5 @@ +pr: 112652 +summary: "[Inference API] Add chunked inference support to the AlibabaCloud AI Search service to support the semantic_text field" +area: Machine Learning +type: enhancement +issues: [] diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 0c48c99b4b81e..2888713358ae6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -24,6 +24,7 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.alibabacloudsearch.AlibabaCloudSearchActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -48,6 +49,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceFields.EMBEDDING_MAX_BATCH_SIZE; public class AlibabaCloudSearchService extends SenderService { public static final String NAME 
= AlibabaCloudSearchUtils.SERVICE_NAME;
@@ -243,7 +245,24 @@ protected void doChunkedInfer(
         TimeValue timeout,
         ActionListener<List<ChunkedInferenceServiceResults>> listener
     ) {
-        listener.onFailure(new ElasticsearchStatusException("Chunking not supported by the {} service", RestStatus.BAD_REQUEST, NAME));
+        // The action creator below is specific to AlibabaCloudSearch models, so reject any other model type up front.
+        if (model instanceof AlibabaCloudSearchModel == false) {
+            listener.onFailure(createInvalidModelException(model));
+            return;
+        }
+
+        AlibabaCloudSearchModel alibabaCloudSearchModel = (AlibabaCloudSearchModel) model;
+        var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents());
+
+        // The chunker is given EMBEDDING_MAX_BATCH_SIZE as its batch limit and returns one listener per batch;
+        // it presumably combines the per-batch responses into `listener`.
+        // NOTE(review): the embedding type is hard-coded to FLOAT - confirm that holds for every supported task type.
+        var batchedRequests = new EmbeddingRequestChunker(input, EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.FLOAT)
+            .batchRequestsWithListeners(listener);
+        for (var request : batchedRequests) {
+            var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings, inputType);
+            action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
+        }
     }
 
     /**
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceFields.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceFields.java
new file mode 100644
index 0000000000000..e110aefb7c75f
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceFields.java
@@ -0,0 +1,23 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.alibabacloudsearch;
+
+/**
+ * Constants for the AlibabaCloud AI Search inference service.
+ */
+public final class AlibabaCloudSearchServiceFields {
+    /**
+     * Batch size limit used when splitting embedding inputs into requests.
+     * Taken from https://help.aliyun.com/zh/open-search/search-platform/developer-reference/text-embedding-api-details
+     */
+    static final int EMBEDDING_MAX_BATCH_SIZE = 32;
+
+    private AlibabaCloudSearchServiceFields() {
+        // constants holder, not meant to be instantiated
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java
index cc70b61226fe3..13cb6d65b70db 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java
@@ -11,6 +11,8 @@
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkingOptions;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
@@ -18,21 +20,28 @@
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.action.alibabacloudsearch.AlibabaCloudSearchActionVisitor;
 import 
org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
+import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils;
 import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModelTests;
 import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsServiceSettingsTests;
 import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettingsTests;
+import org.hamcrest.CoreMatchers;
 import org.hamcrest.MatcherAssert;
 import org.junit.After;
 import org.junit.Before;
 
 import java.io.IOException;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -44,6 +53,7 @@
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
 import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.mockito.Mockito.mock;
 
@@ -156,6 +166,86 @@ public void doInfer(
         }
     }
 
+    public void testChunkedInfer_Batches() throws IOException {
+        var input = List.of("foo", "bar");
+
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) {
+            Map<String, Object> serviceSettingsMap = new HashMap<>();
+            serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.SERVICE_ID, "service_id");
+            serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.HOST, "host");
+            serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, "default");
+            serviceSettingsMap.put(ServiceFields.DIMENSIONS, 1536);
+
+            Map<String, Object> taskSettingsMap = new HashMap<>();
+
+            Map<String, Object> secretSettingsMap = new HashMap<>();
+            secretSettingsMap.put("api_key", "secret");
+
+            // Stub accept() so the returned action never reaches the sender; it answers every request with the same two embeddings.
+            var model = new AlibabaCloudSearchEmbeddingsModel(
+                "service",
+                TaskType.TEXT_EMBEDDING,
+                AlibabaCloudSearchUtils.SERVICE_NAME,
+                serviceSettingsMap,
+                taskSettingsMap,
+                secretSettingsMap,
+                null
+            ) {
+                @Override
+                public ExecutableAction accept(
+                    AlibabaCloudSearchActionVisitor visitor,
+                    Map<String, Object> taskSettings,
+                    InputType inputType
+                ) {
+                    return (inferenceInputs, timeout, listener) -> {
+                        InferenceTextEmbeddingFloatResults results = new InferenceTextEmbeddingFloatResults(
+                            List.of(
+                                new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.0123f, -0.0123f }),
+                                new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.0456f, -0.0456f })
+                            )
+                        );
+
+                        listener.onResponse(results);
+                    };
+                }
+            };
+
+            PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
+            service.chunkedInfer(
+                model,
+                input,
+                new HashMap<>(),
+                InputType.INGEST,
+                new ChunkingOptions(null, null),
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
+
+            var results = listener.actionGet(TIMEOUT);
+            assertThat(results, hasSize(2));
+
+            // first result
+            {
+                assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
+                var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0);
+                assertThat(floatResult.chunks(), hasSize(1));
+                assertEquals(input.get(0), floatResult.chunks().get(0).matchedText());
+                assertTrue(Arrays.equals(new float[] { 0.0123f, -0.0123f }, floatResult.chunks().get(0).embedding()));
+            }
+
+            // second result
+            {
+                assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
+                var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1);
+                assertThat(floatResult.chunks(), hasSize(1));
+                assertEquals(input.get(1), floatResult.chunks().get(0).matchedText());
+                assertTrue(Arrays.equals(new float[] { 0.0456f, -0.0456f }, floatResult.chunks().get(0).embedding()));
+            }
+        }
+    }
+
     private Map getRequestConfigMap(
         Map serviceSettings,
         Map taskSettings,