Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract TransportRequestDeduplication from ShardStateAction #37870

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -1192,12 +1192,12 @@ public void markShardCopyAsStaleIfNeeded(ShardId shardId, String allocationId, R
onSuccess.run();
}

protected final ShardStateAction.Listener createShardActionListener(final Runnable onSuccess,
protected final ActionListener<Void> createShardActionListener(final Runnable onSuccess,
final Consumer<Exception> onPrimaryDemoted,
final Consumer<Exception> onIgnoredFailure) {
return new ShardStateAction.Listener() {
return new ActionListener<Void>() {
@Override
public void onSuccess() {
public void onResponse(Void aVoid) {
onSuccess.run();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateObserver;
Expand All @@ -48,18 +49,17 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.node.NodeClosedException;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.ConnectTransportException;
import org.elasticsearch.transport.EmptyTransportResponseHandler;
import org.elasticsearch.transport.NodeDisconnectedException;
import org.elasticsearch.transport.RemoteTransportException;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestDeduplicator;
import org.elasticsearch.transport.TransportRequestHandler;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportService;
Expand All @@ -71,7 +71,6 @@
import java.util.Locale;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Predicate;

import static org.elasticsearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM;
Expand All @@ -89,7 +88,7 @@ public class ShardStateAction {

// a list of shards that failed during replication
// we keep track of these shards in order to avoid sending duplicate failed shard requests for a single failing shard.
private final ConcurrentMap<FailedShardEntry, CompositeListener> remoteFailedShardsCache = ConcurrentCollections.newConcurrentMap();
private final TransportRequestDeduplicator<FailedShardEntry> remoteFailedShardsDeduplicator = new TransportRequestDeduplicator<>();

@Inject
public ShardStateAction(ClusterService clusterService, TransportService transportService,
Expand All @@ -106,7 +105,7 @@ public ShardStateAction(ClusterService clusterService, TransportService transpor
}

private void sendShardAction(final String actionName, final ClusterState currentState,
final TransportRequest request, final Listener listener) {
final TransportRequest request, final ActionListener<Void> listener) {
ClusterStateObserver observer =
new ClusterStateObserver(currentState, clusterService, null, logger, threadPool.getThreadContext());
DiscoveryNode masterNode = currentState.nodes().getMasterNode();
Expand All @@ -120,7 +119,7 @@ private void sendShardAction(final String actionName, final ClusterState current
actionName, request, new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
@Override
public void handleResponse(TransportResponse.Empty response) {
listener.onSuccess();
listener.onResponse(null);
}

@Override
Expand Down Expand Up @@ -163,60 +162,39 @@ private static boolean isMasterChannelException(TransportException exp) {
* @param listener callback upon completion of the request
*/
public void remoteShardFailed(final ShardId shardId, String allocationId, long primaryTerm, boolean markAsStale, final String message,
@Nullable final Exception failure, Listener listener) {
@Nullable final Exception failure, ActionListener<Void> listener) {
assert primaryTerm > 0L : "primary term should be strictly positive";
final FailedShardEntry shardEntry = new FailedShardEntry(shardId, allocationId, primaryTerm, message, failure, markAsStale);
final CompositeListener compositeListener = new CompositeListener(listener);
final CompositeListener existingListener = remoteFailedShardsCache.putIfAbsent(shardEntry, compositeListener);
if (existingListener == null) {
sendShardAction(SHARD_FAILED_ACTION_NAME, clusterService.state(), shardEntry, new Listener() {
@Override
public void onSuccess() {
try {
compositeListener.onSuccess();
} finally {
remoteFailedShardsCache.remove(shardEntry);
}
}
@Override
public void onFailure(Exception e) {
try {
compositeListener.onFailure(e);
} finally {
remoteFailedShardsCache.remove(shardEntry);
}
}
});
} else {
existingListener.addListener(listener);
}
remoteFailedShardsDeduplicator.executeOnce(
new FailedShardEntry(shardId, allocationId, primaryTerm, message, failure, markAsStale), listener,
(req, reqListener) -> sendShardAction(SHARD_FAILED_ACTION_NAME, clusterService.state(), req, reqListener));
}

int remoteShardFailedCacheSize() {
return remoteFailedShardsCache.size();
return remoteFailedShardsDeduplicator.size();
}

/**
* Send a shard failed request to the master node to update the cluster state when a shard on the local node failed.
*/
public void localShardFailed(final ShardRouting shardRouting, final String message,
@Nullable final Exception failure, Listener listener) {
@Nullable final Exception failure, ActionListener<Void> listener) {
localShardFailed(shardRouting, message, failure, listener, clusterService.state());
}

/**
* Send a shard failed request to the master node to update the cluster state when a shard on the local node failed.
*/
public void localShardFailed(final ShardRouting shardRouting, final String message, @Nullable final Exception failure,
Listener listener, final ClusterState currentState) {
ActionListener<Void> listener, final ClusterState currentState) {
FailedShardEntry shardEntry = new FailedShardEntry(shardRouting.shardId(), shardRouting.allocationId().getId(),
0L, message, failure, true);
sendShardAction(SHARD_FAILED_ACTION_NAME, currentState, shardEntry, listener);
}

// visible for testing
protected void waitForNewMasterAndRetry(String actionName, ClusterStateObserver observer,
TransportRequest request, Listener listener, Predicate<ClusterState> changePredicate) {
TransportRequest request, ActionListener<Void> listener,
Predicate<ClusterState> changePredicate) {
observer.waitForNextChange(new ClusterStateObserver.Listener() {
@Override
public void onNewClusterState(ClusterState state) {
Expand Down Expand Up @@ -497,14 +475,14 @@ public int hashCode() {
public void shardStarted(final ShardRouting shardRouting,
final long primaryTerm,
final String message,
final Listener listener) {
final ActionListener<Void> listener) {
shardStarted(shardRouting, primaryTerm, message, listener, clusterService.state());
}

public void shardStarted(final ShardRouting shardRouting,
final long primaryTerm,
final String message,
final Listener listener,
final ActionListener<Void> listener,
final ClusterState currentState) {
StartedShardEntry entry = new StartedShardEntry(shardRouting.shardId(), shardRouting.allocationId().getId(), primaryTerm, message);
sendShardAction(SHARD_STARTED_ACTION_NAME, currentState, entry, listener);
Expand Down Expand Up @@ -670,97 +648,6 @@ public String toString() {
}
}

/**
 * Callback used to report the outcome of a shard state request.
 * Both methods default to doing nothing, so implementations override only what they need.
 */
public interface Listener {

    /**
     * Invoked when the request completed successfully. The default implementation is a no-op.
     */
    default void onSuccess() {
        // intentionally empty
    }

    /**
     * Invoked for a non-channel failure that {@link ShardStateAction} does not deal with itself.
     *
     * {@link ShardStateAction} already handles the following exceptions:
     * - {@link NotMasterException}
     * - {@link NodeDisconnectedException}
     * - {@link FailedToCommitClusterStateException}
     *
     * Every other exception is passed on to the requester through this method.
     * The default implementation is a no-op.
     *
     * @param e the unexpected cause of the failure on the master
     */
    default void onFailure(final Exception e) {
        // intentionally empty
    }
}

/**
 * A {@link Listener} that fans a single outcome out to a dynamically growing set of listeners.
 * Listeners registered before completion are notified when this listener completes; listeners
 * registered afterwards are notified immediately with the recorded outcome.
 */
static final class CompositeListener implements Listener {
    private final List<Listener> listeners = new ArrayList<>();
    private boolean isNotified;
    private Exception failure;

    CompositeListener(Listener listener) {
        listeners.add(listener);
    }

    /**
     * Registers {@code listener}. If this composite has already completed, the listener is
     * instead notified right away (outside the lock) with the recorded outcome.
     */
    void addListener(Listener listener) {
        synchronized (this) {
            if (isNotified == false) {
                listeners.add(listener);
                return;
            }
        }
        // already completed: replay the stored outcome without holding the monitor
        notifyListener(listener, failure);
    }

    /**
     * Delivers one outcome to one listener: {@code onFailure} when {@code outcome} is
     * non-null, {@code onSuccess} otherwise.
     */
    private static void notifyListener(Listener target, Exception outcome) {
        if (outcome == null) {
            target.onSuccess();
        } else {
            target.onFailure(outcome);
        }
    }

    /**
     * Records the outcome, then notifies every registered listener. Runtime exceptions thrown
     * by listeners do not stop the loop; the first one is rethrown at the end with the others
     * attached as suppressed exceptions.
     */
    private void onCompleted(Exception failure) {
        synchronized (this) {
            this.isNotified = true;
            this.failure = failure;
        }
        RuntimeException collected = null;
        for (Listener registered : listeners) {
            try {
                notifyListener(registered, failure);
            } catch (RuntimeException thrown) {
                if (collected != null) {
                    collected.addSuppressed(thrown);
                } else {
                    collected = thrown;
                }
            }
        }
        if (collected != null) {
            throw collected;
        }
    }

    @Override
    public void onSuccess() {
        onCompleted(null);
    }

    @Override
    public void onFailure(Exception failure) {
        onCompleted(failure);
    }
}

public static class NoLongerPrimaryShardException extends ElasticsearchException {

public NoLongerPrimaryShardException(ShardId shardId, String msg) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,7 @@ public class IndicesClusterStateService extends AbstractLifecycleComponent imple
private final ShardStateAction shardStateAction;
private final NodeMappingRefreshAction nodeMappingRefreshAction;

private static final ShardStateAction.Listener SHARD_STATE_ACTION_LISTENER = new ShardStateAction.Listener() {
};
private static final ActionListener<Void> SHARD_STATE_ACTION_LISTENER = ActionListener.wrap(() -> {});

private final Settings settings;
// a list of shards that failed during recovery
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.transport;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentMap;
import java.util.function.BiConsumer;

/**
 * Deduplicator for {@link TransportRequest}s that keeps track of {@link TransportRequest}s that should
 * not be sent in parallel.
 * @param <T> Transport Request Class
 */
public final class TransportRequestDeduplicator<T extends TransportRequest> {

    private final ConcurrentMap<T, CompositeListener> requests = ConcurrentCollections.newConcurrentMap();

    /**
     * Ensures a given request is not executed multiple times when another equal request is already in-flight.
     * If the request is not yet known to the deduplicator it will invoke the passed callback with an {@link ActionListener}
     * that must be completed by the caller when the request completes. Once that listener is completed the request will be removed from
     * the deduplicator's internal state. If the request is already known to the deduplicator it will keep
     * track of the given listener and invoke it when the listener passed to the callback on first invocation is completed.
     * @param request  Request to deduplicate
     * @param listener Listener to invoke on request completion
     * @param callback Callback to be invoked with request and completion listener the first time the request is added to the deduplicator
     */
    public void executeOnce(T request, ActionListener<Void> listener, BiConsumer<T, ActionListener<Void>> callback) {
        ActionListener<Void> completionListener = requests.computeIfAbsent(request, CompositeListener::new).addListener(listener);
        // Only the caller that registered the first listener gets a non-null completion listener and runs the request.
        if (completionListener != null) {
            callback.accept(request, completionListener);
        }
    }

    /**
     * @return the number of requests currently tracked by this deduplicator
     */
    public int size() {
        return requests.size();
    }

    /**
     * Collects all listeners registered for one request and notifies them once the request completes.
     */
    private final class CompositeListener implements ActionListener<Void> {

        private final List<ActionListener<Void>> listeners = new ArrayList<>();

        private final T request;

        // Both fields are written while holding this object's monitor in onCompleted.
        private boolean isNotified;
        private Exception failure;

        CompositeListener(T request) {
            this.request = request;
        }

        /**
         * Registers the given listener, or notifies it immediately if this composite has already completed.
         *
         * @param listener listener to register
         * @return this instance if {@code listener} is the first one registered (the caller is then
         *         responsible for executing the request), {@code null} otherwise
         */
        CompositeListener addListener(ActionListener<Void> listener) {
            synchronized (this) {
                if (this.isNotified == false) {
                    listeners.add(listener);
                    return listeners.size() == 1 ? this : null;
                }
            }
            // isNotified was observed as true under the lock, so the outcome stored under the same
            // lock in onCompleted is visible here; notify outside the lock.
            if (failure != null) {
                listener.onFailure(failure);
            } else {
                listener.onResponse(null);
            }
            return null;
        }

        /**
         * Records the outcome, passes it to the registered listeners and finally drops this composite
         * from the deduplicator's map.
         *
         * @param failure the failure the request completed with, or {@code null} on success
         */
        private void onCompleted(Exception failure) {
            synchronized (this) {
                this.failure = failure;
                this.isNotified = true;
            }
            try {
                if (failure == null) {
                    ActionListener.onResponse(listeners, null);
                } else {
                    ActionListener.onFailure(listeners, failure);
                }
            } finally {
                // Conditional removal: only drop the mapping while it still points at this composite, so a
                // repeated completion can never evict a newer in-flight entry registered for an equal request.
                requests.remove(request, this);
            }
        }

        @Override
        public void onResponse(final Void aVoid) {
            onCompleted(null);
        }

        @Override
        public void onFailure(Exception failure) {
            onCompleted(failure);
        }
    }
}
Loading