Skip to content

Commit

Permalink
Add sni name to SSLEngine in nio transport
Browse files Browse the repository at this point in the history
This commit is related to elastic#32517. It allows an "sni_server_name"
attribute on a DiscoveryNode to be propagated to the server using
the TLS SNI extentsion. Prior to this commit, this functionality
was only support for the netty transport. This commit adds this
functionality to the security nio transport.
  • Loading branch information
Tim-Brooks committed Nov 26, 2018
1 parent 3f7cae3 commit 8eaead4
Show file tree
Hide file tree
Showing 6 changed files with 225 additions and 197 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ private static void closeRawChannel(Closeable c, Exception e) {
}
}

protected static class RawChannelFactory {
public static class RawChannelFactory {

private final boolean tcpNoDelay;
private final boolean tcpKeepAlive;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import java.nio.channels.SocketChannel;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap;
Expand All @@ -67,7 +68,7 @@ public class NioTransport extends TcpTransport {
protected final PageCacheRecycler pageCacheRecycler;
private final ConcurrentMap<String, TcpChannelFactory> profileToChannelFactory = newConcurrentMap();
private volatile NioGroup nioGroup;
private volatile TcpChannelFactory clientChannelFactory;
private volatile Function<DiscoveryNode, TcpChannelFactory> clientChannelFactory;

protected NioTransport(Settings settings, Version version, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays,
PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry,
Expand All @@ -85,8 +86,7 @@ protected NioTcpServerChannel bind(String name, InetSocketAddress address) throw
@Override
protected NioTcpChannel initiateChannel(DiscoveryNode node) throws IOException {
InetSocketAddress address = node.getAddress().address();
NioTcpChannel channel = nioGroup.openChannel(address, clientChannelFactory);
return channel;
return nioGroup.openChannel(address, clientChannelFactory.apply(node));
}

@Override
Expand All @@ -97,13 +97,13 @@ protected void doStart() {
NioTransport.NIO_WORKER_COUNT.get(settings), (s) -> new EventHandler(this::onNonChannelException, s));

ProfileSettings clientProfileSettings = new ProfileSettings(settings, "default");
clientChannelFactory = channelFactory(clientProfileSettings, true);
clientChannelFactory = clientChannelFactoryFunction(clientProfileSettings);

if (NetworkService.NETWORK_SERVER.get(settings)) {
// loop through all profiles and start them up, special handling for default one
for (ProfileSettings profileSettings : profileSettings) {
String profileName = profileSettings.profileName;
TcpChannelFactory factory = channelFactory(profileSettings, false);
TcpChannelFactory factory = serverChannelFactory(profileSettings);
profileToChannelFactory.putIfAbsent(profileName, factory);
bindServer(profileSettings);
}
Expand Down Expand Up @@ -134,8 +134,12 @@ protected void acceptChannel(NioSocketChannel channel) {
serverAcceptedChannel((NioTcpChannel) channel);
}

protected TcpChannelFactory channelFactory(ProfileSettings settings, boolean isClient) {
return new TcpChannelFactoryImpl(settings);
protected TcpChannelFactory serverChannelFactory(ProfileSettings profileSettings) {
return new TcpChannelFactoryImpl(profileSettings);
}

protected Function<DiscoveryNode, TcpChannelFactory> clientChannelFactoryFunction(ProfileSettings profileSettings) {
return (n) -> new TcpChannelFactoryImpl(profileSettings);
}

protected abstract class TcpChannelFactory extends ChannelFactory<NioTcpServerChannel, NioTcpChannel> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,14 +262,14 @@ public boolean isSSLClientAuthEnabled(SSLConfiguration sslConfiguration) {
/**
* Returns the {@link SSLContext} for the global configuration. Mainly used for testing
*/
SSLContext sslContext() {
public SSLContext sslContext() {
return sslContextHolder(globalSSLConfiguration).sslContext();
}

/**
* Returns the {@link SSLContext} for the configuration
* Returns the {@link SSLContext} for the configuration. Mainly used for testing
*/
SSLContext sslContext(SSLConfiguration configuration) {
public SSLContext sslContext(SSLConfiguration configuration) {
return sslContextHolder(configuration).sslContext();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.Version;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.CloseableChannel;
Expand All @@ -19,12 +20,14 @@
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.nio.BytesChannelContext;
import org.elasticsearch.nio.ChannelFactory;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.ServerChannelContext;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.ConnectTransportException;
import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.nio.NioTcpChannel;
Expand All @@ -38,7 +41,9 @@
import org.elasticsearch.xpack.core.ssl.SSLService;
import org.elasticsearch.xpack.security.transport.filter.IPFilter;

import javax.net.ssl.SNIHostName;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLParameters;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
Expand All @@ -47,6 +52,7 @@
import java.util.Collections;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

import static org.elasticsearch.xpack.core.security.SecurityField.setting;
Expand Down Expand Up @@ -128,8 +134,29 @@ public void onException(TcpChannel channel, Exception e) {
}

@Override
protected TcpChannelFactory channelFactory(ProfileSettings profileSettings, boolean isClient) {
return new SecurityTcpChannelFactory(profileSettings, isClient);
protected TcpChannelFactory serverChannelFactory(ProfileSettings profileSettings) {
return new SecurityTcpChannelFactory(profileSettings, false);
}

@Override
protected Function<DiscoveryNode, TcpChannelFactory> clientChannelFactoryFunction(ProfileSettings profileSettings) {
return (node) -> {
final ChannelFactory.RawChannelFactory rawChannelFactory = new ChannelFactory.RawChannelFactory(profileSettings.tcpNoDelay,
profileSettings.tcpKeepAlive, profileSettings.reuseAddress, Math.toIntExact(profileSettings.sendBufferSize.getBytes()),
Math.toIntExact(profileSettings.receiveBufferSize.getBytes()));
SNIHostName serverName;
String configuredServerName = node.getAttributes().get("server_name");
if (configuredServerName != null) {
try {
serverName = new SNIHostName(configuredServerName);
} catch (IllegalArgumentException e) {
throw new ConnectTransportException(node, "invalid DiscoveryNode server_name [" + configuredServerName + "]", e);
}
} else {
serverName = null;
}
return new SecurityClientTcpChannelFactory(rawChannelFactory, serverName);
};
}

private class SecurityTcpChannelFactory extends TcpChannelFactory {
Expand All @@ -139,12 +166,16 @@ private class SecurityTcpChannelFactory extends TcpChannelFactory {
private final NioIPFilter ipFilter;

private SecurityTcpChannelFactory(ProfileSettings profileSettings, boolean isClient) {
super(new RawChannelFactory(profileSettings.tcpNoDelay,
this(new RawChannelFactory(profileSettings.tcpNoDelay,
profileSettings.tcpKeepAlive,
profileSettings.reuseAddress,
Math.toIntExact(profileSettings.sendBufferSize.getBytes()),
Math.toIntExact(profileSettings.receiveBufferSize.getBytes())));
this.profileName = profileSettings.profileName;
Math.toIntExact(profileSettings.receiveBufferSize.getBytes())), profileSettings.profileName, isClient);
}

private SecurityTcpChannelFactory(RawChannelFactory rawChannelFactory, String profileName, boolean isClient) {
super(rawChannelFactory);
this.profileName = profileName;
this.isClient = isClient;
this.ipFilter = new NioIPFilter(authenticator, profileName);
}
Expand All @@ -162,18 +193,7 @@ public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel)

SocketChannelContext context;
if (sslEnabled) {
SSLEngine sslEngine;
SSLConfiguration defaultConfig = profileConfiguration.get(TcpTransport.DEFAULT_PROFILE);
SSLConfiguration sslConfig = profileConfiguration.getOrDefault(profileName, defaultConfig);
boolean hostnameVerificationEnabled = sslConfig.verificationMode().isHostnameVerificationEnabled();
if (hostnameVerificationEnabled) {
InetSocketAddress inetSocketAddress = (InetSocketAddress) channel.getRemoteAddress();
// we create the socket based on the name given. don't reverse DNS
sslEngine = sslService.createSSLEngine(sslConfig, inetSocketAddress.getHostString(), inetSocketAddress.getPort());
} else {
sslEngine = sslService.createSSLEngine(sslConfig, null, -1);
}
SSLDriver sslDriver = new SSLDriver(sslEngine, isClient);
SSLDriver sslDriver = new SSLDriver(createSSLEngine(channel), isClient);
context = new SSLChannelContext(nioChannel, selector, exceptionHandler, sslDriver, readWriteHandler, buffer, ipFilter);
} else {
context = new BytesChannelContext(nioChannel, selector, exceptionHandler, readWriteHandler, buffer, ipFilter);
Expand All @@ -192,5 +212,46 @@ public NioTcpServerChannel createServerChannel(NioSelector selector, ServerSocke
nioChannel.setContext(context);
return nioChannel;
}

protected SSLEngine createSSLEngine(SocketChannel channel) throws IOException {
SSLEngine sslEngine;
SSLConfiguration defaultConfig = profileConfiguration.get(TcpTransport.DEFAULT_PROFILE);
SSLConfiguration sslConfig = profileConfiguration.getOrDefault(profileName, defaultConfig);
boolean hostnameVerificationEnabled = sslConfig.verificationMode().isHostnameVerificationEnabled();
if (hostnameVerificationEnabled) {
InetSocketAddress inetSocketAddress = (InetSocketAddress) channel.getRemoteAddress();
// we create the socket based on the name given. don't reverse DNS
sslEngine = sslService.createSSLEngine(sslConfig, inetSocketAddress.getHostString(), inetSocketAddress.getPort());
} else {
sslEngine = sslService.createSSLEngine(sslConfig, null, -1);
}
return sslEngine;
}
}

private class SecurityClientTcpChannelFactory extends SecurityTcpChannelFactory {

private final SNIHostName serverName;

private SecurityClientTcpChannelFactory(RawChannelFactory rawChannelFactory, SNIHostName serverName) {
super(rawChannelFactory, TcpTransport.DEFAULT_PROFILE, true);
this.serverName = serverName;
}

@Override
public NioTcpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) {
throw new AssertionError("Cannot create TcpServerChannel with client factory");
}

@Override
protected SSLEngine createSSLEngine(SocketChannel channel) throws IOException {
SSLEngine sslEngine = super.createSSLEngine(channel);
if (serverName != null) {
SSLParameters sslParameters = sslEngine.getSSLParameters();
sslParameters.setServerNames(Collections.singletonList(serverName));
sslEngine.setSSLParameters(sslParameters);
}
return sslEngine;
}
}
}
Loading

0 comments on commit 8eaead4

Please sign in to comment.