Skip to content

Commit

Permalink
[ML] Persist progress when setting DFA task to failed (#61782)
Browse files Browse the repository at this point in the history
When an error occurs and we set the task to failed via
the `DataFrameAnalyticsTask.setFailed` method, we do not
persist progress. This means that if the job is later restarted,
we cannot resume from the point it had reached; instead,
we start the job from scratch and have to redo the reindexing
phase.

This commit solves this bug by persisting the progress before
setting the task to failed.
  • Loading branch information
dimitris-athanasiou authored Sep 1, 2020
1 parent 5f290c2 commit 2ba4e15
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF

Exception reindexError = getReindexError(task.getParams().getId(), reindexResponse);
if (reindexError != null) {
task.markAsFailed(reindexError);
task.setFailed(reindexError);
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
Expand Down Expand Up @@ -38,7 +37,6 @@
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.tasks.TaskResult;
import org.elasticsearch.xpack.core.ml.MlTasks;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.StopDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
Expand Down Expand Up @@ -216,23 +214,25 @@ public void setFailed(Exception error) {
error);
return;
}
LOGGER.error(new ParameterizedMessage("[{}] Setting task to failed", taskParams.getId()), error);
String reason = ExceptionsHelper.unwrapCause(error).getMessage();
DataFrameAnalyticsTaskState newTaskState =
new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.FAILED, getAllocationId(), reason);
updatePersistentTaskState(
newTaskState,
ActionListener.wrap(
updatedTask -> {
String message = Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_UPDATED_STATE_WITH_REASON,
persistProgress(client, taskParams.getId(), () -> {
LOGGER.error(new ParameterizedMessage("[{}] Setting task to failed", taskParams.getId()), error);
String reason = ExceptionsHelper.unwrapCause(error).getMessage();
DataFrameAnalyticsTaskState newTaskState =
new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.FAILED, getAllocationId(), reason);
updatePersistentTaskState(
newTaskState,
ActionListener.wrap(
updatedTask -> {
String message = Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_UPDATED_STATE_WITH_REASON,
DataFrameAnalyticsState.FAILED, reason);
auditor.info(getParams().getId(), message);
LOGGER.info("[{}] {}", getParams().getId(), message);
},
e -> LOGGER.error(new ParameterizedMessage("[{}] Could not update task state to [{}] with reason [{}]",
getParams().getId(), DataFrameAnalyticsState.FAILED, reason), e)
)
);
auditor.info(getParams().getId(), message);
LOGGER.info("[{}] {}", getParams().getId(), message);
},
e -> LOGGER.error(new ParameterizedMessage("[{}] Could not update task state to [{}] with reason [{}]",
getParams().getId(), DataFrameAnalyticsState.FAILED, reason), e)
)
);
});
}

