Skip to content

Commit

Permalink
[Inference API] alibabacloud ai search service support chunk infer to…
Browse files Browse the repository at this point in the history
… support semantic_text field (elastic#112652)
  • Loading branch information
weizijun authored and davidkyle committed Sep 13, 2024
1 parent 12b4900 commit f33f74b
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 1 deletion.
5 changes: 5 additions & 0 deletions docs/changelog/112652.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 110399
summary: "[Inference API] alibabacloud ai search service support chunk infer to support semantic_text field"
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.alibabacloudsearch.AlibabaCloudSearchActionCreator;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
Expand All @@ -48,6 +49,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceFields.EMBEDDING_MAX_BATCH_SIZE;

public class AlibabaCloudSearchService extends SenderService {
public static final String NAME = AlibabaCloudSearchUtils.SERVICE_NAME;
Expand Down Expand Up @@ -243,7 +245,20 @@ protected void doChunkedInfer(
TimeValue timeout,
ActionListener<List<ChunkedInferenceServiceResults>> listener
) {
listener.onFailure(new ElasticsearchStatusException("Chunking not supported by the {} service", RestStatus.BAD_REQUEST, NAME));
if (model instanceof AlibabaCloudSearchModel == false) {
listener.onFailure(createInvalidModelException(model));
return;
}

AlibabaCloudSearchModel alibabaCloudSearchModel = (AlibabaCloudSearchModel) model;
var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents());

var batchedRequests = new EmbeddingRequestChunker(input, EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.FLOAT)
.batchRequestsWithListeners(listener);
for (var request : batchedRequests) {
var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings, inputType);
action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.alibabacloudsearch;

public class AlibabaCloudSearchServiceFields {
/**
* Taken from https://help.aliyun.com/zh/open-search/search-platform/developer-reference/text-embedding-api-details
*/
static final int EMBEDDING_MAX_BATCH_SIZE = 32;
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,37 @@
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.ChunkingOptions;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.alibabacloudsearch.AlibabaCloudSearchActionVisitor;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModelTests;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsServiceSettingsTests;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettingsTests;
import org.hamcrest.CoreMatchers;
import org.hamcrest.MatcherAssert;
import org.junit.After;
import org.junit.Before;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -44,6 +53,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.Mockito.mock;

Expand Down Expand Up @@ -156,6 +166,84 @@ public void doInfer(
}
}

public void testChunkedInfer_Batches() throws IOException {
var input = List.of("foo", "bar");

var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);

try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) {
Map<String, Object> serviceSettingsMap = new HashMap<>();
serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.SERVICE_ID, "service_id");
serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.HOST, "host");
serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, "default");
serviceSettingsMap.put(ServiceFields.DIMENSIONS, 1536);

Map<String, Object> taskSettingsMap = new HashMap<>();

Map<String, Object> secretSettingsMap = new HashMap<>();
secretSettingsMap.put("api_key", "secret");

var model = new AlibabaCloudSearchEmbeddingsModel(
"service",
TaskType.TEXT_EMBEDDING,
AlibabaCloudSearchUtils.SERVICE_NAME,
serviceSettingsMap,
taskSettingsMap,
secretSettingsMap,
null
) {
public ExecutableAction accept(
AlibabaCloudSearchActionVisitor visitor,
Map<String, Object> taskSettings,
InputType inputType
) {
return (inferenceInputs, timeout, listener) -> {
InferenceTextEmbeddingFloatResults results = new InferenceTextEmbeddingFloatResults(
List.of(
new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.0123f, -0.0123f }),
new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.0456f, -0.0456f })
)
);

listener.onResponse(results);
};
}
};

PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
service.chunkedInfer(
model,
input,
new HashMap<>(),
InputType.INGEST,
new ChunkingOptions(null, null),
InferenceAction.Request.DEFAULT_TIMEOUT,
listener
);

var results = listener.actionGet(TIMEOUT);
assertThat(results, hasSize(2));

// first result
{
assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(input.get(0), floatResult.chunks().get(0).matchedText());
assertTrue(Arrays.equals(new float[] { 0.0123f, -0.0123f }, floatResult.chunks().get(0).embedding()));
}

// second result
{
assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1);
assertThat(floatResult.chunks(), hasSize(1));
assertEquals(input.get(1), floatResult.chunks().get(0).matchedText());
assertTrue(Arrays.equals(new float[] { 0.0456f, -0.0456f }, floatResult.chunks().get(0).embedding()));
}
}
}

private Map<String, Object> getRequestConfigMap(
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
Expand Down

0 comments on commit f33f74b

Please sign in to comment.