Skip to content

Commit

Permalink
Merge pull request #566 from zhenlineo/1.7-hostname-for-sni
Browse files Browse the repository at this point in the history
Fixed the bug where original host name is lost in ssl handshake.
  • Loading branch information
zhenlineo authored Mar 6, 2019
2 parents 00522bc + e9e5a93 commit bc68f0f
Show file tree
Hide file tree
Showing 10 changed files with 34 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.net.SocketAddress;
import java.net.URI;
import java.net.UnknownHostException;
import java.util.Objects;

import org.neo4j.driver.v1.net.ServerAddress;

Expand All @@ -36,7 +37,8 @@ public class BoltServerAddress implements ServerAddress
public static final int DEFAULT_PORT = 7687;
public static final BoltServerAddress LOCAL_DEFAULT = new BoltServerAddress( "localhost", DEFAULT_PORT );

private final String host;
private final String originalHost; // This keeps the original host name provided by the user.
private final String host; // This could either be the same as originalHost or it is an IP address resolved from the original host.
private final int port;
private final String stringValue;

Expand All @@ -52,6 +54,12 @@ public BoltServerAddress( URI uri )

public BoltServerAddress( String host, int port )
{
this( host, host, port );
}

public BoltServerAddress( String originalHost, String host, int port )
{
this.originalHost = requireNonNull( originalHost, "original host" );
this.host = requireNonNull( host, "host" );
this.port = requireValidPort( port );
this.stringValue = String.format( "%s:%d", host, port );
Expand All @@ -76,13 +84,13 @@ public boolean equals( Object o )
return false;
}
BoltServerAddress that = (BoltServerAddress) o;
return port == that.port && host.equals( that.host );
return port == that.port && originalHost.equals( that.originalHost ) && host.equals( that.host );
}

@Override
public int hashCode()
{
return 31 * host.hashCode() + port;
return Objects.hash( originalHost, host, port );
}

@Override
Expand Down Expand Up @@ -112,14 +120,14 @@ public SocketAddress toSocketAddress()
*/
public BoltServerAddress resolve() throws UnknownHostException
{
String hostAddress = InetAddress.getByName( host ).getHostAddress();
if ( hostAddress.equals( host ) )
String ipAddress = InetAddress.getByName( host ).getHostAddress();
if ( ipAddress.equals( host ) )
{
return this;
}
else
{
return new BoltServerAddress( hostAddress, port );
return new BoltServerAddress( host, ipAddress, port );
}
}

Expand All @@ -129,6 +137,11 @@ public String host()
return host;
}

public String originalHost()
{
return originalHost;
}