public void updateReindexTaskProgress(ActionListener<Void> listener) {
Expand Down Expand Up @@ -285,13 +285,12 @@ private TaskId getReindexTaskId() {
}

// Visible for testing
static void persistProgress(Client client, String jobId, Runnable runnable) {
void persistProgress(Client client, String jobId, Runnable runnable) {
LOGGER.debug("[{}] Persisting progress", jobId);

String progressDocId = StoredProgress.documentId(jobId);
SetOnce<GetDataFrameAnalyticsStatsAction.Response.Stats> stats = new SetOnce<>();

// Step 4: Run the runnable provided as the argument
// Step 3: Run the runnable provided as the argument
ActionListener<IndexResponse> indexProgressDocListener = ActionListener.wrap(
indexResponse -> {
LOGGER.debug("[{}] Successfully indexed progress document", jobId);
Expand All @@ -304,7 +303,7 @@ static void persistProgress(Client client, String jobId, Runnable runnable) {
}
);

// Step 3: Create or update the progress document:
// Step 2: Create or update the progress document:
// - if the document did not exist, create the new one in the current write index
// - if the document did exist, update it in the index where it resides (not necessarily the current write index)
ActionListener<SearchResponse> searchFormerProgressDocListener = ActionListener.wrap(
Expand All @@ -317,9 +316,10 @@ static void persistProgress(Client client, String jobId, Runnable runnable) {
.id(progressDocId)
.setRequireAlias(AnomalyDetectorsIndex.jobStateIndexWriteAlias().equals(indexOrAlias))
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
List<PhaseProgress> progress = statsHolder.getProgressTracker().report();
try (XContentBuilder jsonBuilder = JsonXContent.contentBuilder()) {
LOGGER.debug("[{}] Persisting progress is: {}", jobId, stats.get().getProgress());
new StoredProgress(stats.get().getProgress()).toXContent(jsonBuilder, Payload.XContent.EMPTY_PARAMS);
LOGGER.debug("[{}] Persisting progress is: {}", jobId, progress);
new StoredProgress(progress).toXContent(jsonBuilder, Payload.XContent.EMPTY_PARAMS);
indexRequest.source(jsonBuilder);
}
executeAsyncWithOrigin(client, ML_ORIGIN, IndexAction.INSTANCE, indexRequest, indexProgressDocListener);
Expand All @@ -331,28 +331,14 @@ static void persistProgress(Client client, String jobId, Runnable runnable) {
}
);

// Step 2: Search for existing progress document in .ml-state*
ActionListener<GetDataFrameAnalyticsStatsAction.Response> getStatsListener = ActionListener.wrap(
statsResponse -> {
stats.set(statsResponse.getResponse().results().get(0));
SearchRequest searchRequest =
new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern())
.source(
new SearchSourceBuilder()
.size(1)
.query(new IdsQueryBuilder().addIds(progressDocId)));
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, searchFormerProgressDocListener);
},
e -> {
LOGGER.error(new ParameterizedMessage(
"[{}] cannot persist progress as an error occurred while retrieving stats", jobId), e);
runnable.run();
}
);

// Step 1: Fetch progress to be persisted
GetDataFrameAnalyticsStatsAction.Request getStatsRequest = new GetDataFrameAnalyticsStatsAction.Request(jobId);
executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsStatsAction.INSTANCE, getStatsRequest, getStatsListener);
// Step 1: Search for existing progress document in .ml-state*
SearchRequest searchRequest =
new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern())
.source(
new SearchSourceBuilder()
.size(1)
.query(new IdsQueryBuilder().addIds(progressDocId)));
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, searchFormerProgressDocListener);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,21 @@
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.persistent.PersistentTasksService;
import org.elasticsearch.persistent.UpdatePersistentTaskStatusAction;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsActionResponseTests;
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.StartingState;
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
Expand All @@ -36,6 +39,7 @@
import org.mockito.InOrder;
import org.mockito.stubbing.Answer;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -126,14 +130,24 @@ public void testDetermineStartingState_GivenEmptyProgress() {
assertThat(startingState, equalTo(StartingState.FINISHED));
}

