From b9384e92917db2352222dad63a1addb395b9fb53 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Thu, 9 Nov 2023 14:53:34 +0000 Subject: [PATCH] Max timeout for inference at ingest --- .../core/ml/action/InferModelAction.java | 48 +++++++++++++----- .../action/InferModelActionRequestTests.java | 8 +-- .../license/MachineLearningLicensingIT.java | 20 +++++--- .../integration/ModelInferenceActionIT.java | 50 +++++++++++++++---- .../inference/ingest/InferenceProcessor.java | 16 +++++- .../ml/queries/TextExpansionQueryBuilder.java | 4 +- .../TextEmbeddingQueryVectorBuilder.java | 4 +- .../ingest/InferenceProcessorTests.java | 18 +++++-- .../TextExpansionQueryBuilderTests.java | 1 + .../TextEmbeddingQueryVectorBuilderTests.java | 1 + 10 files changed, 128 insertions(+), 42 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java index a06d0fe0ce0ce..61e52935f46e9 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java @@ -82,7 +82,7 @@ public static Builder parseRequest(String id, XContentParser parser) { private final List> objectsToInfer; private final InferenceConfigUpdate update; private final boolean previouslyLicensed; - private TimeValue inferenceTimeout; + private final TimeValue inferenceTimeout; // textInput added for uses that accept a query string // and do know which field the model expects to find its // input and so cannot construct a document. @@ -95,18 +95,32 @@ public static Builder parseRequest(String id, XContentParser parser) { * the inference queue for) is set to a high value {@code #DEFAULT_TIMEOUT_FOR_INGEST} * to prefer slow ingest over dropping documents. */ + + /** + * Build a request from a list of documents as maps. + * + * @param id The model Id + * @param docs List of document maps + * @param update Inference config update + * @param previouslyLicensed License has been checked previously + * and can now be skipped + * @param inferenceTimeout The inference timeout (how long the + * request waits in the inference queue for) + * @return the new Request + */ public static Request forIngestDocs( String id, List> docs, InferenceConfigUpdate update, - boolean previouslyLicensed + boolean previouslyLicensed, + TimeValue inferenceTimeout ) { return new Request( ExceptionsHelper.requireNonNull(id, InferModelAction.Request.ID), update, ExceptionsHelper.requireNonNull(Collections.unmodifiableList(docs), DOCS), null, - DEFAULT_TIMEOUT_FOR_INGEST, + inferenceTimeout, previouslyLicensed ); } @@ -114,17 +128,30 @@ public static Request forIngestDocs( /** * Build a request from a list of strings, each string * is one evaluation of the model. - * The inference timeout (how long the request waits in - * the inference queue for) is set to {@code #DEFAULT_TIMEOUT_FOR_API} + * + * @param id The model Id + * @param update Inference config update + * @param textInput Inference input + * @param previouslyLicensed License has been checked previously + * and can now be skipped + * @param inferenceTimeout The inference timeout (how long the + * request waits in the inference queue for) + * @return the new Request */ - public static Request forTextInput(String id, InferenceConfigUpdate update, List textInput) { + public static Request forTextInput( + String id, + InferenceConfigUpdate update, + List textInput, + boolean previouslyLicensed, + TimeValue inferenceTimeout + ) { return new Request( id, update, List.of(), ExceptionsHelper.requireNonNull(textInput, "inference text input"), - DEFAULT_TIMEOUT_FOR_API, - false + inferenceTimeout, + previouslyLicensed ); } @@ -197,11 +224,6 @@ public TimeValue getInferenceTimeout() { return inferenceTimeout; } - public Request setInferenceTimeout(TimeValue inferenceTimeout) { - this.inferenceTimeout = inferenceTimeout; - return this; - } - public boolean isHighPriority() { return highPriority; } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java index 2f073cb32d09a..69c1b23a5ff85 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java @@ -54,12 +54,15 @@ protected Request createTestInstance() { randomAlphaOfLength(10), Stream.generate(InferModelActionRequestTests::randomMap).limit(randomInt(10)).collect(Collectors.toList()), randomInferenceConfigUpdate(), - randomBoolean() + randomBoolean(), + TimeValue.timeValueMillis(randomLongBetween(1, 2048)) ) : Request.forTextInput( randomAlphaOfLength(10), randomInferenceConfigUpdate(), - Arrays.asList(generateRandomStringArray(3, 5, false)) + Arrays.asList(generateRandomStringArray(3, 5, false)), + randomBoolean(), + TimeValue.timeValueMillis(randomLongBetween(1, 2048)) ); request.setHighPriority(randomBoolean()); @@ -114,7 +117,6 @@ protected Request mutateInstance(Request instance) { var r = new Request(modelId, update, objectsToInfer, textInput, timeout, previouslyLicensed); r.setHighPriority(highPriority); - r.setInferenceTimeout(timeout); return r; } diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/license/MachineLearningLicensingIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/license/MachineLearningLicensingIT.java index 9933daa4693ce..01a9c166ff0e4 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/license/MachineLearningLicensingIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/license/MachineLearningLicensingIT.java @@ -669,8 +669,9 @@ public void testMachineLearningInferModelRestricted() { modelId, Collections.singletonList(Collections.emptyMap()), RegressionConfigUpdate.EMPTY_PARAMS, - false - ).setInferenceTimeout(TimeValue.timeValueSeconds(5)), + false, + TimeValue.timeValueSeconds(5) + ), inferModelSuccess ); InferModelAction.Response response = inferModelSuccess.actionGet(); @@ -690,8 +691,9 @@ public void testMachineLearningInferModelRestricted() { modelId, Collections.singletonList(Collections.emptyMap()), RegressionConfigUpdate.EMPTY_PARAMS, - false - ).setInferenceTimeout(TimeValue.timeValueSeconds(5)) + false, + TimeValue.timeValueSeconds(5) + ) ).actionGet(); }); assertThat(e.status(), is(RestStatus.FORBIDDEN)); @@ -706,8 +708,9 @@ public void testMachineLearningInferModelRestricted() { modelId, Collections.singletonList(Collections.emptyMap()), RegressionConfigUpdate.EMPTY_PARAMS, - true - ).setInferenceTimeout(TimeValue.timeValueSeconds(5)), + true, + TimeValue.timeValueSeconds(5) + ), inferModelSuccess ); response = inferModelSuccess.actionGet(); @@ -726,8 +729,9 @@ public void testMachineLearningInferModelRestricted() { modelId, Collections.singletonList(Collections.emptyMap()), RegressionConfigUpdate.EMPTY_PARAMS, - false - ).setInferenceTimeout(TimeValue.timeValueSeconds(5)), + false, + TimeValue.timeValueSeconds(5) + ), listener ); assertThat(listener.actionGet().getInferenceResults(), is(not(empty()))); diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index b9ca3946412bc..e03445912175a 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -175,7 +175,8 @@ public void testInferModels() throws Exception { modelId1, toInfer, RegressionConfigUpdate.EMPTY_PARAMS, - true + true, + InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST ); InferModelAction.Response response = client().execute(InferModelAction.INSTANCE, request).actionGet(); assertThat( @@ -183,7 +184,13 @@ public void testInferModels() throws Exception { contains(1.3, 1.25) ); - request = InferModelAction.Request.forIngestDocs(modelId1, toInfer2, RegressionConfigUpdate.EMPTY_PARAMS, true); + request = InferModelAction.Request.forIngestDocs( + modelId1, + toInfer2, + RegressionConfigUpdate.EMPTY_PARAMS, + true, + InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST + ); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); assertThat( response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults) i).value()).collect(Collectors.toList()), @@ -191,7 +198,13 @@ public void testInferModels() throws Exception { ); // Test classification - request = InferModelAction.Request.forIngestDocs(modelId2, toInfer, ClassificationConfigUpdate.EMPTY_PARAMS, true); + request = InferModelAction.Request.forIngestDocs( + modelId2, + toInfer, + ClassificationConfigUpdate.EMPTY_PARAMS, + true, + InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST + ); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); assertThat( response.getInferenceResults() @@ -206,7 +219,8 @@ public void testInferModels() throws Exception { modelId2, toInfer, new ClassificationConfigUpdate(2, null, null, null, null), - true + true, + InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST ); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); @@ -234,7 +248,8 @@ public void testInferModels() throws Exception { modelId2, toInfer2, new ClassificationConfigUpdate(1, null, null, null, null), - true + true, + InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST ); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); @@ -338,7 +353,8 @@ public void testInferModelMultiClassModel() throws Exception { modelId, toInfer, ClassificationConfigUpdate.EMPTY_PARAMS, - true + true, + InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST ); InferModelAction.Response response = client().execute(InferModelAction.INSTANCE, request).actionGet(); assertThat( @@ -349,7 +365,13 @@ public void testInferModelMultiClassModel() throws Exception { contains("option_0", "option_2") ); - request = InferModelAction.Request.forIngestDocs(modelId, toInfer2, ClassificationConfigUpdate.EMPTY_PARAMS, true); + request = InferModelAction.Request.forIngestDocs( + modelId, + toInfer2, + ClassificationConfigUpdate.EMPTY_PARAMS, + true, + InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST + ); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); assertThat( response.getInferenceResults() @@ -360,7 +382,13 @@ public void testInferModelMultiClassModel() throws Exception { ); // Get top classes - request = InferModelAction.Request.forIngestDocs(modelId, toInfer, new ClassificationConfigUpdate(3, null, null, null, null), true); + request = InferModelAction.Request.forIngestDocs( + modelId, + toInfer, + new ClassificationConfigUpdate(3, null, null, null, null), + true, + InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST + ); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); ClassificationInferenceResults classificationInferenceResults = (ClassificationInferenceResults) response.getInferenceResults() @@ -382,7 +410,8 @@ public void testInferMissingModel() { model, Collections.emptyList(), RegressionConfigUpdate.EMPTY_PARAMS, - true + true, + InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST ); try { client().execute(InferModelAction.INSTANCE, request).actionGet(); @@ -428,7 +457,8 @@ public void testInferMissingFields() throws Exception { modelId, toInferMissingField, RegressionConfigUpdate.EMPTY_PARAMS, - true + true, + InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST ); try { InferenceResults result = client().execute(InferModelAction.INSTANCE, request).actionGet().getInferenceResults().get(0); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index 905317713263e..5518903dde125 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -245,7 +245,13 @@ InferModelAction.Request buildRequest(IngestDocument ingestDocument) { } } } - return InferModelAction.Request.forTextInput(modelId, inferenceConfig, requestInputs); + return InferModelAction.Request.forTextInput( + modelId, + inferenceConfig, + requestInputs, + previouslyLicensed, + InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST + ); } else { Map fields = new HashMap<>(ingestDocument.getSourceAndMetadata()); // Add ingestMetadata as previous processors might have added metadata from which we are predicting (see: foreach processor) @@ -254,7 +260,13 @@ InferModelAction.Request buildRequest(IngestDocument ingestDocument) { } LocalModel.mapFieldsIfNecessary(fields, fieldMap); - return InferModelAction.Request.forIngestDocs(modelId, List.of(fields), inferenceConfig, previouslyLicensed); + return InferModelAction.Request.forIngestDocs( + modelId, + List.of(fields), + inferenceConfig, + previouslyLicensed, + InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST + ); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java index 2d74b1b34888f..40e4f5d9ede78 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java @@ -126,7 +126,9 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws InferModelAction.Request inferRequest = InferModelAction.Request.forTextInput( modelId, TextExpansionConfigUpdate.EMPTY_UPDATE, - List.of(modelText) + List.of(modelText), + false, + InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API ); inferRequest.setHighPriority(true); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilder.java index 2dd76c8fab7cc..2e780c9849bd5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilder.java @@ -95,7 +95,9 @@ public void buildVector(Client client, ActionListener listener) { InferModelAction.Request inferRequest = InferModelAction.Request.forTextInput( modelId, TextEmbeddingConfigUpdate.EMPTY_INSTANCE, - List.of(modelText) + List.of(modelText), + false, + InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API ); inferRequest.setHighPriority(true); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java index a68084aa6eb28..88dcc2ba5d697 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java @@ -303,14 +303,19 @@ public void testGenerateRequestWithEmptyMapping() { }; IngestDocument document = TestIngestDocument.ofIngestWithNullableVersion(source, new HashMap<>()); - assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(source)); + var request = processor.buildRequest(document); + assertThat(request.getObjectsToInfer().get(0), equalTo(source)); + assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST, request.getInferenceTimeout()); Map ingestMetadata = Collections.singletonMap("_value", 3); document = TestIngestDocument.ofIngestWithNullableVersion(source, ingestMetadata); Map expected = new HashMap<>(source); expected.put("_ingest", ingestMetadata); - assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(expected)); + + request = processor.buildRequest(document); + assertThat(request.getObjectsToInfer().get(0), equalTo(expected)); + assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST, request.getInferenceTimeout()); } public void testGenerateWithMapping() { @@ -346,14 +351,18 @@ public void testGenerateWithMapping() { expectedMap.put("categorical", "foo"); expectedMap.put("new_categorical", "foo"); expectedMap.put("un_touched", "bar"); - assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(expectedMap)); + var request = processor.buildRequest(document); + assertThat(request.getObjectsToInfer().get(0), equalTo(expectedMap)); + assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST, request.getInferenceTimeout()); Map ingestMetadata = Collections.singletonMap("_value", "baz"); document = TestIngestDocument.ofIngestWithNullableVersion(source, ingestMetadata); expectedMap = new HashMap<>(expectedMap); expectedMap.put("metafield", "baz"); expectedMap.put("_ingest", ingestMetadata); - assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(expectedMap)); + request = processor.buildRequest(document); + assertThat(request.getObjectsToInfer().get(0), equalTo(expectedMap)); + assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST, request.getInferenceTimeout()); } public void testGenerateWithMappingNestedFields() { @@ -597,6 +606,7 @@ public void testBuildRequestWithInputFields() { assertTrue(request.getObjectsToInfer().isEmpty()); var requestInputs = request.getTextInput(); assertThat(requestInputs, contains("body_text", "title_text")); + assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST, request.getInferenceTimeout()); } public void testBuildRequestWithInputFields_WrongType() { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java index 7326dd0754041..a329a55d8afe9 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java @@ -77,6 +77,7 @@ protected boolean canSimulateMethod(Method method, Object[] args) throws NoSuchM @Override protected Object simulateMethod(Method method, Object[] args) { InferModelAction.Request request = (InferModelAction.Request) args[1]; + assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API, request.getInferenceTimeout()); // Randomisation cannot be used here as {@code #doAssertLuceneQuery} // asserts that 2 rewritten queries are the same diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java index 6fc81dca16176..2c83777487685 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java @@ -38,6 +38,7 @@ protected void doAssertClientRequest(ActionRequest request, TextEmbeddingQueryVe assertThat(inferRequest.getTextInput(), hasSize(1)); assertEquals(builder.getModelText(), inferRequest.getTextInput().get(0)); assertEquals(builder.getModelId(), inferRequest.getId()); + assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API, inferRequest.getInferenceTimeout()); } public ActionResponse createResponse(float[] array, TextEmbeddingQueryVectorBuilder builder) {