diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/AcceptorEventHandler.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/AcceptorEventHandler.java index 49bba47ef0256..ba0fa9356b9eb 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/AcceptorEventHandler.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/AcceptorEventHandler.java @@ -22,13 +22,11 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.transport.nio.channel.ChannelFactory; -import org.elasticsearch.transport.nio.channel.NioChannel; import org.elasticsearch.transport.nio.channel.NioServerSocketChannel; import org.elasticsearch.transport.nio.channel.NioSocketChannel; import org.elasticsearch.transport.nio.channel.SelectionKeyUtils; import java.io.IOException; -import java.util.function.Consumer; import java.util.function.Supplier; /** @@ -37,15 +35,10 @@ public class AcceptorEventHandler extends EventHandler { private final Supplier selectorSupplier; - private final Consumer acceptedChannelCallback; - private final OpenChannels openChannels; - public AcceptorEventHandler(Logger logger, OpenChannels openChannels, Supplier selectorSupplier, - Consumer acceptedChannelCallback) { - super(logger, openChannels); - this.openChannels = openChannels; + public AcceptorEventHandler(Logger logger, Supplier selectorSupplier) { + super(logger); this.selectorSupplier = selectorSupplier; - this.acceptedChannelCallback = acceptedChannelCallback; } /** @@ -56,7 +49,6 @@ public AcceptorEventHandler(Logger logger, OpenChannels openChannels, Supplier NIO_ACCEPTOR_COUNT = intSetting("transport.nio.acceptor_count", 1, 1, Setting.Property.NodeScope); - protected final OpenChannels openChannels = new OpenChannels(logger); - private final ConcurrentMap profileToChannelFactory = newConcurrentMap(); + private final OpenChannels openChannels = new OpenChannels(logger); + private final ConcurrentMap profileToChannelFactory = newConcurrentMap(); private final ArrayList acceptors = new ArrayList<>(); private final ArrayList socketSelectors = new ArrayList<>(); private RoundRobinSelectorSupplier clientSelectorSupplier; - private ChannelFactory clientChannelFactory; + private TcpChannelFactory clientChannelFactory; private int acceptorNumber; public NioTransport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays, @@ -84,17 +86,21 @@ public long getNumOpenServerConnections() { } @Override - protected NioServerSocketChannel bind(String name, InetSocketAddress address) throws IOException { - ChannelFactory channelFactory = this.profileToChannelFactory.get(name); + protected TcpNioServerSocketChannel bind(String name, InetSocketAddress address) throws IOException { + TcpChannelFactory channelFactory = this.profileToChannelFactory.get(name); AcceptingSelector selector = acceptors.get(++acceptorNumber % NioTransport.NIO_ACCEPTOR_COUNT.get(settings)); - return channelFactory.openNioServerSocketChannel(address, selector); + TcpNioServerSocketChannel serverChannel = channelFactory.openNioServerSocketChannel(address, selector); + openChannels.serverChannelOpened(serverChannel); + serverChannel.addCloseListener(ActionListener.wrap(() -> openChannels.channelClosed(serverChannel))); + return serverChannel; } @Override - protected NioChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener connectListener) + protected TcpNioSocketChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener connectListener) throws IOException { - NioSocketChannel channel = clientChannelFactory.openNioChannel(node.getAddress().address(), clientSelectorSupplier.get()); + TcpNioSocketChannel channel = clientChannelFactory.openNioChannel(node.getAddress().address(), clientSelectorSupplier.get()); openChannels.clientChannelOpened(channel); + channel.addCloseListener(ActionListener.wrap(() -> openChannels.channelClosed(channel))); channel.addConnectListener(connectListener); return channel; } @@ -119,14 +125,14 @@ protected void doStart() { Consumer clientContextSetter = getContextSetter("client-socket"); clientSelectorSupplier = new RoundRobinSelectorSupplier(socketSelectors); - clientChannelFactory = new ChannelFactory(new ProfileSettings(settings, "default"), clientContextSetter); + ProfileSettings clientProfileSettings = new ProfileSettings(settings, "default"); + clientChannelFactory = new TcpChannelFactory(clientProfileSettings, clientContextSetter, getServerContextSetter()); if (NetworkService.NETWORK_SERVER.get(settings)) { int acceptorCount = NioTransport.NIO_ACCEPTOR_COUNT.get(settings); for (int i = 0; i < acceptorCount; ++i) { Supplier selectorSupplier = new RoundRobinSelectorSupplier(socketSelectors); - AcceptorEventHandler eventHandler = new AcceptorEventHandler(logger, openChannels, selectorSupplier, - this::serverAcceptedChannel); + AcceptorEventHandler eventHandler = new AcceptorEventHandler(logger, selectorSupplier); AcceptingSelector acceptor = new AcceptingSelector(eventHandler); acceptors.add(acceptor); } @@ -143,7 +149,8 @@ protected void doStart() { for (ProfileSettings profileSettings : profileSettings) { String profileName = profileSettings.profileName; Consumer contextSetter = getContextSetter(profileName); - profileToChannelFactory.putIfAbsent(profileName, new ChannelFactory(profileSettings, contextSetter)); + TcpChannelFactory factory = new TcpChannelFactory(profileSettings, contextSetter, getServerContextSetter()); + profileToChannelFactory.putIfAbsent(profileName, factory); bindServer(profileSettings); } } @@ -169,14 +176,27 @@ protected void stopInternal() { } protected SocketEventHandler getSocketEventHandler() { - return new SocketEventHandler(logger, this::exceptionCaught, openChannels); + return new SocketEventHandler(logger); } final void exceptionCaught(NioSocketChannel channel, Exception exception) { - onException(channel, exception); + onException((TcpNioSocketChannel) channel, exception); } private Consumer getContextSetter(String profileName) { - return (c) -> c.setContexts(new TcpReadContext(c, new TcpReadHandler(profileName,this)), new TcpWriteContext(c)); + return (c) -> c.setContexts(new TcpReadContext(c, new TcpReadHandler(profileName,this)), new TcpWriteContext(c), + this::exceptionCaught); + } + + private void acceptChannel(NioSocketChannel channel) { + TcpNioSocketChannel tcpChannel = (TcpNioSocketChannel) channel; + openChannels.acceptedChannelOpened(tcpChannel); + tcpChannel.addCloseListener(ActionListener.wrap(() -> openChannels.channelClosed(channel))); + serverAcceptedChannel(tcpChannel); + + } + + private Consumer getServerContextSetter() { + return (c) -> c.setAcceptContext(this::acceptChannel); } } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/OpenChannels.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/OpenChannels.java index 68bb2f99bf3c5..12c12aaa48eb1 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/OpenChannels.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/OpenChannels.java @@ -25,6 +25,8 @@ import org.elasticsearch.transport.nio.channel.NioChannel; import org.elasticsearch.transport.nio.channel.NioServerSocketChannel; import org.elasticsearch.transport.nio.channel.NioSocketChannel; +import org.elasticsearch.transport.nio.channel.TcpNioServerSocketChannel; +import org.elasticsearch.transport.nio.channel.TcpNioSocketChannel; import java.util.ArrayList; import java.util.HashSet; @@ -38,9 +40,9 @@ public class OpenChannels implements Releasable { // TODO: Maybe set concurrency levels? - private final ConcurrentMap openClientChannels = newConcurrentMap(); - private final ConcurrentMap openAcceptedChannels = newConcurrentMap(); - private final ConcurrentMap openServerChannels = newConcurrentMap(); + private final ConcurrentMap openClientChannels = newConcurrentMap(); + private final ConcurrentMap openAcceptedChannels = newConcurrentMap(); + private final ConcurrentMap openServerChannels = newConcurrentMap(); private final Logger logger; @@ -48,7 +50,7 @@ public OpenChannels(Logger logger) { this.logger = logger; } - public void serverChannelOpened(NioServerSocketChannel channel) { + public void serverChannelOpened(TcpNioServerSocketChannel channel) { boolean added = openServerChannels.putIfAbsent(channel, System.nanoTime()) == null; if (added && logger.isTraceEnabled()) { logger.trace("server channel opened: {}", channel); @@ -59,7 +61,7 @@ public long serverChannelsCount() { return openServerChannels.size(); } - public void acceptedChannelOpened(NioSocketChannel channel) { + public void acceptedChannelOpened(TcpNioSocketChannel channel) { boolean added = openAcceptedChannels.putIfAbsent(channel, System.nanoTime()) == null; if (added && logger.isTraceEnabled()) { logger.trace("accepted channel opened: {}", channel); @@ -70,14 +72,14 @@ public HashSet getAcceptedChannels() { return new HashSet<>(openAcceptedChannels.keySet()); } - public void clientChannelOpened(NioSocketChannel channel) { + public void clientChannelOpened(TcpNioSocketChannel channel) { boolean added = openClientChannels.putIfAbsent(channel, System.nanoTime()) == null; if (added && logger.isTraceEnabled()) { logger.trace("client channel opened: {}", channel); } } - public Map getClientChannels() { + public Map getClientChannels() { return openClientChannels; } @@ -105,7 +107,7 @@ public void closeServerChannels() { @Override public void close() { - Stream channels = Stream.concat(openClientChannels.keySet().stream(), openAcceptedChannels.keySet().stream()); + Stream channels = Stream.concat(openClientChannels.keySet().stream(), openAcceptedChannels.keySet().stream()); TcpChannel.closeChannels(channels.collect(Collectors.toList()), true); openClientChannels.clear(); diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/SocketEventHandler.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/SocketEventHandler.java index 46292f63d1bda..50362c5a665d9 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/SocketEventHandler.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/SocketEventHandler.java @@ -27,19 +27,16 @@ import org.elasticsearch.transport.nio.channel.WriteContext; import java.io.IOException; -import java.util.function.BiConsumer; /** * Event handler designed to handle events from non-server sockets */ public class SocketEventHandler extends EventHandler { - private final BiConsumer exceptionHandler; private final Logger logger; - public SocketEventHandler(Logger logger, BiConsumer exceptionHandler, OpenChannels openChannels) { - super(logger, openChannels); - this.exceptionHandler = exceptionHandler; + public SocketEventHandler(Logger logger) { + super(logger); this.logger = logger; } @@ -150,6 +147,6 @@ void genericChannelException(NioChannel channel, Exception exception) { } private void exceptionCaught(NioSocketChannel channel, Exception e) { - exceptionHandler.accept(channel, e); + channel.getExceptionContext().accept(channel, e); } } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/TcpReadHandler.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/TcpReadHandler.java index 1260546d34cab..5c2ecea54c3f0 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/TcpReadHandler.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/TcpReadHandler.java @@ -21,6 +21,7 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.transport.nio.channel.NioSocketChannel; +import org.elasticsearch.transport.nio.channel.TcpNioSocketChannel; import java.io.IOException; @@ -34,7 +35,7 @@ public TcpReadHandler(String profile, NioTransport transport) { this.transport = transport; } - public void handleMessage(BytesReference reference, NioSocketChannel channel, int messageBytesLength) { + public void handleMessage(BytesReference reference, TcpNioSocketChannel channel, int messageBytesLength) { try { transport.messageReceived(reference, channel, profile, channel.getRemoteAddress(), messageBytesLength); } catch (IOException e) { diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/AbstractNioChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/AbstractNioChannel.java index 7743fe0d83c2a..7b08d831df83e 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/AbstractNioChannel.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/AbstractNioChannel.java @@ -137,25 +137,18 @@ public S getRawChannel() { return socketChannel; } + @Override + public void addCloseListener(ActionListener listener) { + closeContext.whenComplete(ActionListener.toBiConsumer(listener)); + } + // Package visibility for testing void setSelectionKey(SelectionKey selectionKey) { this.selectionKey = selectionKey; } - // Package visibility for testing + void closeRawChannel() throws IOException { socketChannel.close(); } - - @Override - public void addCloseListener(ActionListener listener) { - closeContext.whenComplete(ActionListener.toBiConsumer(listener)); - } - - @Override - public void setSoLinger(int value) throws IOException { - if (isOpen()) { - socketChannel.setOption(StandardSocketOptions.SO_LINGER, value); - } - } } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ChannelFactory.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ChannelFactory.java index 84385de062681..97433cf4d0aad 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ChannelFactory.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ChannelFactory.java @@ -19,74 +19,79 @@ package org.elasticsearch.transport.nio.channel; - import org.elasticsearch.mocksocket.PrivilegedSocketAccess; -import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.nio.AcceptingSelector; import org.elasticsearch.transport.nio.SocketSelector; import java.io.Closeable; import java.io.IOException; import java.net.InetSocketAddress; -import java.net.ServerSocket; -import java.net.Socket; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; -import java.util.function.Consumer; -public class ChannelFactory { +public abstract class ChannelFactory { - private final Consumer contextSetter; - private final RawChannelFactory rawChannelFactory; + private final ChannelFactory.RawChannelFactory rawChannelFactory; /** - * This will create a {@link ChannelFactory} using the profile settings and context setter passed to this - * constructor. The context setter must be a {@link Consumer} that calls - * {@link NioSocketChannel#setContexts(ReadContext, WriteContext)} with the appropriate read and write - * contexts. The read and write contexts handle the protocol specific encoding and decoding of messages. + * This will create a {@link ChannelFactory} using the raw channel factory passed to the constructor. * - * @param profileSettings the profile settings channels opened by this factory - * @param contextSetter a consumer that takes a channel and sets the read and write contexts + * @param rawChannelFactory a factory that will construct the raw socket channels */ - public ChannelFactory(TcpTransport.ProfileSettings profileSettings, Consumer contextSetter) { - this(new RawChannelFactory(profileSettings.tcpNoDelay, - profileSettings.tcpKeepAlive, - profileSettings.reuseAddress, - Math.toIntExact(profileSettings.sendBufferSize.getBytes()), - Math.toIntExact(profileSettings.receiveBufferSize.getBytes())), contextSetter); - } - - ChannelFactory(RawChannelFactory rawChannelFactory, Consumer contextSetter) { - this.contextSetter = contextSetter; + ChannelFactory(RawChannelFactory rawChannelFactory) { this.rawChannelFactory = rawChannelFactory; } - public NioSocketChannel openNioChannel(InetSocketAddress remoteAddress, SocketSelector selector) throws IOException { + public Socket openNioChannel(InetSocketAddress remoteAddress, SocketSelector selector) throws IOException { SocketChannel rawChannel = rawChannelFactory.openNioChannel(remoteAddress); - NioSocketChannel channel = createChannel(selector, rawChannel); + Socket channel = internalCreateChannel(selector, rawChannel); scheduleChannel(channel, selector); return channel; } - public NioSocketChannel acceptNioChannel(NioServerSocketChannel serverChannel, SocketSelector selector) throws IOException { + public Socket acceptNioChannel(NioServerSocketChannel serverChannel, SocketSelector selector) throws IOException { SocketChannel rawChannel = rawChannelFactory.acceptNioChannel(serverChannel); - NioSocketChannel channel = createChannel(selector, rawChannel); + Socket channel = internalCreateChannel(selector, rawChannel); scheduleChannel(channel, selector); return channel; } - public NioServerSocketChannel openNioServerSocketChannel(InetSocketAddress address, AcceptingSelector selector) - throws IOException { + public ServerSocket openNioServerSocketChannel(InetSocketAddress address, AcceptingSelector selector) throws IOException { ServerSocketChannel rawChannel = rawChannelFactory.openNioServerSocketChannel(address); - NioServerSocketChannel serverChannel = createServerChannel(selector, rawChannel); + ServerSocket serverChannel = internalCreateServerChannel(selector, rawChannel); scheduleServerChannel(serverChannel, selector); return serverChannel; } - private NioSocketChannel createChannel(SocketSelector selector, SocketChannel rawChannel) throws IOException { + /** + * This method should return a new {@link NioSocketChannel} implementation. When this method has + * returned, the channel should be fully created and setup. Read and write contexts and the channel + * exception handler should have been set. + * + * @param selector the channel will be registered with + * @param channel the raw channel + * @return the channel + * @throws IOException related to the creation of the channel + */ + public abstract Socket createChannel(SocketSelector selector, SocketChannel channel) throws IOException; + + /** + * This method should return a new {@link NioServerSocketChannel} implementation. When this method has + * returned, the channel should be fully created and setup. + * + * @param selector the channel will be registered with + * @param channel the raw channel + * @return the server channel + * @throws IOException related to the creation of the channel + */ + public abstract ServerSocket createServerChannel(AcceptingSelector selector, ServerSocketChannel channel) throws IOException; + + private Socket internalCreateChannel(SocketSelector selector, SocketChannel rawChannel) throws IOException { try { - NioSocketChannel channel = new NioSocketChannel(rawChannel, selector); - setContexts(channel); + Socket channel = createChannel(selector, rawChannel); + assert channel.getReadContext() != null : "read context should have been set on channel"; + assert channel.getWriteContext() != null : "write context should have been set on channel"; + assert channel.getExceptionContext() != null : "exception handler should have been set on channel"; return channel; } catch (Exception e) { closeRawChannel(rawChannel, e); @@ -94,16 +99,16 @@ private NioSocketChannel createChannel(SocketSelector selector, SocketChannel ra } } - private NioServerSocketChannel createServerChannel(AcceptingSelector selector, ServerSocketChannel rawChannel) throws IOException { + private ServerSocket internalCreateServerChannel(AcceptingSelector selector, ServerSocketChannel rawChannel) throws IOException { try { - return new NioServerSocketChannel(rawChannel, this, selector); + return createServerChannel(selector, rawChannel); } catch (Exception e) { closeRawChannel(rawChannel, e); throw e; } } - private void scheduleChannel(NioSocketChannel channel, SocketSelector selector) { + private void scheduleChannel(Socket channel, SocketSelector selector) { try { selector.scheduleForRegistration(channel); } catch (IllegalStateException e) { @@ -112,7 +117,7 @@ private void scheduleChannel(NioSocketChannel channel, SocketSelector selector) } } - private void scheduleServerChannel(NioServerSocketChannel channel, AcceptingSelector selector) { + private void scheduleServerChannel(ServerSocket channel, AcceptingSelector selector) { try { selector.scheduleForRegistration(channel); } catch (IllegalStateException e) { @@ -121,12 +126,6 @@ private void scheduleServerChannel(NioServerSocketChannel channel, AcceptingSele } } - private void setContexts(NioSocketChannel channel) { - contextSetter.accept(channel); - assert channel.getReadContext() != null : "read context should have been set on channel"; - assert channel.getWriteContext() != null : "write context should have been set on channel"; - } - private static void closeRawChannel(Closeable c, Exception e) { try { c.close(); @@ -179,7 +178,7 @@ SocketChannel acceptNioChannel(NioServerSocketChannel serverChannel) throws IOEx ServerSocketChannel openNioServerSocketChannel(InetSocketAddress address) throws IOException { ServerSocketChannel serverSocketChannel = ServerSocketChannel.open(); serverSocketChannel.configureBlocking(false); - ServerSocket socket = serverSocketChannel.socket(); + java.net.ServerSocket socket = serverSocketChannel.socket(); try { socket.setReuseAddress(tcpReusedAddress); serverSocketChannel.bind(address); @@ -192,7 +191,7 @@ ServerSocketChannel openNioServerSocketChannel(InetSocketAddress address) throws private void configureSocketChannel(SocketChannel channel) throws IOException { channel.configureBlocking(false); - Socket socket = channel.socket(); + java.net.Socket socket = channel.socket(); socket.setTcpNoDelay(tcpNoDelay); socket.setKeepAlive(tcpKeepAlive); socket.setReuseAddress(tcpReusedAddress); diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioChannel.java index 76262da6f1558..93bc4faa4c5d5 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioChannel.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioChannel.java @@ -19,7 +19,7 @@ package org.elasticsearch.transport.nio.channel; -import org.elasticsearch.transport.TcpChannel; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.transport.nio.ESSelector; import java.io.IOException; @@ -28,7 +28,9 @@ import java.nio.channels.NetworkChannel; import java.nio.channels.SelectionKey; -public interface NioChannel extends TcpChannel { +public interface NioChannel { + + boolean isOpen(); InetSocketAddress getLocalAddress(); @@ -43,4 +45,13 @@ public interface NioChannel extends TcpChannel { SelectionKey getSelectionKey(); NetworkChannel getRawChannel(); + + /** + * Adds a close listener to the channel. Multiple close listeners can be added. There is no guarantee + * about the order in which close listeners will be executed. If the channel is already closed, the + * listener is executed immediately. + * + * @param listener to be called at close + */ + void addCloseListener(ActionListener listener); } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannel.java index 0396a53f45459..ffbd8f7a9874e 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannel.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannel.java @@ -19,16 +19,16 @@ package org.elasticsearch.transport.nio.channel; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.transport.nio.AcceptingSelector; import java.io.IOException; import java.nio.channels.ServerSocketChannel; +import java.util.function.Consumer; public class NioServerSocketChannel extends AbstractNioChannel { private final ChannelFactory channelFactory; + private Consumer acceptContext; public NioServerSocketChannel(ServerSocketChannel socketChannel, ChannelFactory channelFactory, AcceptingSelector selector) throws IOException { @@ -40,9 +40,18 @@ public ChannelFactory getChannelFactory() { return channelFactory; } - @Override - public void sendMessage(BytesReference reference, ActionListener listener) { - throw new UnsupportedOperationException("Cannot send a message to a server channel."); + /** + * This method sets the accept context for a server socket channel. The accept context is called when a + * new channel is accepted. The parameter passed to the context is the new channel. + * + * @param acceptContext to call + */ + public void setAcceptContext(Consumer acceptContext) { + this.acceptContext = acceptContext; + } + + public Consumer getAcceptContext() { + return acceptContext; } @Override diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioSocketChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioSocketChannel.java index d0c3d9c3330d9..b56731aee10b2 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioSocketChannel.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioSocketChannel.java @@ -20,7 +20,6 @@ package org.elasticsearch.transport.nio.channel; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.transport.nio.NetworkBytesReference; import org.elasticsearch.transport.nio.SocketSelector; @@ -31,14 +30,18 @@ import java.nio.channels.SocketChannel; import java.util.Arrays; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.BiConsumer; public class NioSocketChannel extends AbstractNioChannel { private final InetSocketAddress remoteAddress; private final CompletableFuture connectContext = new CompletableFuture<>(); private final SocketSelector socketSelector; + private final AtomicBoolean contextsSet = new AtomicBoolean(false); private WriteContext writeContext; private ReadContext readContext; + private BiConsumer exceptionContext; private Exception connectException; public NioSocketChannel(SocketChannel socketChannel, SocketSelector selector) throws IOException { @@ -47,11 +50,6 @@ public NioSocketChannel(SocketChannel socketChannel, SocketSelector selector) th this.socketSelector = selector; } - @Override - public void sendMessage(BytesReference reference, ActionListener listener) { - writeContext.sendMessage(reference, listener); - } - @Override public void closeFromSelector() throws IOException { assert socketSelector.isOnCurrentThread() : "Should only call from selector thread"; @@ -99,9 +97,14 @@ public int read(NetworkBytesReference reference) throws IOException { return bytesRead; } - public void setContexts(ReadContext readContext, WriteContext writeContext) { - this.readContext = readContext; - this.writeContext = writeContext; + public void setContexts(ReadContext readContext, WriteContext writeContext, BiConsumer exceptionContext) { + if (contextsSet.compareAndSet(false, true)) { + this.readContext = readContext; + this.writeContext = writeContext; + this.exceptionContext = exceptionContext; + } else { + throw new IllegalStateException("Contexts on this channel were already set. They should only be once."); + } } public WriteContext getWriteContext() { @@ -112,6 +115,10 @@ public ReadContext getReadContext() { return readContext; } + public BiConsumer getExceptionContext() { + return exceptionContext; + } + public InetSocketAddress getRemoteAddress() { return remoteAddress; } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpChannelFactory.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpChannelFactory.java new file mode 100644 index 0000000000000..03d6db18e5a41 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpChannelFactory.java @@ -0,0 +1,66 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import org.elasticsearch.transport.TcpTransport; +import org.elasticsearch.transport.nio.AcceptingSelector; +import org.elasticsearch.transport.nio.SocketSelector; + +import java.io.IOException; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.util.function.Consumer; + +/** + * This is an implementation of {@link ChannelFactory} which returns channels that adhere to the + * {@link org.elasticsearch.transport.TcpChannel} interface. The channels will use the provided + * {@link TcpTransport.ProfileSettings}. The provided context setters will be called with the channel after + * construction. + */ +public class TcpChannelFactory extends ChannelFactory { + + private final Consumer contextSetter; + private final Consumer serverContextSetter; + + public TcpChannelFactory(TcpTransport.ProfileSettings profileSettings, Consumer contextSetter, + Consumer serverContextSetter) { + super(new RawChannelFactory(profileSettings.tcpNoDelay, + profileSettings.tcpKeepAlive, + profileSettings.reuseAddress, + Math.toIntExact(profileSettings.sendBufferSize.getBytes()), + Math.toIntExact(profileSettings.receiveBufferSize.getBytes()))); + this.contextSetter = contextSetter; + this.serverContextSetter = serverContextSetter; + } + + @Override + public TcpNioSocketChannel createChannel(SocketSelector selector, SocketChannel channel) throws IOException { + TcpNioSocketChannel nioChannel = new TcpNioSocketChannel(channel, selector); + contextSetter.accept(nioChannel); + return nioChannel; + } + + @Override + public TcpNioServerSocketChannel createServerChannel(AcceptingSelector selector, ServerSocketChannel channel) throws IOException { + TcpNioServerSocketChannel nioServerChannel = new TcpNioServerSocketChannel(channel, this, selector); + serverContextSetter.accept(nioServerChannel); + return nioServerChannel; + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpNioServerSocketChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpNioServerSocketChannel.java new file mode 100644 index 0000000000000..496295bd3203b --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpNioServerSocketChannel.java @@ -0,0 +1,57 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.transport.TcpChannel; +import org.elasticsearch.transport.nio.AcceptingSelector; + +import java.io.IOException; +import java.nio.channels.ServerSocketChannel; + +/** + * This is an implementation of {@link NioServerSocketChannel} that adheres to the {@link TcpChannel} + * interface. As it is a server socket, setting SO_LINGER and sending messages is not supported. + */ +public class TcpNioServerSocketChannel extends NioServerSocketChannel implements TcpChannel { + + TcpNioServerSocketChannel(ServerSocketChannel socketChannel, TcpChannelFactory channelFactory, AcceptingSelector selector) + throws IOException { + super(socketChannel, channelFactory, selector); + } + + @Override + public void sendMessage(BytesReference reference, ActionListener listener) { + throw new UnsupportedOperationException("Cannot send a message to a server channel."); + } + + @Override + public void setSoLinger(int value) throws IOException { + throw new UnsupportedOperationException("Cannot set SO_LINGER on a server channel."); + } + + @Override + public String toString() { + return "TcpNioServerSocketChannel{" + + "localAddress=" + getLocalAddress() + + '}'; + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpNioSocketChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpNioSocketChannel.java new file mode 100644 index 0000000000000..f1ee1bd4e67ad --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpNioSocketChannel.java @@ -0,0 +1,55 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.nio.channel; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.transport.TcpChannel; +import org.elasticsearch.transport.nio.SocketSelector; + +import java.io.IOException; +import java.net.StandardSocketOptions; +import java.nio.channels.SocketChannel; + +public class TcpNioSocketChannel extends NioSocketChannel implements TcpChannel { + + public TcpNioSocketChannel(SocketChannel socketChannel, SocketSelector selector) throws IOException { + super(socketChannel, selector); + } + + public void sendMessage(BytesReference reference, ActionListener listener) { + getWriteContext().sendMessage(reference, listener); + } + + @Override + public void setSoLinger(int value) throws IOException { + if (isOpen()) { + getRawChannel().setOption(StandardSocketOptions.SO_LINGER, value); + } + } + + @Override + public String toString() { + return "TcpNioSocketChannel{" + + "localAddress=" + getLocalAddress() + + ", remoteAddress=" + getRemoteAddress() + + '}'; + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpReadContext.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpReadContext.java index 57aa16ce15e3b..8eeb32a976cac 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpReadContext.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/TcpReadContext.java @@ -34,16 +34,16 @@ public class TcpReadContext implements ReadContext { private static final int DEFAULT_READ_LENGTH = 1 << 14; private final TcpReadHandler handler; - private final NioSocketChannel channel; + private final TcpNioSocketChannel channel; private final TcpFrameDecoder frameDecoder; private final LinkedList references = new LinkedList<>(); private int rawBytesCount = 0; public TcpReadContext(NioSocketChannel channel, TcpReadHandler handler) { - this(channel, handler, new TcpFrameDecoder()); + this((TcpNioSocketChannel) channel, handler, new TcpFrameDecoder()); } - public TcpReadContext(NioSocketChannel channel, TcpReadHandler handler, TcpFrameDecoder frameDecoder) { + public TcpReadContext(TcpNioSocketChannel channel, TcpReadHandler handler, TcpFrameDecoder frameDecoder) { this.handler = handler; this.channel = channel; this.frameDecoder = frameDecoder; diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptorEventHandlerTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptorEventHandlerTests.java index 3f23531407cb0..aedff1721f8d9 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptorEventHandlerTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptorEventHandlerTests.java @@ -19,7 +19,6 @@ package org.elasticsearch.transport.nio; -import org.apache.lucene.util.IOUtils; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.transport.nio.channel.ChannelFactory; import org.elasticsearch.transport.nio.channel.DoNotRegisterServerChannel; @@ -34,11 +33,9 @@ import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.util.ArrayList; -import java.util.Collections; -import java.util.HashSet; +import java.util.function.BiConsumer; import java.util.function.Consumer; -import static org.mockito.Matchers.any; import static org.mockito.Matchers.same; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -49,7 +46,6 @@ public class AcceptorEventHandlerTests extends ESTestCase { private AcceptorEventHandler handler; private SocketSelector socketSelector; private ChannelFactory channelFactory; - private OpenChannels openChannels; private NioServerSocketChannel channel; private Consumer acceptedChannelCallback; @@ -59,24 +55,16 @@ public void setUpHandler() throws IOException { channelFactory = mock(ChannelFactory.class); socketSelector = mock(SocketSelector.class); acceptedChannelCallback = mock(Consumer.class); - openChannels = new OpenChannels(logger); ArrayList selectors = new ArrayList<>(); selectors.add(socketSelector); - handler = new AcceptorEventHandler(logger, openChannels, new RoundRobinSelectorSupplier(selectors), acceptedChannelCallback); + handler = new AcceptorEventHandler(logger, new RoundRobinSelectorSupplier(selectors)); AcceptingSelector selector = mock(AcceptingSelector.class); channel = new DoNotRegisterServerChannel(mock(ServerSocketChannel.class), channelFactory, selector); + channel.setAcceptContext(acceptedChannelCallback); channel.register(); } - public void testHandleRegisterAdjustsOpenChannels() { - assertEquals(0, openChannels.serverChannelsCount()); - - handler.serverChannelRegistered(channel); - - assertEquals(1, openChannels.serverChannelsCount()); - } - public void testHandleRegisterSetsOP_ACCEPTInterest() { assertEquals(0, channel.getSelectionKey().interestOps()); @@ -96,18 +84,13 @@ public void testHandleAcceptCallsChannelFactory() throws IOException { } @SuppressWarnings("unchecked") - public void testHandleAcceptAddsToOpenChannelsAndIsRemovedOnClose() throws IOException { - SocketChannel rawChannel = SocketChannel.open(); - NioSocketChannel childChannel = new NioSocketChannel(rawChannel, socketSelector); - childChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class)); + public void testHandleAcceptCallsServerAcceptCallback() throws IOException { + NioSocketChannel childChannel = new NioSocketChannel(mock(SocketChannel.class), socketSelector); + childChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class), mock(BiConsumer.class)); when(channelFactory.acceptNioChannel(same(channel), same(socketSelector))).thenReturn(childChannel); handler.acceptChannel(channel); verify(acceptedChannelCallback).accept(childChannel); - - assertEquals(new HashSet<>(Collections.singletonList(childChannel)), openChannels.getAcceptedChannels()); - - IOUtils.closeWhileHandlingException(rawChannel); } } diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java index bc02a89a5c18d..55bca45d1c81f 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java @@ -77,7 +77,7 @@ protected Version getCurrentVersion() { @Override protected SocketEventHandler getSocketEventHandler() { - return new TestingSocketEventHandler(logger, this::exceptionCaught, openChannels); + return new TestingSocketEventHandler(logger); } }; MockTransportService mockTransportService = diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketEventHandlerTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketEventHandlerTests.java index cd4e70ab3acb0..8f270d11e5a35 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketEventHandlerTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketEventHandlerTests.java @@ -55,13 +55,13 @@ public class SocketEventHandlerTests extends ESTestCase { public void setUpHandler() throws IOException { exceptionHandler = mock(BiConsumer.class); SocketSelector socketSelector = mock(SocketSelector.class); - handler = new SocketEventHandler(logger, exceptionHandler, mock(OpenChannels.class)); + handler = new SocketEventHandler(logger); rawChannel = mock(SocketChannel.class); channel = new DoNotRegisterChannel(rawChannel, socketSelector); readContext = mock(ReadContext.class); when(rawChannel.finishConnect()).thenReturn(true); - channel.setContexts(readContext, new TcpWriteContext(channel)); + channel.setContexts(readContext, new TcpWriteContext(channel), exceptionHandler); channel.register(); channel.finishConnect(); diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketSelectorTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketSelectorTests.java index 0de1bb72063ba..61a9499f8db32 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketSelectorTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketSelectorTests.java @@ -22,7 +22,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.transport.nio.channel.NioChannel; import org.elasticsearch.transport.nio.channel.NioSocketChannel; import org.elasticsearch.transport.nio.channel.WriteContext; import org.elasticsearch.transport.nio.utils.TestSelectionKey; diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/TestingSocketEventHandler.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/TestingSocketEventHandler.java index 65759cf770552..a3cb92ad37663 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/TestingSocketEventHandler.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/TestingSocketEventHandler.java @@ -26,12 +26,11 @@ import java.util.Collections; import java.util.Set; import java.util.WeakHashMap; -import java.util.function.BiConsumer; public class TestingSocketEventHandler extends SocketEventHandler { - public TestingSocketEventHandler(Logger logger, BiConsumer exceptionHandler, OpenChannels openChannels) { - super(logger, exceptionHandler, openChannels); + public TestingSocketEventHandler(Logger logger) { + super(logger); } private Set hasConnectedMap = Collections.newSetFromMap(new WeakHashMap<>()); diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/WriteOperationTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/WriteOperationTests.java index 1f6f95e62af3e..351ac87eb561e 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/WriteOperationTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/WriteOperationTests.java @@ -22,7 +22,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.transport.nio.channel.NioChannel; import org.elasticsearch.transport.nio.channel.NioSocketChannel; import org.junit.Before; @@ -30,7 +29,6 @@ import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class WriteOperationTests extends ESTestCase { diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/ChannelFactoryTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/ChannelFactoryTests.java index f6bcf26a02c2f..91e1c2023e74c 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/ChannelFactoryTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/ChannelFactoryTests.java @@ -30,11 +30,10 @@ import java.net.InetSocketAddress; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; -import java.util.function.Consumer; +import java.util.function.BiConsumer; import static org.mockito.Matchers.any; import static org.mockito.Matchers.same; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -52,19 +51,12 @@ public class ChannelFactoryTests extends ESTestCase { @Before @SuppressWarnings("unchecked") public void setupFactory() throws IOException { - rawChannelFactory = mock(ChannelFactory.RawChannelFactory.class); - Consumer contextSetter = mock(Consumer.class); - channelFactory = new ChannelFactory(rawChannelFactory, contextSetter); + rawChannelFactory = mock(TcpChannelFactory.RawChannelFactory.class); + channelFactory = new TestChannelFactory(rawChannelFactory); socketSelector = mock(SocketSelector.class); acceptingSelector = mock(AcceptingSelector.class); rawChannel = SocketChannel.open(); rawServerChannel = ServerSocketChannel.open(); - - doAnswer(invocationOnMock -> { - NioSocketChannel channel = (NioSocketChannel) invocationOnMock.getArguments()[0]; - channel.setContexts(mock(ReadContext.class), mock(WriteContext.class)); - return null; - }).when(contextSetter).accept(any()); } @After @@ -138,4 +130,24 @@ public void testOpenedServerChannelRejected() throws IOException { assertFalse(rawServerChannel.isOpen()); } + + private static class TestChannelFactory extends ChannelFactory { + + TestChannelFactory(RawChannelFactory rawChannelFactory) { + super(rawChannelFactory); + } + + @SuppressWarnings("unchecked") + @Override + public NioSocketChannel createChannel(SocketSelector selector, SocketChannel channel) throws IOException { + NioSocketChannel nioSocketChannel = new NioSocketChannel(channel, selector); + nioSocketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class), mock(BiConsumer.class)); + return nioSocketChannel; + } + + @Override + public NioServerSocketChannel createServerChannel(AcceptingSelector selector, ServerSocketChannel channel) throws IOException { + return new NioServerSocketChannel(channel, this, selector); + } + } } diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannelTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannelTests.java index 9c01f5edc6106..ba5d47fe8f8dd 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannelTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannelTests.java @@ -22,10 +22,8 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.transport.TcpChannel; import org.elasticsearch.transport.nio.AcceptingSelector; import org.elasticsearch.transport.nio.AcceptorEventHandler; -import org.elasticsearch.transport.nio.OpenChannels; import org.junit.After; import org.junit.Before; @@ -33,8 +31,6 @@ import java.nio.channels.ServerSocketChannel; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; import java.util.function.Supplier; import static org.mockito.Mockito.mock; @@ -48,7 +44,7 @@ public class NioServerSocketChannelTests extends ESTestCase { @Before @SuppressWarnings("unchecked") public void setSelector() throws IOException { - selector = new AcceptingSelector(new AcceptorEventHandler(logger, mock(OpenChannels.class), mock(Supplier.class), (c) -> {})); + selector = new AcceptingSelector(new AcceptorEventHandler(logger, mock(Supplier.class))); thread = new Thread(selector::runLoop); closedRawChannel = new AtomicBoolean(false); thread.start(); diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioSocketChannelTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioSocketChannelTests.java index e3053a3e73a3c..fecaf8fe9701e 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioSocketChannelTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioSocketChannelTests.java @@ -22,8 +22,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.transport.TcpChannel; -import org.elasticsearch.transport.nio.OpenChannels; import org.elasticsearch.transport.nio.SocketEventHandler; import org.elasticsearch.transport.nio.SocketSelector; import org.junit.After; @@ -36,9 +34,7 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; -import java.util.function.Consumer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -48,13 +44,11 @@ public class NioSocketChannelTests extends ESTestCase { private SocketSelector selector; private AtomicBoolean closedRawChannel; private Thread thread; - private OpenChannels openChannels; @Before @SuppressWarnings("unchecked") public void startSelector() throws IOException { - openChannels = new OpenChannels(logger); - selector = new SocketSelector(new SocketEventHandler(logger, mock(BiConsumer.class), openChannels)); + selector = new SocketSelector(new SocketEventHandler(logger)); thread = new Thread(selector::runLoop); closedRawChannel = new AtomicBoolean(false); thread.start(); @@ -67,13 +61,13 @@ public void stopSelector() throws IOException, InterruptedException { thread.join(); } + @SuppressWarnings("unchecked") public void testClose() throws Exception { AtomicBoolean isClosed = new AtomicBoolean(false); CountDownLatch latch = new CountDownLatch(1); NioSocketChannel socketChannel = new DoNotCloseChannel(mock(SocketChannel.class), selector); - openChannels.clientChannelOpened(socketChannel); - socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class)); + socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class), mock(BiConsumer.class)); socketChannel.addCloseListener(new ActionListener() { @Override public void onResponse(Void o) { @@ -90,7 +84,6 @@ public void onFailure(Exception e) { assertTrue(socketChannel.isOpen()); assertFalse(closedRawChannel.get()); assertFalse(isClosed.get()); - assertTrue(openChannels.getClientChannels().containsKey(socketChannel)); PlainActionFuture closeFuture = PlainActionFuture.newFuture(); socketChannel.addCloseListener(closeFuture); @@ -99,16 +92,16 @@ public void onFailure(Exception e) { assertTrue(closedRawChannel.get()); assertFalse(socketChannel.isOpen()); - assertFalse(openChannels.getClientChannels().containsKey(socketChannel)); latch.await(); assertTrue(isClosed.get()); } + @SuppressWarnings("unchecked") public void testConnectSucceeds() throws Exception { SocketChannel rawChannel = mock(SocketChannel.class); when(rawChannel.finishConnect()).thenReturn(true); NioSocketChannel socketChannel = new DoNotCloseChannel(rawChannel, selector); - socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class)); + socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class), mock(BiConsumer.class)); selector.scheduleForRegistration(socketChannel); PlainActionFuture connectFuture = PlainActionFuture.newFuture(); @@ -120,11 +113,12 @@ public void testConnectSucceeds() throws Exception { assertFalse(closedRawChannel.get()); } + @SuppressWarnings("unchecked") public void testConnectFails() throws Exception { SocketChannel rawChannel = mock(SocketChannel.class); when(rawChannel.finishConnect()).thenThrow(new ConnectException()); NioSocketChannel socketChannel = new DoNotCloseChannel(rawChannel, selector); - socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class)); + socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class), mock(BiConsumer.class)); selector.scheduleForRegistration(socketChannel); PlainActionFuture connectFuture = PlainActionFuture.newFuture(); diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/TcpReadContextTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/TcpReadContextTests.java index 2dc0b32ae5bea..7586b5abd91e0 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/TcpReadContextTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/TcpReadContextTests.java @@ -39,10 +39,9 @@ public class TcpReadContextTests extends ESTestCase { - private static String PROFILE = "profile"; private TcpReadHandler handler; private int messageLength; - private NioSocketChannel channel; + private TcpNioSocketChannel channel; private TcpReadContext readContext; @Before @@ -50,7 +49,7 @@ public void init() throws IOException { handler = mock(TcpReadHandler.class); messageLength = randomInt(96) + 4; - channel = mock(NioSocketChannel.class); + channel = mock(TcpNioSocketChannel.class); readContext = new TcpReadContext(channel, handler); } @@ -144,5 +143,4 @@ private static byte[] createMessage(int length) { } return bytes; } - }