Skip to content

Commit

Permalink
Check for closed connection while opening
Browse files Browse the repository at this point in the history
While opening a connection to a node, a channel can subsequently
close. If this happens, a future callback whose purpose is to close all
other channels and disconnect from the node will fire. However, this
future will not be ready to close all the channels because the
connection will not be exposed to the future callback yet. Since this
callback is run once, we will never try to disconnect from this node
again and we will be left with a closed channel. This commit adds a
check that all channels are open before exposing the channel and throws
a general connection exception. In this case, the usual connection retry
logic will take over.

Relates #26932
  • Loading branch information
jasontedor authored Oct 10, 2017
1 parent d6fc4af commit 4c06b8f
Show file tree
Hide file tree
Showing 10 changed files with 116 additions and 21 deletions.
16 changes: 15 additions & 1 deletion core/src/main/java/org/elasticsearch/transport/TcpTransport.java
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,9 @@ public final NodeChannels openConnection(DiscoveryNode node, ConnectionProfile c
nodeChannels = new NodeChannels(nodeChannels, version); // clone the channels - we now have the correct version
transportService.onConnectionOpened(nodeChannels);
connectionRef.set(nodeChannels);
if (Arrays.stream(nodeChannels.channels).allMatch(this::isOpen) == false) {
throw new ConnectTransportException(node, "a channel closed while connecting");
}
success = true;
return nodeChannels;
} catch (ConnectTransportException e) {
Expand Down Expand Up @@ -1034,7 +1037,18 @@ protected void innerOnFailure(Exception e) {
*/
protected abstract void sendMessage(Channel channel, BytesReference reference, ActionListener<Channel> listener);

protected abstract NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile connectionProfile,
/**
* Connect to the node with channels as defined by the specified connection profile. Implementations must invoke the specified channel
* close callback when a channel is closed.
*
* @param node the node to connect to
* @param connectionProfile the connection profile
* @param onChannelClose callback to invoke when a channel is closed
* @return the channels
* @throws IOException if an I/O exception occurs while opening channels
*/
protected abstract NodeChannels connectToChannels(DiscoveryNode node,
ConnectionProfile connectionProfile,
Consumer<Channel> onChannelClose) throws IOException;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import java.util.Objects;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.function.Function;
import java.util.function.Predicate;
Expand Down Expand Up @@ -187,6 +188,15 @@ protected TaskManager createTaskManager() {
return new TaskManager(settings);
}

/**
* The executor service for this transport service.
*
* @return the executor service
*/
protected ExecutorService getExecutorService() {
return threadPool.generic();
}

void setTracerLogInclude(List<String> tracerLogInclude) {
this.tracerLogInclude = tracerLogInclude.toArray(Strings.EMPTY_ARRAY);
}
Expand Down Expand Up @@ -232,7 +242,7 @@ protected void doStop() {
if (holderToNotify != null) {
// callback that an exception happened, but on a different thread since we don't
// want handlers to worry about stack overflows
threadPool.generic().execute(new AbstractRunnable() {
getExecutorService().execute(new AbstractRunnable() {
@Override
public void onRejection(Exception e) {
// if we get rejected during node shutdown we don't wanna bubble it up
Expand Down Expand Up @@ -879,20 +889,20 @@ void onNodeConnected(final DiscoveryNode node) {
// connectToNode(); connection is completed successfully
// addConnectionListener(); this listener shouldn't be called
final Stream<TransportConnectionListener> listenersToNotify = TransportService.this.connectionListeners.stream();
threadPool.generic().execute(() -> listenersToNotify.forEach(listener -> listener.onNodeConnected(node)));
getExecutorService().execute(() -> listenersToNotify.forEach(listener -> listener.onNodeConnected(node)));
}

void onConnectionOpened(Transport.Connection connection) {
// capture listeners before spawning the background callback so the following pattern won't trigger a call
// connectToNode(); connection is completed successfully
// addConnectionListener(); this listener shouldn't be called
final Stream<TransportConnectionListener> listenersToNotify = TransportService.this.connectionListeners.stream();
threadPool.generic().execute(() -> listenersToNotify.forEach(listener -> listener.onConnectionOpened(connection)));
getExecutorService().execute(() -> listenersToNotify.forEach(listener -> listener.onConnectionOpened(connection)));
}

public void onNodeDisconnected(final DiscoveryNode node) {
try {
threadPool.generic().execute( () -> {
getExecutorService().execute( () -> {
for (final TransportConnectionListener connectionListener : connectionListeners) {
connectionListener.onNodeDisconnected(node);
}
Expand All @@ -911,7 +921,7 @@ void onConnectionClosed(Transport.Connection connection) {
if (holderToNotify != null) {
// callback that an exception happened, but on a different thread since we don't
// want handlers to worry about stack overflows
threadPool.generic().execute(() -> holderToNotify.handler().handleException(new NodeDisconnectedException(
getExecutorService().execute(() -> holderToNotify.handler().handleException(new NodeDisconnectedException(
connection.getNode(), holderToNotify.action())));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,8 @@ protected void sendMessage(Object o, BytesReference reference, ActionListener li
}

@Override
protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile,
Consumer onChannelClose) throws IOException {
protected NodeChannels connectToChannels(
DiscoveryNode node, ConnectionProfile profile, Consumer onChannelClose) throws IOException {
return new NodeChannels(node, new Object[profile.getNumConnections()], profile);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
public class SimpleNetty4TransportTests extends AbstractSimpleTransportTestCase {

public static MockTransportService nettyFromThreadPool(Settings settings, ThreadPool threadPool, final Version version,
ClusterSettings clusterSettings, boolean doHandshake) {
ClusterSettings clusterSettings, boolean doHandshake) {
NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
Transport transport = new Netty4Transport(settings, threadPool, new NetworkService(Collections.emptyList()),
BigArrays.NON_RECYCLING_INSTANCE, namedWriteableRegistry, new NoneCircuitBreakerService()) {
Expand Down Expand Up @@ -86,6 +86,13 @@ protected MockTransportService build(Settings settings, Version version, Cluster
return transportService;
}

@Override
protected void closeConnectionChannel(Transport transport, Transport.Connection connection) throws IOException {
final Netty4Transport t = (Netty4Transport) transport;
@SuppressWarnings("unchecked") final TcpTransport<Channel>.NodeChannels channels = (TcpTransport<Channel>.NodeChannels) connection;
t.closeChannels(channels.getChannels().subList(0, randomIntBetween(1, channels.getChannels().size())), true, false);
}

public void testConnectException() throws UnknownHostException {
try {
serviceA.connectToNode(new DiscoveryNode("C", new TransportAddress(InetAddress.getByName("localhost"), 9876),
Expand All @@ -108,7 +115,8 @@ public void testBindUnavailableAddress() {
.build();
ClusterSettings clusterSettings = new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
BindTransportException bindTransportException = expectThrows(BindTransportException.class, () -> {
MockTransportService transportService = nettyFromThreadPool(settings, threadPool, Version.CURRENT, clusterSettings, true);
MockTransportService transportService =
nettyFromThreadPool(settings, threadPool, Version.CURRENT, clusterSettings, true);
try {
transportService.start();
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
import java.util.Set;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
Expand Down Expand Up @@ -167,6 +168,17 @@ protected TaskManager createTaskManager() {
}
}

private volatile String executorName;

public void setExecutorName(final String executorName) {
this.executorName = executorName;
}

@Override
protected ExecutorService getExecutorService() {
return executorName == null ? super.getExecutorService() : getThreadPool().executor(executorName);
}

/**
* Clears all the registered rules.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,10 @@

import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasToString;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.startsWith;
Expand Down Expand Up @@ -147,14 +149,14 @@ public void onNodeDisconnected(DiscoveryNode node) {
private MockTransportService buildService(final String name, final Version version, ClusterSettings clusterSettings,
Settings settings, boolean acceptRequests, boolean doHandshake) {
MockTransportService service = build(
Settings.builder()
.put(settings)
.put(Node.NODE_NAME_SETTING.getKey(), name)
.put(TransportService.TRACE_LOG_INCLUDE_SETTING.getKey(), "")
.put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), "NOTHING")
.build(),
version,
clusterSettings, doHandshake);
Settings.builder()
.put(settings)
.put(Node.NODE_NAME_SETTING.getKey(), name)
.put(TransportService.TRACE_LOG_INCLUDE_SETTING.getKey(), "")
.put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), "NOTHING")
.build(),
version,
clusterSettings, doHandshake);
if (acceptRequests) {
service.acceptIncomingRequests();
}
Expand Down Expand Up @@ -2612,4 +2614,33 @@ public void testProfilesIncludesDefault() {
assertEquals(new HashSet<>(Arrays.asList("default", "test")), profileSettings.stream().map(s -> s.profileName).collect(Collectors
.toSet()));
}

public void testChannelCloseWhileConnecting() throws IOException {
try (MockTransportService service = build(Settings.builder().put("name", "close").build(), version0, null, true)) {
service.setExecutorName(ThreadPool.Names.SAME); // make sure stuff is executed in a blocking fashion
service.addConnectionListener(new TransportConnectionListener() {
@Override
public void onConnectionOpened(final Transport.Connection connection) {
try {
closeConnectionChannel(service.getOriginalTransport(), connection);
} catch (final IOException e) {
throw new AssertionError(e);
}
}
});
final ConnectionProfile.Builder builder = new ConnectionProfile.Builder();
builder.addConnections(1,
TransportRequestOptions.Type.BULK,
TransportRequestOptions.Type.PING,
TransportRequestOptions.Type.RECOVERY,
TransportRequestOptions.Type.REG,
TransportRequestOptions.Type.STATE);
final ConnectTransportException e =
expectThrows(ConnectTransportException.class, () -> service.openConnection(nodeA, builder.build()));
assertThat(e, hasToString(containsString(("a channel closed while connecting"))));
}
}

protected abstract void closeConnectionChannel(Transport transport, Transport.Connection connection) throws IOException;

}
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ private void readMessage(MockChannel mockChannel, StreamInput input) throws IOEx
}

@Override
protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile,
protected NodeChannels connectToChannels(DiscoveryNode node,
ConnectionProfile profile,
Consumer<MockChannel> onChannelClose) throws IOException {
final MockChannel[] mockChannels = new MockChannel[1];
final NodeChannels nodeChannels = new NodeChannels(node, mockChannels, LIGHT_PROFILE); // we always use light here
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ public NioClient(Logger logger, OpenChannels openChannels, Supplier<SocketSelect
this.channelFactory = channelFactory;
}

public boolean connectToChannels(DiscoveryNode node, NioSocketChannel[] channels, TimeValue connectTimeout,
public boolean connectToChannels(DiscoveryNode node,
NioSocketChannel[] channels,
TimeValue connectTimeout,
Consumer<NioChannel> closeListener) throws IOException {
boolean allowedToConnect = semaphore.tryAcquire();
if (allowedToConnect == false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.util.Collections;

public class MockTcpTransportTests extends AbstractSimpleTransportTestCase {

@Override
protected MockTransportService build(Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake) {
NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
Expand All @@ -53,4 +54,13 @@ protected Version executeHandshake(DiscoveryNode node, MockChannel mockChannel,
mockTransportService.start();
return mockTransportService;
}

@Override
protected void closeConnectionChannel(Transport transport, Transport.Connection connection) throws IOException {
final MockTcpTransport t = (MockTcpTransport) transport;
@SuppressWarnings("unchecked") final TcpTransport<MockTcpTransport.MockChannel>.NodeChannels channels =
(TcpTransport<MockTcpTransport.MockChannel>.NodeChannels) connection;
t.closeChannels(channels.getChannels().subList(0, randomIntBetween(1, channels.getChannels().size())), true, false);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
public class SimpleNioTransportTests extends AbstractSimpleTransportTestCase {

public static MockTransportService nioFromThreadPool(Settings settings, ThreadPool threadPool, final Version version,
ClusterSettings clusterSettings, boolean doHandshake) {
ClusterSettings clusterSettings, boolean doHandshake) {
NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
NetworkService networkService = new NetworkService(Collections.emptyList());
Transport transport = new NioTransport(settings, threadPool,
Expand Down Expand Up @@ -96,6 +96,13 @@ protected MockTransportService build(Settings settings, Version version, Cluster
return transportService;
}

@Override
protected void closeConnectionChannel(Transport transport, Transport.Connection connection) throws IOException {
final NioTransport t = (NioTransport) transport;
@SuppressWarnings("unchecked") TcpTransport<NioChannel>.NodeChannels channels = (TcpTransport<NioChannel>.NodeChannels) connection;
t.closeChannels(channels.getChannels().subList(0, randomIntBetween(1, channels.getChannels().size())), true, false);
}

public void testConnectException() throws UnknownHostException {
try {
serviceA.connectToNode(new DiscoveryNode("C", new TransportAddress(InetAddress.getByName("localhost"), 9876),
Expand Down

0 comments on commit 4c06b8f

Please sign in to comment.