diff --git a/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java b/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java index 1757548c28b09..4432d864fd36a 100644 --- a/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java +++ b/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java @@ -50,7 +50,6 @@ import org.elasticsearch.discovery.Discovery; import org.elasticsearch.threadpool.ThreadPool; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -365,28 +364,11 @@ public void clusterStatePublished(ClusterChangedEvent clusterChangedEvent) { } public Discovery.AckListener createAckListener(ThreadPool threadPool, ClusterState newClusterState) { - ArrayList ackListeners = new ArrayList<>(); - - //timeout straightaway, otherwise we could wait forever as the timeout thread has not started - nonFailedTasks.stream().filter(task -> task.listener instanceof AckedClusterStateTaskListener).forEach(task -> { - final AckedClusterStateTaskListener ackedListener = (AckedClusterStateTaskListener) task.listener; - if (ackedListener.ackTimeout() == null || ackedListener.ackTimeout().millis() == 0) { - ackedListener.onAckTimeout(); - } else { - try { - ackListeners.add(new AckCountDownListener(ackedListener, newClusterState.version(), newClusterState.nodes(), - threadPool)); - } catch (EsRejectedExecutionException ex) { - if (logger.isDebugEnabled()) { - logger.debug("Couldn't schedule timeout thread - node might be shutting down", ex); - } - //timeout straightaway, otherwise we could wait forever as the timeout thread has not started - ackedListener.onAckTimeout(); - } - } - }); - - return new DelegatingAckListener(ackListeners); + return new DelegatingAckListener(nonFailedTasks.stream() + .filter(task -> task.listener instanceof AckedClusterStateTaskListener) + .map(task -> new AckCountDownListener((AckedClusterStateTaskListener) task.listener, newClusterState.version(), + newClusterState.nodes(), threadPool)) + .collect(Collectors.toList())); } public boolean clusterStateUnchanged() { @@ -549,6 +531,13 @@ private DelegatingAckListener(List listeners) { this.listeners = listeners; } + @Override + public void onCommit(TimeValue commitTime) { + for (Discovery.AckListener listener : listeners) { + listener.onCommit(commitTime); + } + } + @Override public void onNodeAck(DiscoveryNode node, @Nullable Exception e) { for (Discovery.AckListener listener : listeners) { @@ -564,14 +553,16 @@ private static class AckCountDownListener implements Discovery.AckListener { private final AckedClusterStateTaskListener ackedTaskListener; private final CountDown countDown; private final DiscoveryNode masterNode; + private final ThreadPool threadPool; private final long clusterStateVersion; - private final Future ackTimeoutCallback; + private volatile Future ackTimeoutCallback; private Exception lastFailure; AckCountDownListener(AckedClusterStateTaskListener ackedTaskListener, long clusterStateVersion, DiscoveryNodes nodes, ThreadPool threadPool) { this.ackedTaskListener = ackedTaskListener; this.clusterStateVersion = clusterStateVersion; + this.threadPool = threadPool; this.masterNode = nodes.getMasterNode(); int countDown = 0; for (DiscoveryNode node : nodes) { @@ -581,8 +572,27 @@ private static class AckCountDownListener implements Discovery.AckListener { } } logger.trace("expecting {} acknowledgements for cluster_state update (version: {})", countDown, clusterStateVersion); - this.countDown = new CountDown(countDown); - this.ackTimeoutCallback = threadPool.schedule(ackedTaskListener.ackTimeout(), ThreadPool.Names.GENERIC, () -> onTimeout()); + this.countDown = new CountDown(countDown + 1); // we also wait for onCommit to be called + } + + @Override + public void onCommit(TimeValue commitTime) { + TimeValue ackTimeout = ackedTaskListener.ackTimeout(); + if (ackTimeout == null) { + ackTimeout = TimeValue.ZERO; + } + final TimeValue timeLeft = TimeValue.timeValueNanos(Math.max(0, ackTimeout.nanos() - commitTime.nanos())); + if (timeLeft.nanos() == 0L) { + onTimeout(); + } else if (countDown.countDown()) { + finish(); + } else { + this.ackTimeoutCallback = threadPool.schedule(timeLeft, ThreadPool.Names.GENERIC, this::onTimeout); + // re-check if onNodeAck has not completed while we were scheduling the timeout + if (countDown.isCountedDown()) { + FutureUtils.cancel(ackTimeoutCallback); + } + } } @Override @@ -599,12 +609,16 @@ public void onNodeAck(DiscoveryNode node, @Nullable Exception e) { } if (countDown.countDown()) { - logger.trace("all expected nodes acknowledged cluster_state update (version: {})", clusterStateVersion); - FutureUtils.cancel(ackTimeoutCallback); - ackedTaskListener.onAllNodesAcked(lastFailure); + finish(); } } + private void finish() { + logger.trace("all expected nodes acknowledged cluster_state update (version: {})", clusterStateVersion); + FutureUtils.cancel(ackTimeoutCallback); + ackedTaskListener.onAllNodesAcked(lastFailure); + } + public void onTimeout() { if (countDown.fastForward()) { logger.trace("timeout waiting for acknowledgement for cluster_state update (version: {})", clusterStateVersion); diff --git a/server/src/main/java/org/elasticsearch/discovery/Discovery.java b/server/src/main/java/org/elasticsearch/discovery/Discovery.java index 9c70876032442..b58f61bac89bb 100644 --- a/server/src/main/java/org/elasticsearch/discovery/Discovery.java +++ b/server/src/main/java/org/elasticsearch/discovery/Discovery.java @@ -25,6 +25,7 @@ import org.elasticsearch.common.Nullable; import org.elasticsearch.common.component.LifecycleComponent; import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.unit.TimeValue; import java.io.IOException; @@ -48,6 +49,19 @@ public interface Discovery extends LifecycleComponent { void publish(ClusterChangedEvent clusterChangedEvent, AckListener ackListener); interface AckListener { + /** + * Should be called when the discovery layer has committed the clusters state (i.e. even if this publication fails, + * it is guaranteed to appear in future publications). + * @param commitTime the time it took to commit the cluster state + */ + void onCommit(TimeValue commitTime); + + /** + * Should be called whenever the discovery layer receives confirmation from a node that it has successfully applied + * the cluster state. In case of failures, an exception should be provided as parameter. + * @param node the node + * @param e the optional exception + */ void onNodeAck(DiscoveryNode node, @Nullable Exception e); } diff --git a/server/src/main/java/org/elasticsearch/discovery/single/SingleNodeDiscovery.java b/server/src/main/java/org/elasticsearch/discovery/single/SingleNodeDiscovery.java index cd775e29f5a2f..d7c37febb5dd2 100644 --- a/server/src/main/java/org/elasticsearch/discovery/single/SingleNodeDiscovery.java +++ b/server/src/main/java/org/elasticsearch/discovery/single/SingleNodeDiscovery.java @@ -30,6 +30,7 @@ import org.elasticsearch.cluster.service.MasterService; import org.elasticsearch.common.component.AbstractLifecycleComponent; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.discovery.Discovery; import org.elasticsearch.discovery.DiscoveryStats; import org.elasticsearch.transport.TransportService; @@ -61,6 +62,7 @@ public SingleNodeDiscovery(final Settings settings, final TransportService trans public synchronized void publish(final ClusterChangedEvent event, final AckListener ackListener) { clusterState = event.state(); + ackListener.onCommit(TimeValue.ZERO); CountDownLatch latch = new CountDownLatch(1); ClusterApplyListener listener = new ClusterApplyListener() { diff --git a/server/src/main/java/org/elasticsearch/discovery/zen/PublishClusterStateAction.java b/server/src/main/java/org/elasticsearch/discovery/zen/PublishClusterStateAction.java index cd87a41526313..5398b2a057ae4 100644 --- a/server/src/main/java/org/elasticsearch/discovery/zen/PublishClusterStateAction.java +++ b/server/src/main/java/org/elasticsearch/discovery/zen/PublishClusterStateAction.java @@ -158,7 +158,8 @@ public void publish(final ClusterChangedEvent clusterChangedEvent, final int min } try { - innerPublish(clusterChangedEvent, nodesToPublishTo, sendingController, sendFullVersion, serializedStates, serializedDiffs); + innerPublish(clusterChangedEvent, nodesToPublishTo, sendingController, ackListener, sendFullVersion, serializedStates, + serializedDiffs); } catch (Discovery.FailedToCommitClusterStateException t) { throw t; } catch (Exception e) { @@ -173,8 +174,9 @@ public void publish(final ClusterChangedEvent clusterChangedEvent, final int min } private void innerPublish(final ClusterChangedEvent clusterChangedEvent, final Set nodesToPublishTo, - final SendingController sendingController, final boolean sendFullVersion, - final Map serializedStates, final Map serializedDiffs) { + final SendingController sendingController, final Discovery.AckListener ackListener, + final boolean sendFullVersion, final Map serializedStates, + final Map serializedDiffs) { final ClusterState clusterState = clusterChangedEvent.state(); final ClusterState previousState = clusterChangedEvent.previousState(); @@ -195,8 +197,12 @@ private void innerPublish(final ClusterChangedEvent clusterChangedEvent, final S sendingController.waitForCommit(discoverySettings.getCommitTimeout()); + final long commitTime = System.nanoTime() - publishingStartInNanos; + + ackListener.onCommit(TimeValue.timeValueNanos(commitTime)); + try { - long timeLeftInNanos = Math.max(0, publishTimeout.nanos() - (System.nanoTime() - publishingStartInNanos)); + long timeLeftInNanos = Math.max(0, publishTimeout.nanos() - commitTime); final BlockingClusterStatePublishResponseHandler publishResponseHandler = sendingController.getPublishResponseHandler(); sendingController.setPublishingTimedOut(!publishResponseHandler.awaitAllNodes(TimeValue.timeValueNanos(timeLeftInNanos))); if (sendingController.getPublishingTimedOut()) { diff --git a/server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java b/server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java index 1b747f2268747..f75363c7ab5c7 100644 --- a/server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java @@ -22,6 +22,7 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.Version; +import org.elasticsearch.cluster.AckedClusterStateUpdateTask; import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; @@ -39,6 +40,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.concurrent.BaseFuture; +import org.elasticsearch.discovery.Discovery; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.MockLogAppender; import org.elasticsearch.test.junit.annotations.TestLogging; @@ -65,6 +67,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; import static java.util.Collections.emptyMap; import static java.util.Collections.emptySet; @@ -680,6 +683,132 @@ public void onFailure(String source, Exception e) { mockAppender.assertAllExpectationsMatched(); } + public void testAcking() throws InterruptedException { + final DiscoveryNode node1 = new DiscoveryNode("node1", buildNewFakeTransportAddress(), emptyMap(), emptySet(), Version.CURRENT); + final DiscoveryNode node2 = new DiscoveryNode("node2", buildNewFakeTransportAddress(), emptyMap(), emptySet(), Version.CURRENT); + final DiscoveryNode node3 = new DiscoveryNode("node3", buildNewFakeTransportAddress(), emptyMap(), emptySet(), Version.CURRENT); + TimedMasterService timedMasterService = new TimedMasterService(Settings.builder().put("cluster.name", + MasterServiceTests.class.getSimpleName()).build(), threadPool); + ClusterState initialClusterState = ClusterState.builder(new ClusterName(MasterServiceTests.class.getSimpleName())) + .nodes(DiscoveryNodes.builder() + .add(node1) + .add(node2) + .add(node3) + .localNodeId(node1.getId()) + .masterNodeId(node1.getId())) + .blocks(ClusterBlocks.EMPTY_CLUSTER_BLOCK).build(); + final AtomicReference> publisherRef = new AtomicReference<>(); + timedMasterService.setClusterStatePublisher((cce, l) -> publisherRef.get().accept(cce, l)); + timedMasterService.setClusterStateSupplier(() -> initialClusterState); + timedMasterService.start(); + + + // check that we don't time out before even committing the cluster state + { + final CountDownLatch latch = new CountDownLatch(1); + + publisherRef.set((clusterChangedEvent, ackListener) -> { + throw new Discovery.FailedToCommitClusterStateException("mock exception"); + }); + + timedMasterService.submitStateUpdateTask("test2", new AckedClusterStateUpdateTask(null, null) { + @Override + public ClusterState execute(ClusterState currentState) { + return ClusterState.builder(currentState).build(); + } + + @Override + public TimeValue ackTimeout() { + return TimeValue.ZERO; + } + + @Override + public TimeValue timeout() { + return null; + } + + @Override + public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) { + fail(); + } + + @Override + protected Void newResponse(boolean acknowledged) { + fail(); + return null; + } + + @Override + public void onFailure(String source, Exception e) { + latch.countDown(); + } + + @Override + public void onAckTimeout() { + fail(); + } + }); + + latch.await(); + } + + // check that we timeout if commit took too long + { + final CountDownLatch latch = new CountDownLatch(2); + + final TimeValue ackTimeout = TimeValue.timeValueMillis(randomInt(100)); + + publisherRef.set((clusterChangedEvent, ackListener) -> { + ackListener.onCommit(TimeValue.timeValueMillis(ackTimeout.millis() + randomInt(100))); + ackListener.onNodeAck(node1, null); + ackListener.onNodeAck(node2, null); + ackListener.onNodeAck(node3, null); + }); + + timedMasterService.submitStateUpdateTask("test2", new AckedClusterStateUpdateTask(null, null) { + @Override + public ClusterState execute(ClusterState currentState) { + return ClusterState.builder(currentState).build(); + } + + @Override + public TimeValue ackTimeout() { + return ackTimeout; + } + + @Override + public TimeValue timeout() { + return null; + } + + @Override + public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) { + latch.countDown(); + } + + @Override + protected Void newResponse(boolean acknowledged) { + fail(); + return null; + } + + @Override + public void onFailure(String source, Exception e) { + fail(); + } + + @Override + public void onAckTimeout() { + latch.countDown(); + } + }); + + latch.await(); + } + + timedMasterService.close(); + } + static class TimedMasterService extends MasterService { public volatile Long currentTimeOverride = null; diff --git a/server/src/test/java/org/elasticsearch/discovery/zen/PublishClusterStateActionTests.java b/server/src/test/java/org/elasticsearch/discovery/zen/PublishClusterStateActionTests.java index c8e85382994c7..ac1719269e7ae 100644 --- a/server/src/test/java/org/elasticsearch/discovery/zen/PublishClusterStateActionTests.java +++ b/server/src/test/java/org/elasticsearch/discovery/zen/PublishClusterStateActionTests.java @@ -42,6 +42,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.discovery.Discovery; import org.elasticsearch.discovery.DiscoverySettings; import org.elasticsearch.node.Node; @@ -815,9 +816,16 @@ public AssertingAckListener publishState(PublishClusterStateAction action, Clust public static class AssertingAckListener implements Discovery.AckListener { private final List> errors = new CopyOnWriteArrayList<>(); private final CountDownLatch countDown; + private final CountDownLatch commitCountDown; public AssertingAckListener(int nodeCount) { countDown = new CountDownLatch(nodeCount); + commitCountDown = new CountDownLatch(1); + } + + @Override + public void onCommit(TimeValue commitTime) { + commitCountDown.countDown(); } @Override @@ -830,6 +838,7 @@ public void onNodeAck(DiscoveryNode node, @Nullable Exception e) { public void await(long timeout, TimeUnit unit) throws InterruptedException { assertThat(awaitErrors(timeout, unit), emptyIterable()); + assertTrue(commitCountDown.await(timeout, unit)); } public List> awaitErrors(long timeout, TimeUnit unit) throws InterruptedException {