Max timeout for inference at ingest
davidkyle committed Nov 9, 2023
1 parent 3a09c64 commit b9384e9
Showing 10 changed files with 128 additions and 42 deletions.
[File 1 of 10]
@@ -82,7 +82,7 @@ public static Builder parseRequest(String id, XContentParser parser) {
private final List<Map<String, Object>> objectsToInfer;
private final InferenceConfigUpdate update;
private final boolean previouslyLicensed;
private TimeValue inferenceTimeout;
private final TimeValue inferenceTimeout;
// textInput added for uses that accept a query string
// and do not know which field the model expects to find its
// input and so cannot construct a document.
@@ -95,36 +95,63 @@ public static Builder parseRequest(String id, XContentParser parser) {
* the inference queue for) is set to a high value {@code #DEFAULT_TIMEOUT_FOR_INGEST}
* to prefer slow ingest over dropping documents.
*/

/**
* Build a request from a list of documents as maps.
*
* @param id The model Id
* @param docs List of document maps
* @param update Inference config update
* @param previouslyLicensed License has been checked previously
* and can now be skipped
* @param inferenceTimeout The inference timeout (how long the
* request waits in the inference queue for)
* @return the new Request
*/
public static Request forIngestDocs(
String id,
List<Map<String, Object>> docs,
InferenceConfigUpdate update,
boolean previouslyLicensed
boolean previouslyLicensed,
TimeValue inferenceTimeout
) {
return new Request(
ExceptionsHelper.requireNonNull(id, InferModelAction.Request.ID),
update,
ExceptionsHelper.requireNonNull(Collections.unmodifiableList(docs), DOCS),
null,
DEFAULT_TIMEOUT_FOR_INGEST,
inferenceTimeout,
previouslyLicensed
);
}
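
For illustration, a minimal sketch of how a caller now supplies the queue timeout directly to this factory method instead of using the removed setInferenceTimeout setter. The model id, document map, and timeout value are made up, and the snippet assumes the usual imports for these classes:

InferModelAction.Request ingestRequest = InferModelAction.Request.forIngestDocs(
    "my-model",                              // hypothetical model id
    List.of(Map.of("field_a", 42.0)),        // documents to infer, as maps
    RegressionConfigUpdate.EMPTY_PARAMS,     // no per-request config changes
    false,                                   // previouslyLicensed
    TimeValue.timeValueSeconds(30)           // inferenceTimeout chosen by the caller
);

The ingest callers changed elsewhere in this diff pass InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST as the last argument.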

/**
* Build a request from a list of strings, each string
* is one evaluation of the model.
* The inference timeout (how long the request waits in
* the inference queue for) is set to {@code #DEFAULT_TIMEOUT_FOR_API}
*
* @param id The model Id
* @param update Inference config update
* @param textInput Inference input
* @param previouslyLicensed License has been checked previously
* and can now be skipped
* @param inferenceTimeout The inference timeout (how long the
* request waits in the inference queue for)
* @return the new Request
*/
public static Request forTextInput(String id, InferenceConfigUpdate update, List<String> textInput) {
public static Request forTextInput(
String id,
InferenceConfigUpdate update,
List<String> textInput,
boolean previouslyLicensed,
TimeValue inferenceTimeout
) {
return new Request(
id,
update,
List.of(),
ExceptionsHelper.requireNonNull(textInput, "inference text input"),
DEFAULT_TIMEOUT_FOR_API,
false
inferenceTimeout,
previouslyLicensed
);
}
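
A corresponding sketch for the text-input path, mirroring the query builder callers later in this diff; the model id and query text are illustrative:

InferModelAction.Request textRequest = InferModelAction.Request.forTextInput(
    "my-model",                                        // hypothetical model id
    TextExpansionConfigUpdate.EMPTY_UPDATE,            // inference config update
    List.of("some query text"),                        // one string per model evaluation
    false,                                             // previouslyLicensed
    InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API   // inferenceTimeout
);
textRequest.setHighPriority(true);                     // query-path callers mark the request high priority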

@@ -197,11 +224,6 @@ public TimeValue getInferenceTimeout() {
return inferenceTimeout;
}

public Request setInferenceTimeout(TimeValue inferenceTimeout) {
this.inferenceTimeout = inferenceTimeout;
return this;
}

public boolean isHighPriority() {
return highPriority;
}
[File 2 of 10]
@@ -54,12 +54,15 @@ protected Request createTestInstance() {
randomAlphaOfLength(10),
Stream.generate(InferModelActionRequestTests::randomMap).limit(randomInt(10)).collect(Collectors.toList()),
randomInferenceConfigUpdate(),
randomBoolean()
randomBoolean(),
TimeValue.timeValueMillis(randomLongBetween(1, 2048))
)
: Request.forTextInput(
randomAlphaOfLength(10),
randomInferenceConfigUpdate(),
Arrays.asList(generateRandomStringArray(3, 5, false))
Arrays.asList(generateRandomStringArray(3, 5, false)),
randomBoolean(),
TimeValue.timeValueMillis(randomLongBetween(1, 2048))
);

request.setHighPriority(randomBoolean());
@@ -114,7 +117,6 @@ protected Request mutateInstance(Request instance) {

var r = new Request(modelId, update, objectsToInfer, textInput, timeout, previouslyLicensed);
r.setHighPriority(highPriority);
r.setInferenceTimeout(timeout);
return r;
}

[File 3 of 10]
@@ -669,8 +669,9 @@ public void testMachineLearningInferModelRestricted() {
modelId,
Collections.singletonList(Collections.emptyMap()),
RegressionConfigUpdate.EMPTY_PARAMS,
false
).setInferenceTimeout(TimeValue.timeValueSeconds(5)),
false,
TimeValue.timeValueSeconds(5)
),
inferModelSuccess
);
InferModelAction.Response response = inferModelSuccess.actionGet();
@@ -690,8 +691,9 @@ public void testMachineLearningInferModelRestricted() {
modelId,
Collections.singletonList(Collections.emptyMap()),
RegressionConfigUpdate.EMPTY_PARAMS,
false
).setInferenceTimeout(TimeValue.timeValueSeconds(5))
false,
TimeValue.timeValueSeconds(5)
)
).actionGet();
});
assertThat(e.status(), is(RestStatus.FORBIDDEN));
@@ -706,8 +708,9 @@ public void testMachineLearningInferModelRestricted() {
modelId,
Collections.singletonList(Collections.emptyMap()),
RegressionConfigUpdate.EMPTY_PARAMS,
true
).setInferenceTimeout(TimeValue.timeValueSeconds(5)),
true,
TimeValue.timeValueSeconds(5)
),
inferModelSuccess
);
response = inferModelSuccess.actionGet();
@@ -726,8 +729,9 @@ public void testMachineLearningInferModelRestricted() {
modelId,
Collections.singletonList(Collections.emptyMap()),
RegressionConfigUpdate.EMPTY_PARAMS,
false
).setInferenceTimeout(TimeValue.timeValueSeconds(5)),
false,
TimeValue.timeValueSeconds(5)
),
listener
);
assertThat(listener.actionGet().getInferenceResults(), is(not(empty())));
[File 4 of 10]
@@ -175,23 +175,36 @@ public void testInferModels() throws Exception {
modelId1,
toInfer,
RegressionConfigUpdate.EMPTY_PARAMS,
true
true,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST
);
InferModelAction.Response response = client().execute(InferModelAction.INSTANCE, request).actionGet();
assertThat(
response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults) i).value()).collect(Collectors.toList()),
contains(1.3, 1.25)
);

request = InferModelAction.Request.forIngestDocs(modelId1, toInfer2, RegressionConfigUpdate.EMPTY_PARAMS, true);
request = InferModelAction.Request.forIngestDocs(
modelId1,
toInfer2,
RegressionConfigUpdate.EMPTY_PARAMS,
true,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST
);
response = client().execute(InferModelAction.INSTANCE, request).actionGet();
assertThat(
response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults) i).value()).collect(Collectors.toList()),
contains(1.65, 1.55)
);

// Test classification
request = InferModelAction.Request.forIngestDocs(modelId2, toInfer, ClassificationConfigUpdate.EMPTY_PARAMS, true);
request = InferModelAction.Request.forIngestDocs(
modelId2,
toInfer,
ClassificationConfigUpdate.EMPTY_PARAMS,
true,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST
);
response = client().execute(InferModelAction.INSTANCE, request).actionGet();
assertThat(
response.getInferenceResults()
@@ -206,7 +219,8 @@ public void testInferModels() throws Exception {
modelId2,
toInfer,
new ClassificationConfigUpdate(2, null, null, null, null),
true
true,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST
);
response = client().execute(InferModelAction.INSTANCE, request).actionGet();

@@ -234,7 +248,8 @@ public void testInferModels() throws Exception {
modelId2,
toInfer2,
new ClassificationConfigUpdate(1, null, null, null, null),
true
true,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST
);
response = client().execute(InferModelAction.INSTANCE, request).actionGet();

@@ -338,7 +353,8 @@ public void testInferModelMultiClassModel() throws Exception {
modelId,
toInfer,
ClassificationConfigUpdate.EMPTY_PARAMS,
true
true,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST
);
InferModelAction.Response response = client().execute(InferModelAction.INSTANCE, request).actionGet();
assertThat(
@@ -349,7 +365,13 @@ public void testInferModelMultiClassModel() throws Exception {
contains("option_0", "option_2")
);

request = InferModelAction.Request.forIngestDocs(modelId, toInfer2, ClassificationConfigUpdate.EMPTY_PARAMS, true);
request = InferModelAction.Request.forIngestDocs(
modelId,
toInfer2,
ClassificationConfigUpdate.EMPTY_PARAMS,
true,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST
);
response = client().execute(InferModelAction.INSTANCE, request).actionGet();
assertThat(
response.getInferenceResults()
@@ -360,7 +382,13 @@ public void testInferModelMultiClassModel() throws Exception {
);

// Get top classes
request = InferModelAction.Request.forIngestDocs(modelId, toInfer, new ClassificationConfigUpdate(3, null, null, null, null), true);
request = InferModelAction.Request.forIngestDocs(
modelId,
toInfer,
new ClassificationConfigUpdate(3, null, null, null, null),
true,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST
);
response = client().execute(InferModelAction.INSTANCE, request).actionGet();

ClassificationInferenceResults classificationInferenceResults = (ClassificationInferenceResults) response.getInferenceResults()
@@ -382,7 +410,8 @@ public void testInferMissingModel() {
model,
Collections.emptyList(),
RegressionConfigUpdate.EMPTY_PARAMS,
true
true,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST
);
try {
client().execute(InferModelAction.INSTANCE, request).actionGet();
@@ -428,7 +457,8 @@ public void testInferMissingFields() throws Exception {
modelId,
toInferMissingField,
RegressionConfigUpdate.EMPTY_PARAMS,
true
true,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST
);
try {
InferenceResults result = client().execute(InferModelAction.INSTANCE, request).actionGet().getInferenceResults().get(0);
[File 5 of 10]
@@ -245,7 +245,13 @@ InferModelAction.Request buildRequest(IngestDocument ingestDocument) {
}
}
}
return InferModelAction.Request.forTextInput(modelId, inferenceConfig, requestInputs);
return InferModelAction.Request.forTextInput(
modelId,
inferenceConfig,
requestInputs,
previouslyLicensed,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST
);
} else {
Map<String, Object> fields = new HashMap<>(ingestDocument.getSourceAndMetadata());
// Add ingestMetadata as previous processors might have added metadata from which we are predicting (see: foreach processor)
@@ -254,7 +260,13 @@ InferModelAction.Request buildRequest(IngestDocument ingestDocument) {
}

LocalModel.mapFieldsIfNecessary(fields, fieldMap);
return InferModelAction.Request.forIngestDocs(modelId, List.of(fields), inferenceConfig, previouslyLicensed);
return InferModelAction.Request.forIngestDocs(
modelId,
List.of(fields),
inferenceConfig,
previouslyLicensed,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST
);
}
}

[File 6 of 10]
@@ -126,7 +126,9 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
InferModelAction.Request inferRequest = InferModelAction.Request.forTextInput(
modelId,
TextExpansionConfigUpdate.EMPTY_UPDATE,
List.of(modelText)
List.of(modelText),
false,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API
);
inferRequest.setHighPriority(true);

[File 7 of 10]
@@ -95,7 +95,9 @@ public void buildVector(Client client, ActionListener<float[]> listener) {
InferModelAction.Request inferRequest = InferModelAction.Request.forTextInput(
modelId,
TextEmbeddingConfigUpdate.EMPTY_INSTANCE,
List.of(modelText)
List.of(modelText),
false,
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API
);
inferRequest.setHighPriority(true);

[File 8 of 10]
@@ -303,14 +303,19 @@ public void testGenerateRequestWithEmptyMapping() {
};
IngestDocument document = TestIngestDocument.ofIngestWithNullableVersion(source, new HashMap<>());

assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(source));
var request = processor.buildRequest(document);
assertThat(request.getObjectsToInfer().get(0), equalTo(source));
assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST, request.getInferenceTimeout());

Map<String, Object> ingestMetadata = Collections.singletonMap("_value", 3);
document = TestIngestDocument.ofIngestWithNullableVersion(source, ingestMetadata);

Map<String, Object> expected = new HashMap<>(source);
expected.put("_ingest", ingestMetadata);
assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(expected));

request = processor.buildRequest(document);
assertThat(request.getObjectsToInfer().get(0), equalTo(expected));
assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST, request.getInferenceTimeout());
}

public void testGenerateWithMapping() {
@@ -346,14 +351,18 @@ public void testGenerateWithMapping() {
expectedMap.put("categorical", "foo");
expectedMap.put("new_categorical", "foo");
expectedMap.put("un_touched", "bar");
assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(expectedMap));
var request = processor.buildRequest(document);
assertThat(request.getObjectsToInfer().get(0), equalTo(expectedMap));
assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST, request.getInferenceTimeout());

Map<String, Object> ingestMetadata = Collections.singletonMap("_value", "baz");
document = TestIngestDocument.ofIngestWithNullableVersion(source, ingestMetadata);
expectedMap = new HashMap<>(expectedMap);
expectedMap.put("metafield", "baz");
expectedMap.put("_ingest", ingestMetadata);
assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(expectedMap));
request = processor.buildRequest(document);
assertThat(request.getObjectsToInfer().get(0), equalTo(expectedMap));
assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST, request.getInferenceTimeout());
}

public void testGenerateWithMappingNestedFields() {
@@ -597,6 +606,7 @@ public void testBuildRequestWithInputFields() {
assertTrue(request.getObjectsToInfer().isEmpty());
var requestInputs = request.getTextInput();
assertThat(requestInputs, contains("body_text", "title_text"));
assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_INGEST, request.getInferenceTimeout());
}

public void testBuildRequestWithInputFields_WrongType() {
[File 9 of 10]
@@ -77,6 +77,7 @@ protected boolean canSimulateMethod(Method method, Object[] args) throws NoSuchM
@Override
protected Object simulateMethod(Method method, Object[] args) {
InferModelAction.Request request = (InferModelAction.Request) args[1];
assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API, request.getInferenceTimeout());

// Randomisation cannot be used here as {@code #doAssertLuceneQuery}
// asserts that 2 rewritten queries are the same
[File 10 of 10 not shown]