@Override
public int port()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ private SslHandler createSslHandler()
private SSLEngine createSslEngine()
{
SSLContext sslContext = securityPlan.sslContext();
SSLEngine sslEngine = sslContext.createSSLEngine( address.host(), address.port() );
SSLEngine sslEngine = sslContext.createSSLEngine( address.originalHost(), address.port() );
sslEngine.setUseClientMode( true );
if ( securityPlan.requiresHostnameVerification() )
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public Set<ServerAddress> resolve( ServerAddress initialRouter )
try
{
return Stream.of( InetAddress.getAllByName( initialRouter.host() ) )
.map( address -> new BoltServerAddress( address.getHostAddress(), initialRouter.port() ) )
.map( address -> new BoltServerAddress( initialRouter.host(), address.getHostAddress(), initialRouter.port() ) )
.collect( toSet() );
}
catch ( UnknownHostException e )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ public class ServerVersion
{
public static final String NEO4J_PRODUCT = "Neo4j";

public static final ServerVersion v4_0_0 = new ServerVersion( NEO4J_PRODUCT, 4, 0, 0 );
public static final ServerVersion v3_5_0 = new ServerVersion( NEO4J_PRODUCT, 3, 5, 0 );
public static final ServerVersion v3_4_0 = new ServerVersion( NEO4J_PRODUCT, 3, 4, 0 );
public static final ServerVersion v3_2_0 = new ServerVersion( NEO4J_PRODUCT, 3, 2, 0 );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,6 @@ public interface ServerAddress
*/
static ServerAddress of( String host, int port )
{
return new BoltServerAddress( host, port );
return new BoltServerAddress( host, host, port );
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ void shouldUpdateChannelAttributes()
@Test
void shouldIncludeSniHostName() throws Exception
{
BoltServerAddress address = new BoltServerAddress( "database.neo4j.com", 8989 );
BoltServerAddress address = new BoltServerAddress( "database.neo4j.com", "10.0.0.18", 8989 );
NettyChannelInitializer initializer = new NettyChannelInitializer( address, trustAllCertificates(), 10000, Clock.SYSTEM, DEV_NULL_LOGGING );

initializer.initChannel( channel );
Expand All @@ -125,7 +125,7 @@ void shouldIncludeSniHostName() throws Exception
List<SNIServerName> sniServerNames = sslParameters.getServerNames();
assertThat( sniServerNames, hasSize( 1 ) );
assertThat( sniServerNames.get( 0 ), instanceOf( SNIHostName.class ) );
assertThat( ((SNIHostName) sniServerNames.get( 0 )).getAsciiName(), equalTo( address.host() ) );
assertThat( ((SNIHostName) sniServerNames.get( 0 )).getAsciiName(), equalTo( address.originalHost() ) );
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import static org.neo4j.driver.internal.util.ServerVersion.v3_2_0;
import static org.neo4j.driver.internal.util.ServerVersion.v3_4_0;
import static org.neo4j.driver.internal.util.ServerVersion.v3_5_0;
import static org.neo4j.driver.internal.util.ServerVersion.v4_0_0;

public enum Neo4jFeature
{
Expand All @@ -36,7 +37,8 @@ public enum Neo4jFeature
STATEMENT_RESULT_TIMINGS( v3_1_0 ),
LIST_QUERIES_PROCEDURE( v3_1_0 ),
CONNECTOR_LISTEN_ADDRESS_CONFIGURATION( v3_1_0 ),
BOLT_V3( v3_5_0 );
BOLT_V3( v3_5_0 ),
BOLT_V4( v4_0_0 );

private final ServerVersion availableFromVersion;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.neo4j.driver.internal.cluster.RoutingSettings;
import org.neo4j.driver.internal.retry.RetrySettings;
import org.neo4j.driver.internal.util.ChannelTrackingDriverFactory;
import org.neo4j.driver.internal.util.DisabledOnNeo4jWith;
import org.neo4j.driver.internal.util.FailingConnectionDriverFactory;
import org.neo4j.driver.internal.util.FakeClock;
import org.neo4j.driver.internal.util.ServerVersion;
Expand Down Expand Up @@ -86,6 +87,7 @@
import static org.neo4j.driver.internal.logging.DevNullLogging.DEV_NULL_LOGGING;
import static org.neo4j.driver.internal.util.Matchers.connectionAcquisitionTimeoutError;
import static org.neo4j.driver.internal.util.Neo4jFeature.BOLT_V3;
import static org.neo4j.driver.internal.util.Neo4jFeature.BOLT_V4;
import static org.neo4j.driver.v1.Values.parameters;
import static org.neo4j.driver.v1.util.DaemonThreadFactory.daemon;
import static org.neo4j.driver.v1.util.TestUtil.await;
Expand Down Expand Up @@ -137,6 +139,7 @@ void shouldExecuteReadAndWritesWhenDriverSuppliedWithAddressOfLeader() throws Ex
}

@Test
@DisabledOnNeo4jWith( BOLT_V4 )
void shouldExecuteReadAndWritesWhenRouterIsDiscovered() throws Exception
{
Cluster cluster = clusterRule.getCluster();
Expand All @@ -157,6 +160,7 @@ void shouldExecuteReadAndWritesWhenDriverSuppliedWithAddressOfFollower() throws
}

@Test
@DisabledOnNeo4jWith( BOLT_V4 )
void sessionCreationShouldFailIfCallingDiscoveryProcedureOnEdgeServer()
{
Cluster cluster = clusterRule.getCluster();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ void connectionUsedForSessionRunReturnedToThePoolWhenServerErrorDuringResultFetc
Connection connection1 = connectionPool.lastAcquiredConnectionSpy;
verify( connection1, never() ).release();

assertThrows( ClientException.class, result::hasNext );
assertThrows( ClientException.class, result::consume );

Connection connection2 = connectionPool.lastAcquiredConnectionSpy;
assertSame( connection1, connection2 );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public int boltPort()

public BoltServerAddress boltAddress()
{
return new BoltServerAddress( "localhost", boltPort() );
return new BoltServerAddress( boltUri() );
}

public URI boltUri()
Expand Down

0 comments on commit bc68f0f

Please sign in to comment.