Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Adding retries for starting model deployment #99673

Merged
5 changes: 5 additions & 0 deletions docs/changelog/99673.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 99673
summary: Adding retry logic for start model deployment API
area: Machine Learning
type: bug
issues: [ ]
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,16 @@ void loadQueuedModels() {
logger.debug(() -> "[" + deploymentId + "] Start deployment failed as model [" + modelId + "] was not found", ex);
handleLoadFailure(loadingTask, ExceptionsHelper.missingTrainedModel(modelId, ex));
} else if (ExceptionsHelper.unwrapCause(ex) instanceof SearchPhaseExecutionException) {
/*
* This case will not catch the ElasticsearchException generated by the ChunkedTrainedModelRestorer when the maximum
* number of retries for a SearchPhaseExecutionException or CircuitBreakingException (CBE) has been reached. This is
* intentional: if the restorer's own retry logic is exhausted we should return the error and not retry here. The
* generated ElasticsearchException will contain the SearchPhaseExecutionException or CBE but cannot be unwrapped.
*/
logger.debug(() -> "[" + deploymentId + "] Start deployment failed, will retry", ex);
// A search phase execution failure should be retried, push task back to the queue

// This will cause the entire model to be reloaded (all the chunks)
loadingToRetry.add(loadingTask);
} else {
handleLoadFailure(loadingTask, ex);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,20 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.core.CheckedFunction;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.sort.SortBuilders;
Expand All @@ -36,6 +40,7 @@
import java.io.IOException;
import java.io.InputStream;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

import static org.elasticsearch.core.Strings.format;
Expand All @@ -58,6 +63,8 @@ public class ChunkedTrainedModelRestorer {
private static final Logger logger = LogManager.getLogger(ChunkedTrainedModelRestorer.class);

private static final int MAX_NUM_DEFINITION_DOCS = 20;
private static final int SEARCH_RETRY_LIMIT = 5;
private static final TimeValue SEARCH_FAILURE_RETRY_WAIT_TIME = new TimeValue(5, TimeUnit.SECONDS);

private final Client client;
private final NamedXContentRegistry xContentRegistry;
Expand Down Expand Up @@ -142,7 +149,14 @@ private void doSearch(
UTILITY_THREAD_POOL_NAME,
Thread.currentThread().getName()
);
SearchResponse searchResponse = client.search(searchRequest).actionGet();

SearchResponse searchResponse = retryingSearch(
client,
modelId,
searchRequest,
SEARCH_RETRY_LIMIT,
SEARCH_FAILURE_RETRY_WAIT_TIME
);
if (searchResponse.getHits().getHits().length == 0) {
errorConsumer.accept(new ResourceNotFoundException(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
return;
Expand Down Expand Up @@ -201,6 +215,47 @@ private void doSearch(
}
}

/**
 * Runs {@code searchRequest} and blocks for the response, retrying when the failure unwraps to a
 * {@link SearchPhaseExecutionException} or a {@link CircuitBreakingException}. Any other failure is rethrown
 * immediately without retrying.
 *
 * @param client        client used to execute the search
 * @param modelId       id of the model being loaded; only used in log and error messages
 * @param searchRequest the request to execute; the same instance is re-sent on every attempt
 * @param retries       number of retries allowed after the first attempt, so at most {@code retries + 1} searches run
 * @param sleep         how long to wait between a failed attempt and the next one
 * @return the response of the first search that succeeds
 * @throws ElasticsearchException if a retryable failure is still occurring once {@code retries} is exhausted; the last
 *                                failure is attached as the cause
 * @throws InterruptedException   if the thread is interrupted while waiting between attempts
 */
static SearchResponse retryingSearch(Client client, String modelId, SearchRequest searchRequest, int retries, TimeValue sleep)
    throws InterruptedException {
    int failureCount = 0;

    while (true) {
        try {
            return client.search(searchRequest).actionGet();
        } catch (Exception e) {
            // Unwrap once and reuse the result for both retryable-type checks.
            Throwable cause = ExceptionsHelper.unwrapCause(e);
            if (cause instanceof SearchPhaseExecutionException == false && cause instanceof CircuitBreakingException == false) {
                throw e;
            }

            if (failureCount >= retries) {
                logger.warn(format("[%s] searching for model part failed %s times, returning failure", modelId, retries));
                /*
                 * ElasticsearchException does not implement the ElasticsearchWrapperException interface so this exception
                 * cannot be unwrapped. This is important because the TrainedModelAssignmentNodeService has retry logic
                 * when a SearchPhaseExecutionException occurs. Throwing a non-unwrappable exception intentionally prevents
                 * that code from attempting to retry loading the entire model. If the retry logic here fails after the
                 * set retries we should not retry loading the entire model to avoid additional strain on the cluster.
                 */
                throw new ElasticsearchException(
                    format(
                        "loading model [%s] failed after [%s] retries. The deployment is now in a failed state, "
                            + "the error may be transient please stop the deployment and restart",
                        modelId,
                        retries
                    ),
                    e
                );
            }

            failureCount++;
            logger.debug(format("[%s] searching for model part failed %s times, retrying", modelId, failureCount));
            // Sleep for the full requested duration. Converting to whole seconds first would silently turn any
            // sub-second wait into no wait at all.
            TimeUnit.NANOSECONDS.sleep(sleep.nanos());
        }
    }
}

private static SearchRequestBuilder buildSearchBuilder(Client client, String modelId, String index, int searchSize) {
return client.prepareSearch(index)
.setQuery(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.ml.inference.persistence;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.test.ESTestCase;

import java.util.concurrent.TimeUnit;

import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class ChunkedTrainedModelRestorerTests extends ESTestCase {
public void testRetryingSearch_ReturnsSearchResults() throws InterruptedException {
var mockClient = mock(Client.class);
var mockSearchResponse = mock(SearchResponse.class, RETURNS_DEEP_STUBS);

PlainActionFuture<SearchResponse> searchFuture = new PlainActionFuture<>();
searchFuture.onResponse(mockSearchResponse);
when(mockClient.search(any())).thenReturn(searchFuture);

var request = createSearchRequest();

assertThat(
ChunkedTrainedModelRestorer.retryingSearch(mockClient, "", request, 5, new TimeValue(1, TimeUnit.NANOSECONDS)),
is(mockSearchResponse)
);

verify(mockClient, times(1)).search(any());
}

public void testRetryingSearch_ThrowsSearchPhaseExceptionWithNoRetries() {
try (var mockClient = mock(Client.class)) {
var searchPhaseException = new SearchPhaseExecutionException("phase", "error", ShardSearchFailure.EMPTY_ARRAY);
when(mockClient.search(any())).thenThrow(searchPhaseException);

var request = createSearchRequest();

ElasticsearchException exception = expectThrows(
ElasticsearchException.class,
() -> ChunkedTrainedModelRestorer.retryingSearch(mockClient, "1", request, 0, new TimeValue(1, TimeUnit.NANOSECONDS))
);

assertThat(exception.getCause(), is(searchPhaseException));
assertThat(
exception.getMessage(),
is(
"loading model [1] failed after [0] retries. The deployment is now in a failed state, the error may be "
+ "transient please stop the deployment and restart"
)
);
verify(mockClient, times(1)).search(any());
}
}

public void testRetryingSearch_ThrowsSearchPhaseExceptionAfterOneRetry() {
try (var mockClient = mock(Client.class)) {
var searchPhaseException = new SearchPhaseExecutionException("phase", "error", ShardSearchFailure.EMPTY_ARRAY);
when(mockClient.search(any())).thenThrow(searchPhaseException);

var request = createSearchRequest();

ElasticsearchException exception = expectThrows(
ElasticsearchException.class,
() -> ChunkedTrainedModelRestorer.retryingSearch(mockClient, "", request, 1, new TimeValue(1, TimeUnit.NANOSECONDS))
);

assertThat(exception.getCause(), is(searchPhaseException));
verify(mockClient, times(2)).search(any());
}
}

public void testRetryingSearch_ThrowsCircuitBreakingExceptionAfterOneRetry_FromSearchPhaseException() {
try (var mockClient = mock(Client.class)) {
var searchPhaseException = new SearchPhaseExecutionException("phase", "error", ShardSearchFailure.EMPTY_ARRAY);
var circuitBreakerException = new CircuitBreakingException("error", CircuitBreaker.Durability.TRANSIENT);
when(mockClient.search(any())).thenThrow(searchPhaseException).thenThrow(circuitBreakerException);

var request = createSearchRequest();

ElasticsearchException exception = expectThrows(
ElasticsearchException.class,
() -> ChunkedTrainedModelRestorer.retryingSearch(mockClient, "", request, 1, new TimeValue(1, TimeUnit.NANOSECONDS))
);

assertThat(exception.getCause(), is(circuitBreakerException));
verify(mockClient, times(2)).search(any());
}
}

public void testRetryingSearch_EnsureExceptionCannotBeUnwrapped() {
try (var mockClient = mock(Client.class)) {
var searchPhaseExecutionException = new SearchPhaseExecutionException("phase", "error", ShardSearchFailure.EMPTY_ARRAY);
when(mockClient.search(any())).thenThrow(searchPhaseExecutionException);

var request = createSearchRequest();

ElasticsearchException exception = expectThrows(
ElasticsearchException.class,
() -> ChunkedTrainedModelRestorer.retryingSearch(mockClient, "", request, 1, new TimeValue(1, TimeUnit.NANOSECONDS))
);

assertThat(ExceptionsHelper.unwrapCause(exception), is(exception));
assertThat(ExceptionsHelper.unwrapCause(exception), instanceOf(ElasticsearchException.class));
verify(mockClient, times(2)).search(any());
}
}

public void testRetryingSearch_ThrowsIllegalArgumentExceptionIgnoringRetries() {
try (var mockClient = mock(Client.class)) {
var exception = new IllegalArgumentException("Error");
when(mockClient.search(any())).thenThrow(exception);

var request = createSearchRequest();

IllegalArgumentException thrownException = expectThrows(
IllegalArgumentException.class,
() -> ChunkedTrainedModelRestorer.retryingSearch(mockClient, "", request, 1, new TimeValue(1, TimeUnit.NANOSECONDS))
);

assertThat(thrownException, is(exception));
verify(mockClient, times(1)).search(any());
}
}

public void testRetryingSearch_ThrowsSearchPhaseExceptionOnce_ThenReturnsResponse() throws InterruptedException {
try (var mockClient = mock(Client.class)) {
var mockSearchResponse = mock(SearchResponse.class, RETURNS_DEEP_STUBS);

PlainActionFuture<SearchResponse> searchFuture = new PlainActionFuture<>();
searchFuture.onResponse(mockSearchResponse);

var searchPhaseException = new SearchPhaseExecutionException("phase", "error", ShardSearchFailure.EMPTY_ARRAY);
when(mockClient.search(any())).thenThrow(searchPhaseException).thenReturn(searchFuture);

var request = createSearchRequest();

assertThat(
ChunkedTrainedModelRestorer.retryingSearch(mockClient, "", request, 1, new TimeValue(1, TimeUnit.NANOSECONDS)),
is(mockSearchResponse)
);

verify(mockClient, times(2)).search(any());
}
}

private static SearchRequest createSearchRequest() {
return new SearchRequest("index");
}
}