diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/ChannelFactory.java b/libs/nio/src/main/java/org/elasticsearch/nio/ChannelFactory.java index f0dc3e567fef6..97750af2432c3 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/ChannelFactory.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/ChannelFactory.java @@ -146,7 +146,7 @@ private static void closeRawChannel(Closeable c, Exception e) { } } - protected static class RawChannelFactory { + public static class RawChannelFactory { private final boolean tcpNoDelay; private final boolean tcpKeepAlive; diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java index b91015214e02f..c0005cffc79a7 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java @@ -51,6 +51,7 @@ import java.nio.channels.SocketChannel; import java.util.concurrent.ConcurrentMap; import java.util.function.Consumer; +import java.util.function.Function; import java.util.function.Supplier; import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap; @@ -67,7 +68,7 @@ public class NioTransport extends TcpTransport { protected final PageCacheRecycler pageCacheRecycler; private final ConcurrentMap profileToChannelFactory = newConcurrentMap(); private volatile NioGroup nioGroup; - private volatile TcpChannelFactory clientChannelFactory; + private volatile Function clientChannelFactory; protected NioTransport(Settings settings, Version version, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays, PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry, @@ -85,8 +86,7 @@ protected NioTcpServerChannel bind(String name, InetSocketAddress address) throw @Override protected NioTcpChannel initiateChannel(DiscoveryNode node) throws IOException { InetSocketAddress address = node.getAddress().address(); - NioTcpChannel channel = nioGroup.openChannel(address, clientChannelFactory); - return channel; + return nioGroup.openChannel(address, clientChannelFactory.apply(node)); } @Override @@ -97,13 +97,13 @@ protected void doStart() { NioTransport.NIO_WORKER_COUNT.get(settings), (s) -> new EventHandler(this::onNonChannelException, s)); ProfileSettings clientProfileSettings = new ProfileSettings(settings, "default"); - clientChannelFactory = channelFactory(clientProfileSettings, true); + clientChannelFactory = clientChannelFactoryFunction(clientProfileSettings); if (NetworkService.NETWORK_SERVER.get(settings)) { // loop through all profiles and start them up, special handling for default one for (ProfileSettings profileSettings : profileSettings) { String profileName = profileSettings.profileName; - TcpChannelFactory factory = channelFactory(profileSettings, false); + TcpChannelFactory factory = serverChannelFactory(profileSettings); profileToChannelFactory.putIfAbsent(profileName, factory); bindServer(profileSettings); } @@ -134,8 +134,12 @@ protected void acceptChannel(NioSocketChannel channel) { serverAcceptedChannel((NioTcpChannel) channel); } - protected TcpChannelFactory channelFactory(ProfileSettings settings, boolean isClient) { - return new TcpChannelFactoryImpl(settings); + protected TcpChannelFactory serverChannelFactory(ProfileSettings profileSettings) { + return new TcpChannelFactoryImpl(profileSettings); + } + + protected Function clientChannelFactoryFunction(ProfileSettings profileSettings) { + return (n) -> new TcpChannelFactoryImpl(profileSettings); } protected abstract class TcpChannelFactory extends ChannelFactory { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ssl/SSLService.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ssl/SSLService.java index 428daf56059ca..1a7641ef64b88 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ssl/SSLService.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ssl/SSLService.java @@ -262,14 +262,14 @@ public boolean isSSLClientAuthEnabled(SSLConfiguration sslConfiguration) { /** * Returns the {@link SSLContext} for the global configuration. Mainly used for testing */ - SSLContext sslContext() { + public SSLContext sslContext() { return sslContextHolder(globalSSLConfiguration).sslContext(); } /** - * Returns the {@link SSLContext} for the configuration + * Returns the {@link SSLContext} for the configuration. Mainly used for testing */ - SSLContext sslContext(SSLConfiguration configuration) { + public SSLContext sslContext(SSLConfiguration configuration) { return sslContextHolder(configuration).sslContext(); } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java index 3c576d8350e26..57d6b39236039 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java @@ -9,6 +9,7 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.Version; +import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.CloseableChannel; @@ -19,12 +20,14 @@ import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.nio.BytesChannelContext; +import org.elasticsearch.nio.ChannelFactory; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.ServerChannelContext; import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.ConnectTransportException; import org.elasticsearch.transport.TcpChannel; import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.nio.NioTcpChannel; @@ -38,7 +41,9 @@ import org.elasticsearch.xpack.core.ssl.SSLService; import org.elasticsearch.xpack.security.transport.filter.IPFilter; +import javax.net.ssl.SNIHostName; import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLParameters; import java.io.IOException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; @@ -47,6 +52,7 @@ import java.util.Collections; import java.util.Map; import java.util.function.Consumer; +import java.util.function.Function; import java.util.function.Supplier; import static org.elasticsearch.xpack.core.security.SecurityField.setting; @@ -128,8 +134,29 @@ public void onException(TcpChannel channel, Exception e) { } @Override - protected TcpChannelFactory channelFactory(ProfileSettings profileSettings, boolean isClient) { - return new SecurityTcpChannelFactory(profileSettings, isClient); + protected TcpChannelFactory serverChannelFactory(ProfileSettings profileSettings) { + return new SecurityTcpChannelFactory(profileSettings, false); + } + + @Override + protected Function clientChannelFactoryFunction(ProfileSettings profileSettings) { + return (node) -> { + final ChannelFactory.RawChannelFactory rawChannelFactory = new ChannelFactory.RawChannelFactory(profileSettings.tcpNoDelay, + profileSettings.tcpKeepAlive, profileSettings.reuseAddress, Math.toIntExact(profileSettings.sendBufferSize.getBytes()), + Math.toIntExact(profileSettings.receiveBufferSize.getBytes())); + SNIHostName serverName; + String configuredServerName = node.getAttributes().get("server_name"); + if (configuredServerName != null) { + try { + serverName = new SNIHostName(configuredServerName); + } catch (IllegalArgumentException e) { + throw new ConnectTransportException(node, "invalid DiscoveryNode server_name [" + configuredServerName + "]", e); + } + } else { + serverName = null; + } + return new SecurityClientTcpChannelFactory(rawChannelFactory, serverName); + }; } private class SecurityTcpChannelFactory extends TcpChannelFactory { @@ -139,12 +166,16 @@ private class SecurityTcpChannelFactory extends TcpChannelFactory { private final NioIPFilter ipFilter; private SecurityTcpChannelFactory(ProfileSettings profileSettings, boolean isClient) { - super(new RawChannelFactory(profileSettings.tcpNoDelay, + this(new RawChannelFactory(profileSettings.tcpNoDelay, profileSettings.tcpKeepAlive, profileSettings.reuseAddress, Math.toIntExact(profileSettings.sendBufferSize.getBytes()), - Math.toIntExact(profileSettings.receiveBufferSize.getBytes()))); - this.profileName = profileSettings.profileName; + Math.toIntExact(profileSettings.receiveBufferSize.getBytes())), profileSettings.profileName, isClient); + } + + private SecurityTcpChannelFactory(RawChannelFactory rawChannelFactory, String profileName, boolean isClient) { + super(rawChannelFactory); + this.profileName = profileName; this.isClient = isClient; this.ipFilter = new NioIPFilter(authenticator, profileName); } @@ -162,18 +193,7 @@ public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) SocketChannelContext context; if (sslEnabled) { - SSLEngine sslEngine; - SSLConfiguration defaultConfig = profileConfiguration.get(TcpTransport.DEFAULT_PROFILE); - SSLConfiguration sslConfig = profileConfiguration.getOrDefault(profileName, defaultConfig); - boolean hostnameVerificationEnabled = sslConfig.verificationMode().isHostnameVerificationEnabled(); - if (hostnameVerificationEnabled) { - InetSocketAddress inetSocketAddress = (InetSocketAddress) channel.getRemoteAddress(); - // we create the socket based on the name given. don't reverse DNS - sslEngine = sslService.createSSLEngine(sslConfig, inetSocketAddress.getHostString(), inetSocketAddress.getPort()); - } else { - sslEngine = sslService.createSSLEngine(sslConfig, null, -1); - } - SSLDriver sslDriver = new SSLDriver(sslEngine, isClient); + SSLDriver sslDriver = new SSLDriver(createSSLEngine(channel), isClient); context = new SSLChannelContext(nioChannel, selector, exceptionHandler, sslDriver, readWriteHandler, buffer, ipFilter); } else { context = new BytesChannelContext(nioChannel, selector, exceptionHandler, readWriteHandler, buffer, ipFilter); @@ -192,5 +212,46 @@ public NioTcpServerChannel createServerChannel(NioSelector selector, ServerSocke nioChannel.setContext(context); return nioChannel; } + + protected SSLEngine createSSLEngine(SocketChannel channel) throws IOException { + SSLEngine sslEngine; + SSLConfiguration defaultConfig = profileConfiguration.get(TcpTransport.DEFAULT_PROFILE); + SSLConfiguration sslConfig = profileConfiguration.getOrDefault(profileName, defaultConfig); + boolean hostnameVerificationEnabled = sslConfig.verificationMode().isHostnameVerificationEnabled(); + if (hostnameVerificationEnabled) { + InetSocketAddress inetSocketAddress = (InetSocketAddress) channel.getRemoteAddress(); + // we create the socket based on the name given. don't reverse DNS + sslEngine = sslService.createSSLEngine(sslConfig, inetSocketAddress.getHostString(), inetSocketAddress.getPort()); + } else { + sslEngine = sslService.createSSLEngine(sslConfig, null, -1); + } + return sslEngine; + } + } + + private class SecurityClientTcpChannelFactory extends SecurityTcpChannelFactory { + + private final SNIHostName serverName; + + private SecurityClientTcpChannelFactory(RawChannelFactory rawChannelFactory, SNIHostName serverName) { + super(rawChannelFactory, TcpTransport.DEFAULT_PROFILE, true); + this.serverName = serverName; + } + + @Override + public NioTcpServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) { + throw new AssertionError("Cannot create TcpServerChannel with client factory"); + } + + @Override + protected SSLEngine createSSLEngine(SocketChannel channel) throws IOException { + SSLEngine sslEngine = super.createSSLEngine(channel); + if (serverName != null) { + SSLParameters sslParameters = sslEngine.getSSLParameters(); + sslParameters.setServerNames(Collections.singletonList(serverName)); + sslEngine.setSSLParameters(sslParameters); + } + return sslEngine; + } } } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/AbstractSimpleSecurityTransportTestCase.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/AbstractSimpleSecurityTransportTestCase.java index a716d955bbe4e..0bf08d1cb07be 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/AbstractSimpleSecurityTransportTestCase.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/AbstractSimpleSecurityTransportTestCase.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.settings.MockSecureSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.env.TestEnvironment; import org.elasticsearch.node.Node; import org.elasticsearch.test.transport.MockTransportService; @@ -21,6 +22,7 @@ import org.elasticsearch.transport.ConnectTransportException; import org.elasticsearch.transport.ConnectionProfile; import org.elasticsearch.transport.TcpTransport; +import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.common.socket.SocketAccess; import org.elasticsearch.xpack.core.ssl.SSLConfiguration; @@ -28,12 +30,24 @@ import javax.net.SocketFactory; import javax.net.ssl.HandshakeCompletedListener; +import javax.net.ssl.SNIHostName; +import javax.net.ssl.SNIMatcher; +import javax.net.ssl.SNIServerName; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLServerSocket; +import javax.net.ssl.SSLServerSocketFactory; import javax.net.ssl.SSLSocket; import java.io.IOException; +import java.io.UncheckedIOException; import java.net.InetAddress; +import java.net.InetSocketAddress; import java.net.SocketTimeoutException; import java.net.UnknownHostException; import java.nio.file.Path; +import java.util.Collections; +import java.util.EnumSet; +import java.util.HashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; @@ -44,6 +58,19 @@ public abstract class AbstractSimpleSecurityTransportTestCase extends AbstractSimpleTransportTestCase { + private static final ConnectionProfile SINGLE_CHANNEL_PROFILE; + + static { + ConnectionProfile.Builder builder = new ConnectionProfile.Builder(); + builder.addConnections(1, + TransportRequestOptions.Type.BULK, + TransportRequestOptions.Type.PING, + TransportRequestOptions.Type.RECOVERY, + TransportRequestOptions.Type.REG, + TransportRequestOptions.Type.STATE); + SINGLE_CHANNEL_PROFILE = builder.build(); + } + protected SSLService createSSLService() { return createSSLService(Settings.EMPTY); } @@ -54,11 +81,11 @@ protected SSLService createSSLService(Settings settings) { MockSecureSettings secureSettings = new MockSecureSettings(); secureSettings.setString("xpack.ssl.secure_key_passphrase", "testnode"); Settings settings1 = Settings.builder() - .put(settings) .put("xpack.security.transport.ssl.enabled", true) .put("xpack.ssl.key", testnodeKey) .put("xpack.ssl.certificate", testnodeCert) .put("path.home", createTempDir()) + .put(settings) .setSecureSettings(secureSettings) .build(); try { @@ -167,4 +194,108 @@ public void testRenegotiation() throws Exception { stream.flush(); } } + + public void testSNIServerNameIsPropagated() throws Exception { + SSLService sslService = createSSLService(); + + final SSLConfiguration sslConfiguration = sslService.getSSLConfiguration("xpack.ssl"); + SSLContext sslContext = sslService.sslContext(sslConfiguration); + final SSLServerSocketFactory serverSocketFactory = sslContext.getServerSocketFactory(); + final String sniIp = "sni-hostname"; + final SNIHostName sniHostName = new SNIHostName(sniIp); + final CountDownLatch latch = new CountDownLatch(2); + + try (SSLServerSocket sslServerSocket = (SSLServerSocket) serverSocketFactory.createServerSocket()) { + SSLParameters sslParameters = sslServerSocket.getSSLParameters(); + sslParameters.setSNIMatchers(Collections.singletonList(new SNIMatcher(0) { + @Override + public boolean matches(SNIServerName sniServerName) { + if (sniHostName.equals(sniServerName)) { + latch.countDown(); + return true; + } else { + return false; + } + } + })); + sslServerSocket.setSSLParameters(sslParameters); + + SocketAccess.doPrivileged(() -> sslServerSocket.bind(getLocalEphemeral())); + + new Thread(() -> { + try { + SSLSocket acceptedSocket = (SSLSocket) SocketAccess.doPrivileged(sslServerSocket::accept); + acceptedSocket.addHandshakeCompletedListener((e) -> { + latch.countDown(); + IOUtils.closeWhileHandlingException(acceptedSocket); + }); + // A read call will execute the handshake + acceptedSocket.getInputStream().read(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }).start(); + + InetSocketAddress serverAddress = (InetSocketAddress) SocketAccess.doPrivileged(sslServerSocket::getLocalSocketAddress); + + Settings settings = Settings.builder().put("name", "TS_TEST").put("xpack.ssl.verification_mode", "none").build(); + try (MockTransportService serviceC = build(settings, version0, null, true)) { + serviceC.acceptIncomingRequests(); + + HashMap attributes = new HashMap<>(); + attributes.put("server_name", sniIp); + DiscoveryNode node = new DiscoveryNode("server_node_id", new TransportAddress(serverAddress), attributes, + EnumSet.allOf(DiscoveryNode.Role.class), Version.CURRENT); + + new Thread(() -> { + try { + serviceC.connectToNode(node, SINGLE_CHANNEL_PROFILE); + } catch (ConnectTransportException ex) { + // Ignore. The other side is not setup to do the ES handshake. So this will fail. + } + }).start(); + + latch.await(); + } + } + } + + public void testInvalidSNIServerName() throws Exception { + SSLService sslService = createSSLService(); + + final SSLConfiguration sslConfiguration = sslService.getSSLConfiguration("xpack.ssl"); + SSLContext sslContext = sslService.sslContext(sslConfiguration); + final SSLServerSocketFactory serverSocketFactory = sslContext.getServerSocketFactory(); + final String sniIp = "invalid_hostname"; + + try (SSLServerSocket sslServerSocket = (SSLServerSocket) serverSocketFactory.createServerSocket()) { + SocketAccess.doPrivileged(() -> sslServerSocket.bind(getLocalEphemeral())); + + new Thread(() -> { + try { + SocketAccess.doPrivileged(sslServerSocket::accept); + } catch (IOException e) { + // We except an IOException from the `accept` call because the server socket will be + // closed before the call returns. + } + }).start(); + + InetSocketAddress serverAddress = (InetSocketAddress) SocketAccess.doPrivileged(sslServerSocket::getLocalSocketAddress); + + Settings settings = Settings.builder().put("name", "TS_TEST").put("xpack.ssl.verification_mode", "none").build(); + try (MockTransportService serviceC = build(settings, version0, null, true)) { + serviceC.acceptIncomingRequests(); + + HashMap attributes = new HashMap<>(); + attributes.put("server_name", sniIp); + DiscoveryNode node = new DiscoveryNode("server_node_id", new TransportAddress(serverAddress), attributes, + EnumSet.allOf(DiscoveryNode.Role.class), Version.CURRENT); + + ConnectTransportException connectException = expectThrows(ConnectTransportException.class, + () -> serviceC.connectToNode(node, SINGLE_CHANNEL_PROFILE)); + + assertThat(connectException.getMessage(), containsString("invalid DiscoveryNode server_name [invalid_hostname]")); + } + } + } } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SimpleSecurityNetty4ServerTransportTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SimpleSecurityNetty4ServerTransportTests.java index ec85c41e6b107..bf31240148a2e 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SimpleSecurityNetty4ServerTransportTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SimpleSecurityNetty4ServerTransportTests.java @@ -5,13 +5,6 @@ */ package org.elasticsearch.xpack.security.transport.netty4; -import io.netty.bootstrap.ServerBootstrap; -import io.netty.channel.Channel; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelInitializer; -import io.netty.channel.nio.NioEventLoopGroup; -import io.netty.channel.socket.nio.NioServerSocketChannel; -import io.netty.handler.ssl.SslHandler; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.node.DiscoveryNode; @@ -19,51 +12,20 @@ import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.transport.ConnectTransportException; import org.elasticsearch.transport.ConnectionProfile; import org.elasticsearch.transport.TcpChannel; import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.Transport; -import org.elasticsearch.transport.TransportRequestOptions; -import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.core.ssl.SSLService; import org.elasticsearch.xpack.security.transport.AbstractSimpleSecurityTransportTestCase; -import javax.net.ssl.SNIHostName; -import javax.net.ssl.SNIMatcher; -import javax.net.ssl.SNIServerName; -import javax.net.ssl.SSLEngine; -import javax.net.ssl.SSLParameters; -import java.net.InetSocketAddress; import java.util.Collections; -import java.util.EnumSet; -import java.util.HashMap; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; - -import static org.elasticsearch.xpack.core.security.SecurityField.setting; -import static org.hamcrest.Matchers.containsString; public class SimpleSecurityNetty4ServerTransportTests extends AbstractSimpleSecurityTransportTestCase { - private static final ConnectionProfile SINGLE_CHANNEL_PROFILE; - - static { - ConnectionProfile.Builder builder = new ConnectionProfile.Builder(); - builder.addConnections(1, - TransportRequestOptions.Type.BULK, - TransportRequestOptions.Type.PING, - TransportRequestOptions.Type.RECOVERY, - TransportRequestOptions.Type.REG, - TransportRequestOptions.Type.STATE); - SINGLE_CHANNEL_PROFILE = builder.build(); - } - public MockTransportService nettyFromThreadPool(Settings settings, ThreadPool threadPool, final Version version, ClusterSettings clusterSettings, boolean doHandshake) { NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); @@ -103,134 +65,4 @@ protected MockTransportService build(Settings settings, Version version, Cluster transportService.start(); return transportService; } - - public void testSNIServerNameIsPropagated() throws Exception { - SSLService sslService = createSSLService(); - final ServerBootstrap serverBootstrap = new ServerBootstrap(); - boolean success = false; - try { - serverBootstrap.group(new NioEventLoopGroup(1)); - serverBootstrap.channel(NioServerSocketChannel.class); - - final String sniIp = "sni-hostname"; - final SNIHostName sniHostName = new SNIHostName(sniIp); - final CountDownLatch latch = new CountDownLatch(2); - serverBootstrap.childHandler(new ChannelInitializer() { - - @Override - protected void initChannel(Channel ch) { - SSLEngine serverEngine = sslService.createSSLEngine(sslService.getSSLConfiguration(setting("transport.ssl.")), - null, -1); - serverEngine.setUseClientMode(false); - SSLParameters sslParameters = serverEngine.getSSLParameters(); - sslParameters.setSNIMatchers(Collections.singletonList(new SNIMatcher(0) { - @Override - public boolean matches(SNIServerName sniServerName) { - if (sniHostName.equals(sniServerName)) { - latch.countDown(); - return true; - } else { - return false; - } - } - })); - serverEngine.setSSLParameters(sslParameters); - final SslHandler sslHandler = new SslHandler(serverEngine); - sslHandler.handshakeFuture().addListener(future -> latch.countDown()); - ch.pipeline().addFirst("sslhandler", sslHandler); - } - }); - serverBootstrap.validate(); - ChannelFuture serverFuture = serverBootstrap.bind(getLocalEphemeral()); - serverFuture.await(); - InetSocketAddress serverAddress = (InetSocketAddress) serverFuture.channel().localAddress(); - - try (MockTransportService serviceC = build( - Settings.builder() - .put("name", "TS_TEST") - .put(TransportService.TRACE_LOG_INCLUDE_SETTING.getKey(), "") - .put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), "NOTHING") - .build(), - version0, - null, true)) { - serviceC.acceptIncomingRequests(); - - HashMap attributes = new HashMap<>(); - attributes.put("server_name", sniIp); - DiscoveryNode node = new DiscoveryNode("server_node_id", new TransportAddress(serverAddress), attributes, - EnumSet.allOf(DiscoveryNode.Role.class), Version.CURRENT); - - new Thread(() -> { - try { - serviceC.connectToNode(node, SINGLE_CHANNEL_PROFILE); - } catch (ConnectTransportException ex) { - // Ignore. The other side is not setup to do the ES handshake. So this will fail. - } - }).start(); - - latch.await(); - serverBootstrap.config().group().shutdownGracefully(0, 5, TimeUnit.SECONDS); - success = true; - } - } finally { - if (success == false) { - serverBootstrap.config().group().shutdownGracefully(0, 5, TimeUnit.SECONDS); - } - } - } - - public void testInvalidSNIServerName() throws Exception { - SSLService sslService = createSSLService(); - final ServerBootstrap serverBootstrap = new ServerBootstrap(); - boolean success = false; - try { - serverBootstrap.group(new NioEventLoopGroup(1)); - serverBootstrap.channel(NioServerSocketChannel.class); - - final String sniIp = "invalid_hostname"; - serverBootstrap.childHandler(new ChannelInitializer() { - - @Override - protected void initChannel(Channel ch) { - SSLEngine serverEngine = sslService.createSSLEngine(sslService.getSSLConfiguration(setting("transport.ssl.")), - null, -1); - serverEngine.setUseClientMode(false); - final SslHandler sslHandler = new SslHandler(serverEngine); - ch.pipeline().addFirst("sslhandler", sslHandler); - } - }); - serverBootstrap.validate(); - ChannelFuture serverFuture = serverBootstrap.bind(getLocalEphemeral()); - serverFuture.await(); - InetSocketAddress serverAddress = (InetSocketAddress) serverFuture.channel().localAddress(); - - try (MockTransportService serviceC = build( - Settings.builder() - .put("name", "TS_TEST") - .put(TransportService.TRACE_LOG_INCLUDE_SETTING.getKey(), "") - .put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), "NOTHING") - .build(), - version0, - null, true)) { - serviceC.acceptIncomingRequests(); - - HashMap attributes = new HashMap<>(); - attributes.put("server_name", sniIp); - DiscoveryNode node = new DiscoveryNode("server_node_id", new TransportAddress(serverAddress), attributes, - EnumSet.allOf(DiscoveryNode.Role.class), Version.CURRENT); - - ConnectTransportException connectException = expectThrows(ConnectTransportException.class, - () -> serviceC.connectToNode(node, SINGLE_CHANNEL_PROFILE)); - - assertThat(connectException.getMessage(), containsString("invalid DiscoveryNode server_name [invalid_hostname]")); - - serverBootstrap.config().group().shutdownGracefully(0, 5, TimeUnit.SECONDS); - success = true; - } - } finally { - if (success == false) { - serverBootstrap.config().group().shutdownGracefully(0, 5, TimeUnit.SECONDS); - } - } - } }