Skip to content

Commit

Permalink
Socket connect timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
xinchen10 committed Sep 19, 2020
1 parent 709eafd commit 86f022e
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 19 deletions.
55 changes: 42 additions & 13 deletions src/Transport/TcpTransportInitiator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@ namespace Microsoft.Azure.Amqp.Transport
using System;
using System.Net;
using System.Net.Sockets;
using System.Threading;

sealed class TcpTransportInitiator : TransportInitiator
{
readonly TcpTransportSettings transportSettings;
TransportAsyncCallbackArgs callbackArgs;
SocketAsyncEventArgs connectEventArgs;
Timer timer;
int state;

internal TcpTransportInitiator(TcpTransportSettings transportSettings)
{
Expand All @@ -19,16 +23,20 @@ internal TcpTransportInitiator(TcpTransportSettings transportSettings)

public override bool ConnectAsync(TimeSpan timeout, TransportAsyncCallbackArgs callbackArgs)
{
// TODO: set socket connect timeout to timeout
this.callbackArgs = callbackArgs;
this.callbackArgs.Exception = null;
this.callbackArgs.Transport = null;

DnsEndPoint dnsEndPoint = new DnsEndPoint(this.transportSettings.Host, this.transportSettings.Port);
this.connectEventArgs = new SocketAsyncEventArgs();
this.connectEventArgs.Completed += new EventHandler<SocketAsyncEventArgs>(OnConnectComplete);
this.connectEventArgs.RemoteEndPoint = dnsEndPoint;
this.connectEventArgs.UserToken = this;

SocketAsyncEventArgs connectEventArgs = new SocketAsyncEventArgs();
connectEventArgs.Completed += new EventHandler<SocketAsyncEventArgs>(OnConnectComplete);
connectEventArgs.RemoteEndPoint = dnsEndPoint;
connectEventArgs.UserToken = this;
if (timeout < TimeSpan.MaxValue)
{
this.timer = new Timer(s => OnTimer(s), this, timeout, Timeout.InfiniteTimeSpan);
}

// On Linux platform, socket connections are allowed to be initiated on the socket instance
// with hostname due to multiple IP address DNS resolution possibility.
Expand All @@ -40,20 +48,40 @@ public override bool ConnectAsync(TimeSpan timeout, TransportAsyncCallbackArgs c
}
else
{
this.Complete(connectEventArgs, true);
if (Interlocked.CompareExchange(ref this.state, 1, 0) == 0)
{
this.Complete(this.connectEventArgs, true);
}
else
{
this.connectEventArgs.ConnectSocket?.Dispose();
}

return false;
}
}

static void OnConnectComplete(object sender, SocketAsyncEventArgs e)
{
TcpTransportInitiator thisPtr = (TcpTransportInitiator)e.UserToken;
if (thisPtr.callbackArgs.Transport == null && thisPtr.callbackArgs.Exception == null)
if (Interlocked.CompareExchange(ref thisPtr.state, 1, 0) == 0)
{
// Mono invokes the callback twice from the callback event handler
// Ignore the second one as a workaround.
thisPtr.Complete(e, false);
}
else
{
e.ConnectSocket?.Dispose();
}
}

static void OnTimer(object obj)
{
var thisPtr = (TcpTransportInitiator)obj;
if (Interlocked.CompareExchange(ref thisPtr.state, 1, 0) == 0)
{
thisPtr.connectEventArgs.SocketError = SocketError.TimedOut;
thisPtr.Complete(thisPtr.connectEventArgs, false);
}
}

void Complete(SocketAsyncEventArgs e, bool completeSynchronously)
Expand All @@ -63,17 +91,16 @@ void Complete(SocketAsyncEventArgs e, bool completeSynchronously)
if (e.SocketError != SocketError.Success)
{
exception = new SocketException((int)e.SocketError);
if (e.AcceptSocket != null)
if (e.ConnectSocket != null)
{
e.AcceptSocket.Dispose();
e.ConnectSocket.Dispose();
}
}
else
{
try
{
Fx.Assert(e.ConnectSocket != null, "Must have a valid socket accepted.");
e.ConnectSocket.NoDelay = true;
transport = new TcpTransport(e.ConnectSocket, this.transportSettings);
transport.Open();
}
Expand All @@ -84,15 +111,17 @@ void Complete(SocketAsyncEventArgs e, bool completeSynchronously)
{
transport.SafeClose();
}

transport = null;
}
}

e.Dispose();
this.timer?.Dispose();

this.callbackArgs.CompletedSynchronously = completeSynchronously;
this.callbackArgs.Exception = exception;
this.callbackArgs.Transport = transport;

if (!completeSynchronously)
{
this.callbackArgs.CompletedCallback(this.callbackArgs);
Expand Down
43 changes: 37 additions & 6 deletions test/TestCases/AmqpTransportTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
{
using System;
using System.Diagnostics;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using global::Microsoft.Azure.Amqp;
using global::Microsoft.Azure.Amqp.Transport;
using Xunit;

Expand Down Expand Up @@ -45,6 +48,38 @@ public void TcpTransportTest()
Assert.True(serverContext.Success);
}

[Fact]
public void ConnectTimeoutTest()
{
const int port = 30888;
IPAddress address = IPAddress.Loopback;
// Creat a listener socket but do not listen on it
var socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp) { NoDelay = true };
socket.Bind(new IPEndPoint(address, port));

try
{
var tcp = new TcpTransportSettings() { Host = "localhost", Port = port };
var amqp = new AmqpSettings();
amqp.TransportProviders.Add(new AmqpTransportProvider());
var initiator = new AmqpTransportInitiator(amqp, tcp);
var task = initiator.ConnectTaskAsync(TimeSpan.FromSeconds(1));
Assert.False(task.IsCompleted);

Thread.Sleep(2000);
Assert.True(task.IsFaulted);
Assert.NotNull(task.Exception);

var ex = task.Exception.GetBaseException() as SocketException;
Assert.NotNull(ex);
Assert.Equal(SocketError.TimedOut, (SocketError)ex.ErrorCode);
}
finally
{
socket.Close();
}
}

internal static TransportBase AcceptServerTransport(TransportSettings settings)
{
ManualResetEvent complete = new ManualResetEvent(false);
Expand Down Expand Up @@ -75,12 +110,8 @@ internal static TransportBase AcceptServerTransport(TransportSettings settings)

complete.WaitOne();
complete.Dispose();

transport.Closed += (s, a) =>
{
listener.Close();
Debug.WriteLine("Listeners Closed.");
};
listener.Close();
Debug.WriteLine("Listeners Closed.");

return transport;
}
Expand Down

0 comments on commit 86f022e

Please sign in to comment.