From 86f022ebd486c600a83cd80a4d5526585634343e Mon Sep 17 00:00:00 2001 From: xinchen Date: Sat, 19 Sep 2020 11:22:57 -0700 Subject: [PATCH] Socket connect timeout --- src/Transport/TcpTransportInitiator.cs | 55 ++++++++++++++++++++------ test/TestCases/AmqpTransportTests.cs | 43 +++++++++++++++++--- 2 files changed, 79 insertions(+), 19 deletions(-) diff --git a/src/Transport/TcpTransportInitiator.cs b/src/Transport/TcpTransportInitiator.cs index 97d5fbbe..0d18d6cf 100644 --- a/src/Transport/TcpTransportInitiator.cs +++ b/src/Transport/TcpTransportInitiator.cs @@ -6,11 +6,15 @@ namespace Microsoft.Azure.Amqp.Transport using System; using System.Net; using System.Net.Sockets; + using System.Threading; sealed class TcpTransportInitiator : TransportInitiator { readonly TcpTransportSettings transportSettings; TransportAsyncCallbackArgs callbackArgs; + SocketAsyncEventArgs connectEventArgs; + Timer timer; + int state; internal TcpTransportInitiator(TcpTransportSettings transportSettings) { @@ -19,16 +23,20 @@ internal TcpTransportInitiator(TcpTransportSettings transportSettings) public override bool ConnectAsync(TimeSpan timeout, TransportAsyncCallbackArgs callbackArgs) { - // TODO: set socket connect timeout to timeout this.callbackArgs = callbackArgs; this.callbackArgs.Exception = null; this.callbackArgs.Transport = null; + DnsEndPoint dnsEndPoint = new DnsEndPoint(this.transportSettings.Host, this.transportSettings.Port); + this.connectEventArgs = new SocketAsyncEventArgs(); + this.connectEventArgs.Completed += new EventHandler(OnConnectComplete); + this.connectEventArgs.RemoteEndPoint = dnsEndPoint; + this.connectEventArgs.UserToken = this; - SocketAsyncEventArgs connectEventArgs = new SocketAsyncEventArgs(); - connectEventArgs.Completed += new EventHandler(OnConnectComplete); - connectEventArgs.RemoteEndPoint = dnsEndPoint; - connectEventArgs.UserToken = this; + if (timeout < TimeSpan.MaxValue) + { + this.timer = new Timer(s => OnTimer(s), this, timeout, Timeout.InfiniteTimeSpan); + } // On Linux platform, socket connections are allowed to be initiated on the socket instance // with hostname due to multiple IP address DNS resolution possibility. @@ -40,7 +48,15 @@ public override bool ConnectAsync(TimeSpan timeout, TransportAsyncCallbackArgs c } else { - this.Complete(connectEventArgs, true); + if (Interlocked.CompareExchange(ref this.state, 1, 0) == 0) + { + this.Complete(this.connectEventArgs, true); + } + else + { + this.connectEventArgs.ConnectSocket?.Dispose(); + } + return false; } } @@ -48,12 +64,24 @@ public override bool ConnectAsync(TimeSpan timeout, TransportAsyncCallbackArgs c static void OnConnectComplete(object sender, SocketAsyncEventArgs e) { TcpTransportInitiator thisPtr = (TcpTransportInitiator)e.UserToken; - if (thisPtr.callbackArgs.Transport == null && thisPtr.callbackArgs.Exception == null) + if (Interlocked.CompareExchange(ref thisPtr.state, 1, 0) == 0) { - // Mono invokes the callback twice from the callback event handler - // Ignore the second one as a workaround. thisPtr.Complete(e, false); } + else + { + e.ConnectSocket?.Dispose(); + } + } + + static void OnTimer(object obj) + { + var thisPtr = (TcpTransportInitiator)obj; + if (Interlocked.CompareExchange(ref thisPtr.state, 1, 0) == 0) + { + thisPtr.connectEventArgs.SocketError = SocketError.TimedOut; + thisPtr.Complete(thisPtr.connectEventArgs, false); + } } void Complete(SocketAsyncEventArgs e, bool completeSynchronously) @@ -63,9 +91,9 @@ void Complete(SocketAsyncEventArgs e, bool completeSynchronously) if (e.SocketError != SocketError.Success) { exception = new SocketException((int)e.SocketError); - if (e.AcceptSocket != null) + if (e.ConnectSocket != null) { - e.AcceptSocket.Dispose(); + e.ConnectSocket.Dispose(); } } else @@ -73,7 +101,6 @@ void Complete(SocketAsyncEventArgs e, bool completeSynchronously) try { Fx.Assert(e.ConnectSocket != null, "Must have a valid socket accepted."); - e.ConnectSocket.NoDelay = true; transport = new TcpTransport(e.ConnectSocket, this.transportSettings); transport.Open(); } @@ -84,15 +111,17 @@ void Complete(SocketAsyncEventArgs e, bool completeSynchronously) { transport.SafeClose(); } + transport = null; } } e.Dispose(); + this.timer?.Dispose(); + this.callbackArgs.CompletedSynchronously = completeSynchronously; this.callbackArgs.Exception = exception; this.callbackArgs.Transport = transport; - if (!completeSynchronously) { this.callbackArgs.CompletedCallback(this.callbackArgs); diff --git a/test/TestCases/AmqpTransportTests.cs b/test/TestCases/AmqpTransportTests.cs index 9a03aba1..da53cb3f 100644 --- a/test/TestCases/AmqpTransportTests.cs +++ b/test/TestCases/AmqpTransportTests.cs @@ -2,7 +2,10 @@ { using System; using System.Diagnostics; + using System.Net; + using System.Net.Sockets; using System.Threading; + using global::Microsoft.Azure.Amqp; using global::Microsoft.Azure.Amqp.Transport; using Xunit; @@ -45,6 +48,38 @@ public void TcpTransportTest() Assert.True(serverContext.Success); } + [Fact] + public void ConnectTimeoutTest() + { + const int port = 30888; + IPAddress address = IPAddress.Loopback; + // Creat a listener socket but do not listen on it + var socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp) { NoDelay = true }; + socket.Bind(new IPEndPoint(address, port)); + + try + { + var tcp = new TcpTransportSettings() { Host = "localhost", Port = port }; + var amqp = new AmqpSettings(); + amqp.TransportProviders.Add(new AmqpTransportProvider()); + var initiator = new AmqpTransportInitiator(amqp, tcp); + var task = initiator.ConnectTaskAsync(TimeSpan.FromSeconds(1)); + Assert.False(task.IsCompleted); + + Thread.Sleep(2000); + Assert.True(task.IsFaulted); + Assert.NotNull(task.Exception); + + var ex = task.Exception.GetBaseException() as SocketException; + Assert.NotNull(ex); + Assert.Equal(SocketError.TimedOut, (SocketError)ex.ErrorCode); + } + finally + { + socket.Close(); + } + } + internal static TransportBase AcceptServerTransport(TransportSettings settings) { ManualResetEvent complete = new ManualResetEvent(false); @@ -75,12 +110,8 @@ internal static TransportBase AcceptServerTransport(TransportSettings settings) complete.WaitOne(); complete.Dispose(); - - transport.Closed += (s, a) => - { - listener.Close(); - Debug.WriteLine("Listeners Closed."); - }; + listener.Close(); + Debug.WriteLine("Listeners Closed."); return transport; }