[ML] Adding retries for starting model deployment #99673

Merged
5 changes: 5 additions & 0 deletions docs/changelog/99673.yaml
@@ -0,0 +1,5 @@
pr: 99673
summary: Adding retry logic for start model deployment API
area: Machine Learning
type: bug
issues: [ ]
@@ -10,15 +10,18 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.core.CheckedFunction;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.sort.SortBuilders;
@@ -36,6 +39,7 @@
import java.io.IOException;
import java.io.InputStream;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

import static org.elasticsearch.core.Strings.format;
@@ -58,6 +62,8 @@ public class ChunkedTrainedModelRestorer {
private static final Logger logger = LogManager.getLogger(ChunkedTrainedModelRestorer.class);

private static final int MAX_NUM_DEFINITION_DOCS = 20;
private static final int SEARCH_RETRY_LIMIT = 5;
private static final TimeValue SEARCH_FAILURE_RETRY_WAIT_TIME = new TimeValue(5, TimeUnit.SECONDS);

private final Client client;
private final NamedXContentRegistry xContentRegistry;
@@ -142,7 +148,14 @@ private void doSearch(
UTILITY_THREAD_POOL_NAME,
Thread.currentThread().getName()
);
SearchResponse searchResponse = client.search(searchRequest).actionGet();

SearchResponse searchResponse = retryingSearch(
client,
modelId,
searchRequest,
SEARCH_RETRY_LIMIT,
SEARCH_FAILURE_RETRY_WAIT_TIME
);
if (searchResponse.getHits().getHits().length == 0) {
errorConsumer.accept(new ResourceNotFoundException(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
return;
@@ -201,6 +214,31 @@ private void doSearch(
}
}

    static SearchResponse retryingSearch(Client client, String modelId, SearchRequest searchRequest, int retries, TimeValue sleep)
        throws InterruptedException {
        int failureCount = 0;

        while (true) {
            try {
                return client.search(searchRequest).actionGet();
            } catch (Exception e) {
                if (ExceptionsHelper.unwrapCause(e) instanceof SearchPhaseExecutionException == false
                    && ExceptionsHelper.unwrapCause(e) instanceof CircuitBreakingException == false) {
                    throw e;
                }

                if (failureCount >= retries) {
                    logger.warn(format("[%s] searching for model part failed %s times, returning failure", modelId, retries));
                    throw e;
Member

Should the retry logic in TrainedModelAssignmentNodeService now be removed? If the search has already failed after N tries, then trying to reload the entire model puts the cluster under even more strain.

https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java#L219

Instead of throwing e, please wrap it in another exception with a sensible error message, something along the lines of: loading model [model_id] failed after [SEARCH_RETRY_LIMIT] retries. The deployment is now in a failed state; the error may be transient, so please stop the deployment and restart it.

The error message ends up in the failed routing state here and will eventually make its way back to the caller, assuming the request did not time out.

https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java#L718
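
For illustration, the wrapping suggested in this comment could look roughly like the sketch below. It uses only JDK types: `RetryWrapSketch`, `ModelLoadException` and `searchWithRetries` are hypothetical stand-ins rather than Elasticsearch classes, the wait between attempts is left out, and the message wording is only one possible phrasing.

```java
import java.util.function.Supplier;

/** Sketch only: every name in this class is hypothetical and not part of Elasticsearch. */
final class RetryWrapSketch {

    /** Hypothetical stand-in for the wrapping exception the review asks for. */
    static final class ModelLoadException extends RuntimeException {
        ModelLoadException(String message, Throwable cause) {
            super(message, cause);
        }
    }

    /**
     * Calls the supplier up to (1 + retries) times. Once the retry budget is used up,
     * the last failure is wrapped in an exception carrying a user-facing message.
     */
    static <T> T searchWithRetries(String modelId, int retries, Supplier<T> search) {
        int failureCount = 0;
        while (true) {
            try {
                return search.get();
            } catch (RuntimeException e) {
                if (failureCount >= retries) {
                    String message = String.format(
                        "loading model [%s] failed after [%d] retries. The deployment is now in a failed state; "
                            + "the error may be transient, so please stop the deployment and restart it",
                        modelId,
                        retries
                    );
                    // Keep the original failure as the cause so it is not lost.
                    throw new ModelLoadException(message, e);
                }
                failureCount++;
            }
        }
    }

    public static void main(String[] args) {
        try {
            searchWithRetries("my-model", 2, () -> { throw new IllegalStateException("shards failed"); });
        } catch (ModelLoadException e) {
            System.out.println(e.getMessage());
            System.out.println("cause: " + e.getCause().getMessage());
        }
    }
}
```

Keeping the original exception as the cause means the underlying search failure still shows up in logs, while the caller sees the actionable message.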

Contributor Author

Ah good point 👍

Contributor Author

Dave and I talked about this offline: we don't need to adjust the retry logic in TrainedModelAssignmentNodeService, because the wrapping ElasticsearchException cannot be unwrapped and so will not trigger that retry logic.
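
To make that reasoning concrete, here is a small JDK-only sketch. It assumes an unwrap helper that peels only exceptions implementing a marker "wrapper" interface, which is how `ExceptionsHelper.unwrapCause` is understood to behave here. All types below (`WrapperException`, `RemoteWrapper`, `SearchFailed`, `LoadFailed`) are hypothetical stand-ins, not Elasticsearch classes.

```java
/** Sketch only: stand-in types showing why a plain wrapping exception hides a retryable cause. */
final class UnwrapSketch {

    /** Hypothetical marker interface: only exceptions implementing it get unwrapped. */
    interface WrapperException {}

    /** A wrapper in the marker sense, for example a transport-level envelope. */
    static final class RemoteWrapper extends RuntimeException implements WrapperException {
        RemoteWrapper(Throwable cause) {
            super(cause);
        }
    }

    /** Stand-in for a failure type that the outer retry logic treats as retryable. */
    static final class SearchFailed extends RuntimeException {
        SearchFailed(String message) {
            super(message);
        }
    }

    /** A plain exception that has a cause but does not implement the marker interface. */
    static final class LoadFailed extends RuntimeException {
        LoadFailed(String message, Throwable cause) {
            super(message, cause);
        }
    }

    /** Peels marker wrappers only, with a depth cap to guard against cause cycles. */
    static Throwable unwrapCause(Throwable t) {
        Throwable result = t;
        int depth = 0;
        while (result instanceof WrapperException
            && result.getCause() != null
            && result.getCause() != result
            && depth++ < 10) {
            result = result.getCause();
        }
        return result;
    }

    /** The outer retry check: retry only if the unwrapped failure is a SearchFailed. */
    static boolean isRetryable(Throwable t) {
        return unwrapCause(t) instanceof SearchFailed;
    }

    public static void main(String[] args) {
        Throwable search = new SearchFailed("all shards failed");
        System.out.println(isRetryable(new RemoteWrapper(search)));             // true
        System.out.println(isRetryable(new LoadFailed("giving up", search)));   // false
    }
}
```

Under that assumption, wrapping the final failure in a non-wrapper exception is enough to stop an outer instanceof-based retry check from firing again.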

                }

                failureCount++;
                logger.debug(format("[%s] searching for model part failed %s times, retrying", modelId, failureCount));
                TimeUnit.SECONDS.sleep(sleep.getSeconds());
            }
        }
    }

private static SearchRequestBuilder buildSearchBuilder(Client client, String modelId, String index, int searchSize) {
return client.prepareSearch(index)
.setQuery(
Expand Down
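
A note on the wait in `retryingSearch`: the loop sleeps via `TimeUnit.SECONDS.sleep(sleep.getSeconds())`. Assuming `TimeValue.getSeconds()` converts to whole seconds with truncation, the way `java.util.concurrent.TimeUnit` conversions do (an assumption, not checked against `TimeValue` here), any wait shorter than one second becomes a zero-length sleep. The JDK-only sketch below shows that arithmetic. `SleepGranularitySketch` and `wholeSeconds` are hypothetical names.

```java
import java.util.concurrent.TimeUnit;

/** Sketch only: shows whole-second truncation using plain JDK TimeUnit conversions. */
final class SleepGranularitySketch {

    /** Whole seconds contained in the duration; fractional seconds are dropped. */
    static long wholeSeconds(long duration, TimeUnit unit) {
        return unit.toSeconds(duration);
    }

    public static void main(String[] args) {
        System.out.println(wholeSeconds(5, TimeUnit.SECONDS));         // 5
        System.out.println(wholeSeconds(1, TimeUnit.NANOSECONDS));     // 0
        System.out.println(wholeSeconds(1500, TimeUnit.MILLISECONDS)); // 1
    }
}
```

If that assumption holds, the 5 second production wait is unaffected, and a 1 nanosecond wait like the one the tests pass in results in no sleep at all.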
@@ -0,0 +1,144 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.ml.inference.persistence;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.test.ESTestCase;

import java.util.concurrent.TimeUnit;

import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class ChunkedTrainedModelRestorerTests extends ESTestCase {
    public void testRetryingSearch_ReturnsSearchResults() throws InterruptedException {
        var mockClient = mock(Client.class);
        var mockSearchResponse = mock(SearchResponse.class, RETURNS_DEEP_STUBS);

        PlainActionFuture<SearchResponse> searchFuture = new PlainActionFuture<>();
        searchFuture.onResponse(mockSearchResponse);
        when(mockClient.search(any())).thenReturn(searchFuture);

        var request = createSearchRequest();

        assertThat(
            ChunkedTrainedModelRestorer.retryingSearch(mockClient, "", request, 5, new TimeValue(1, TimeUnit.NANOSECONDS)),
            is(mockSearchResponse)
        );

        verify(mockClient, times(1)).search(any());
    }

    public void testRetryingSearch_ThrowsSearchPhaseExceptionWithNoRetries() {
        try (var mockClient = mock(Client.class)) {
            var searchPhaseException = new SearchPhaseExecutionException("phase", "error", ShardSearchFailure.EMPTY_ARRAY);
            when(mockClient.search(any())).thenThrow(searchPhaseException);

            var request = createSearchRequest();

            SearchPhaseExecutionException exception = expectThrows(
                SearchPhaseExecutionException.class,
                () -> ChunkedTrainedModelRestorer.retryingSearch(mockClient, "", request, 0, new TimeValue(1, TimeUnit.NANOSECONDS))
            );

            assertThat(exception, is(searchPhaseException));
            verify(mockClient, times(1)).search(any());
        }
    }

    public void testRetryingSearch_ThrowsSearchPhaseExceptionAfterOneRetry() {
        try (var mockClient = mock(Client.class)) {
            var searchPhaseException = new SearchPhaseExecutionException("phase", "error", ShardSearchFailure.EMPTY_ARRAY);
            when(mockClient.search(any())).thenThrow(searchPhaseException);

            var request = createSearchRequest();

            SearchPhaseExecutionException exception = expectThrows(
                SearchPhaseExecutionException.class,
                () -> ChunkedTrainedModelRestorer.retryingSearch(mockClient, "", request, 1, new TimeValue(1, TimeUnit.NANOSECONDS))
            );

            assertThat(exception, is(searchPhaseException));
            verify(mockClient, times(2)).search(any());
        }
    }

    public void testRetryingSearch_ThrowsCircuitBreakingExceptionAfterOneRetry_FromSearchPhaseException() {
        try (var mockClient = mock(Client.class)) {
            var searchPhaseException = new SearchPhaseExecutionException("phase", "error", ShardSearchFailure.EMPTY_ARRAY);
            var circuitBreakerException = new CircuitBreakingException("error", CircuitBreaker.Durability.TRANSIENT);
            when(mockClient.search(any())).thenThrow(searchPhaseException).thenThrow(circuitBreakerException);

            var request = createSearchRequest();

            CircuitBreakingException exception = expectThrows(
                CircuitBreakingException.class,
                () -> ChunkedTrainedModelRestorer.retryingSearch(mockClient, "", request, 1, new TimeValue(1, TimeUnit.NANOSECONDS))
            );

            assertThat(exception, is(circuitBreakerException));
            verify(mockClient, times(2)).search(any());
        }
    }

    public void testRetryingSearch_ThrowsIOExceptionIgnoringRetries() {
        try (var mockClient = mock(Client.class)) {
            var exception = new ElasticsearchException("Error");
            when(mockClient.search(any())).thenThrow(exception);

            var request = createSearchRequest();

            ElasticsearchException thrownException = expectThrows(
                ElasticsearchException.class,
                () -> ChunkedTrainedModelRestorer.retryingSearch(mockClient, "", request, 1, new TimeValue(1, TimeUnit.NANOSECONDS))
            );

            assertThat(thrownException, is(exception));
            verify(mockClient, times(1)).search(any());
        }
    }

    public void testRetryingSearch_ThrowsSearchPhaseExceptionOnce_ThenReturnsResponse() throws InterruptedException {
        try (var mockClient = mock(Client.class)) {
            var mockSearchResponse = mock(SearchResponse.class, RETURNS_DEEP_STUBS);

            PlainActionFuture<SearchResponse> searchFuture = new PlainActionFuture<>();
            searchFuture.onResponse(mockSearchResponse);

            var searchPhaseException = new SearchPhaseExecutionException("phase", "error", ShardSearchFailure.EMPTY_ARRAY);
            when(mockClient.search(any())).thenThrow(searchPhaseException).thenReturn(searchFuture);

            var request = createSearchRequest();

            assertThat(
                ChunkedTrainedModelRestorer.retryingSearch(mockClient, "", request, 1, new TimeValue(1, TimeUnit.NANOSECONDS)),
                is(mockSearchResponse)
            );

            verify(mockClient, times(2)).search(any());
        }
    }

    private static SearchRequest createSearchRequest() {
        return new SearchRequest("index");
    }
}
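
The tests above pin down how attempts are counted: a retry limit of N allows at most N + 1 calls to `search` (0 retries gives 1 call, 1 retry gives 2 calls). As a rough, JDK-only illustration of the same check without Mockito, the sketch below uses a hand-rolled counting stub. `AttemptCountSketch`, `retrying` and `failingFirst` are hypothetical names, and the loop leaves out the sleep and the exception-type filtering from the real method.

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

/** Sketch only: a simplified retry loop plus a counting stub, using JDK types alone. */
final class AttemptCountSketch {

    /** Same loop shape as retryingSearch in the diff, without sleeping or filtering exception types. */
    static <T> T retrying(Supplier<T> call, int retries) {
        int failureCount = 0;
        while (true) {
            try {
                return call.get();
            } catch (RuntimeException e) {
                if (failureCount >= retries) {
                    throw e;
                }
                failureCount++;
            }
        }
    }

    /** Stub that fails on the first {@code failures} calls, then returns "ok"; every call is counted. */
    static Supplier<String> failingFirst(int failures, AtomicInteger calls) {
        return () -> {
            if (calls.incrementAndGet() <= failures) {
                throw new IllegalStateException("simulated search failure");
            }
            return "ok";
        };
    }

    public static void main(String[] args) {
        AtomicInteger calls = new AtomicInteger();
        String result = retrying(failingFirst(1, calls), 1);
        System.out.println(result + " after " + calls.get() + " calls"); // prints "ok after 2 calls"
    }
}
```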