private void testPersistProgress(SearchHits searchHits, String expectedIndexOrAlias) {
private void testPersistProgress(SearchHits searchHits, String expectedIndexOrAlias) throws IOException {
Client client = mock(Client.class);
ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
when(client.threadPool()).thenReturn(threadPool);

GetDataFrameAnalyticsStatsAction.Response getStatsResponse = GetDataFrameAnalyticsStatsActionResponseTests.randomResponse(1);
doAnswer(withResponse(getStatsResponse)).when(client).execute(eq(GetDataFrameAnalyticsStatsAction.INSTANCE), any(), any());
ClusterService clusterService = mock(ClusterService.class);
DataFrameAnalyticsManager analyticsManager = mock(DataFrameAnalyticsManager.class);
DataFrameAnalyticsAuditor auditor = mock(DataFrameAnalyticsAuditor.class);
PersistentTasksService persistentTasksService = new PersistentTasksService(clusterService, mock(ThreadPool.class), client);

List<PhaseProgress> progress = List.of(
new PhaseProgress(ProgressTracker.REINDEXING, 100),
new PhaseProgress(ProgressTracker.LOADING_DATA, 50),
new PhaseProgress(ProgressTracker.WRITING_RESULTS, 0));

StartDataFrameAnalyticsAction.TaskParams taskParams = new StartDataFrameAnalyticsAction.TaskParams(
"task_id", Version.CURRENT, progress, false);

SearchResponse searchResponse = mock(SearchResponse.class);
when(searchResponse.getHits()).thenReturn(searchHits);
Expand All @@ -142,14 +156,20 @@ private void testPersistProgress(SearchHits searchHits, String expectedIndexOrAl
IndexResponse indexResponse = mock(IndexResponse.class);
doAnswer(withResponse(indexResponse)).when(client).execute(eq(IndexAction.INSTANCE), any(), any());

TaskManager taskManager = mock(TaskManager.class);

Runnable runnable = mock(Runnable.class);

DataFrameAnalyticsTask.persistProgress(client, "task_id", runnable);
DataFrameAnalyticsTask task =
new DataFrameAnalyticsTask(
123, "type", "action", null, Map.of(), client, clusterService, analyticsManager, auditor, taskParams);
task.init(persistentTasksService, taskManager, "task-id", 42);

task.persistProgress(client, "task_id", runnable);

ArgumentCaptor<IndexRequest> indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class);

InOrder inOrder = inOrder(client, runnable);
inOrder.verify(client).execute(eq(GetDataFrameAnalyticsStatsAction.INSTANCE), any(), any());
inOrder.verify(client).execute(eq(SearchAction.INSTANCE), any(), any());
inOrder.verify(client).execute(eq(IndexAction.INSTANCE), indexRequestCaptor.capture(), any());
inOrder.verify(runnable).run();
Expand All @@ -158,27 +178,33 @@ private void testPersistProgress(SearchHits searchHits, String expectedIndexOrAl
IndexRequest indexRequest = indexRequestCaptor.getValue();
assertThat(indexRequest.index(), equalTo(expectedIndexOrAlias));
assertThat(indexRequest.id(), equalTo("data_frame_analytics-task_id-progress"));

try (XContentParser parser = JsonXContent.jsonXContent.createParser(
NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, indexRequest.source().utf8ToString())) {
StoredProgress parsedProgress = StoredProgress.PARSER.apply(parser, null);
assertThat(parsedProgress.get(), equalTo(progress));
}
}

public void testPersistProgress_ProgressDocumentCreated() {
public void testPersistProgress_ProgressDocumentCreated() throws IOException {
testPersistProgress(SearchHits.empty(), ".ml-state-write");
}

public void testPersistProgress_ProgressDocumentUpdated() {
public void testPersistProgress_ProgressDocumentUpdated() throws IOException {
testPersistProgress(
new SearchHits(new SearchHit[]{ SearchHit.createFromMap(Map.of("_index", ".ml-state-dummy")) }, null, 0.0f),
".ml-state-dummy");
}

public void testSetFailed() {
public void testSetFailed() throws IOException {
testSetFailed(false);
}

public void testSetFailedDuringNodeShutdown() {
public void testSetFailedDuringNodeShutdown() throws IOException {
testSetFailed(true);
}

private void testSetFailed(boolean nodeShuttingDown) {
private void testSetFailed(boolean nodeShuttingDown) throws IOException {
ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
Client client = mock(Client.class);
Expand All @@ -190,15 +216,25 @@ private void testSetFailed(boolean nodeShuttingDown) {
PersistentTasksService persistentTasksService = new PersistentTasksService(clusterService, mock(ThreadPool.class), client);
TaskManager taskManager = mock(TaskManager.class);

List<PhaseProgress> progress = List.of(
new PhaseProgress(ProgressTracker.REINDEXING, 100),
new PhaseProgress(ProgressTracker.LOADING_DATA, 100),
new PhaseProgress(ProgressTracker.WRITING_RESULTS, 30));

StartDataFrameAnalyticsAction.TaskParams taskParams =
new StartDataFrameAnalyticsAction.TaskParams(
"job-id",
Version.CURRENT,
List.of(
new PhaseProgress(ProgressTracker.REINDEXING, 0),
new PhaseProgress(ProgressTracker.LOADING_DATA, 0),
new PhaseProgress(ProgressTracker.WRITING_RESULTS, 0)),
progress,
false);

SearchResponse searchResponse = mock(SearchResponse.class);
when(searchResponse.getHits()).thenReturn(SearchHits.empty());
doAnswer(withResponse(searchResponse)).when(client).execute(eq(SearchAction.INSTANCE), any(), any());

IndexResponse indexResponse = mock(IndexResponse.class);
doAnswer(withResponse(indexResponse)).when(client).execute(eq(IndexAction.INSTANCE), any(), any());

DataFrameAnalyticsTask task =
new DataFrameAnalyticsTask(
123, "type", "action", null, Map.of(), client, clusterService, analyticsManager, auditor, taskParams);
Expand All @@ -210,7 +246,23 @@ private void testSetFailed(boolean nodeShuttingDown) {
verify(analyticsManager).isNodeShuttingDown();
verify(client, atLeastOnce()).settings();
verify(client, atLeastOnce()).threadPool();

if (nodeShuttingDown == false) {
// Verify progress was persisted
ArgumentCaptor<IndexRequest> indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class);
verify(client).execute(eq(SearchAction.INSTANCE), any(), any());
verify(client).execute(eq(IndexAction.INSTANCE), indexRequestCaptor.capture(), any());

IndexRequest indexRequest = indexRequestCaptor.getValue();
assertThat(indexRequest.index(), equalTo(AnomalyDetectorsIndex.jobStateIndexWriteAlias()));
assertThat(indexRequest.id(), equalTo("data_frame_analytics-job-id-progress"));

try (XContentParser parser = JsonXContent.jsonXContent.createParser(
NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, indexRequest.source().utf8ToString())) {
StoredProgress parsedProgress = StoredProgress.PARSER.apply(parser, null);
assertThat(parsedProgress.get(), equalTo(progress));
}

verify(client).execute(
same(UpdatePersistentTaskStatusAction.INSTANCE),
eq(new UpdatePersistentTaskStatusAction.Request(
Expand Down

0 comments on commit 2ba4e15

Please sign in to comment.