diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java index 9f925fc21971..334f7dd2f2e2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java @@ -290,7 +290,7 @@ void persistProgress(Client client, String jobId, Runnable runnable) { String progressDocId = StoredProgress.documentId(jobId); - // Step 3: Run the runnable provided as the argument + // Step 4: Run the runnable provided as the argument ActionListener indexProgressDocListener = ActionListener.wrap( indexResponse -> { LOGGER.debug("[{}] Successfully indexed progress document", jobId); @@ -303,7 +303,7 @@ void persistProgress(Client client, String jobId, Runnable runnable) { } ); - // Step 2: Create or update the progress document: + // Step 3: Create or update the progress document: // - if the document did not exist, create the new one in the current write index // - if the document did exist, update it in the index where it resides (not necessarily the current write index) ActionListener searchFormerProgressDocListener = ActionListener.wrap( @@ -331,14 +331,26 @@ void persistProgress(Client client, String jobId, Runnable runnable) { } ); - // Step 1: Search for existing progress document in .ml-state* - SearchRequest searchRequest = - new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern()) - .source( - new SearchSourceBuilder() - .size(1) - .query(new IdsQueryBuilder().addIds(progressDocId))); - executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, searchFormerProgressDocListener); + // Step 2: Search for existing progress document in .ml-state* + ActionListener reindexProgressUpdateListener = ActionListener.wrap( + aVoid -> { + SearchRequest searchRequest = + new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern()) + .source( + new SearchSourceBuilder() + .size(1) + .query(new IdsQueryBuilder().addIds(progressDocId))); + executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, searchFormerProgressDocListener); + }, + e -> { + LOGGER.error(new ParameterizedMessage( + "[{}] cannot persist progress as an error occurred while updating reindexing task progress", taskParams.getId()), e); + runnable.run(); + } + ); + + // Step 1: Update reindexing progress as it could be stale + updateReindexTaskProgress(reindexProgressUpdateListener); } /** diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTaskTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTaskTests.java index b9695352062b..12de679c1b1d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTaskTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTaskTests.java @@ -46,6 +46,7 @@ import java.util.Map; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Matchers.same; @@ -216,8 +217,9 @@ private void testSetFailed(boolean nodeShuttingDown) throws IOException { PersistentTasksService persistentTasksService = new PersistentTasksService(clusterService, mock(ThreadPool.class), client); TaskManager taskManager = mock(TaskManager.class); + // We leave reindexing progress here to zero in order to check it is updated before it is persisted List progress = List.of( - new PhaseProgress(ProgressTracker.REINDEXING, 100), + new PhaseProgress(ProgressTracker.REINDEXING, 0), new PhaseProgress(ProgressTracker.LOADING_DATA, 100), new PhaseProgress(ProgressTracker.WRITING_RESULTS, 30)); @@ -239,6 +241,7 @@ private void testSetFailed(boolean nodeShuttingDown) throws IOException { new DataFrameAnalyticsTask( 123, "type", "action", null, Map.of(), client, clusterService, analyticsManager, auditor, taskParams); task.init(persistentTasksService, taskManager, "task-id", 42); + task.setReindexingFinished(); Exception exception = new Exception("some exception"); task.setFailed(exception); @@ -260,7 +263,8 @@ private void testSetFailed(boolean nodeShuttingDown) throws IOException { try (XContentParser parser = JsonXContent.jsonXContent.createParser( NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, indexRequest.source().utf8ToString())) { StoredProgress parsedProgress = StoredProgress.PARSER.apply(parser, null); - assertThat(parsedProgress.get(), equalTo(progress)); + assertThat(parsedProgress.get(), hasSize(3)); + assertThat(parsedProgress.get().get(0), equalTo(new PhaseProgress("reindexing", 100))); } verify(client).execute( @@ -269,7 +273,7 @@ private void testSetFailed(boolean nodeShuttingDown) throws IOException { "task-id", 42, new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.FAILED, 42, "some exception"))), any()); } - verifyNoMoreInteractions(client, clusterService, analyticsManager, auditor, taskManager); + verifyNoMoreInteractions(client, analyticsManager, auditor, taskManager); } @SuppressWarnings("unchecked")