Skip to content

Commit

Permalink
Remove tcp profile from low level nio channel (#27441)
Browse files Browse the repository at this point in the history
This is related to #27260. Currently every nio channel has a profile
field. Profile is a concept that only relates to the tcp transport. Http
channels will not have profiles. This commit moves the profile from the
nio channel to the read context. The context is the level that protocol
specific features and logic should live.
  • Loading branch information
Tim-Brooks committed Nov 20, 2017
1 parent 01bac8d commit aec44a0
Show file tree
Hide file tree
Showing 15 changed files with 79 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ public class NioTransport extends TcpTransport {
intSetting("transport.nio.acceptor_count", 1, 1, Setting.Property.NodeScope);

protected final OpenChannels openChannels = new OpenChannels(logger);
private final Consumer<NioSocketChannel> contextSetter;
private final ConcurrentMap<String, ChannelFactory> profileToChannelFactory = newConcurrentMap();
private final ArrayList<AcceptingSelector> acceptors = new ArrayList<>();
private final ArrayList<SocketSelector> socketSelectors = new ArrayList<>();
Expand All @@ -77,7 +76,6 @@ public class NioTransport extends TcpTransport {
public NioTransport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays,
NamedWriteableRegistry namedWriteableRegistry, CircuitBreakerService circuitBreakerService) {
super("nio", settings, threadPool, bigArrays, circuitBreakerService, namedWriteableRegistry, networkService);
contextSetter = (c) -> c.setContexts(new TcpReadContext(c, new TcpReadHandler(this)), new TcpWriteContext(c));
}

@Override
Expand All @@ -89,7 +87,7 @@ public long getNumOpenServerConnections() {
protected NioServerSocketChannel bind(String name, InetSocketAddress address) throws IOException {
ChannelFactory channelFactory = this.profileToChannelFactory.get(name);
AcceptingSelector selector = acceptors.get(++acceptorNumber % NioTransport.NIO_ACCEPTOR_COUNT.get(settings));
return channelFactory.openNioServerSocketChannel(name, address, selector);
return channelFactory.openNioServerSocketChannel(address, selector);
}

@Override
Expand Down Expand Up @@ -119,8 +117,9 @@ protected void doStart() {
}
}

Consumer<NioSocketChannel> clientContextSetter = getContextSetter("client-socket");
clientSelectorSupplier = new RoundRobinSelectorSupplier(socketSelectors);
clientChannelFactory = new ChannelFactory(new ProfileSettings(settings, "default"), contextSetter);
clientChannelFactory = new ChannelFactory(new ProfileSettings(settings, "default"), clientContextSetter);

if (NetworkService.NETWORK_SERVER.get(settings)) {
int acceptorCount = NioTransport.NIO_ACCEPTOR_COUNT.get(settings);
Expand All @@ -142,7 +141,9 @@ protected void doStart() {

// loop through all profiles and start them up, special handling for default one
for (ProfileSettings profileSettings : profileSettings) {
profileToChannelFactory.putIfAbsent(profileSettings.profileName, new ChannelFactory(profileSettings, contextSetter));
String profileName = profileSettings.profileName;
Consumer<NioSocketChannel> contextSetter = getContextSetter(profileName);
profileToChannelFactory.putIfAbsent(profileName, new ChannelFactory(profileSettings, contextSetter));
bindServer(profileSettings);
}
}
Expand Down Expand Up @@ -174,4 +175,8 @@ protected SocketEventHandler getSocketEventHandler() {
final void exceptionCaught(NioSocketChannel channel, Exception exception) {
onException(channel, exception);
}

private Consumer<NioSocketChannel> getContextSetter(String profileName) {
return (c) -> c.setContexts(new TcpReadContext(c, new TcpReadHandler(profileName,this)), new TcpWriteContext(c));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,17 @@

public class TcpReadHandler {

private final String profile;
private final NioTransport transport;

public TcpReadHandler(NioTransport transport) {
public TcpReadHandler(String profile, NioTransport transport) {
this.profile = profile;
this.transport = transport;
}

public void handleMessage(BytesReference reference, NioSocketChannel channel, int messageBytesLength) {
try {
transport.messageReceived(reference, channel, channel.getProfile(), channel.getRemoteAddress(), messageBytesLength);
transport.messageReceived(reference, channel, profile, channel.getRemoteAddress(), messageBytesLength);
} catch (IOException e) {
handleException(channel, e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,11 @@ public abstract class AbstractNioChannel<S extends SelectableChannel & NetworkCh
final AtomicBoolean isClosing = new AtomicBoolean(false);

private final InetSocketAddress localAddress;
private final String profile;
private final CompletableFuture<Void> closeContext = new CompletableFuture<>();
private final ESSelector selector;
private SelectionKey selectionKey;

AbstractNioChannel(String profile, S socketChannel, ESSelector selector) throws IOException {
this.profile = profile;
AbstractNioChannel(S socketChannel, ESSelector selector) throws IOException {
this.socketChannel = socketChannel;
this.localAddress = (InetSocketAddress) socketChannel.getLocalAddress();
this.selector = selector;
Expand All @@ -78,11 +76,6 @@ public InetSocketAddress getLocalAddress() {
return localAddress;
}

@Override
public String getProfile() {
return profile;
}

/**
* Schedules a channel to be closed by the selector event loop with which it is registered.
* <p>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
package org.elasticsearch.transport.nio.channel;


import org.apache.lucene.util.IOUtils;
import org.elasticsearch.mocksocket.PrivilegedSocketAccess;
import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.nio.AcceptingSelector;
Expand Down Expand Up @@ -64,33 +63,51 @@ public ChannelFactory(TcpTransport.ProfileSettings profileSettings, Consumer<Nio

public NioSocketChannel openNioChannel(InetSocketAddress remoteAddress, SocketSelector selector) throws IOException {
SocketChannel rawChannel = rawChannelFactory.openNioChannel(remoteAddress);
NioSocketChannel channel = new NioSocketChannel(NioChannel.CLIENT, rawChannel, selector);
setContexts(channel);
NioSocketChannel channel = createChannel(selector, rawChannel);
scheduleChannel(channel, selector);
return channel;
}

public NioSocketChannel acceptNioChannel(NioServerSocketChannel serverChannel, SocketSelector selector) throws IOException {
SocketChannel rawChannel = rawChannelFactory.acceptNioChannel(serverChannel);
NioSocketChannel channel = new NioSocketChannel(serverChannel.getProfile(), rawChannel, selector);
setContexts(channel);
NioSocketChannel channel = createChannel(selector, rawChannel);
scheduleChannel(channel, selector);
return channel;
}

public NioServerSocketChannel openNioServerSocketChannel(String profileName, InetSocketAddress address, AcceptingSelector selector)
public NioServerSocketChannel openNioServerSocketChannel(InetSocketAddress address, AcceptingSelector selector)
throws IOException {
ServerSocketChannel rawChannel = rawChannelFactory.openNioServerSocketChannel(address);
NioServerSocketChannel serverChannel = new NioServerSocketChannel(profileName, rawChannel, this, selector);
NioServerSocketChannel serverChannel = createServerChannel(selector, rawChannel);
scheduleServerChannel(serverChannel, selector);
return serverChannel;
}

private NioSocketChannel createChannel(SocketSelector selector, SocketChannel rawChannel) throws IOException {
try {
NioSocketChannel channel = new NioSocketChannel(rawChannel, selector);
setContexts(channel);
return channel;
} catch (Exception e) {
closeRawChannel(rawChannel, e);
throw e;
}
}

private NioServerSocketChannel createServerChannel(AcceptingSelector selector, ServerSocketChannel rawChannel) throws IOException {
try {
return new NioServerSocketChannel(rawChannel, this, selector);
} catch (Exception e) {
closeRawChannel(rawChannel, e);
throw e;
}
}

private void scheduleChannel(NioSocketChannel channel, SocketSelector selector) {
try {
selector.scheduleForRegistration(channel);
} catch (IllegalStateException e) {
IOUtils.closeWhileHandlingException(channel.getRawChannel());
closeRawChannel(channel.getRawChannel(), e);
throw e;
}
}
Expand All @@ -99,7 +116,7 @@ private void scheduleServerChannel(NioServerSocketChannel channel, AcceptingSele
try {
selector.scheduleForRegistration(channel);
} catch (IllegalStateException e) {
IOUtils.closeWhileHandlingException(channel.getRawChannel());
closeRawChannel(channel.getRawChannel(), e);
throw e;
}
}
Expand All @@ -110,6 +127,14 @@ private void setContexts(NioSocketChannel channel) {
assert channel.getWriteContext() != null : "write context should have been set on channel";
}

private static void closeRawChannel(Closeable c, Exception e) {
try {
c.close();
} catch (IOException closeException) {
e.addSuppressed(closeException);
}
}

static class RawChannelFactory {

private final boolean tcpNoDelay;
Expand Down Expand Up @@ -142,7 +167,12 @@ SocketChannel openNioChannel(InetSocketAddress remoteAddress) throws IOException
SocketChannel acceptNioChannel(NioServerSocketChannel serverChannel) throws IOException {
ServerSocketChannel serverSocketChannel = serverChannel.getRawChannel();
SocketChannel socketChannel = PrivilegedSocketAccess.accept(serverSocketChannel);
configureSocketChannel(socketChannel);
try {
configureSocketChannel(socketChannel);
} catch (IOException e) {
closeRawChannel(socketChannel, e);
throw e;
}
return socketChannel;
}

Expand All @@ -160,14 +190,6 @@ ServerSocketChannel openNioServerSocketChannel(InetSocketAddress address) throws
return serverSocketChannel;
}

private void closeRawChannel(Closeable c, IOException e) {
try {
c.close();
} catch (IOException closeException) {
e.addSuppressed(closeException);
}
}

private void configureSocketChannel(SocketChannel channel) throws IOException {
channel.configureBlocking(false);
Socket socket = channel.socket();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,8 @@

public interface NioChannel extends TcpChannel {

String CLIENT = "client-socket";

InetSocketAddress getLocalAddress();

String getProfile();

void close();

void closeFromSelector() throws IOException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ public class NioServerSocketChannel extends AbstractNioChannel<ServerSocketChann

private final ChannelFactory channelFactory;

public NioServerSocketChannel(String profile, ServerSocketChannel socketChannel, ChannelFactory channelFactory,
AcceptingSelector selector) throws IOException {
super(profile, socketChannel, selector);
public NioServerSocketChannel(ServerSocketChannel socketChannel, ChannelFactory channelFactory, AcceptingSelector selector)
throws IOException {
super(socketChannel, selector);
this.channelFactory = channelFactory;
}

Expand All @@ -48,8 +48,7 @@ public void sendMessage(BytesReference reference, ActionListener<Void> listener)
@Override
public String toString() {
return "NioServerSocketChannel{" +
"profile=" + getProfile() +
", localAddress=" + getLocalAddress() +
"localAddress=" + getLocalAddress() +
'}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ public class NioSocketChannel extends AbstractNioChannel<SocketChannel> {
private ReadContext readContext;
private Exception connectException;

public NioSocketChannel(String profile, SocketChannel socketChannel, SocketSelector selector) throws IOException {
super(profile, socketChannel, selector);
public NioSocketChannel(SocketChannel socketChannel, SocketSelector selector) throws IOException {
super(socketChannel, selector);
this.remoteAddress = (InetSocketAddress) socketChannel.getRemoteAddress();
this.socketSelector = selector;
}
Expand Down Expand Up @@ -169,8 +169,7 @@ public void addConnectListener(ActionListener<Void> listener) {
@Override
public String toString() {
return "NioSocketChannel{" +
"profile=" + getProfile() +
", localAddress=" + getLocalAddress() +
"localAddress=" + getLocalAddress() +
", remoteAddress=" + remoteAddress +
'}';
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,11 @@
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.nio.channel.ChannelFactory;
import org.elasticsearch.transport.nio.channel.DoNotRegisterServerChannel;
import org.elasticsearch.transport.nio.channel.NioChannel;
import org.elasticsearch.transport.nio.channel.NioServerSocketChannel;
import org.elasticsearch.transport.nio.channel.NioSocketChannel;
import org.elasticsearch.transport.nio.channel.ReadContext;
import org.elasticsearch.transport.nio.channel.WriteContext;
import org.junit.Before;
import org.mockito.ArgumentCaptor;

import java.io.IOException;
import java.nio.channels.SelectionKey;
Expand Down Expand Up @@ -67,7 +65,7 @@ public void setUpHandler() throws IOException {
handler = new AcceptorEventHandler(logger, openChannels, new RoundRobinSelectorSupplier(selectors), acceptedChannelCallback);

AcceptingSelector selector = mock(AcceptingSelector.class);
channel = new DoNotRegisterServerChannel("", mock(ServerSocketChannel.class), channelFactory, selector);
channel = new DoNotRegisterServerChannel(mock(ServerSocketChannel.class), channelFactory, selector);
channel.register();
}

Expand All @@ -88,7 +86,7 @@ public void testHandleRegisterSetsOP_ACCEPTInterest() {
}

public void testHandleAcceptCallsChannelFactory() throws IOException {
NioSocketChannel childChannel = new NioSocketChannel("", mock(SocketChannel.class), socketSelector);
NioSocketChannel childChannel = new NioSocketChannel(mock(SocketChannel.class), socketSelector);
when(channelFactory.acceptNioChannel(same(channel), same(socketSelector))).thenReturn(childChannel);

handler.acceptChannel(channel);
Expand All @@ -100,7 +98,7 @@ public void testHandleAcceptCallsChannelFactory() throws IOException {
@SuppressWarnings("unchecked")
public void testHandleAcceptAddsToOpenChannelsAndIsRemovedOnClose() throws IOException {
SocketChannel rawChannel = SocketChannel.open();
NioSocketChannel childChannel = new NioSocketChannel("", rawChannel, socketSelector);
NioSocketChannel childChannel = new NioSocketChannel(rawChannel, socketSelector);
childChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class));
when(channelFactory.acceptNioChannel(same(channel), same(socketSelector))).thenReturn(childChannel);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public void setUpHandler() throws IOException {
SocketSelector socketSelector = mock(SocketSelector.class);
handler = new SocketEventHandler(logger, exceptionHandler, mock(OpenChannels.class));
rawChannel = mock(SocketChannel.class);
channel = new DoNotRegisterChannel("", rawChannel, socketSelector);
channel = new DoNotRegisterChannel(rawChannel, socketSelector);
readContext = mock(ReadContext.class);
when(rawChannel.finishConnect()).thenReturn(true);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,12 @@ public void ensureClosed() throws IOException {
public void testAcceptChannel() throws IOException {
NioServerSocketChannel serverChannel = mock(NioServerSocketChannel.class);
when(rawChannelFactory.acceptNioChannel(serverChannel)).thenReturn(rawChannel);
when(serverChannel.getProfile()).thenReturn("parent-profile");

NioSocketChannel channel = channelFactory.acceptNioChannel(serverChannel, socketSelector);

verify(socketSelector).scheduleForRegistration(channel);

assertEquals(socketSelector, channel.getSelector());
assertEquals("parent-profile", channel.getProfile());
assertEquals(rawChannel, channel.getRawChannel());
}

Expand All @@ -106,7 +104,6 @@ public void testOpenChannel() throws IOException {
verify(socketSelector).scheduleForRegistration(channel);

assertEquals(socketSelector, channel.getSelector());
assertEquals("client-socket", channel.getProfile());
assertEquals(rawChannel, channel.getRawChannel());
}

Expand All @@ -124,13 +121,11 @@ public void testOpenServerChannel() throws IOException {
InetSocketAddress address = mock(InetSocketAddress.class);
when(rawChannelFactory.openNioServerSocketChannel(same(address))).thenReturn(rawServerChannel);

String profile = "profile";
NioServerSocketChannel channel = channelFactory.openNioServerSocketChannel(profile, address, acceptingSelector);
NioServerSocketChannel channel = channelFactory.openNioServerSocketChannel(address, acceptingSelector);

verify(acceptingSelector).scheduleForRegistration(channel);

assertEquals(acceptingSelector, channel.getSelector());
assertEquals(profile, channel.getProfile());
assertEquals(rawServerChannel, channel.getRawChannel());
}

Expand All @@ -139,7 +134,7 @@ public void testOpenedServerChannelRejected() throws IOException {
when(rawChannelFactory.openNioServerSocketChannel(same(address))).thenReturn(rawServerChannel);
doThrow(new IllegalStateException()).when(acceptingSelector).scheduleForRegistration(any());

expectThrows(IllegalStateException.class, () -> channelFactory.openNioServerSocketChannel("", address, acceptingSelector));
expectThrows(IllegalStateException.class, () -> channelFactory.openNioServerSocketChannel(address, acceptingSelector));

assertFalse(rawServerChannel.isOpen());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@

public class DoNotRegisterChannel extends NioSocketChannel {

public DoNotRegisterChannel(String profile, SocketChannel socketChannel, SocketSelector selector) throws IOException {
super(profile, socketChannel, selector);
public DoNotRegisterChannel(SocketChannel socketChannel, SocketSelector selector) throws IOException {
super(socketChannel, selector);
}

@Override
Expand Down
Loading

0 comments on commit aec44a0

Please sign in to comment.