From 060de1cdbc78d901d5354d7fe5848b1d2ac8a154 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marie=20P=C3=ADchov=C3=A1?= <11718369+ManickaP@users.noreply.github.com> Date: Wed, 13 Jul 2022 23:13:34 +0200 Subject: [PATCH] [QUIC] API QuicConnection (#71783), QuicStream (#71969), QUIC public (#72031) (#72106) * [QUIC] API QuicConnection (#71783) * Listener comment; PreviewFeature attribute * Feedback * QuicConnection new API including compilable implementation * Fixed logging * Fixed S.N.Quic and S.N.Http tests * Options now correspond to the issue * Feedback * Comments, PreviewFeature attribute and RemoteCertificate disposal. * Preview feature attribute is assembly wide * Some typos * Fixed test with certificate * Default values as constants * Event handlers split into methods called via switch expression. * Some more comments * Unified unsafe usage * Fixed some more tests * Cleaned up some exceptions and resource strings. * Feedback * Latest greatest API proposal. * Fixed Http solution * Feedback * [QUIC] API QuicStream (#71969) * Quic stream API surface * Fixed test compilation * Fixed http test compilation * HttpLoopbackConnection Dispose -> DisposeAsync * QuicStream implementation * Fixed some tests * Fixed all QUIC and HTTP tests * Fixed exception type for stream closed by connection close * Feedback * Fixed WebSocket.Client test build * Feedback, test fixes * Fixed build on framework and windows * Fixed winhandler test * Swap variable based on order in defining class * Post merge fixes * Feedback and build * Reverted connection state to pass around abort error code * Fixed exception type. * [QUIC] System.Net.Quic API made public (#72031) * System.Net.Quic removed from ASP transport package and made part of SDK ref * Removed manual references to System.Net.Quic.csproj --- .../System/Net/Http/GenericLoopbackServer.cs | 4 +- .../Net/Http/Http2LoopbackConnection.cs | 4 +- .../System/Net/Http/Http2LoopbackServer.cs | 4 +- .../Net/Http/Http3LoopbackConnection.cs | 28 +- .../System/Net/Http/Http3LoopbackServer.cs | 18 +- .../System/Net/Http/Http3LoopbackStream.cs | 21 +- .../Net/Http/HttpAgnosticLoopbackServer.cs | 15 +- .../HttpClientHandlerTest.Authentication.cs | 2 +- .../tests/System/Net/Http/LoopbackServer.cs | 11 +- ...Internal.Runtime.AspNetCore.Transport.proj | 5 +- src/libraries/NetCoreAppLibrary.props | 1 - .../FunctionalTests/ClientCertificateTest.cs | 4 +- .../src/System.Net.Http.csproj | 6 +- .../Http/SocketsHttpHandler/ConnectHelper.cs | 22 +- .../SocketsHttpHandler/Http3Connection.cs | 14 +- .../SocketsHttpHandler/Http3RequestStream.cs | 16 +- .../HttpClientHandlerTest.AltSvc.cs | 4 +- .../HttpClientHandlerTest.Http2.cs | 4 +- .../HttpClientHandlerTest.Http3.cs | 113 +- ...cketsHttpHandlerTest.Http2KeepAlivePing.cs | 2 +- .../FunctionalTests/SocketsHttpHandlerTest.cs | 14 +- .../System.Net.Http.Functional.Tests.csproj | 1 - .../tests/FunctionalTests/TelemetryTest.cs | 4 +- .../System.Net.Quic/ref/System.Net.Quic.cs | 72 +- .../ref/System.Net.Quic.csproj | 3 - .../src/Resources/Strings.resx | 44 +- .../src/System.Net.Quic.csproj | 1 + .../MsQuic/Internal/MsQuicAddressHelpers.cs | 35 - .../MsQuic/Internal/MsQuicApi.cs | 153 -- .../MsQuic/Internal/MsQuicParameterHelpers.cs | 101 - .../Internal/ResettableCompletionSource.cs | 84 - .../Interop/SafeMsQuicConfigurationHandle.cs | 297 --- .../Interop/SafeMsQuicConnectionHandle.cs | 14 - .../MsQuic/Interop/SafeMsQuicStreamHandle.cs | 14 - .../MsQuic/MsQuicConnection.cs | 774 -------- .../Implementations/MsQuic/MsQuicStream.cs | 1694 ----------------- .../src/System/Net/Quic/Internal/MsQuicApi.cs | 148 ++ .../MsQuic => }/Internal/MsQuicBuffers.cs | 40 +- .../Net/Quic/Internal/MsQuicConfiguration.cs | 235 +++ .../System/Net/Quic/Internal/MsQuicHelpers.cs | 84 + .../{Interop => Internal}/MsQuicSafeHandle.cs | 13 +- .../Net/Quic/Internal/ReceiveBuffers.cs | 80 + .../Internal/ResettableValueTaskSource.cs | 281 +++ .../System/Net/Quic/Internal/ThrowHelper.cs | 245 ++- .../Net/Quic/Internal/ValueTaskSource.cs | 11 +- .../MsQuic => }/Interop/msquic.cs | 0 .../MsQuic => }/Interop/msquic_extensions.cs | 0 .../MsQuic => }/Interop/msquic_generated.cs | 0 .../Interop/msquic_generated_linux.cs | 0 .../Interop/msquic_generated_macos.cs | 0 .../Interop/msquic_generated_windows.cs | 0 .../src/System/Net/Quic/QuicAbortDirection.cs | 24 + .../QuicConnection.SslConnectionOptions.cs | 127 ++ .../src/System/Net/Quic/QuicConnection.cs | 632 +++++- .../System/Net/Quic/QuicConnectionOptions.cs | 104 +- .../src/System/Net/Quic/QuicDefaults.cs | 35 + .../src/System/Net/Quic/QuicError.cs | 2 +- .../Quic/QuicListener.PendingConnection.cs | 5 +- .../src/System/Net/Quic/QuicListener.cs | 31 +- .../System/Net/Quic/QuicListenerOptions.cs | 33 +- .../src/System/Net/Quic/QuicStream.Stream.cs | 167 ++ .../src/System/Net/Quic/QuicStream.cs | 613 +++++- .../src/System/Net/Quic/QuicStreamType.cs | 21 + .../MsQuicCipherSuitesPolicyTests.cs | 3 +- .../tests/FunctionalTests/MsQuicTests.cs | 296 ++- .../FunctionalTests/QuicConnectionTests.cs | 83 +- .../FunctionalTests/QuicListenerTests.cs | 24 +- ...icStreamConnectedStreamConformanceTests.cs | 16 +- .../tests/FunctionalTests/QuicStreamTests.cs | 234 +-- .../tests/FunctionalTests/QuicTestBase.cs | 59 +- .../System.Net.Quic.Functional.Tests.csproj | 1 - .../tests/ClientWebSocketOptionsTests.cs | 2 +- 72 files changed, 3133 insertions(+), 4119 deletions(-) delete mode 100644 src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/MsQuicAddressHelpers.cs delete mode 100644 src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/MsQuicApi.cs delete mode 100644 src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/MsQuicParameterHelpers.cs delete mode 100644 src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/ResettableCompletionSource.cs delete mode 100644 src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/SafeMsQuicConfigurationHandle.cs delete mode 100644 src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/SafeMsQuicConnectionHandle.cs delete mode 100644 src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/SafeMsQuicStreamHandle.cs delete mode 100644 src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs delete mode 100644 src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs create mode 100644 src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicApi.cs rename src/libraries/System.Net.Quic/src/System/Net/Quic/{Implementations/MsQuic => }/Internal/MsQuicBuffers.cs (74%) create mode 100644 src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicConfiguration.cs create mode 100644 src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs rename src/libraries/System.Net.Quic/src/System/Net/Quic/{Interop => Internal}/MsQuicSafeHandle.cs (89%) create mode 100644 src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ReceiveBuffers.cs create mode 100644 src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ResettableValueTaskSource.cs rename src/libraries/System.Net.Quic/src/System/Net/Quic/{Implementations/MsQuic => }/Interop/msquic.cs (100%) rename src/libraries/System.Net.Quic/src/System/Net/Quic/{Implementations/MsQuic => }/Interop/msquic_extensions.cs (100%) rename src/libraries/System.Net.Quic/src/System/Net/Quic/{Implementations/MsQuic => }/Interop/msquic_generated.cs (100%) rename src/libraries/System.Net.Quic/src/System/Net/Quic/{Implementations/MsQuic => }/Interop/msquic_generated_linux.cs (100%) rename src/libraries/System.Net.Quic/src/System/Net/Quic/{Implementations/MsQuic => }/Interop/msquic_generated_macos.cs (100%) rename src/libraries/System.Net.Quic/src/System/Net/Quic/{Implementations/MsQuic => }/Interop/msquic_generated_windows.cs (100%) create mode 100644 src/libraries/System.Net.Quic/src/System/Net/Quic/QuicAbortDirection.cs create mode 100644 src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.SslConnectionOptions.cs create mode 100644 src/libraries/System.Net.Quic/src/System/Net/Quic/QuicDefaults.cs create mode 100644 src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.Stream.cs create mode 100644 src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStreamType.cs diff --git a/src/libraries/Common/tests/System/Net/Http/GenericLoopbackServer.cs b/src/libraries/Common/tests/System/Net/Http/GenericLoopbackServer.cs index 5407a6095cd0a..c3fb1512fbce8 100644 --- a/src/libraries/Common/tests/System/Net/Http/GenericLoopbackServer.cs +++ b/src/libraries/Common/tests/System/Net/Http/GenericLoopbackServer.cs @@ -122,9 +122,9 @@ private void CloseWebSocket() } } - public abstract class GenericLoopbackConnection : IDisposable + public abstract class GenericLoopbackConnection : IAsyncDisposable { - public abstract void Dispose(); + public abstract ValueTask DisposeAsync(); public abstract Task InitializeConnectionAsync(); diff --git a/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs b/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs index 75415dcc97798..4ca66fabed420 100644 --- a/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs +++ b/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs @@ -838,12 +838,12 @@ public async Task SendResponseBodyAsync(int streamId, ReadOnlyMemory respo await SendResponseDataAsync(streamId, responseBody, isFinal).ConfigureAwait(false); } - public override void Dispose() + public override async ValueTask DisposeAsync() { // Might have been already shutdown manually via WaitForConnectionShutdownAsync which nulls the _connectionStream. if (_connectionStream != null) { - ShutdownIgnoringErrorsAsync(_lastStreamId).GetAwaiter().GetResult(); + await ShutdownIgnoringErrorsAsync(_lastStreamId); } } diff --git a/src/libraries/Common/tests/System/Net/Http/Http2LoopbackServer.cs b/src/libraries/Common/tests/System/Net/Http/Http2LoopbackServer.cs index 9d17c93f0d0cb..e6ac7ae626640 100644 --- a/src/libraries/Common/tests/System/Net/Http/Http2LoopbackServer.cs +++ b/src/libraries/Common/tests/System/Net/Http/Http2LoopbackServer.cs @@ -148,7 +148,7 @@ public override void Dispose() public override async Task HandleRequestAsync(HttpStatusCode statusCode = HttpStatusCode.OK, IList headers = null, string content = "") { - using (Http2LoopbackConnection connection = await EstablishConnectionAsync().ConfigureAwait(false)) + await using (Http2LoopbackConnection connection = await EstablishConnectionAsync().ConfigureAwait(false)) { return await connection.HandleRequestAsync(statusCode, headers, content).ConfigureAwait(false); } @@ -156,7 +156,7 @@ public override async Task HandleRequestAsync(HttpStatusCode st public override async Task AcceptConnectionAsync(Func funcAsync) { - using (Http2LoopbackConnection connection = await EstablishConnectionAsync().ConfigureAwait(false)) + await using (Http2LoopbackConnection connection = await EstablishConnectionAsync().ConfigureAwait(false)) { await funcAsync(connection).ConfigureAwait(false); } diff --git a/src/libraries/Common/tests/System/Net/Http/Http3LoopbackConnection.cs b/src/libraries/Common/tests/System/Net/Http/Http3LoopbackConnection.cs index fd9e04190fbf6..c64ea03a392c6 100644 --- a/src/libraries/Common/tests/System/Net/Http/Http3LoopbackConnection.cs +++ b/src/libraries/Common/tests/System/Net/Http/Http3LoopbackConnection.cs @@ -56,17 +56,17 @@ public Http3LoopbackConnection(QuicConnection connection) public long MaxHeaderListSize { get; private set; } = -1; - public override void Dispose() + public override async ValueTask DisposeAsync() { // Close any remaining request streams (but NOT control streams, as these should not be closed while the connection is open) foreach (Http3LoopbackStream stream in _openStreams.Values) { - stream.Dispose(); + await stream.DisposeAsync().ConfigureAwait(false); } foreach (QuicStream stream in _delayedStreams) { - stream.Dispose(); + await stream.DisposeAsync().ConfigureAwait(false); } // We don't dispose the connection currently, because this causes races when the server connection is closed before @@ -79,8 +79,8 @@ public override void Dispose() _connection.Dispose(); // Dispose control streams so that we release their handles too. - _inboundControlStream?.Dispose(); - _outboundControlStream?.Dispose(); + await _inboundControlStream?.DisposeAsync().ConfigureAwait(false); + await _outboundControlStream?.DisposeAsync().ConfigureAwait(false); #endif } @@ -91,12 +91,12 @@ public async Task CloseAsync(long errorCode) public async ValueTask OpenUnidirectionalStreamAsync() { - return new Http3LoopbackStream(await _connection.OpenUnidirectionalStreamAsync()); + return new Http3LoopbackStream(await _connection.OpenOutboundStreamAsync(QuicStreamType.Unidirectional)); } public async ValueTask OpenBidirectionalStreamAsync() { - return new Http3LoopbackStream(await _connection.OpenBidirectionalStreamAsync()); + return new Http3LoopbackStream(await _connection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional)); } public static int GetRequestId(QuicStream stream) @@ -104,7 +104,7 @@ public static int GetRequestId(QuicStream stream) Debug.Assert(stream.CanRead && stream.CanWrite, "Stream must be a request stream."); // TODO: QUIC streams can have IDs larger than int.MaxValue; update all our tests to use long rather than int. - return checked((int)stream.StreamId + 1); + return checked((int)stream.Id + 1); } public Http3LoopbackStream GetOpenRequest(int requestId = 0) @@ -131,7 +131,7 @@ async Task EnsureControlStreamAcceptedInternalAsync() while (true) { - QuicStream quicStream = await _connection.AcceptStreamAsync().ConfigureAwait(false); + QuicStream quicStream = await _connection.AcceptInboundStreamAsync().ConfigureAwait(false); if (!quicStream.CanWrite) { @@ -165,16 +165,16 @@ public async Task AcceptRequestStreamAsync() if (!_delayedStreams.TryDequeue(out QuicStream quicStream)) { - quicStream = await _connection.AcceptStreamAsync().ConfigureAwait(false); + quicStream = await _connection.AcceptInboundStreamAsync().ConfigureAwait(false); } var stream = new Http3LoopbackStream(quicStream); Assert.True(quicStream.CanWrite, "Expected writeable stream."); - _openStreams.Add(checked((int)quicStream.StreamId), stream); + _openStreams.Add(checked((int)quicStream.Id), stream); _currentStream = stream; - _currentStreamId = quicStream.StreamId; + _currentStreamId = quicStream.Id; return stream; } @@ -293,9 +293,9 @@ public async Task WaitForClientDisconnectAsync(bool refuseNewRequests = true) break; } - using (stream) + await using (stream) { - await stream.AbortAndWaitForShutdownAsync(H3_REQUEST_REJECTED); + stream.Abort(H3_REQUEST_REJECTED); } } diff --git a/src/libraries/Common/tests/System/Net/Http/Http3LoopbackServer.cs b/src/libraries/Common/tests/System/Net/Http/Http3LoopbackServer.cs index ba5109bc24bd7..60328d0b9f13f 100644 --- a/src/libraries/Common/tests/System/Net/Http/Http3LoopbackServer.cs +++ b/src/libraries/Common/tests/System/Net/Http/Http3LoopbackServer.cs @@ -36,8 +36,10 @@ public Http3LoopbackServer(Http3Options options = null) { var serverOptions = new QuicServerConnectionOptions() { - MaxBidirectionalStreams = options.MaxBidirectionalStreams, - MaxUnidirectionalStreams = options.MaxUnidirectionalStreams, + DefaultStreamErrorCode = Http3LoopbackConnection.H3_REQUEST_CANCELLED, + DefaultCloseErrorCode = Http3LoopbackConnection.H3_NO_ERROR, + MaxInboundBidirectionalStreams = options.MaxInboundBidirectionalStreams, + MaxInboundUnidirectionalStreams = options.MaxInboundUnidirectionalStreams, ServerAuthenticationOptions = new SslServerAuthenticationOptions { EnabledSslProtocols = options.SslProtocols, @@ -80,14 +82,14 @@ public override async Task EstablishGenericConnection public override async Task AcceptConnectionAsync(Func funcAsync) { - using Http3LoopbackConnection con = await EstablishHttp3ConnectionAsync().ConfigureAwait(false); + await using Http3LoopbackConnection con = await EstablishHttp3ConnectionAsync().ConfigureAwait(false); await funcAsync(con).ConfigureAwait(false); await con.ShutdownAsync(); } public override async Task HandleRequestAsync(HttpStatusCode statusCode = HttpStatusCode.OK, IList headers = null, string content = "") { - using var con = (Http3LoopbackConnection)await EstablishGenericConnectionAsync().ConfigureAwait(false); + await using Http3LoopbackConnection con = (Http3LoopbackConnection)await EstablishGenericConnectionAsync().ConfigureAwait(false); return await con.HandleRequestAsync(statusCode, headers, content).ConfigureAwait(false); } } @@ -136,16 +138,16 @@ private static Http3Options CreateOptions(GenericLoopbackOptions options) } public class Http3Options : GenericLoopbackOptions { - public int MaxUnidirectionalStreams { get; set; } + public int MaxInboundUnidirectionalStreams { get; set; } - public int MaxBidirectionalStreams { get; set; } + public int MaxInboundBidirectionalStreams { get; set; } public string Alpn { get; set; } public Http3Options() { - MaxUnidirectionalStreams = 10; - MaxBidirectionalStreams = 100; + MaxInboundUnidirectionalStreams = 10; + MaxInboundBidirectionalStreams = 100; Alpn = SslApplicationProtocol.Http3.ToString(); } } diff --git a/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs b/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs index b3d7b7f564423..3cc397c37d1a0 100644 --- a/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs +++ b/src/libraries/Common/tests/System/Net/Http/Http3LoopbackStream.cs @@ -15,7 +15,7 @@ namespace System.Net.Test.Common { - internal sealed class Http3LoopbackStream : IDisposable + internal sealed class Http3LoopbackStream : IAsyncDisposable { private const int MaximumVarIntBytes = 8; private const long VarIntMax = (1L << 62) - 1; @@ -40,12 +40,9 @@ public Http3LoopbackStream(QuicStream stream) _stream = stream; } - public void Dispose() - { - _stream.Dispose(); - } + public ValueTask DisposeAsync() => _stream.DisposeAsync(); - public long StreamId => _stream.StreamId; + public long StreamId => _stream.Id; public async Task HandleRequestAsync(HttpStatusCode statusCode = HttpStatusCode.OK, IList headers = null, string content = "") { @@ -285,9 +282,7 @@ public async Task SendResponseBodyAsync(byte[] content, bool isFinal = true) if (isFinal) { - _stream.Shutdown(); - await _stream.ShutdownCompleted().ConfigureAwait(false); - Dispose(); + _stream.CompleteWrites(); } } @@ -389,7 +384,7 @@ async Task WaitForWriteCancellation() { try { - await _stream.WaitForWriteCompletionAsync(); + await _stream.WritesClosed; } catch (QuicException ex) when (ex.QuicError == QuicError.StreamAborted && ex.ApplicationErrorCode == Http3LoopbackConnection.H3_REQUEST_CANCELLED) { @@ -424,11 +419,9 @@ private async Task DrainResponseData() } } - public async Task AbortAndWaitForShutdownAsync(long errorCode) + public void Abort(long errorCode) { - _stream.AbortRead(errorCode); - _stream.AbortWrite(errorCode); - await _stream.ShutdownCompleted(); + _stream.Abort(QuicAbortDirection.Both, errorCode); } public async Task<(long? frameType, byte[] payload)> ReadFrameAsync() diff --git a/src/libraries/Common/tests/System/Net/Http/HttpAgnosticLoopbackServer.cs b/src/libraries/Common/tests/System/Net/Http/HttpAgnosticLoopbackServer.cs index 88a8071153e6c..48626cb2de6f6 100644 --- a/src/libraries/Common/tests/System/Net/Http/HttpAgnosticLoopbackServer.cs +++ b/src/libraries/Common/tests/System/Net/Http/HttpAgnosticLoopbackServer.cs @@ -109,15 +109,18 @@ public override async Task EstablishGenericConnection { return connection = await Http2LoopbackServerFactory.Singleton.CreateConnectionAsync(new SocketWrapper(socket), stream, options).ConfigureAwait(false); } - else + else { throw new Exception($"Invalid ClearTextVersion={_options.ClearTextVersion} specified"); } } catch - { - connection?.Dispose(); - connection = null; + { + if (connection is not null) + { + await connection.DisposeAsync(); + connection = null; + } stream.Dispose(); throw; } @@ -132,7 +135,7 @@ public override async Task EstablishGenericConnection public override async Task HandleRequestAsync(HttpStatusCode statusCode = HttpStatusCode.OK, IList headers = null, string content = "") { - using (GenericLoopbackConnection connection = await EstablishGenericConnectionAsync().ConfigureAwait(false)) + await using (GenericLoopbackConnection connection = await EstablishGenericConnectionAsync().ConfigureAwait(false)) { return await connection.HandleRequestAsync(statusCode, headers, content).ConfigureAwait(false); } @@ -140,7 +143,7 @@ public override async Task HandleRequestAsync(HttpStatusCode st public override async Task AcceptConnectionAsync(Func funcAsync) { - using (GenericLoopbackConnection connection = await EstablishGenericConnectionAsync().ConfigureAwait(false)) + await using (GenericLoopbackConnection connection = await EstablishGenericConnectionAsync().ConfigureAwait(false)) { await funcAsync(connection).ConfigureAwait(false); } diff --git a/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.Authentication.cs b/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.Authentication.cs index 4d7e95ee25563..b19c02b52ebc0 100644 --- a/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.Authentication.cs +++ b/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.Authentication.cs @@ -717,7 +717,7 @@ await LoopbackServer.CreateClientAndServerAsync( Assert.Equal(0, requestData.GetHeaderValueCount("Authorization")); // Establish a session connection - using var connection = await server.EstablishConnectionAsync(); + await using LoopbackServer.Connection connection = await server.EstablishConnectionAsync(); requestData = await connection.ReadRequestDataAsync(); string authHeaderValue = requestData.GetSingleHeaderValue("Authorization"); Assert.Contains("NTLM", authHeaderValue); diff --git a/src/libraries/Common/tests/System/Net/Http/LoopbackServer.cs b/src/libraries/Common/tests/System/Net/Http/LoopbackServer.cs index 51445abb09c65..07a039da95349 100644 --- a/src/libraries/Common/tests/System/Net/Http/LoopbackServer.cs +++ b/src/libraries/Common/tests/System/Net/Http/LoopbackServer.cs @@ -156,7 +156,7 @@ public async Task EstablishConnectionAsync() public async Task AcceptConnectionAsync(Func funcAsync) { - using (Connection connection = await EstablishConnectionAsync().ConfigureAwait(false)) + await using (Connection connection = await EstablishConnectionAsync().ConfigureAwait(false)) { await funcAsync(connection).ConfigureAwait(false); } @@ -654,7 +654,7 @@ private async Task ReadLineBytesAsync() return null; } - public override void Dispose() + public override async ValueTask DisposeAsync() { try { @@ -666,7 +666,12 @@ public override void Dispose() } catch (Exception) { } +#if !NETSTANDARD2_0 && !NETFRAMEWORK + await _stream.DisposeAsync().ConfigureAwait(false); +#else _stream.Dispose(); + await Task.CompletedTask.ConfigureAwait(false); +#endif _socket?.Dispose(); } @@ -1076,7 +1081,7 @@ public override Task WaitForCloseAsync(CancellationToken cancellationToken) public override async Task HandleRequestAsync(HttpStatusCode statusCode = HttpStatusCode.OK, IList headers = null, string content = "") { - using (Connection connection = await EstablishConnectionAsync().ConfigureAwait(false)) + await using (Connection connection = await EstablishConnectionAsync().ConfigureAwait(false)) { return await connection.HandleRequestAsync(statusCode, headers, content).ConfigureAwait(false); } diff --git a/src/libraries/Microsoft.Internal.Runtime.AspNetCore.Transport/src/Microsoft.Internal.Runtime.AspNetCore.Transport.proj b/src/libraries/Microsoft.Internal.Runtime.AspNetCore.Transport/src/Microsoft.Internal.Runtime.AspNetCore.Transport.proj index dc521c46e86e4..cd8628c81afd0 100644 --- a/src/libraries/Microsoft.Internal.Runtime.AspNetCore.Transport/src/Microsoft.Internal.Runtime.AspNetCore.Transport.proj +++ b/src/libraries/Microsoft.Internal.Runtime.AspNetCore.Transport/src/Microsoft.Internal.Runtime.AspNetCore.Transport.proj @@ -14,9 +14,8 @@ - - + - System.Net.Quic; System.Private.CoreLib; System.Private.DataContractSerialization; System.Private.Uri; diff --git a/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/ClientCertificateTest.cs b/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/ClientCertificateTest.cs index c88b81f9886ce..e3f25bfeb6bac 100644 --- a/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/ClientCertificateTest.cs +++ b/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/ClientCertificateTest.cs @@ -41,7 +41,7 @@ await LoopbackServer.CreateClientAndServerAsync( }, async s => { - using (LoopbackServer.Connection connection = await s.EstablishConnectionAsync().ConfigureAwait(false)) + await using (LoopbackServer.Connection connection = await s.EstablishConnectionAsync().ConfigureAwait(false)) { SslStream sslStream = connection.Stream as SslStream; Assert.NotNull(sslStream); @@ -76,7 +76,7 @@ await Http2LoopbackServer.CreateClientAndServerAsync( }, async s => { - using (Http2LoopbackConnection connection = await s.EstablishConnectionAsync().ConfigureAwait(false)) + await using (Http2LoopbackConnection connection = await s.EstablishConnectionAsync().ConfigureAwait(false)) { SslStream sslStream = connection.Stream as SslStream; Assert.NotNull(sslStream); diff --git a/src/libraries/System.Net.Http/src/System.Net.Http.csproj b/src/libraries/System.Net.Http/src/System.Net.Http.csproj index f27128545761c..d8780e26ccc84 100644 --- a/src/libraries/System.Net.Http/src/System.Net.Http.csproj +++ b/src/libraries/System.Net.Http/src/System.Net.Http.csproj @@ -434,13 +434,12 @@ - - - @@ -453,6 +452,7 @@ + diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs index ebe0ba529a0b9..04c5bef5d95b7 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs @@ -107,22 +107,22 @@ public static async ValueTask EstablishSslConnectionAsync(SslClientAu public static async ValueTask ConnectQuicAsync(HttpRequestMessage request, DnsEndPoint endPoint, TimeSpan idleTimeout, SslClientAuthenticationOptions clientAuthenticationOptions, CancellationToken cancellationToken) { clientAuthenticationOptions = SetUpRemoteCertificateValidationCallback(clientAuthenticationOptions, request); - QuicConnection connection = await QuicConnection.ConnectAsync(new QuicClientConnectionOptions() - { - MaxBidirectionalStreams = 0, // Client doesn't support inbound streams: https://www.rfc-editor.org/rfc/rfc9114.html#name-bidirectional-streams. An extension might change this. - MaxUnidirectionalStreams = 5, // Minimum is 3: https://www.rfc-editor.org/rfc/rfc9114.html#unidirectional-streams (1x control stream + 2x QPACK). Set to 100 if/when support for PUSH streams is added. - IdleTimeout = idleTimeout, - RemoteEndPoint = endPoint, - ClientAuthenticationOptions = clientAuthenticationOptions - }, cancellationToken).ConfigureAwait(false); + try { - await connection.ConnectAsync(cancellationToken).ConfigureAwait(false); - return connection; + return await QuicConnection.ConnectAsync(new QuicClientConnectionOptions() + { + MaxInboundBidirectionalStreams = 0, // Client doesn't support inbound streams: https://www.rfc-editor.org/rfc/rfc9114.html#name-bidirectional-streams. An extension might change this. + MaxInboundUnidirectionalStreams = 5, // Minimum is 3: https://www.rfc-editor.org/rfc/rfc9114.html#unidirectional-streams (1x control stream + 2x QPACK). Set to 100 if/when support for PUSH streams is added. + IdleTimeout = idleTimeout, + DefaultStreamErrorCode = (long)Http3ErrorCode.RequestCancelled, + DefaultCloseErrorCode = (long)Http3ErrorCode.NoError, + RemoteEndPoint = endPoint, + ClientAuthenticationOptions = clientAuthenticationOptions + }, cancellationToken).ConfigureAwait(false); } catch (Exception ex) { - connection.Dispose(); throw CreateWrappedException(ex, endPoint.Host, endPoint.Port, cancellationToken); } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3Connection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3Connection.cs index 3201d5de92472..e22e3502f6c9e 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3Connection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3Connection.cs @@ -132,7 +132,7 @@ private void CheckForShutdown() QuicConnection connection = _connection; _connection = null; - _ = _connectionClosedTask.ContinueWith(closeTask => + _ = _connectionClosedTask.ContinueWith(async closeTask => { if (closeTask.IsFaulted && NetEventSource.Log.IsEnabled()) { @@ -141,7 +141,7 @@ private void CheckForShutdown() try { - connection.Dispose(); + await connection.DisposeAsync().ConfigureAwait(false); } catch (Exception ex) { @@ -184,7 +184,7 @@ public async Task SendAsync(HttpRequestMessage request, lon queueStartingTimestamp = Stopwatch.GetTimestamp(); } - quicStream = await conn.OpenBidirectionalStreamAsync(cancellationToken).ConfigureAwait(false); + quicStream = await conn.OpenOutboundStreamAsync(QuicStreamType.Bidirectional, cancellationToken).ConfigureAwait(false); requestStream = new Http3RequestStream(request, this, quicStream); lock (SyncObj) @@ -210,7 +210,7 @@ public async Task SendAsync(HttpRequestMessage request, lon throw new HttpRequestException(SR.net_http_request_aborted, null, RequestRetryType.RetryOnConnectionFailure); } - requestStream!.StreamId = quicStream.StreamId; + requestStream!.StreamId = quicStream.Id; bool goAway; lock (SyncObj) @@ -366,7 +366,7 @@ private async Task SendSettingsAsync() { try { - _clientControl = await _connection!.OpenUnidirectionalStreamAsync().ConfigureAwait(false); + _clientControl = await _connection!.OpenOutboundStreamAsync(QuicStreamType.Unidirectional).ConfigureAwait(false); await _clientControl.WriteAsync(_pool.Settings.Http3SettingsFrame, CancellationToken.None).ConfigureAwait(false); } catch (Exception ex) @@ -410,7 +410,7 @@ private async Task AcceptStreamsAsync() } // No cancellation token is needed here; we expect the operation to cancel itself when _connection is disposed. - streamTask = _connection!.AcceptStreamAsync(CancellationToken.None); + streamTask = _connection!.AcceptInboundStreamAsync(CancellationToken.None); } QuicStream stream = await streamTask.ConfigureAwait(false); @@ -542,7 +542,7 @@ private async Task ProcessServerStreamAsync(QuicStream stream) NetEventSource.Info(this, $"Ignoring server-initiated stream of unknown type {unknownStreamType}."); } - stream.AbortRead((long)Http3ErrorCode.StreamCreationError); + stream.Abort(QuicAbortDirection.Read, (long)Http3ErrorCode.StreamCreationError); stream.Dispose(); return; } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs index 58a2159d3a007..3f16d194d0c8e 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs @@ -260,7 +260,7 @@ await Task.WhenAny(sendContentTask, readResponseTask).ConfigureAwait(false) == s // We're either observing GOAWAY, or the cancellationToken parameter has been canceled. if (cancellationToken.IsCancellationRequested) { - _stream.AbortWrite((long)Http3ErrorCode.RequestCancelled); + _stream.Abort(QuicAbortDirection.Write, (long)Http3ErrorCode.RequestCancelled); throw new TaskCanceledException(ex.Message, ex, cancellationToken); } else @@ -277,7 +277,7 @@ await Task.WhenAny(sendContentTask, readResponseTask).ConfigureAwait(false) == s } catch (Exception ex) { - _stream.AbortWrite((long)Http3ErrorCode.InternalError); + _stream.Abort(QuicAbortDirection.Write, (long)Http3ErrorCode.InternalError); if (ex is HttpRequestException) { throw; @@ -398,7 +398,7 @@ private async Task SendContentAsync(HttpContent content, CancellationToken cance } else { - _stream.Shutdown(); + _stream.CompleteWrites(); } if (HttpTelemetry.Log.IsEnabled()) HttpTelemetry.Log.RequestContentStop(writeStream.BytesWritten); @@ -814,7 +814,7 @@ private async ValueTask ReadHeadersAsync(long headersLength, CancellationToken c // https://tools.ietf.org/html/draft-ietf-quic-http-24#section-4.1.1 if (headersLength > _headerBudgetRemaining) { - _stream.AbortWrite((long)Http3ErrorCode.ExcessiveLoad); + _stream.Abort(QuicAbortDirection.Write, (long)Http3ErrorCode.ExcessiveLoad); throw new HttpRequestException(SR.Format(SR.net_http_response_headers_exceeded_length, _connection.Pool.Settings._maxResponseHeadersLength * 1024L)); } @@ -1201,12 +1201,12 @@ private void HandleReadResponseContentException(Exception ex, CancellationToken _connection.Abort(ex); throw new IOException(SR.net_http_client_execution_error, new HttpRequestException(SR.net_http_client_execution_error, ex)); case OperationCanceledException oce when oce.CancellationToken == cancellationToken: - _stream.AbortRead((long)Http3ErrorCode.RequestCancelled); + _stream.Abort(QuicAbortDirection.Read, (long)Http3ErrorCode.RequestCancelled); ExceptionDispatchInfo.Throw(ex); // Rethrow. return; // Never reached. } - _stream.AbortRead((long)Http3ErrorCode.InternalError); + _stream.Abort(QuicAbortDirection.Read, (long)Http3ErrorCode.InternalError); throw new IOException(SR.net_http_client_execution_error, new HttpRequestException(SR.net_http_client_execution_error, ex)); } @@ -1264,12 +1264,12 @@ private void AbortStream() // If the request body isn't completed, cancel it now. if (_requestContentLengthRemaining != 0) // 0 is used for the end of content writing, -1 is used for unknown Content-Length { - _stream.AbortWrite((long)Http3ErrorCode.RequestCancelled); + _stream.Abort(QuicAbortDirection.Write, (long)Http3ErrorCode.RequestCancelled); } // If the response body isn't completed, cancel it now. if (_responseDataPayloadRemaining != -1) // -1 is used for EOF, 0 for consumed DATA frame payload before the next read { - _stream.AbortRead((long)Http3ErrorCode.RequestCancelled); + _stream.Abort(QuicAbortDirection.Read, (long)Http3ErrorCode.RequestCancelled); } } diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.AltSvc.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.AltSvc.cs index 5cc4e1aad36bd..f71efde663105 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.AltSvc.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.AltSvc.cs @@ -81,7 +81,7 @@ public async Task AltSvc_ConnectionFrame_UpgradeFrom20_Success() Task firstResponseTask = client.GetAsync(firstServer.Address); Task serverTask = Task.Run(async () => { - using Http2LoopbackConnection connection = await firstServer.EstablishConnectionAsync(); + await using Http2LoopbackConnection connection = await firstServer.EstablishConnectionAsync(); int streamId = await connection.ReadRequestHeaderAsync(); await connection.WriteFrameAsync(new AltSvcFrame($"https://{firstServer.Address.IdnHost}:{firstServer.Address.Port}", $"h3=\"{secondServer.Address.IdnHost}:{secondServer.Address.Port}\"", streamId: 0)); @@ -106,7 +106,7 @@ public async Task AltSvc_ResponseFrame_UpgradeFrom20_Success() Task firstResponseTask = client.GetAsync(firstServer.Address); Task serverTask = Task.Run(async () => { - using Http2LoopbackConnection connection = await firstServer.EstablishConnectionAsync(); + await using Http2LoopbackConnection connection = await firstServer.EstablishConnectionAsync(); int streamId = await connection.ReadRequestHeaderAsync(); await connection.SendDefaultResponseHeadersAsync(streamId); diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http2.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http2.cs index dfc3625aaa7b1..41543b32f7348 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http2.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http2.cs @@ -291,7 +291,7 @@ public async Task Http2_ServerSendsInvalidSettingsValue_Error(SettingId settingI await Assert.ThrowsAsync(() => sendTask); - connection.Dispose(); + await connection.DisposeAsync(); } } @@ -2609,7 +2609,7 @@ public async Task ConnectAsync_ReadWriteWebSocketStream() Assert.Equal(0, await responseStream.ReadAsync(readBuffer).AsTask().WaitAsync(TimeSpan.FromSeconds(10))); Assert.NotNull(connection); - connection.Dispose(); + await connection.DisposeAsync(); } [Fact] diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs index 8d7c6987f53f9..f322751bfdd01 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs @@ -41,12 +41,12 @@ public async Task ClientSettingsReceived_Success(int headerSizeLimit) Task serverTask = Task.Run(async () => { - using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); + await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); (Http3LoopbackStream settingsStream, Http3LoopbackStream requestStream) = await connection.AcceptControlAndRequestStreamAsync(); - using (settingsStream) - using (requestStream) + await using (settingsStream) + await using (requestStream) { Assert.False(settingsStream.CanWrite, "Expected unidirectional control stream."); Assert.Equal(headerSizeLimit * 1024L, connection.MaxHeaderListSize); @@ -81,14 +81,14 @@ public async Task ClientSettingsReceived_Success(int headerSizeLimit) [InlineData(1000)] public async Task SendMoreThanStreamLimitRequests_Succeeds(int streamLimit) { - using Http3LoopbackServer server = CreateHttp3LoopbackServer(new Http3Options() { MaxBidirectionalStreams = streamLimit }); + using Http3LoopbackServer server = CreateHttp3LoopbackServer(new Http3Options() { MaxInboundBidirectionalStreams = streamLimit }); Task serverTask = Task.Run(async () => { - using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); + await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); for (int i = 0; i < streamLimit + 1; ++i) { - using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); + await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); await stream.HandleRequestAsync(); } }); @@ -119,14 +119,14 @@ public async Task SendMoreThanStreamLimitRequests_Succeeds(int streamLimit) [InlineData(1000)] public async Task SendStreamLimitRequestsConcurrently_Succeeds(int streamLimit) { - using Http3LoopbackServer server = CreateHttp3LoopbackServer(new Http3Options() { MaxBidirectionalStreams = streamLimit }); + using Http3LoopbackServer server = CreateHttp3LoopbackServer(new Http3Options() { MaxInboundBidirectionalStreams = streamLimit }); Task serverTask = Task.Run(async () => { - using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); + await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); for (int i = 0; i < streamLimit; ++i) { - using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); + await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); await stream.HandleRequestAsync(); } }); @@ -165,13 +165,13 @@ public async Task SendStreamLimitRequestsConcurrently_Succeeds(int streamLimit) [InlineData(1000)] public async Task SendMoreThanStreamLimitRequestsConcurrently_LastWaits(int streamLimit) { - using Http3LoopbackServer server = CreateHttp3LoopbackServer(new Http3Options() { MaxBidirectionalStreams = streamLimit }); + using Http3LoopbackServer server = CreateHttp3LoopbackServer(new Http3Options() { MaxInboundBidirectionalStreams = streamLimit }); var lastRequestContentStarted = new TaskCompletionSource(); Task serverTask = Task.Run(async () => { // Read the first streamLimit requests, keep the streams open to make the last one wait. - using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); + await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); var streams = new Http3LoopbackStream[streamLimit]; for (int i = 0; i < streamLimit; ++i) { @@ -183,7 +183,7 @@ public async Task SendMoreThanStreamLimitRequestsConcurrently_LastWaits(int stre // Make the last request running independently. var lastRequest = Task.Run(async () => { - using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); + await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); await stream.HandleRequestAsync(); }); @@ -194,7 +194,7 @@ public async Task SendMoreThanStreamLimitRequestsConcurrently_LastWaits(int stre for (int i = 0; i < streamLimit; ++i) { await streams[i].SendResponseAsync(); - streams[i].Dispose(); + await streams[i].DisposeAsync(); // After the first request is fully processed, the last request should unblock and get processed. if (i == 0) { @@ -273,15 +273,15 @@ public async Task ReservedFrameType_Throws() Task serverTask = Task.Run(async () => { - using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); - using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); + await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); + await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); await stream.SendFrameAsync(ReservedHttp2PriorityFrameId, new byte[8]); QuicException ex = await AssertThrowsQuicExceptionAsync(QuicError.ConnectionAborted, async () => { await stream.HandleRequestAsync(); - using Http3LoopbackStream stream2 = await connection.AcceptRequestStreamAsync(); + await using Http3LoopbackStream stream2 = await connection.AcceptRequestStreamAsync(); }); Assert.Equal(UnexpectedFrameErrorCode, ex.ApplicationErrorCode); @@ -313,8 +313,8 @@ public async Task RequestSentResponseDisposed_ThrowsOnServer() Task serverTask = Task.Run(async () => { - using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); - using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); + await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); + await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); HttpRequestData request = await stream.ReadRequestDataAsync(); await stream.SendResponseHeadersAsync(); @@ -371,8 +371,8 @@ public async Task RequestSendingResponseDisposed_ThrowsOnServer() Task serverTask = Task.Run(async () => { - using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); - using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); + await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); + await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); HttpRequestData request = await stream.ReadRequestDataAsync(false); await stream.SendResponseHeadersAsync(); @@ -436,10 +436,10 @@ public async Task ServerCertificateCustomValidationCallback_Succeeds() Task serverTask = Task.Run(async () => { - using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); - using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); + await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); + await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); await stream.HandleRequestAsync(); - using Http3LoopbackStream stream2 = await connection.AcceptRequestStreamAsync(); + await using Http3LoopbackStream stream2 = await connection.AcceptRequestStreamAsync(); await stream2.HandleRequestAsync(); }); @@ -479,8 +479,8 @@ public async Task EmptyCustomContent_FlushHeaders() Task serverTask = Task.Run(async () => { - using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); - using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); + await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); + await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); // Receive headers and unblock the client. await stream.ReadRequestDataAsync(false); @@ -528,7 +528,7 @@ public async Task DisposeHttpClient_Http3ConnectionIsClosed() Task serverTask = Task.Run(async () => { - using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); + await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); HttpRequestData request = await connection.ReadRequestDataAsync(); await connection.SendResponseAsync(); @@ -665,8 +665,8 @@ public async Task ResponseCancellation_ServerReceivesCancellation(CancellationTy Task serverTask = Task.Run(async () => { - using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); - using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); + await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); + await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); HttpRequestData request = await stream.ReadRequestDataAsync().ConfigureAwait(false); @@ -746,8 +746,8 @@ public async Task ResponseCancellation_BothCancellationTokenAndDispose_Success() Task serverTask = Task.Run(async () => { - using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); - using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); + await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); + await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); HttpRequestData request = await stream.ReadRequestDataAsync().ConfigureAwait(false); @@ -829,7 +829,7 @@ public async Task Alpn_H3_Success() Task serverTask = Task.Run(async () => { connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); - using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); + await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); await stream.HandleRequestAsync(); }); @@ -850,7 +850,7 @@ public async Task Alpn_H3_Success() SslApplicationProtocol negotiatedAlpn = ExtractMsQuicNegotiatedAlpn(connection); Assert.Equal(new SslApplicationProtocol("h3"), negotiatedAlpn); - connection.Dispose(); + await connection.DisposeAsync(); } [Fact] @@ -890,31 +890,10 @@ public async Task Alpn_NonH3_NegotiationFailure() private SslApplicationProtocol ExtractMsQuicNegotiatedAlpn(Http3LoopbackConnection loopbackConnection) { - // TODO: rewrite after object structure change - // current structure: - // Http3LoopbackConnection -> private QuicConnection _connection - // QuicConnection -> private QuicConnectionProvider _provider (= MsQuicConnection) - // MsQuicConnection -> private SslApplicationProtocol _negotiatedAlpnProtocol - FieldInfo quicConnectionField = loopbackConnection.GetType().GetField("_connection", BindingFlags.Instance | BindingFlags.NonPublic); Assert.NotNull(quicConnectionField); - object quicConnection = quicConnectionField.GetValue(loopbackConnection); - Assert.NotNull(quicConnection); - Assert.Equal("QuicConnection", quicConnection.GetType().Name); - - FieldInfo msQuicConnectionField = quicConnection.GetType().GetField("_provider", BindingFlags.Instance | BindingFlags.NonPublic); - Assert.NotNull(msQuicConnectionField); - object msQuicConnection = msQuicConnectionField.GetValue(quicConnection); - Assert.NotNull(msQuicConnection); - Assert.Equal("MsQuicConnection", msQuicConnection.GetType().Name); - - FieldInfo alpnField = msQuicConnection.GetType().GetField("_negotiatedAlpnProtocol", BindingFlags.Instance | BindingFlags.NonPublic); - Assert.NotNull(alpnField); - object alpn = alpnField.GetValue(msQuicConnection); - Assert.NotNull(alpn); - Assert.IsType(alpn); - - return (SslApplicationProtocol)alpn; + QuicConnection quicConnection = Assert.IsType(quicConnectionField.GetValue(loopbackConnection)); + return quicConnection.NegotiatedApplicationProtocol; } [Theory] @@ -927,7 +906,7 @@ public async Task StatusCodes_ReceiveSuccess(HttpStatusCode statusCode, bool qpa Task serverTask = Task.Run(async () => { connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); - using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); + await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); HttpRequestData request = await stream.ReadRequestDataAsync().ConfigureAwait(false); @@ -955,7 +934,7 @@ public async Task StatusCodes_ReceiveSuccess(HttpStatusCode statusCode, bool qpa await serverTask; Assert.NotNull(connection); - connection.Dispose(); + await connection.DisposeAsync(); } [Theory] @@ -1038,9 +1017,9 @@ public async Task EchoServerStreaming_DifferentMessageSize_Success(int messageSi await serverTask.WaitAsync(TimeSpan.FromSeconds(60)); - serverStream.Dispose(); + await serverStream.DisposeAsync(); Assert.NotNull(connection); - connection.Dispose(); + await connection.DisposeAsync(); } [Fact] @@ -1103,9 +1082,9 @@ public async Task RequestContentStreaming_Timeout_BothClientAndServerReceiveCanc Assert.Equal(268 /*H3_REQUEST_CANCELLED (0x10C)*/, ex.ApplicationErrorCode); Assert.NotNull(serverStream); - serverStream.Dispose(); + await serverStream.DisposeAsync(); Assert.NotNull(connection); - connection.Dispose(); + await connection.DisposeAsync(); } [Fact] @@ -1168,9 +1147,9 @@ public async Task RequestContentStreaming_Cancellation_BothClientAndServerReceiv Assert.Equal(268 /*H3_REQUEST_CANCELLED (0x10C)*/, ex.ApplicationErrorCode); Assert.NotNull(serverStream); - serverStream.Dispose(); + await serverStream.DisposeAsync(); Assert.NotNull(connection); - connection.Dispose(); + await connection.DisposeAsync(); } [Fact] @@ -1249,9 +1228,9 @@ public async Task DuplexStreaming_RequestCTCancellation_DoesNotApply() await serverTask.WaitAsync(TimeSpan.FromSeconds(120)); Assert.NotNull(serverStream); - serverStream.Dispose(); + await serverStream.DisposeAsync(); Assert.NotNull(connection); - connection.Dispose(); + await connection.DisposeAsync(); } [Theory] @@ -1335,9 +1314,9 @@ public async Task DuplexStreaming_AbortByServer_StreamingCancelled(bool graceful await serverTask.WaitAsync(TimeSpan.FromSeconds(120)); Assert.NotNull(serverStream); - serverStream.Dispose(); + await serverStream.DisposeAsync(); Assert.NotNull(connection); - connection.Dispose(); + await connection.DisposeAsync(); } private static async Task AssertThrowsQuicExceptionAsync(QuicError expectedError, Func testCode) diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Http2KeepAlivePing.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Http2KeepAlivePing.cs index 57ab51d3fa175..b2af9384413eb 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Http2KeepAlivePing.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Http2KeepAlivePing.cs @@ -323,7 +323,7 @@ private async Task ProcessIncomingFramesAsync(CancellationToken cancellationToke } _output?.WriteLine("ProcessIncomingFramesAsync finished"); - _connection.Dispose(); + await _connection.DisposeAsync(); } private void DisablePingResponse() => Interlocked.Exchange(ref _sendPingResponse, 0); diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs index 6797e8e07a82b..21347f3901e5f 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs @@ -2250,21 +2250,21 @@ public async Task Http2_MultipleConnectionsEnabled_ManyRequestsEnqueuedSimultane List<(Http2LoopbackConnection connection, int streamId)> acceptedRequests = new List<(Http2LoopbackConnection connection, int streamId)>(); - using Http2LoopbackConnection c1 = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = 100 }); + await using Http2LoopbackConnection c1 = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = 100 }); for (int i = 0; i < MaxConcurrentStreams; i++) { (int streamId, _) = await c1.ReadAndParseRequestHeaderAsync(); acceptedRequests.Add((c1, streamId)); } - using Http2LoopbackConnection c2 = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = 100 }); + await using Http2LoopbackConnection c2 = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = 100 }); for (int i = 0; i < MaxConcurrentStreams; i++) { (int streamId, _) = await c2.ReadAndParseRequestHeaderAsync(); acceptedRequests.Add((c2, streamId)); } - using Http2LoopbackConnection c3 = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = 100 }); + await using Http2LoopbackConnection c3 = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = 100 }); (int finalStreamId, _) = await c3.ReadAndParseRequestHeaderAsync(); acceptedRequests.Add((c3, finalStreamId)); @@ -2646,7 +2646,7 @@ public async Task ConnectCallback_UseMemoryBuffer_Success(bool useSsl) Task serverTask = Task.Run(async () => { - using GenericLoopbackConnection loopbackConnection = await LoopbackServerFactory.CreateConnectionAsync(socket: null, serverStream, options); + await using GenericLoopbackConnection loopbackConnection = await LoopbackServerFactory.CreateConnectionAsync(socket: null, serverStream, options); await loopbackConnection.InitializeConnectionAsync(); HttpRequestData requestData = await loopbackConnection.ReadRequestDataAsync(); @@ -2708,7 +2708,7 @@ public async Task ConnectCallback_UseUnixDomainSocket_Success(bool useSsl) Task clientTask = client.GetStringAsync($"{(options.UseSsl ? "https" : "http")}://{guid}/foo"); Socket serverSocket = await listenSocket.AcceptAsync(); - using (GenericLoopbackConnection loopbackConnection = await LoopbackServerFactory.CreateConnectionAsync(socket: null, new NetworkStream(serverSocket, ownsSocket: true), options)) + await using (GenericLoopbackConnection loopbackConnection = await LoopbackServerFactory.CreateConnectionAsync(socket: null, new NetworkStream(serverSocket, ownsSocket: true), options)) { await loopbackConnection.InitializeConnectionAsync(); @@ -2771,7 +2771,7 @@ public async Task ConnectCallback_ConnectionPrefix_Success(bool useSsl) await serverStream.WriteAsync(ResponsePrefix); - using GenericLoopbackConnection loopbackConnection = await LoopbackServerFactory.CreateConnectionAsync(socket: null, serverStream, options); + await using GenericLoopbackConnection loopbackConnection = await LoopbackServerFactory.CreateConnectionAsync(socket: null, serverStream, options); await loopbackConnection.InitializeConnectionAsync(); HttpRequestData requestData = await loopbackConnection.ReadRequestDataAsync(); @@ -3271,7 +3271,7 @@ public async Task PlaintextStreamFilter_ConnectionPrefix_Success(bool useSsl) await serverStream.WriteAsync(ResponsePrefix); - using GenericLoopbackConnection loopbackConnection = await LoopbackServerFactory.CreateConnectionAsync(socket: null, serverStream, new GenericLoopbackOptions() { UseSsl = false }); + await using GenericLoopbackConnection loopbackConnection = await LoopbackServerFactory.CreateConnectionAsync(socket: null, serverStream, new GenericLoopbackOptions() { UseSsl = false }); await loopbackConnection.InitializeConnectionAsync(); HttpRequestData requestData = await loopbackConnection.ReadRequestDataAsync(); diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/System.Net.Http.Functional.Tests.csproj b/src/libraries/System.Net.Http/tests/FunctionalTests/System.Net.Http.Functional.Tests.csproj index 8a56b5033f29e..3f45782c7f14d 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/System.Net.Http.Functional.Tests.csproj +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/System.Net.Http.Functional.Tests.csproj @@ -300,7 +300,6 @@ - diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/TelemetryTest.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/TelemetryTest.cs index 6f31d7088501b..f4458d3a276ce 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/TelemetryTest.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/TelemetryTest.cs @@ -664,7 +664,7 @@ await GetFactoryForVersion(version).CreateClientAndServerAsync( connection = await server.EstablishGenericConnectionAsync(); } - using (connection) + await using (connection) { // Dummy request to ensure that the MaxConcurrentStreams setting has been acknowledged await connection.ReadRequestDataAsync(readBody: false); @@ -682,7 +682,7 @@ await GetFactoryForVersion(version).CreateClientAndServerAsync( await connection.ReadRequestDataAsync(readBody: false); await connection.SendResponseAsync(); }; - }, options: new Http3Options { MaxBidirectionalStreams = 1 }); + }, options: new Http3Options { MaxInboundBidirectionalStreams = 1 }); await WaitForEventCountersAsync(events); }); diff --git a/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs b/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs index ae739503f7730..fed27168ef8ca 100644 --- a/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs +++ b/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs @@ -6,38 +6,43 @@ namespace System.Net.Quic { + [System.FlagsAttribute] + public enum QuicAbortDirection + { + Read = 1, + Write = 2, + Both = 3, + } public sealed partial class QuicClientConnectionOptions : System.Net.Quic.QuicConnectionOptions { public QuicClientConnectionOptions() { } - public required System.Net.Security.SslClientAuthenticationOptions ClientAuthenticationOptions { get { throw null; } set { } } + public System.Net.Security.SslClientAuthenticationOptions ClientAuthenticationOptions { get { throw null; } set { } } public System.Net.IPEndPoint? LocalEndPoint { get { throw null; } set { } } - public required System.Net.EndPoint RemoteEndPoint { get { throw null; } set { } } + public System.Net.EndPoint RemoteEndPoint { get { throw null; } set { } } } - public sealed partial class QuicConnection : System.IDisposable + public sealed partial class QuicConnection : System.IAsyncDisposable { internal QuicConnection() { } - public bool Connected { get { throw null; } } public static bool IsSupported { get { throw null; } } - public System.Net.IPEndPoint? LocalEndPoint { get { throw null; } } + public System.Net.IPEndPoint LocalEndPoint { get { throw null; } } public System.Net.Security.SslApplicationProtocol NegotiatedApplicationProtocol { get { throw null; } } public System.Security.Cryptography.X509Certificates.X509Certificate? RemoteCertificate { get { throw null; } } - public System.Net.EndPoint RemoteEndPoint { get { throw null; } } - public System.Threading.Tasks.ValueTask AcceptStreamAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public System.Net.IPEndPoint RemoteEndPoint { get { throw null; } } + public System.Threading.Tasks.ValueTask AcceptInboundStreamAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public System.Threading.Tasks.ValueTask CloseAsync(long errorCode, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public static System.Threading.Tasks.ValueTask ConnectAsync(System.Net.Quic.QuicClientConnectionOptions options, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - public System.Threading.Tasks.ValueTask ConnectAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - public void Dispose() { } - public int GetRemoteAvailableBidirectionalStreamCount() { throw null; } - public int GetRemoteAvailableUnidirectionalStreamCount() { throw null; } - public System.Threading.Tasks.ValueTask OpenBidirectionalStreamAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - public System.Threading.Tasks.ValueTask OpenUnidirectionalStreamAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public System.Threading.Tasks.ValueTask DisposeAsync() { throw null; } + public System.Threading.Tasks.ValueTask OpenOutboundStreamAsync(System.Net.Quic.QuicStreamType type, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public override string ToString() { throw null; } } public abstract partial class QuicConnectionOptions { internal QuicConnectionOptions() { } + public long DefaultCloseErrorCode { get { throw null; } set { } } + public long DefaultStreamErrorCode { get { throw null; } set { } } public System.TimeSpan IdleTimeout { get { throw null; } set { } } - public int MaxBidirectionalStreams { get { throw null; } set { } } - public int MaxUnidirectionalStreams { get { throw null; } set { } } + public int MaxInboundBidirectionalStreams { get { throw null; } set { } } + public int MaxInboundUnidirectionalStreams { get { throw null; } set { } } } public enum QuicError { @@ -74,15 +79,15 @@ internal QuicListener() { } public sealed partial class QuicListenerOptions { public QuicListenerOptions() { } - public required System.Collections.Generic.List ApplicationProtocols { get { throw null; } set { } } - public required System.Func> ConnectionOptionsCallback { get { throw null; } set { } } + public System.Collections.Generic.List ApplicationProtocols { get { throw null; } set { } } + public System.Func> ConnectionOptionsCallback { get { throw null; } set { } } public int ListenBacklog { get { throw null; } set { } } - public required System.Net.IPEndPoint ListenEndPoint { get { throw null; } set { } } + public System.Net.IPEndPoint ListenEndPoint { get { throw null; } set { } } } public sealed partial class QuicServerConnectionOptions : System.Net.Quic.QuicConnectionOptions { public QuicServerConnectionOptions() { } - public required System.Net.Security.SslServerAuthenticationOptions ServerAuthenticationOptions { get { throw null; } set { } } + public System.Net.Security.SslServerAuthenticationOptions ServerAuthenticationOptions { get { throw null; } set { } } } public sealed partial class QuicStream : System.IO.Stream { @@ -91,38 +96,41 @@ internal QuicStream() { } public override bool CanSeek { get { throw null; } } public override bool CanTimeout { get { throw null; } } public override bool CanWrite { get { throw null; } } + public long Id { get { throw null; } } public override long Length { get { throw null; } } public override long Position { get { throw null; } set { } } - public bool ReadsCompleted { get { throw null; } } + public System.Threading.Tasks.Task ReadsClosed { get { throw null; } } public override int ReadTimeout { get { throw null; } set { } } - public long StreamId { get { throw null; } } + public System.Net.Quic.QuicStreamType Type { get { throw null; } } + public System.Threading.Tasks.Task WritesClosed { get { throw null; } } public override int WriteTimeout { get { throw null; } set { } } - public void AbortRead(long errorCode) { } - public void AbortWrite(long errorCode) { } + public void Abort(System.Net.Quic.QuicAbortDirection abortDirection, long errorCode) { } public override System.IAsyncResult BeginRead(byte[] buffer, int offset, int count, System.AsyncCallback? callback, object? state) { throw null; } public override System.IAsyncResult BeginWrite(byte[] buffer, int offset, int count, System.AsyncCallback? callback, object? state) { throw null; } + public void CompleteWrites() { } protected override void Dispose(bool disposing) { } + public override System.Threading.Tasks.ValueTask DisposeAsync() { throw null; } public override int EndRead(System.IAsyncResult asyncResult) { throw null; } public override void EndWrite(System.IAsyncResult asyncResult) { } public override void Flush() { } - public override System.Threading.Tasks.Task FlushAsync(System.Threading.CancellationToken cancellationToken) { throw null; } + public override System.Threading.Tasks.Task FlushAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public override int Read(byte[] buffer, int offset, int count) { throw null; } public override int Read(System.Span buffer) { throw null; } - public override System.Threading.Tasks.Task ReadAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) { throw null; } + public override System.Threading.Tasks.Task ReadAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public override System.Threading.Tasks.ValueTask ReadAsync(System.Memory buffer, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public override int ReadByte() { throw null; } public override long Seek(long offset, System.IO.SeekOrigin origin) { throw null; } public override void SetLength(long value) { } - public void Shutdown() { } - public System.Threading.Tasks.ValueTask ShutdownCompleted(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - public System.Threading.Tasks.ValueTask WaitForWriteCompletionAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public override void Write(byte[] buffer, int offset, int count) { } public override void Write(System.ReadOnlySpan buffer) { } - public System.Threading.Tasks.ValueTask WriteAsync(System.Buffers.ReadOnlySequence buffers, bool endStream, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - public System.Threading.Tasks.ValueTask WriteAsync(System.Buffers.ReadOnlySequence buffers, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } - public override System.Threading.Tasks.Task WriteAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) { throw null; } - public System.Threading.Tasks.ValueTask WriteAsync(System.ReadOnlyMemory buffer, bool endStream, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public override System.Threading.Tasks.Task WriteAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public System.Threading.Tasks.ValueTask WriteAsync(System.ReadOnlyMemory buffer, bool completeWrites, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public override System.Threading.Tasks.ValueTask WriteAsync(System.ReadOnlyMemory buffer, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public override void WriteByte(byte value) { } } + public enum QuicStreamType + { + Unidirectional = 0, + Bidirectional = 1, + } } diff --git a/src/libraries/System.Net.Quic/ref/System.Net.Quic.csproj b/src/libraries/System.Net.Quic/ref/System.Net.Quic.csproj index 4e5d642f64624..cbb50d6c555b1 100644 --- a/src/libraries/System.Net.Quic/ref/System.Net.Quic.csproj +++ b/src/libraries/System.Net.Quic/ref/System.Net.Quic.csproj @@ -1,9 +1,6 @@ $(NetCoreAppCurrent) - - true diff --git a/src/libraries/System.Net.Quic/src/Resources/Strings.resx b/src/libraries/System.Net.Quic/src/Resources/Strings.resx index 2b74167931721..f2581b3d16eea 100644 --- a/src/libraries/System.Net.Quic/src/Resources/Strings.resx +++ b/src/libraries/System.Net.Quic/src/Resources/Strings.resx @@ -117,9 +117,6 @@ System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 - - Only IPv4 or IPv6 are supported - Connection aborted by peer ({0}). @@ -141,15 +138,36 @@ System.Net.Quic is not supported on this platform. - - Unsupported address family of '{0}' for remote endpoint. - Writing is not allowed on stream. + + '{0}'' should be within [0, {1}) range. + + + '{0}' must be specified to start the listener. + + + '{0}' must be specified and contain at least one item to start the listener. + + + '{0}' must be specified to open the connection. + + + '{0}' must be specified to accept the connection. + Timeout can only be set to 'System.Threading.Timeout.Infinite' or a value > 0. + + '{0}' in not supported remote endpoint type, expected IP or DNS endpoint." + + + '{0}' must be specified and contain at least one item to establish the connection. + + + Server must provide a certificate in '{0}' or '{1}' or via '{2}' for the connection. + Connection timed out waiting for a response from the peer. @@ -162,21 +180,18 @@ The remote certificate is invalid because of errors in the certificate chain: {0} - - Connection is not connected. - An internal error has occured. {0} - - The application protocol list is invalid. - Could not use a TLS version required by Quic. TLS 1.3 may have been disabled in the registry. CipherSuitePolicy must specify at least one cipher supported by QUIC. + + QuicConnection is configured to not accept any streams. + The local address is already in use. @@ -211,15 +226,12 @@ Authentication failed because the remote party sent a TLS alert: '{0}'. - + The AddressFamily {0} is not valid for the {1} end point, use {2} instead. The supplied {0} is an invalid size for the {1} end point. - - QuicConnection is configured to not accept any streams. - diff --git a/src/libraries/System.Net.Quic/src/System.Net.Quic.csproj b/src/libraries/System.Net.Quic/src/System.Net.Quic.csproj index 5d35afdb51134..2511b13f92237 100644 --- a/src/libraries/System.Net.Quic/src/System.Net.Quic.csproj +++ b/src/libraries/System.Net.Quic/src/System.Net.Quic.csproj @@ -34,6 +34,7 @@ + diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/MsQuicAddressHelpers.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/MsQuicAddressHelpers.cs deleted file mode 100644 index 436405fc30281..0000000000000 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/MsQuicAddressHelpers.cs +++ /dev/null @@ -1,35 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Diagnostics; -using System.Net.Sockets; -using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; -using Microsoft.Quic; - -namespace System.Net.Quic.Implementations.MsQuic.Internal -{ - internal static class MsQuicAddressHelpers - { - internal static unsafe IPEndPoint ToIPEndPoint(this ref QuicAddr quicAddress) - { - // MsQuic always uses storage size as if IPv6 was used - // QuicAddr is native memory, it cannot be moved by GC, thus no need for fixed expression here. - Span addressBytes = new Span((byte*)Unsafe.AsPointer(ref quicAddress), Internals.SocketAddress.IPv6AddressSize); - return new Internals.SocketAddress(SocketAddressPal.GetAddressFamily(addressBytes), addressBytes).GetIPEndPoint(); - } - - internal static unsafe QuicAddr ToQuicAddr(this IPEndPoint iPEndPoint) - { - // TODO: is the layout same for SocketAddress.Buffer and QuicAddr on all platforms? - QuicAddr result = default; - Span rawAddress = MemoryMarshal.AsBytes(new Span(ref result)); - - Internals.SocketAddress address = IPEndPointExtensions.Serialize(iPEndPoint); - Debug.Assert(address.Size <= rawAddress.Length); - - address.Buffer.AsSpan(0, address.Size).CopyTo(rawAddress); - return result; - } - } -} diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/MsQuicApi.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/MsQuicApi.cs deleted file mode 100644 index 5cf30951913a9..0000000000000 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/MsQuicApi.cs +++ /dev/null @@ -1,153 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Diagnostics.CodeAnalysis; -using System.Runtime.InteropServices; -using System.Text; -using Microsoft.Quic; - -using static Microsoft.Quic.MsQuic; - -#if TARGET_WINDOWS -using Microsoft.Win32; -#endif - -namespace System.Net.Quic.Implementations.MsQuic.Internal -{ - internal sealed unsafe class MsQuicApi - { - private static readonly Version MinWindowsVersion = new Version(10, 0, 20145, 1000); - - private static readonly Version MsQuicVersion = new Version(2, 0); - - public MsQuicSafeHandle Registration { get; } - - public QUIC_API_TABLE* ApiTable { get; } - - // This is workaround for a bug in ILTrimmer. - // Without these DynamicDependency attributes, .ctor() will be removed from the safe handles. - // Remove once fixed: https://github.com/mono/linker/issues/1660 - [DynamicDependency(DynamicallyAccessedMemberTypes.PublicConstructors, typeof(SafeMsQuicConfigurationHandle))] - [DynamicDependency(DynamicallyAccessedMemberTypes.PublicConstructors, typeof(SafeMsQuicConnectionHandle))] - [DynamicDependency(DynamicallyAccessedMemberTypes.PublicConstructors, typeof(SafeMsQuicStreamHandle))] - [DynamicDependency(DynamicallyAccessedMemberTypes.PublicConstructors, typeof(MsQuicSafeHandle))] - [DynamicDependency(DynamicallyAccessedMemberTypes.PublicConstructors, typeof(MsQuicContextSafeHandle))] - private MsQuicApi(QUIC_API_TABLE* apiTable) - { - ApiTable = apiTable; - - fixed (byte* pAppName = "System.Net.Quic"u8) - { - var cfg = new QUIC_REGISTRATION_CONFIG - { - AppName = (sbyte*)pAppName, - ExecutionProfile = QUIC_EXECUTION_PROFILE.LOW_LATENCY - }; - - QUIC_HANDLE* handle; - ThrowHelper.ThrowIfMsQuicError(ApiTable->RegistrationOpen(&cfg, &handle), "RegistrationOpen failed"); - - Registration = new MsQuicSafeHandle(handle, apiTable->RegistrationClose, SafeHandleType.Registration); - } - } - - internal static MsQuicApi Api { get; } = null!; - - internal static bool IsQuicSupported { get; } - - internal static bool Tls13ServerMayBeDisabled { get; } - internal static bool Tls13ClientMayBeDisabled { get; } - - static MsQuicApi() - { - if (OperatingSystem.IsWindows()) - { - if (!IsWindowsVersionSupported()) - { - if (NetEventSource.Log.IsEnabled()) - { - NetEventSource.Info(null, $"Current Windows version ({Environment.OSVersion}) is not supported by QUIC. Minimal supported version is {MinWindowsVersion}"); - } - - return; - } - - Tls13ServerMayBeDisabled = IsTls13Disabled(true); - Tls13ClientMayBeDisabled = IsTls13Disabled(false); - } - - IntPtr msQuicHandle; - if (NativeLibrary.TryLoad($"{Interop.Libraries.MsQuic}.{MsQuicVersion.Major}", typeof(MsQuicApi).Assembly, DllImportSearchPath.AssemblyDirectory, out msQuicHandle) || - NativeLibrary.TryLoad(Interop.Libraries.MsQuic, typeof(MsQuicApi).Assembly, DllImportSearchPath.AssemblyDirectory, out msQuicHandle)) - { - try - { - if (NativeLibrary.TryGetExport(msQuicHandle, "MsQuicOpenVersion", out IntPtr msQuicOpenVersionAddress)) - { - QUIC_API_TABLE* apiTable; - delegate* unmanaged[Cdecl] msQuicOpenVersion = (delegate* unmanaged[Cdecl])msQuicOpenVersionAddress; - if (StatusSucceeded(msQuicOpenVersion((uint)MsQuicVersion.Major, &apiTable))) - { - int arraySize = 4; - uint* libVersion = stackalloc uint[arraySize]; - uint size = (uint)arraySize * sizeof(uint); - if (StatusSucceeded(apiTable->GetParam(null, QUIC_PARAM_GLOBAL_LIBRARY_VERSION, &size, libVersion))) - { - var version = new Version((int)libVersion[0], (int)libVersion[1], (int)libVersion[2], (int)libVersion[3]); - if (version >= MsQuicVersion) - { - Api = new MsQuicApi(apiTable); - IsQuicSupported = true; - } - else - { - if (NetEventSource.Log.IsEnabled()) - { - NetEventSource.Info(null, $"Incompatible MsQuic library version '{version}', expecting '{MsQuicVersion}'"); - } - } - } - } - } - } - finally - { - if (!IsQuicSupported) - { - NativeLibrary.Free(msQuicHandle); - } - } - } - } - - private static bool IsWindowsVersionSupported() => OperatingSystem.IsWindowsVersionAtLeast(MinWindowsVersion.Major, - MinWindowsVersion.Minor, MinWindowsVersion.Build, MinWindowsVersion.Revision); - - private static bool IsTls13Disabled(bool isServer) - { -#if TARGET_WINDOWS - string SChannelTls13RegistryKey = isServer - ? @"SYSTEM\CurrentControlSet\Control\SecurityProviders\SCHANNEL\Protocols\TLS 1.3\Server" - : @"SYSTEM\CurrentControlSet\Control\SecurityProviders\SCHANNEL\Protocols\TLS 1.3\Client"; - - using var regKey = Registry.LocalMachine.OpenSubKey(SChannelTls13RegistryKey); - - if (regKey is null) - { - return false; - } - - if (regKey.GetValue("Enabled") is int enabled && enabled == 0) - { - return true; - } - - if (regKey.GetValue("DisabledByDefault") is int disabled && disabled == 1) - { - return true; - } -#endif - return false; - } - } -} diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/MsQuicParameterHelpers.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/MsQuicParameterHelpers.cs deleted file mode 100644 index 00cc40d71b4cd..0000000000000 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/MsQuicParameterHelpers.cs +++ /dev/null @@ -1,101 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Diagnostics; -using System.Runtime.InteropServices; -using System.Net.Sockets; -using Microsoft.Quic; -using static Microsoft.Quic.MsQuic; - -namespace System.Net.Quic.Implementations.MsQuic.Internal -{ - internal static class MsQuicParameterHelpers - { - internal static unsafe IPEndPoint GetIPEndPointParam(MsQuicApi api, MsQuicSafeHandle nativeObject, uint param, AddressFamily? addressFamilyOverride = null) - { - // MsQuic always uses storage size as if IPv6 was used - uint valueLen = (uint)Internals.SocketAddress.IPv6AddressSize; - Span address = stackalloc byte[Internals.SocketAddress.IPv6AddressSize]; - - fixed (byte* paddress = &MemoryMarshal.GetReference(address)) - { - ThrowHelper.ThrowIfMsQuicError(api.ApiTable->GetParam( - nativeObject.QuicHandle, - param, - &valueLen, - paddress), "GetIPEndPointParam failed."); - } - - address = address.Slice(0, (int)valueLen); - - return new Internals.SocketAddress(addressFamilyOverride ?? SocketAddressPal.GetAddressFamily(address), address).GetIPEndPoint(); - } - - internal static unsafe void SetIPEndPointParam(MsQuicApi api, MsQuicSafeHandle nativeObject, uint param, IPEndPoint value) - { - Internals.SocketAddress socketAddress = IPEndPointExtensions.Serialize(value); - - // MsQuic always reads same amount of memory as if IPv6 was used, so we can't pass pointer to socketAddress.Buffer directly - Span address = stackalloc byte[Internals.SocketAddress.IPv6AddressSize]; - socketAddress.Buffer.AsSpan(0, socketAddress.Size).CopyTo(address); - address.Slice(socketAddress.Size).Clear(); - - fixed (byte* paddress = &MemoryMarshal.GetReference(address)) - { - ThrowHelper.ThrowIfMsQuicError(api.ApiTable->SetParam( - nativeObject.QuicHandle, - param, - (uint)address.Length, - paddress), "Could not set IPEndPoint"); - } - } - - internal static unsafe ushort GetUShortParam(MsQuicApi api, MsQuicSafeHandle nativeObject, uint param) - { - ushort value; - uint valueLen = (uint)sizeof(ushort); - - ThrowHelper.ThrowIfMsQuicError(api.ApiTable->GetParam( - nativeObject.QuicHandle, - param, - &valueLen, - (byte*)&value), "GetUShortParam failed"); - Debug.Assert(valueLen == sizeof(ushort)); - - return value; - } - - internal static unsafe void SetUShortParam(MsQuicApi api, MsQuicSafeHandle nativeObject, uint param, ushort value) - { - ThrowHelper.ThrowIfMsQuicError(api.ApiTable->SetParam( - nativeObject.QuicHandle, - param, - sizeof(ushort), - (byte*)&value), "Could not set ushort"); - } - - internal static unsafe ulong GetULongParam(MsQuicApi api, MsQuicSafeHandle nativeObject, uint param) - { - ulong value; - uint valueLen = (uint)sizeof(ulong); - - ThrowHelper.ThrowIfMsQuicError(api.ApiTable->GetParam( - nativeObject.QuicHandle, - param, - &valueLen, - (byte*)&value), "GetULongParam failed"); - Debug.Assert(valueLen == sizeof(ulong)); - - return value; - } - - internal static unsafe void SetULongParam(MsQuicApi api, MsQuicSafeHandle nativeObject, uint param, ulong value) - { - ThrowHelper.ThrowIfMsQuicError(api.ApiTable->SetParam( - nativeObject.QuicHandle, - param, - sizeof(ulong), - (byte*)&value), "Could not set ulong"); - } - } -} diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/ResettableCompletionSource.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/ResettableCompletionSource.cs deleted file mode 100644 index fb94687962ee1..0000000000000 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/ResettableCompletionSource.cs +++ /dev/null @@ -1,84 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Threading.Tasks; -using System.Threading.Tasks.Sources; - -namespace System.Net.Quic.Implementations.MsQuic.Internal -{ - /// - /// A resettable completion source which can be completed multiple times. - /// Used to make methods async between completed events and their associated async method. - /// - internal sealed class ResettableCompletionSource : IValueTaskSource, IValueTaskSource - { - private ManualResetValueTaskSourceCore _valueTaskSource; - - public ResettableCompletionSource() - { - _valueTaskSource.RunContinuationsAsynchronously = true; - } - - public ValueTask GetValueTask() - { - return new ValueTask(this, _valueTaskSource.Version); - } - - public ValueTask GetTypelessValueTask() - { - return new ValueTask(this, _valueTaskSource.Version); - } - - public ValueTaskSourceStatus GetStatus(short token) - { - return _valueTaskSource.GetStatus(token); - } - - public void OnCompleted(Action continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags) - { - _valueTaskSource.OnCompleted(continuation, state, token, flags); - } - - public void Complete(T result) - { - _valueTaskSource.SetResult(result); - } - - public void CompleteException(Exception ex) - { - _valueTaskSource.SetException(ex); - } - - public T GetResult(short token) - { - bool isValid = token == _valueTaskSource.Version; - try - { - return _valueTaskSource.GetResult(token); - } - finally - { - if (isValid) - { - _valueTaskSource.Reset(); - } - } - } - - void IValueTaskSource.GetResult(short token) - { - bool isValid = token == _valueTaskSource.Version; - try - { - _valueTaskSource.GetResult(token); - } - finally - { - if (isValid) - { - _valueTaskSource.Reset(); - } - } - } - } - } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/SafeMsQuicConfigurationHandle.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/SafeMsQuicConfigurationHandle.cs deleted file mode 100644 index fb0f9fc11ffec..0000000000000 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/SafeMsQuicConfigurationHandle.cs +++ /dev/null @@ -1,297 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Buffers; -using System.Collections.Generic; -using System.Diagnostics; -using System.Net.Security; -using System.Runtime.InteropServices; -using System.Security.Cryptography.X509Certificates; -using System.Threading; -using Microsoft.Quic; -using static Microsoft.Quic.MsQuic; - -namespace System.Net.Quic.Implementations.MsQuic.Internal -{ - internal sealed class SafeMsQuicConfigurationHandle : MsQuicSafeHandle - { - public unsafe SafeMsQuicConfigurationHandle(QUIC_HANDLE* handle) - : base(handle, MsQuicApi.Api.ApiTable->ConfigurationClose, SafeHandleType.Configuration) - { } - - // TODO: consider moving the static code from here to keep all the handle classes small and simple. - public static SafeMsQuicConfigurationHandle Create(QuicClientConnectionOptions options) - { - X509Certificate? certificate = null; - - if (options.ClientAuthenticationOptions != null) - { - SslClientAuthenticationOptions clientAuthenticationOptions = options.ClientAuthenticationOptions; - -#pragma warning disable SYSLIB0040 // NoEncryption and AllowNoEncryption are obsolete - if (clientAuthenticationOptions.EncryptionPolicy == EncryptionPolicy.NoEncryption) - { - throw new PlatformNotSupportedException(SR.Format(SR.net_quic_ssl_option, nameof(clientAuthenticationOptions.EncryptionPolicy))); - } -#pragma warning restore SYSLIB0040 - - if (clientAuthenticationOptions.LocalCertificateSelectionCallback != null) - { - X509Certificate? cert = clientAuthenticationOptions.LocalCertificateSelectionCallback( - options, - clientAuthenticationOptions.TargetHost ?? string.Empty, - clientAuthenticationOptions.ClientCertificates ?? new X509CertificateCollection(), - null, - Array.Empty()); - - if (cert is X509Certificate2 cert2 && cert2.Handle != IntPtr.Zero && cert2.HasPrivateKey) - { - certificate = cert; - } - } - else if (clientAuthenticationOptions.ClientCertificates != null) - { - foreach (X509Certificate cert in clientAuthenticationOptions.ClientCertificates) - { - - if (cert is X509Certificate2 cert2 && cert2.Handle != IntPtr.Zero && cert2.HasPrivateKey) - { - // Pick first certificate with private key. - certificate = cert; - break; - } - } - } - } - - QUIC_CREDENTIAL_FLAGS flags = QUIC_CREDENTIAL_FLAGS.CLIENT; - if (OperatingSystem.IsWindows()) - { - flags |= QUIC_CREDENTIAL_FLAGS.USE_SUPPLIED_CREDENTIALS; - } - return Create(options, flags, certificate: certificate, certificateContext: null, options.ClientAuthenticationOptions?.ApplicationProtocols, options.ClientAuthenticationOptions?.CipherSuitesPolicy); - } - - public static SafeMsQuicConfigurationHandle Create(QuicServerConnectionOptions options, SslServerAuthenticationOptions? serverAuthenticationOptions, string? targetHost = null) - { - QUIC_CREDENTIAL_FLAGS flags = QUIC_CREDENTIAL_FLAGS.NONE; - X509Certificate? certificate = serverAuthenticationOptions?.ServerCertificate; - - if (serverAuthenticationOptions != null) - { -#pragma warning disable SYSLIB0040 // NoEncryption and AllowNoEncryption are obsolete - if (serverAuthenticationOptions.EncryptionPolicy == EncryptionPolicy.NoEncryption) - { - throw new PlatformNotSupportedException(SR.Format(SR.net_quic_ssl_option, nameof(serverAuthenticationOptions.EncryptionPolicy))); - } -#pragma warning restore SYSLIB0040 - - if (serverAuthenticationOptions.ClientCertificateRequired) - { - flags |= QUIC_CREDENTIAL_FLAGS.REQUIRE_CLIENT_AUTHENTICATION | QUIC_CREDENTIAL_FLAGS.INDICATE_CERTIFICATE_RECEIVED | QUIC_CREDENTIAL_FLAGS.NO_CERTIFICATE_VALIDATION; - } - - if (certificate == null && serverAuthenticationOptions?.ServerCertificateSelectionCallback != null && targetHost != null) - { - certificate = serverAuthenticationOptions.ServerCertificateSelectionCallback(options, targetHost); - } - } - - return Create(options, flags, certificate, serverAuthenticationOptions?.ServerCertificateContext, serverAuthenticationOptions?.ApplicationProtocols, serverAuthenticationOptions?.CipherSuitesPolicy); - } - - // TODO: this is called from MsQuicListener and when it fails it wreaks havoc in MsQuicListener finalizer. - // Consider moving bigger logic like this outside of constructor call chains. - private static unsafe SafeMsQuicConfigurationHandle Create(QuicConnectionOptions options, QUIC_CREDENTIAL_FLAGS flags, X509Certificate? certificate, SslStreamCertificateContext? certificateContext, List? alpnProtocols, CipherSuitesPolicy? cipherSuitesPolicy) - { - // TODO: some of these checks should be done by the QuicOptions type. - if (alpnProtocols == null || alpnProtocols.Count == 0) - { - throw new Exception("At least one SslApplicationProtocol value must be present in SslClientAuthenticationOptions or SslServerAuthenticationOptions."); - } - - if (options.MaxBidirectionalStreams > ushort.MaxValue) - { - throw new Exception("MaxBidirectionalStreams overflow."); - } - - if (options.MaxBidirectionalStreams > ushort.MaxValue) - { - throw new Exception("MaxBidirectionalStreams overflow."); - } - - bool isServer = (flags & QUIC_CREDENTIAL_FLAGS.CLIENT) == 0; - if (isServer) - { - if (certificate == null && certificateContext == null) - { - throw new Exception("Server must provide certificate"); - } - } - else - { - flags |= QUIC_CREDENTIAL_FLAGS.INDICATE_CERTIFICATE_RECEIVED | QUIC_CREDENTIAL_FLAGS.NO_CERTIFICATE_VALIDATION; - } - - if (!OperatingSystem.IsWindows()) - { - // Use certificate handles on Windows, fall-back to ASN1 otherwise. - flags |= QUIC_CREDENTIAL_FLAGS.USE_PORTABLE_CERTIFICATES; - } - - Debug.Assert(!MsQuicApi.Api.Registration.IsInvalid); - - QUIC_SETTINGS settings = default(QUIC_SETTINGS); - settings.IsSet.PeerUnidiStreamCount = 1; - settings.PeerUnidiStreamCount = (ushort)options.MaxUnidirectionalStreams; - settings.IsSet.PeerBidiStreamCount = 1; - settings.PeerBidiStreamCount = (ushort)options.MaxBidirectionalStreams; - - if (options.IdleTimeout != TimeSpan.Zero) - { - settings.IsSet.IdleTimeoutMs = 1; - if (options.IdleTimeout != Timeout.InfiniteTimeSpan) - { - if (options.IdleTimeout <= TimeSpan.Zero) throw new Exception("IdleTimeout must not be negative."); - settings.IdleTimeoutMs = (ulong)options.IdleTimeout.TotalMilliseconds; - } - else - { - settings.IdleTimeoutMs = 0; - } - } - - SafeMsQuicConfigurationHandle configurationHandle; - X509Certificate2[]? intermediates = null; - - QUIC_HANDLE* handle; - using var msquicBuffers = new MsQuicBuffers(); - msquicBuffers.Initialize(alpnProtocols, alpnProtocol => alpnProtocol.Protocol); - ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->ConfigurationOpen( - MsQuicApi.Api.Registration.QuicHandle, - msquicBuffers.Buffers, - (uint)alpnProtocols.Count, - &settings, - (uint)sizeof(QUIC_SETTINGS), - (void*)IntPtr.Zero, - &handle), "ConfigurationOpen failed"); - configurationHandle = new SafeMsQuicConfigurationHandle(handle); - - try - { - QUIC_CREDENTIAL_CONFIG config = default; - config.Flags = flags; // TODO: consider using LOAD_ASYNCHRONOUS with a callback. - - if (cipherSuitesPolicy != null) - { - config.Flags |= QUIC_CREDENTIAL_FLAGS.SET_ALLOWED_CIPHER_SUITES; - config.AllowedCipherSuites = CipherSuitePolicyToFlags(cipherSuitesPolicy); - } - - if (certificateContext != null) - { - certificate = certificateContext.Certificate; - intermediates = certificateContext.IntermediateCertificates; - } - - int status; - if (certificate != null) - { - if (OperatingSystem.IsWindows()) - { - config.Type = QUIC_CREDENTIAL_TYPE.CERTIFICATE_CONTEXT; - config.CertificateContext = (void*)certificate.Handle; - status = MsQuicApi.Api.ApiTable->ConfigurationLoadCredential(configurationHandle.QuicHandle, &config); - } - else - { - byte[] asn1; - - if (intermediates?.Length > 0) - { - X509Certificate2Collection collection = new X509Certificate2Collection(); - collection.Add(certificate); - for (int i = 0; i < intermediates?.Length; i++) - { - collection.Add(intermediates[i]); - } - - asn1 = collection.Export(X509ContentType.Pkcs12)!; - } - else - { - asn1 = certificate.Export(X509ContentType.Pkcs12); - } - - fixed (byte* ptr = asn1) - { - QUIC_CERTIFICATE_PKCS12 pkcs12Config = new QUIC_CERTIFICATE_PKCS12 - { - Asn1Blob = ptr, - Asn1BlobLength = (uint)asn1.Length, - PrivateKeyPassword = (sbyte*)IntPtr.Zero - }; - - config.Type = QUIC_CREDENTIAL_TYPE.CERTIFICATE_PKCS12; - config.CertificatePkcs12 = &pkcs12Config; - status = MsQuicApi.Api.ApiTable->ConfigurationLoadCredential(configurationHandle.QuicHandle, &config); - } - } - } - else - { - config.Type = QUIC_CREDENTIAL_TYPE.NONE; - status = MsQuicApi.Api.ApiTable->ConfigurationLoadCredential(configurationHandle.QuicHandle, &config); - } - -#if TARGET_WINDOWS - if ((Interop.SECURITY_STATUS)status == Interop.SECURITY_STATUS.AlgorithmMismatch && (isServer ? MsQuicApi.Tls13ServerMayBeDisabled : MsQuicApi.Tls13ClientMayBeDisabled)) - { - throw new PlatformNotSupportedException(SR.net_quic_tls_version_notsupported); - } -#endif - - ThrowHelper.ThrowIfMsQuicError(status, "ConfigurationLoadCredential failed"); - } - catch - { - configurationHandle.Dispose(); - throw; - } - - return configurationHandle; - } - - private static QUIC_ALLOWED_CIPHER_SUITE_FLAGS CipherSuitePolicyToFlags(CipherSuitesPolicy cipherSuitesPolicy) - { - QUIC_ALLOWED_CIPHER_SUITE_FLAGS flags = QUIC_ALLOWED_CIPHER_SUITE_FLAGS.NONE; - - foreach (TlsCipherSuite cipher in cipherSuitesPolicy.AllowedCipherSuites) - { - switch (cipher) - { - case TlsCipherSuite.TLS_AES_128_GCM_SHA256: - flags |= QUIC_ALLOWED_CIPHER_SUITE_FLAGS.AES_128_GCM_SHA256; - break; - case TlsCipherSuite.TLS_AES_256_GCM_SHA384: - flags |= QUIC_ALLOWED_CIPHER_SUITE_FLAGS.AES_256_GCM_SHA384; - break; - case TlsCipherSuite.TLS_CHACHA20_POLY1305_SHA256: - flags |= QUIC_ALLOWED_CIPHER_SUITE_FLAGS.CHACHA20_POLY1305_SHA256; - break; - case TlsCipherSuite.TLS_AES_128_CCM_SHA256: // not supported by MsQuic (yet?), but QUIC RFC allows it so we ignore it. - default: - // ignore - break; - } - } - - if (flags == QUIC_ALLOWED_CIPHER_SUITE_FLAGS.NONE) - { - throw new ArgumentException(SR.net_quic_empty_cipher_suite, nameof(SslClientAuthenticationOptions.CipherSuitesPolicy)); - } - - return flags; - } - } -} diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/SafeMsQuicConnectionHandle.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/SafeMsQuicConnectionHandle.cs deleted file mode 100644 index 5a084d6d4af7d..0000000000000 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/SafeMsQuicConnectionHandle.cs +++ /dev/null @@ -1,14 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using Microsoft.Quic; - -namespace System.Net.Quic.Implementations.MsQuic.Internal -{ - internal sealed class SafeMsQuicConnectionHandle : MsQuicSafeHandle - { - public unsafe SafeMsQuicConnectionHandle(QUIC_HANDLE* handle) - : base(handle, MsQuicApi.Api.ApiTable->ConnectionClose, SafeHandleType.Connection) - { } - } -} diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/SafeMsQuicStreamHandle.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/SafeMsQuicStreamHandle.cs deleted file mode 100644 index d3848b5a9ee1f..0000000000000 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/SafeMsQuicStreamHandle.cs +++ /dev/null @@ -1,14 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using Microsoft.Quic; - -namespace System.Net.Quic.Implementations.MsQuic.Internal -{ - internal sealed class SafeMsQuicStreamHandle : MsQuicSafeHandle - { - public unsafe SafeMsQuicStreamHandle(QUIC_HANDLE* handle) - : base(handle, MsQuicApi.Api.ApiTable->StreamClose, SafeHandleType.Stream) - { } - } -} diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs deleted file mode 100644 index 759c5e1aae71a..0000000000000 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs +++ /dev/null @@ -1,774 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Diagnostics; -using System.Net.Quic.Implementations.MsQuic.Internal; -using System.Net; -using System.Net.Security; -using System.Net.Sockets; -using System.Runtime.CompilerServices; -using System.Runtime.ExceptionServices; -using System.Runtime.InteropServices; -using System.Security.Authentication; -using System.Security.Cryptography; -using System.Security.Cryptography.X509Certificates; -using System.Threading; -using System.Threading.Channels; -using System.Threading.Tasks; -using Microsoft.Quic; -using static Microsoft.Quic.MsQuic; - -namespace System.Net.Quic.Implementations.MsQuic -{ - internal sealed class MsQuicConnection : IDisposable - { - private static readonly Oid s_clientAuthOid = new Oid("1.3.6.1.5.5.7.3.2", "1.3.6.1.5.5.7.3.2"); - private static readonly Oid s_serverAuthOid = new Oid("1.3.6.1.5.5.7.3.1", "1.3.6.1.5.5.7.3.1"); - private const uint DefaultResetValue = 0xffffffff; // Arbitrary value unlikely to conflict with application protocols. - - // TODO: remove this. - // This is only used for client-initiated connections, and isn't needed even then once Connect() has been called. - private SafeMsQuicConfigurationHandle? _configuration; - - private readonly State _state = new State(); - private int _disposed; - - private bool _canAccept; - - private IPEndPoint? _localEndPoint; - private readonly EndPoint _remoteEndPoint; - private SslApplicationProtocol _negotiatedAlpnProtocol; - - internal sealed class State - { - public SafeMsQuicConnectionHandle Handle = null!; // set inside of MsQuicConnection ctor. - - public GCHandle StateGCHandle; - - // These exists to prevent GC of the MsQuicConnection in the middle of an async op (Connect or Shutdown). - public MsQuicConnection? Connection; - public bool ShutdownInProgress; - - public readonly ValueTaskSource ConnectTcs = new ValueTaskSource(); - // TODO: only allocate these when there is an outstanding shutdown. - public readonly TaskCompletionSource ShutdownTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - public long AbortErrorCode = -1; - public int StreamCount; - private bool _closing; - - // Certificate validation properties - public X509Certificate? RemoteCertificate; - public bool RemoteCertificateRequired; - public X509RevocationMode RevocationMode = X509RevocationMode.Offline; - public RemoteCertificateValidationCallback? RemoteCertificateValidationCallback; - public bool IsServer; - public string? TargetHost; - - // Queue for accepted streams. - // Backlog limit is managed by MsQuic so it can be unbounded here. - public readonly Channel AcceptQueue = Channel.CreateUnbounded(new UnboundedChannelOptions() - { - SingleWriter = true, - }); - - public void RemoveStream(MsQuicStream? stream) - { - bool releaseHandles; - lock (this) - { - StreamCount--; - Debug.Assert(StreamCount >= 0); - releaseHandles = _closing && StreamCount == 0; - } - - if (releaseHandles) - { - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"{Handle} releasing handle after last stream."); - Handle?.Dispose(); - } - } - - public bool TryQueueNewStream(SafeMsQuicStreamHandle streamHandle, QUIC_STREAM_OPEN_FLAGS flags) - { - var stream = new MsQuicStream(this, streamHandle, flags); - if (AcceptQueue.Writer.TryWrite(stream)) - { - return true; - } - else - { - stream.Dispose(); - return false; - } - } - - public bool TryAddStream(MsQuicStream stream) - { - lock (this) - { - if (_closing) - { - return false; - } - - StreamCount++; - return true; - } - } - - // This is called under lock from connection dispose - public void SetClosing() - { - lock (this) - { - _closing = true; - } - } - } - - // constructor for inbound connections - internal unsafe MsQuicConnection(QUIC_HANDLE* handle, QUIC_NEW_CONNECTION_INFO* info) - { - _state.Handle = new SafeMsQuicConnectionHandle(handle); - _state.StateGCHandle = GCHandle.Alloc(_state); - _state.IsServer = true; - - try - { - Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)"); - MsQuicApi.Api.ApiTable->SetConnectionCallback(_state.Handle.QuicHandle, &NativeCallback, (void*)GCHandle.ToIntPtr(_state.StateGCHandle)); - } - catch - { - _state.StateGCHandle.Free(); - throw; - } - - _remoteEndPoint = info->RemoteAddress->ToIPEndPoint(); - _localEndPoint = info->LocalAddress->ToIPEndPoint(); - _negotiatedAlpnProtocol = new SslApplicationProtocol(new Span(info->NegotiatedAlpn, info->NegotiatedAlpnLength).ToArray()); - - if (NetEventSource.Log.IsEnabled()) - { - NetEventSource.Info(_state, $"{_state.Handle} Inbound connection created"); - } - } - - // constructor for outbound connections - public unsafe MsQuicConnection(QuicClientConnectionOptions options) - { - ArgumentNullException.ThrowIfNull(options.RemoteEndPoint, nameof(options.RemoteEndPoint)); - - _canAccept = options.MaxBidirectionalStreams > 0 || options.MaxUnidirectionalStreams > 0; - _remoteEndPoint = options.RemoteEndPoint; - _configuration = SafeMsQuicConfigurationHandle.Create(options); - _state.RemoteCertificateRequired = true; - _state.RevocationMode = options.ClientAuthenticationOptions.CertificateRevocationCheckMode; - _state.RemoteCertificateValidationCallback = options.ClientAuthenticationOptions.RemoteCertificateValidationCallback; - _state.TargetHost = options.ClientAuthenticationOptions.TargetHost; - - _state.StateGCHandle = GCHandle.Alloc(_state); - try - { - QUIC_HANDLE* handle; - Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)"); - ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->ConnectionOpen( - MsQuicApi.Api.Registration.QuicHandle, - &NativeCallback, - (void*)GCHandle.ToIntPtr(_state.StateGCHandle), - &handle), "Could not open the connection"); - _state.Handle = new SafeMsQuicConnectionHandle(handle); - } - catch - { - _state.StateGCHandle.Free(); - throw; - } - - if (NetEventSource.Log.IsEnabled()) - { - NetEventSource.Info(_state, $"{_state.Handle} Outbound connection created"); - } - } - - internal IPEndPoint? LocalEndPoint => _localEndPoint; - - internal EndPoint RemoteEndPoint => _remoteEndPoint; - - internal X509Certificate? RemoteCertificate => _state.RemoteCertificate; - - internal SslApplicationProtocol NegotiatedApplicationProtocol => _negotiatedAlpnProtocol; - - internal bool Connected => _state.ConnectTcs.IsCompleted; - - private static unsafe int HandleEventConnected(State state, ref QUIC_CONNECTION_EVENT connectionEvent) - { - if (state.ConnectTcs.IsCompleted) - { - return QUIC_STATUS_SUCCESS; - } - - // Connected will already be true for connections accepted from a listener. - Debug.Assert(!Monitor.IsEntered(state)); - - Debug.Assert(state.Connection != null); - //state.Connection._remoteEndPoint = MsQuicParameterHelpers.GetIPEndPointParam(MsQuicApi.Api, state.Handle, QUIC_PARAM_CONN_REMOTE_ADDRESS); - state.Connection._localEndPoint = MsQuicParameterHelpers.GetIPEndPointParam(MsQuicApi.Api, state.Handle, QUIC_PARAM_CONN_LOCAL_ADDRESS); - state.Connection._negotiatedAlpnProtocol = new SslApplicationProtocol(new Span(connectionEvent.CONNECTED.NegotiatedAlpn, connectionEvent.CONNECTED.NegotiatedAlpnLength).ToArray()); - if (!state.ShutdownInProgress) - { - state.Connection = null; - } - - state.ConnectTcs.TrySetResult(); - - return QUIC_STATUS_SUCCESS; - } - - private static int HandleEventShutdownInitiatedByTransport(State state, ref QUIC_CONNECTION_EVENT connectionEvent) - { - if (!state.ConnectTcs.IsCompleted) - { - Debug.Assert(state.Connection != null); - if (!state.ShutdownInProgress) - { - state.Connection = null; - } - - state.ConnectTcs.TrySetException(ThrowHelper.GetExceptionForMsQuicStatus(connectionEvent.SHUTDOWN_INITIATED_BY_TRANSPORT.Status, "Connection has been shutdown by transport")); - } - - // To throw QuicConnectionAbortedException (instead of QuicOperationAbortedException) out of AcceptStreamAsync() since - // it wasn't our side who shutdown the connection. - // We should rather keep the Status and propagate it either in a different exception or as a different field of QuicConnectionAbortedException. - // See: https://github.com/dotnet/runtime/issues/60133 - state.AbortErrorCode = 0; - state.AcceptQueue.Writer.TryComplete(); - return QUIC_STATUS_SUCCESS; - } - - private static int HandleEventShutdownInitiatedByPeer(State state, ref QUIC_CONNECTION_EVENT connectionEvent) - { - state.AbortErrorCode = (long)connectionEvent.SHUTDOWN_INITIATED_BY_PEER.ErrorCode; - state.AcceptQueue.Writer.TryComplete(); - return QUIC_STATUS_SUCCESS; - } - - private static int HandleEventShutdownComplete(State state, ref QUIC_CONNECTION_EVENT connectionEvent) - { - // This is the final event on the connection, so free the GCHandle used by the event callback. - state.StateGCHandle.Free(); - - state.Connection = null; - state.ShutdownInProgress = false; - - state.ShutdownTcs.SetResult(QUIC_STATUS_SUCCESS); - - // Stop accepting new streams. - state.AcceptQueue.Writer.TryComplete(); - - return QUIC_STATUS_SUCCESS; - } - - private static unsafe int HandleEventNewStream(State state, ref QUIC_CONNECTION_EVENT connectionEvent) - { - var streamHandle = new SafeMsQuicStreamHandle(connectionEvent.PEER_STREAM_STARTED.Stream); - if (!state.TryQueueNewStream(streamHandle, connectionEvent.PEER_STREAM_STARTED.Flags)) - { - // This will call StreamCloseDelegate and free the stream. - // We will return Success to the MsQuic to prevent double free. - streamHandle.Dispose(); - } - - return QUIC_STATUS_SUCCESS; - } - - private static int HandleEventStreamsAvailable(State state, ref QUIC_CONNECTION_EVENT connectionEvent) - { - return QUIC_STATUS_SUCCESS; - } - - private static unsafe int HandleEventPeerCertificateReceived(State state, ref QUIC_CONNECTION_EVENT connectionEvent) - { - SslPolicyErrors sslPolicyErrors = SslPolicyErrors.None; - X509Chain? chain = null; - X509Certificate2? certificate = null; - X509Certificate2Collection? additionalCertificates = null; - IntPtr certificateBuffer = IntPtr.Zero; - int certificateLength = 0; - - try - { - IntPtr certificateHandle = (IntPtr)connectionEvent.PEER_CERTIFICATE_RECEIVED.Certificate; - if (certificateHandle != IntPtr.Zero) - { - if (OperatingSystem.IsWindows()) - { - certificate = new X509Certificate2(certificateHandle); - } - else - { - unsafe - { - QUIC_BUFFER* certBuffer = (QUIC_BUFFER*)certificateHandle; - certificate = new X509Certificate2(new ReadOnlySpan(certBuffer->Buffer, (int)certBuffer->Length)); - certificateBuffer = (IntPtr)certBuffer->Buffer; - certificateLength = (int)certBuffer->Length; - - IntPtr chainHandle = (IntPtr)connectionEvent.PEER_CERTIFICATE_RECEIVED.Chain; - if (chainHandle != IntPtr.Zero) - { - QUIC_BUFFER* chainBuffer = (QUIC_BUFFER*)chainHandle; - if (chainBuffer->Length != 0 && chainBuffer->Buffer != null) - { - additionalCertificates = new X509Certificate2Collection(); - additionalCertificates.Import(new ReadOnlySpan(chainBuffer->Buffer, (int)chainBuffer->Length)); - } - } - } - } - } - - if (certificate == null) - { - if (NetEventSource.Log.IsEnabled() && state.RemoteCertificateRequired) NetEventSource.Error(state, $"{state.Handle} Remote certificate required, but no remote certificate received"); - sslPolicyErrors |= SslPolicyErrors.RemoteCertificateNotAvailable; - } - else - { - chain = new X509Chain(); - chain.ChainPolicy.RevocationMode = state.RevocationMode; - chain.ChainPolicy.RevocationFlag = X509RevocationFlag.ExcludeRoot; - chain.ChainPolicy.ApplicationPolicy.Add(state.IsServer ? s_clientAuthOid : s_serverAuthOid); - - if (additionalCertificates != null && additionalCertificates.Count > 1) - { - chain.ChainPolicy.ExtraStore.AddRange(additionalCertificates); - } - - sslPolicyErrors |= CertificateValidation.BuildChainAndVerifyProperties(chain, certificate, true, state.IsServer, state.TargetHost, certificateBuffer, certificateLength); - } - - if (!state.RemoteCertificateRequired) - { - sslPolicyErrors &= ~SslPolicyErrors.RemoteCertificateNotAvailable; - } - - state.RemoteCertificate = certificate; - - if (state.RemoteCertificateValidationCallback != null) - { - bool success = state.RemoteCertificateValidationCallback(state, certificate, chain, sslPolicyErrors); - // Unset the callback to prevent multiple invocations of the callback per a single connection. - // Return the same value as the custom callback just did. - state.RemoteCertificateValidationCallback = (_, _, _, _) => success; - - if (!success && NetEventSource.Log.IsEnabled()) - NetEventSource.Error(state, $"{state.Handle} Remote certificate rejected by verification callback"); - - if (!success) - { - if (state.IsServer) - { - return QUIC_STATUS_USER_CANCELED; - } - - throw new AuthenticationException(SR.net_quic_cert_custom_validation); - } - - return QUIC_STATUS_SUCCESS; - } - - if (NetEventSource.Log.IsEnabled()) - NetEventSource.Info(state, $"{state.Handle} Certificate validation for '${certificate?.Subject}' finished with ${sslPolicyErrors}"); - - - if (sslPolicyErrors != SslPolicyErrors.None) - { - if (state.IsServer) - { - return QUIC_STATUS_HANDSHAKE_FAILURE; - } - - throw new AuthenticationException(SR.Format(SR.net_quic_cert_chain_validation, sslPolicyErrors)); - } - - return QUIC_STATUS_SUCCESS; - } - catch (Exception ex) - { - if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(state, $"{state.Handle} Certificate validation failed ${ex.Message}"); - throw; - } - } - - internal async ValueTask AcceptStreamAsync(CancellationToken cancellationToken = default) - { - ObjectDisposedException.ThrowIf(_disposed == 1, this); - - if (!_canAccept) - { - throw new InvalidOperationException(SR.net_quic_accept_not_allowed); - } - - MsQuicStream stream; - - try - { - stream = await _state.AcceptQueue.Reader.ReadAsync(cancellationToken).ConfigureAwait(false); - } - catch (ChannelClosedException) - { - throw ThrowHelper.GetConnectionAbortedException(_state.AbortErrorCode); - } - - return stream; - } - - private async ValueTask OpenStreamAsync(QUIC_STREAM_OPEN_FLAGS flags, CancellationToken cancellationToken) - { - ObjectDisposedException.ThrowIf(_disposed == 1, this); - if (!Connected) - { - throw new InvalidOperationException(SR.net_quic_not_connected); - } - - var stream = new MsQuicStream(_state, flags); - - try - { - await stream.StartAsync(cancellationToken).ConfigureAwait(false); - } - catch - { - stream.Dispose(); - throw; - } - - return stream; - } - - internal ValueTask OpenUnidirectionalStreamAsync(CancellationToken cancellationToken = default) - => OpenStreamAsync(QUIC_STREAM_OPEN_FLAGS.UNIDIRECTIONAL, cancellationToken); - internal ValueTask OpenBidirectionalStreamAsync(CancellationToken cancellationToken = default) - => OpenStreamAsync(QUIC_STREAM_OPEN_FLAGS.NONE, cancellationToken); - - internal int GetRemoteAvailableUnidirectionalStreamCount() - { - Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)"); - return MsQuicParameterHelpers.GetUShortParam(MsQuicApi.Api, _state.Handle, QUIC_PARAM_CONN_LOCAL_UNIDI_STREAM_COUNT); - } - - internal int GetRemoteAvailableBidirectionalStreamCount() - { - Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)"); - return MsQuicParameterHelpers.GetUShortParam(MsQuicApi.Api, _state.Handle, QUIC_PARAM_CONN_LOCAL_BIDI_STREAM_COUNT); - } - - internal unsafe ValueTask ConnectAsync(CancellationToken cancellationToken = default) - { - ObjectDisposedException.ThrowIf(_disposed == 1, this); - - if (_configuration is null) - { - throw new InvalidOperationException($"{nameof(ConnectAsync)} must not be called on a connection obtained from a listener."); - } - - ushort af = _remoteEndPoint.AddressFamily switch - { - AddressFamily.Unspecified => (ushort)QUIC_ADDRESS_FAMILY_UNSPEC, - AddressFamily.InterNetwork => (ushort)QUIC_ADDRESS_FAMILY_INET, - AddressFamily.InterNetworkV6 => (ushort)QUIC_ADDRESS_FAMILY_INET6, - _ => throw new ArgumentException(SR.Format(SR.net_quic_unsupported_address_family, _remoteEndPoint.AddressFamily)) - }; - - Debug.Assert(_state.StateGCHandle.IsAllocated); - - if (_state.ConnectTcs.TryInitialize(out ValueTask valueTask, cancellationToken: cancellationToken)) - { - _state.Connection = this; - string targetHost; - int port; - - if (_remoteEndPoint is IPEndPoint ipEndPoint) - { - Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)"); - MsQuicParameterHelpers.SetIPEndPointParam(MsQuicApi.Api, _state.Handle, QUIC_PARAM_CONN_REMOTE_ADDRESS, ipEndPoint); - targetHost = _state.TargetHost ?? ((IPEndPoint)_remoteEndPoint).Address.ToString(); - port = ((IPEndPoint)_remoteEndPoint).Port; - - } - else if (_remoteEndPoint is DnsEndPoint dnsEndPoint) - { - port = dnsEndPoint.Port; - string dnsHost = dnsEndPoint.Host!; - - // We don't have way how to set separate SNI and name for connection at this moment. - // If the name is actually IP address we can use it to make at least some cases work for people - // who want to bypass DNS but connect to specific virtual host. - if (!dnsHost.Equals(_state.TargetHost, StringComparison.InvariantCultureIgnoreCase) && !string.IsNullOrEmpty(_state.TargetHost)) - { - targetHost = _state.TargetHost!; - if (IPAddress.TryParse(dnsHost, out IPAddress? address)) - { - // This is form of IPAddress and _state.TargetHost is set to different string - Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)"); - MsQuicParameterHelpers.SetIPEndPointParam(MsQuicApi.Api, _state.Handle, QUIC_PARAM_CONN_REMOTE_ADDRESS, new IPEndPoint(address, port)); - } - else - { - IPAddress[] addresses = Dns.GetHostAddressesAsync(dnsHost, cancellationToken).GetAwaiter().GetResult(); - cancellationToken.ThrowIfCancellationRequested(); - if (addresses.Length == 0) - { - throw new SocketException((int)SocketError.HostNotFound); - } - Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)"); - // We can do something better than just using first IP but that is what - // MsQuic does today anyway. - MsQuicParameterHelpers.SetIPEndPointParam(MsQuicApi.Api, _state.Handle, QUIC_PARAM_CONN_REMOTE_ADDRESS, new IPEndPoint(addresses[0], port)); - } - } - else - { - // We defer everything to MsQuic. - targetHost = dnsHost; - } - } - else - { - throw new ArgumentException($"Unsupported remote endpoint type '{_remoteEndPoint.GetType()}'."); - } - - IntPtr pTargetHost = Marshal.StringToCoTaskMemAnsi(targetHost); - try - { - Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)"); - ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->ConnectionStart( - _state.Handle.QuicHandle, - _configuration.QuicHandle, - af, - (sbyte*)pTargetHost, - (ushort)port), "Failed to connect to peer"); - - // this handle is ref counted by MsQuic, so safe to dispose here. - _configuration.Dispose(); - _configuration = null; - } - catch - { - _state.Connection = null; - throw; - } - finally - { - Marshal.FreeCoTaskMem(pTargetHost); - } - } - - return valueTask; - } - - internal unsafe ValueTask FinishHandshakeAsync(QuicServerConnectionOptions options, string? targetHost, CancellationToken cancellationToken = default) - { - ObjectDisposedException.ThrowIf(_disposed == 1, this); - - if (_state.ConnectTcs.TryInitialize(out var valueTask, this, cancellationToken)) - { - _canAccept = options.MaxBidirectionalStreams > 0 || options.MaxUnidirectionalStreams > 0; - _state.Connection = this; - try - { - _state.RemoteCertificateRequired = options.ServerAuthenticationOptions.ClientCertificateRequired; - _state.RevocationMode = options.ServerAuthenticationOptions.CertificateRevocationCheckMode; - _state.RemoteCertificateValidationCallback = options.ServerAuthenticationOptions.RemoteCertificateValidationCallback; - _configuration = SafeMsQuicConfigurationHandle.Create(options, options.ServerAuthenticationOptions, targetHost); - ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->ConnectionSetConfiguration( - _state.Handle.QuicHandle, - _configuration.QuicHandle)); - } - catch - { - _state.Connection = null; - throw; - } - } - - return valueTask; - } - - private unsafe ValueTask ShutdownAsync( - QUIC_CONNECTION_SHUTDOWN_FLAGS Flags, - long ErrorCode) - { - // Store the connection into the GCHandle'd state to prevent GC if user calls ShutdownAsync and gets rid of all references to the MsQuicConnection. - _state.ShutdownInProgress = true; - _state.Connection = this; - - try - { - Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)"); - MsQuicApi.Api.ApiTable->ConnectionShutdown( - _state.Handle.QuicHandle, - Flags, - (ulong)ErrorCode); - } - catch - { - _state.ShutdownInProgress = false; - _state.Connection = null; - throw; - } - - return new ValueTask(_state.ShutdownTcs.Task); - } - -#pragma warning disable CS3016 - [UnmanagedCallersOnly(CallConvs = new Type[] { typeof(CallConvCdecl) })] -#pragma warning restore CS3016 - private static unsafe int NativeCallback(QUIC_HANDLE* connection, void* context, QUIC_CONNECTION_EVENT* connectionEvent) - { - GCHandle gcHandle = GCHandle.FromIntPtr((IntPtr)context); - Debug.Assert(gcHandle.IsAllocated); - Debug.Assert(gcHandle.Target is not null); - var state = (State)gcHandle.Target; - - if (NetEventSource.Log.IsEnabled()) - { - NetEventSource.Info(state, $"{state.Handle} Connection received event {connectionEvent->Type}"); - } - - try - { - switch (connectionEvent->Type) - { - case QUIC_CONNECTION_EVENT_TYPE.CONNECTED: - return HandleEventConnected(state, ref *connectionEvent); - case QUIC_CONNECTION_EVENT_TYPE.SHUTDOWN_INITIATED_BY_TRANSPORT: - return HandleEventShutdownInitiatedByTransport(state, ref *connectionEvent); - case QUIC_CONNECTION_EVENT_TYPE.SHUTDOWN_INITIATED_BY_PEER: - return HandleEventShutdownInitiatedByPeer(state, ref *connectionEvent); - case QUIC_CONNECTION_EVENT_TYPE.SHUTDOWN_COMPLETE: - return HandleEventShutdownComplete(state, ref *connectionEvent); - case QUIC_CONNECTION_EVENT_TYPE.PEER_STREAM_STARTED: - return HandleEventNewStream(state, ref *connectionEvent); - case QUIC_CONNECTION_EVENT_TYPE.STREAMS_AVAILABLE: - return HandleEventStreamsAvailable(state, ref *connectionEvent); - case QUIC_CONNECTION_EVENT_TYPE.PEER_CERTIFICATE_RECEIVED: - return HandleEventPeerCertificateReceived(state, ref *connectionEvent); - default: - return QUIC_STATUS_SUCCESS; - } - } - catch (Exception ex) - { - if (NetEventSource.Log.IsEnabled()) - { - NetEventSource.Error(state, $"{state.Handle} Exception occurred during handling {connectionEvent->Type} connection callback: {ex}"); - } - - if (!state.ConnectTcs.IsCompleted) - { - // This is opportunistic if we get exception and have ability to propagate it to caller. - state.ConnectTcs.TrySetException(ex); - state.Connection = null; - } - else - { - Debug.Fail($"{state.Handle} Exception occurred during handling {connectionEvent->Type} connection callback: {ex}"); - } - - // TODO: trigger an exception on any outstanding async calls. - return QUIC_STATUS_INTERNAL_ERROR; - } - } - - public void Dispose() - { - Dispose(true); - GC.SuppressFinalize(this); - } - - ~MsQuicConnection() - { - Dispose(false); - } - - private async Task FlushAcceptQueue() - { - _state.AcceptQueue.Writer.TryComplete(); - await foreach (MsQuicStream stream in _state.AcceptQueue.Reader.ReadAllAsync().ConfigureAwait(false)) - { - if (stream.CanRead) - { - stream.AbortRead(DefaultResetValue); - } - if (stream.CanWrite) - { - stream.AbortWrite(DefaultResetValue); - } - stream.Dispose(); - } - } - - private unsafe void Dispose(bool disposing) - { - int disposed = Interlocked.Exchange(ref _disposed, 1); - if (disposed != 0) - { - return; - } - - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(_state, $"{_state.Handle} Connection disposing {disposing}"); - - // If we haven't already shutdown gracefully (via a successful CloseAsync call), then force an abortive shutdown. - if (_state.Handle != null && !_state.Handle.IsInvalid && !_state.Handle.IsClosed) - { - // Handle can be null if outbound constructor failed and we are called from finalizer. - Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)"); - MsQuicApi.Api.ApiTable->ConnectionShutdown( - _state.Handle.QuicHandle, - QUIC_CONNECTION_SHUTDOWN_FLAGS.SILENT, - 0); - } - - bool releaseHandles = false; - lock (_state) - { - _state.Connection = null; - if (_state.StreamCount == 0) - { - releaseHandles = true; - } - else - { - // We have pending streams so we need to defer cleanup until last one is gone. - _state.SetClosing(); - } - } - - FlushAcceptQueue().GetAwaiter().GetResult(); - _configuration?.Dispose(); - if (releaseHandles) - { - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(_state, $"{_state.Handle} Connection releasing handle"); - - // We may not be fully initialized if constructor fails. - _state.Handle?.Dispose(); - } - } - - // TODO: this appears abortive and will cause prior successfully shutdown and closed streams to drop data. - // It's unclear how to gracefully wait for a connection to be 100% done. - internal ValueTask CloseAsync(long errorCode, CancellationToken cancellationToken = default) - { - if (_disposed == 1) - { - return default; - } - - return ShutdownAsync(QUIC_CONNECTION_SHUTDOWN_FLAGS.NONE, errorCode); - } - } -} diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs deleted file mode 100644 index 485ebc7d59d6b..0000000000000 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs +++ /dev/null @@ -1,1694 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Buffers; -using System.Diagnostics; -using System.IO; -using System.Net.Quic.Implementations.MsQuic.Internal; -using System.Runtime.CompilerServices; -using System.Runtime.ExceptionServices; -using System.Runtime.InteropServices; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.Quic; -using static Microsoft.Quic.MsQuic; - -namespace System.Net.Quic.Implementations.MsQuic -{ - internal sealed class MsQuicStream : IAsyncDisposable, IDisposable - { - // The state is passed to msquic and then it's passed back by msquic to the callback handler. - private readonly State _state = new State(); - - private readonly bool _canRead; - private readonly bool _canWrite; - - private int _disposed; - - private sealed class State - { - public SafeMsQuicStreamHandle Handle = null!; // set in ctor. - // Roots the state in GC and it won't get collected while this exist. - // It must be kept alive until we receive SHUTDOWN_COMPLETE event - public GCHandle StateGCHandle; - - public long StreamId = -1; - - public MsQuicStream? Stream; // roots the stream in the pinned state to prevent GC during an async read I/O. - public MsQuicConnection.State ConnectionState = null!; // set in ctor. - - public ReadState ReadState; - - // set when ReadState.Aborted: - public long ReadErrorCode = -1; - - // filled when ReadState.BuffersAvailable: - public QUIC_BUFFER[] ReceiveQuicBuffers = Array.Empty(); - public int ReceiveQuicBuffersCount; - public int ReceiveQuicBuffersTotalBytes; - public bool ReceiveIsFinal; - - // set when ReadState.PendingRead: - public Memory ReceiveUserBuffer; - public CancellationTokenRegistration ReceiveCancellationRegistration; - // Resettable completions to be used for multiple calls to receive. - public readonly ResettableCompletionSource ReceiveResettableCompletionSource = new ResettableCompletionSource(); - - public SendState SendState; - public long SendErrorCode = -1; - - public MsQuicBuffers SendBuffers; - - // Resettable completions to be used for multiple calls to send. - public readonly ResettableCompletionSource SendResettableCompletionSource = new ResettableCompletionSource(); - - public ShutdownWriteState ShutdownWriteState; - - // Set once writes have been shutdown. - public readonly TaskCompletionSource ShutdownWriteCompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - // Set once stream has been started and within peer's advertised stream limits - public readonly TaskCompletionSource StartCompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - public ShutdownState ShutdownState; - - // The value makes sure that we release the handles only once. - public int ShutdownDone; - public const int ShutdownDone_Disposed = 1; - public const int ShutdownDone_NotificationReceived = 2; - - // Set once stream have been shutdown. - public readonly TaskCompletionSource ShutdownCompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - public State() - { - SendBuffers = new MsQuicBuffers(); - } - - public void Cleanup() - { - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"{Handle} releasing handles."); - - ShutdownState = ShutdownState.Finished; - CleanupSendState(this); - Handle?.Dispose(); - SendBuffers.Dispose(); - if (StateGCHandle.IsAllocated) StateGCHandle.Free(); - ConnectionState?.RemoveStream(null); - } - } - - // inbound. - internal unsafe MsQuicStream(MsQuicConnection.State connectionState, SafeMsQuicStreamHandle streamHandle, QUIC_STREAM_OPEN_FLAGS flags) - { - if (!connectionState.TryAddStream(this)) - { - throw new ObjectDisposedException(nameof(QuicConnection)); - } - // this assignment should be done before SetCallbackHandlerDelegate to prevent NRE in HandleEventConnectionClose - // but after TryAddStream to prevent unnecessary RemoveStream in finalizer - _state.ConnectionState = connectionState; - - // Inbound streams are already started - _state.StartCompletionSource.SetResult(); - _state.Handle = streamHandle; - _state.StreamId = GetStreamId(streamHandle); - - _canRead = true; - _canWrite = !flags.HasFlag(QUIC_STREAM_OPEN_FLAGS.UNIDIRECTIONAL); - if (!_canWrite) - { - _state.SendState = SendState.Closed; - } - - _state.StateGCHandle = GCHandle.Alloc(_state); - try - { - Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)"); - MsQuicApi.Api.ApiTable->SetStreamCallback(_state.Handle.QuicHandle, &NativeCallback, (void*)GCHandle.ToIntPtr(_state.StateGCHandle)); - } - catch - { - _state.StateGCHandle.Free(); - // don't free the streamHandle, it will be freed by the caller - throw; - } - - if (NetEventSource.Log.IsEnabled()) - { - NetEventSource.Info( - _state, - $"{_state.Handle} Inbound {(flags.HasFlag(QUIC_STREAM_OPEN_FLAGS.UNIDIRECTIONAL) ? "uni" : "bi")}directional stream created " + - $"in connection {_state.ConnectionState.Handle} with StreamId {_state.StreamId}."); - } - } - - // outbound. - internal unsafe MsQuicStream(MsQuicConnection.State connectionState, QUIC_STREAM_OPEN_FLAGS flags) - { - Debug.Assert(connectionState.Handle != null); - - if (!connectionState.TryAddStream(this)) - { - throw new ObjectDisposedException(nameof(QuicConnection)); - } - // this assignment should be done before StreamOpenDelegate to prevent NRE in HandleEventConnectionClose - // but after TryAddStream to prevent unnecessary RemoveStream in finalizer - _state.ConnectionState = connectionState; - - _canRead = !flags.HasFlag(QUIC_STREAM_OPEN_FLAGS.UNIDIRECTIONAL); - _canWrite = true; - - _state.StateGCHandle = GCHandle.Alloc(_state); - if (!_canRead) - { - _state.ReadState = ReadState.Closed; - } - - try - { - QUIC_HANDLE* handle; - Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)"); - int status = MsQuicApi.Api.ApiTable->StreamOpen( - connectionState.Handle.QuicHandle, - flags, - &NativeCallback, - (void*)GCHandle.ToIntPtr(_state.StateGCHandle), - &handle); - - if (status == QUIC_STATUS_ABORTED) - { - // connection already aborted by peer, throw relevant exception - throw ThrowHelper.GetConnectionAbortedException(connectionState.AbortErrorCode); - } - - ThrowHelper.ThrowIfMsQuicError(status, "Failed to open stream to peer"); - _state.Handle = new SafeMsQuicStreamHandle(handle); - } - catch - { - _state.Handle?.Dispose(); - _state.StateGCHandle.Free(); - throw; - } - - if (NetEventSource.Log.IsEnabled()) - { - NetEventSource.Info( - _state, - $"{_state.Handle} Outbound {(flags.HasFlag(QUIC_STREAM_OPEN_FLAGS.UNIDIRECTIONAL) ? "uni" : "bi")}directional stream created " + - $"in connection {_state.ConnectionState.Handle}."); - } - } - - internal bool CanRead => _disposed == 0 && _canRead; - - internal bool CanWrite => _disposed == 0 && _canWrite; - - internal bool ReadsCompleted => _state.ReadState == ReadState.ReadsCompleted; - -#pragma warning disable CA1822 - internal bool CanTimeout => true; -#pragma warning restore CA1822 - - private int _readTimeout = Timeout.Infinite; - - internal int ReadTimeout - { - get - { - ThrowIfDisposed(); - return _readTimeout; - } - set - { - ThrowIfDisposed(); - if (value <= 0 && value != System.Threading.Timeout.Infinite) - { - throw new ArgumentOutOfRangeException(nameof(value), SR.net_quic_timeout_use_gt_zero); - } - _readTimeout = value; - } - } - - private int _writeTimeout = Timeout.Infinite; - internal int WriteTimeout - { - get - { - ThrowIfDisposed(); - return _writeTimeout; - } - set - { - ThrowIfDisposed(); - if (value <= 0 && value != System.Threading.Timeout.Infinite) - { - throw new ArgumentOutOfRangeException(nameof(value), SR.net_quic_timeout_use_gt_zero); - } - _writeTimeout = value; - } - } - - internal long StreamId - { - get - { - ThrowIfDisposed(); - Debug.Assert(_state.StreamId != -1); - return _state.StreamId; - } - } - - internal ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) - { - return WriteAsync(buffer, endStream: false, cancellationToken); - } - - internal ValueTask WriteAsync(ReadOnlySequence buffers, CancellationToken cancellationToken = default) - { - return WriteAsync(buffers, endStream: false, cancellationToken); - } - - internal ValueTask WriteAsync(ReadOnlySequence buffers, bool endStream, CancellationToken cancellationToken = default) - { - return WriteAsync(static (state, buffers) => state.SendBuffers.Initialize(buffers), buffers, buffers.IsEmpty, endStream, cancellationToken); - } - - internal ValueTask WriteAsync(ReadOnlyMemory buffer, bool endStream, CancellationToken cancellationToken = default) - { - return WriteAsync(static (state, buffer) => state.SendBuffers.Initialize(buffer), buffer, buffer.IsEmpty, endStream, cancellationToken); - } - - private async ValueTask WriteAsync(Action stateSetup, TBuffer buffer, bool isEmpty, bool endStream, CancellationToken cancellationToken) - { - ThrowIfDisposed(); - - if (cancellationToken.IsCancellationRequested) - { - lock (_state) - { - if (_state.SendState == SendState.None || _state.SendState == SendState.Pending) - { - _state.SendState = SendState.Aborted; - } - } - - throw new OperationCanceledException(cancellationToken); - } - - if (_state.SendState == SendState.Closed) - { - throw new InvalidOperationException(SR.net_quic_writing_notallowed); - } - // Use Volatile.Read to ensure we read the actual SendErrorCode set by the racing callback thread. - if ((SendState)Volatile.Read(ref Unsafe.As(ref _state.SendState)) == SendState.Aborted) - { - if (_state.SendErrorCode != -1) - { - // aborted by peer - throw ThrowHelper.GetStreamAbortedException(_state.SendErrorCode); - } - - // aborted locally - throw ThrowHelper.GetOperationAbortedException(SR.net_quic_writing_aborted); - } - - // if token was already cancelled, this would execute synchronously - using CancellationTokenRegistration registration = cancellationToken.UnsafeRegister(static (s, token) => - { - var state = (State)s!; - bool shouldComplete = false; - - lock (state) - { - if (state.SendState == SendState.None || state.SendState == SendState.Pending) - { - state.SendState = SendState.Aborted; - shouldComplete = true; - } - } - - if (shouldComplete) - { - state.SendResettableCompletionSource.CompleteException( - ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException("Write was canceled", token))); - } - }, _state); - - lock (_state) - { - if (_state.SendState == SendState.Aborted) - { - cancellationToken.ThrowIfCancellationRequested(); - - if (_state.SendErrorCode != -1) - { - // aborted by peer - throw ThrowHelper.GetStreamAbortedException(_state.SendErrorCode); - } - - // aborted locally - throw ThrowHelper.GetOperationAbortedException(SR.net_quic_writing_aborted); - } - if (_state.SendState == SendState.ConnectionClosed) - { - throw GetConnectionAbortedException(_state); - } - - if (_state.SendState == SendState.Pending || _state.SendState == SendState.Finished) - { - throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, "write")); - } - - // Change the state in the same lock where we check for final states to prevent coming back from Aborted/ConnectionClosed. - Debug.Assert(_state.SendState != SendState.Pending); - _state.SendState = isEmpty ? SendState.Finished : SendState.Pending; - } - - await WriteAsyncCore(stateSetup, buffer, isEmpty, endStream).ConfigureAwait(false); - - lock (_state) - { - if (_state.SendState == SendState.Finished) - { - _state.SendState = SendState.None; - } - } - } - - private unsafe ValueTask WriteAsyncCore(Action stateSetup, TBuffer buffer, bool isEmpty, bool endStream) - { - if (isEmpty) - { - if (endStream) - { - // Start graceful shutdown sequence if passed in the fin flag and there is an empty buffer. - StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0); - } - return default; - } - - stateSetup(_state, buffer); - - Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)"); - int status = MsQuicApi.Api.ApiTable->StreamSend( - _state.Handle.QuicHandle, - _state.SendBuffers.Buffers, - (uint)_state.SendBuffers.Count, - endStream ? QUIC_SEND_FLAGS.FIN : QUIC_SEND_FLAGS.NONE, - (void*)IntPtr.Zero); - - if (StatusFailed(status)) - { - lock (_state) - { - if (_state.SendState == SendState.Pending) - { - _state.SendState = SendState.Finished; - } - } - - CleanupSendState(_state); - - if (status == QUIC_STATUS_ABORTED) - { - if (_state.SendErrorCode != -1) - { - throw ThrowHelper.GetStreamAbortedException(_state.SendErrorCode); - } - throw ThrowHelper.GetConnectionAbortedException(_state.ConnectionState.AbortErrorCode); - } - ThrowHelper.ThrowIfMsQuicError(status, "Could not send data to peer."); - } - - return _state.SendResettableCompletionSource.GetTypelessValueTask(); - } - - internal async ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) - { - // - // If MsQuic indicated that some data were received (QUIC_STREAM_EVENT_RECEIVE), we use it to complete the request - // synchronously. Otherwise we setup the request to be completed by the HandleEventReceive handler. - // - - ThrowIfDisposed(); - - if (_state.ReadState == ReadState.Closed) - { - throw new InvalidOperationException(SR.net_quic_reading_notallowed); - } - - if (NetEventSource.Log.IsEnabled()) - { - NetEventSource.Info(_state, $"{_state.Handle} Stream reading into Memory of '{destination.Length}' bytes."); - } - - ReadState initialReadState; // value before transitions - long abortError; - bool preCanceled = false; - - int bytesRead = -1; - bool reenableReceive = false; - lock (_state) - { - initialReadState = _state.ReadState; - abortError = _state.ReadErrorCode; - - // Failure scenario: pre-canceled token. Transition: Any non-final -> Aborted - // PendingRead or PendingReadFinished state indicates there is another concurrent read operation in flight - // which is forbidden, so it is handled separately - if (initialReadState != ReadState.PendingRead && initialReadState != ReadState.PendingReadFinished && cancellationToken.IsCancellationRequested) - { - initialReadState = ReadState.Aborted; - CleanupReadStateAndCheckPending(_state, ReadState.Aborted); - preCanceled = true; - } - - // Success scenario: EOS already reached, completing synchronously. No transition (final state) - if (initialReadState == ReadState.ReadsCompleted) - { - return 0; - } - - // Success scenario: no data available yet, will return a task to wait on. Transition None->PendingRead - if (initialReadState == ReadState.None) - { - Debug.Assert(_state.Stream is null); - - _state.ReceiveUserBuffer = destination; - _state.Stream = this; - _state.ReadState = ReadState.PendingRead; - - if (cancellationToken.CanBeCanceled) - { - // Failure scenario: cancellation. Transition: Any non-final -> Aborted - _state.ReceiveCancellationRegistration = cancellationToken.UnsafeRegister(static (obj, token) => - { - var state = (State)obj!; - bool completePendingRead; - lock (state) - { - completePendingRead = CleanupReadStateAndCheckPending(state, ReadState.Aborted); - } - - if (completePendingRead) - { - state.ReceiveResettableCompletionSource.CompleteException(ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException(token))); - } - }, _state); - } - else - { - _state.ReceiveCancellationRegistration = default; - } - } - - // Success scenario: data already available, completing synchronously. - // Transition IndividualReadComplete->None, or IndividualReadComplete->ReadsCompleted, if it was the last message and we fully consumed it - if (initialReadState == ReadState.IndividualReadComplete) - { - _state.ReadState = ReadState.None; - - bytesRead = CopyMsQuicBuffersToUserBuffer(_state.ReceiveQuicBuffers.AsSpan(0, _state.ReceiveQuicBuffersCount), destination.Span); - - if (bytesRead != _state.ReceiveQuicBuffersTotalBytes) - { - // Need to re-enable receives because MsQuic will pause them when we don't consume the entire buffer. - reenableReceive = true; - } - else if (_state.ReceiveIsFinal) - { - // This was a final message and we've consumed everything. We can complete the state without waiting for PEER_SEND_SHUTDOWN - _state.ReadState = ReadState.ReadsCompleted; - } - } - } - - if (initialReadState == ReadState.None) - { - // wait for the incoming data to finish the read. - bytesRead = await _state.ReceiveResettableCompletionSource.GetValueTask().ConfigureAwait(false); - - // Reset the read state - lock (_state) - { - if (_state.ReadState == ReadState.PendingReadFinished) - { - _state.ReadState = ReadState.None; - } - } - - return bytesRead; - } - - // methods below need to be called outside of the lock - if (bytesRead > -1) - { - ReceiveComplete(bytesRead); - - if (reenableReceive) - { - EnableReceive(); - } - - return bytesRead; - } - - // All success scenarios returned at this point. Failure scenarios below: - - Exception? ex = null; - - switch (initialReadState) - { - case ReadState.PendingRead: - case ReadState.PendingReadFinished: - ex = new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, "read")); - break; - case ReadState.Aborted: - ex = preCanceled ? new OperationCanceledException(cancellationToken) : - ThrowHelper.GetStreamAbortedException(abortError); - break; - case ReadState.ConnectionClosed: - default: - Debug.Assert(initialReadState == ReadState.ConnectionClosed, $"{nameof(ReadState)} of '{initialReadState}' is unaccounted for in {nameof(ReadAsync)}."); - ex = GetConnectionAbortedException(_state); - break; - } - - throw ex; - } - - /// The number of bytes copied. - private static unsafe int CopyMsQuicBuffersToUserBuffer(ReadOnlySpan sourceBuffers, Span destinationBuffer) - { - if (sourceBuffers.Length == 0) - { - return 0; - } - - int originalDestinationLength = destinationBuffer.Length; - QUIC_BUFFER nativeBuffer; - int takeLength; - int i = 0; - - do - { - nativeBuffer = sourceBuffers[i]; - takeLength = Math.Min((int)nativeBuffer.Length, destinationBuffer.Length); - - new Span(nativeBuffer.Buffer, takeLength).CopyTo(destinationBuffer); - destinationBuffer = destinationBuffer.Slice(takeLength); - } - while (destinationBuffer.Length != 0 && ++i < sourceBuffers.Length); - - return originalDestinationLength - destinationBuffer.Length; - } - - internal void AbortRead(long errorCode) - { - if (_disposed == 1) - { - // Dispose called AbortRead already - return; - } - - bool shouldComplete = false; - lock (_state) - { - shouldComplete = CleanupReadStateAndCheckPending(_state, ReadState.Aborted); - } - - if (shouldComplete) - { - _state.ReceiveResettableCompletionSource.CompleteException( - ExceptionDispatchInfo.SetCurrentStackTrace(ThrowHelper.GetOperationAbortedException(SR.net_quic_reading_aborted))); - } - - StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_RECEIVE, errorCode); - } - - internal void AbortWrite(long errorCode) - { - if (_disposed == 1) - { - // Dispose already triggered graceful shutdown - // It is unsafe to try to trigger abortive shutdown now, because final event arriving after Dispose releases SafeHandle - // so if it arrives after our check but before we call msquic, me might end up with access violation - return; - } - - bool shouldComplete = false; - bool shouldCompleteSends = false; - - lock (_state) - { - if (_state.SendState == SendState.None || _state.SendState == SendState.Pending) - { - shouldCompleteSends = true; - } - - if (_state.SendState < SendState.Aborted) - { - _state.SendState = SendState.Aborted; - } - - if (_state.ShutdownWriteState == ShutdownWriteState.None) - { - _state.ShutdownWriteState = ShutdownWriteState.Canceled; - shouldComplete = true; - } - } - - if (shouldComplete) - { - _state.ShutdownWriteCompletionSource.SetException( - ExceptionDispatchInfo.SetCurrentStackTrace(ThrowHelper.GetOperationAbortedException(SR.net_quic_writing_aborted))); - } - - if (shouldCompleteSends) - { - _state.SendResettableCompletionSource.CompleteException( - ExceptionDispatchInfo.SetCurrentStackTrace(ThrowHelper.GetOperationAbortedException(SR.net_quic_writing_aborted))); - } - - StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_SEND, errorCode); - } - - private unsafe void StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS flags, long errorCode) - { - Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)"); - ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->StreamShutdown( - _state.Handle.QuicHandle, - flags, - (uint)errorCode), "StreamShutdown failed"); - } - - internal async ValueTask ShutdownCompleted(CancellationToken cancellationToken = default) - { - ThrowIfDisposed(); - - lock (_state) - { - if (_state.ShutdownState == ShutdownState.ConnectionClosed) - { - throw GetConnectionAbortedException(_state); - } - } - - using CancellationTokenRegistration registration = cancellationToken.UnsafeRegister(static (s, token) => - { - var state = (State)s!; - bool shouldComplete = false; - lock (state) - { - if (state.ShutdownState == ShutdownState.None) - { - state.ShutdownState = ShutdownState.Canceled; - shouldComplete = true; - } - } - - if (shouldComplete) - { - state.ShutdownCompletionSource.SetException( - ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException("Wait for shutdown was canceled", token))); - } - }, _state); - - await _state.ShutdownCompletionSource.Task.ConfigureAwait(false); - } - - internal ValueTask WaitForWriteCompletionAsync(CancellationToken cancellationToken = default) - { - // TODO: What should happen if this is called for a unidirectional stream and there are no writes? - - ThrowIfDisposed(); - - lock (_state) - { - if (_state.ShutdownWriteState == ShutdownWriteState.ConnectionClosed) - { - throw GetConnectionAbortedException(_state); - } - } - - return new ValueTask(_state.ShutdownWriteCompletionSource.Task.WaitAsync(cancellationToken)); - } - - internal void Shutdown() - { - ThrowIfDisposed(); - - lock (_state) - { - if (_state.SendState < SendState.Finished) - { - _state.SendState = SendState.Finished; - } - } - - // it is ok to send shutdown several times, MsQuic will ignore it - StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0); - } - - // TODO consider removing sync-over-async with blocking calls. - internal int Read(Span buffer) - { - ThrowIfDisposed(); - byte[] rentedBuffer = ArrayPool.Shared.Rent(buffer.Length); - CancellationTokenSource? cts = null; - try - { - if (_readTimeout > 0) - { - cts = new CancellationTokenSource(_readTimeout); - } - int readLength = ReadAsync(new Memory(rentedBuffer, 0, buffer.Length), cts != null ? cts.Token : default).AsTask().GetAwaiter().GetResult(); - rentedBuffer.AsSpan(0, readLength).CopyTo(buffer); - return readLength; - } - catch (OperationCanceledException) when (cts != null && cts.IsCancellationRequested) - { - // sync operations do not have Cancellation - throw new IOException(SR.net_quic_timeout); - } - finally - { - ArrayPool.Shared.Return(rentedBuffer); - cts?.Dispose(); - } - } - - internal void Write(ReadOnlySpan buffer) - { - ThrowIfDisposed(); - CancellationTokenSource? cts = null; - - - if (_writeTimeout > 0) - { - cts = new CancellationTokenSource(_writeTimeout); - } - - // TODO: optimize this. - try - { - WriteAsync(buffer.ToArray()).AsTask().GetAwaiter().GetResult(); - } - catch (OperationCanceledException) when (cts != null && cts.IsCancellationRequested) - { - // sync operations do not have Cancellation - throw new IOException(SR.net_quic_timeout); - } - finally - { - cts?.Dispose(); - } - } - - // MsQuic doesn't support explicit flushing - internal void Flush() - { - ThrowIfDisposed(); - } - - // MsQuic doesn't support explicit flushing - internal Task FlushAsync(CancellationToken cancellationToken = default) - { - ThrowIfDisposed(); - - return Task.CompletedTask; - } - - public ValueTask DisposeAsync() - { - // TODO: perform a graceful shutdown and wait for completion? - - Dispose(true); - return default; - } - - public void Dispose() - { - Dispose(true); - GC.SuppressFinalize(this); - } - - ~MsQuicStream() - { - Dispose(false); - } - - private void Dispose(bool disposing) - { - int disposed = Interlocked.Exchange(ref _disposed, 1); - if (disposed != 0) - { - return; - } - - - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(_state, $"{_state.Handle} Stream disposing {disposing}"); - - bool callShutdown = false; - bool abortRead = false; - bool completeRead = false; - lock (_state) - { - if (_state.SendState < SendState.Aborted) - { - callShutdown = true; - } - - // We can enter Aborted state from both AbortRead call (aborts on the wire) and a Cancellation callback (only changes state) - // We need to ensure read is aborted on the wire here. We let msquic handle a second call to abort as a no-op - if (_state.ReadState < ReadState.ReadsCompleted || _state.ReadState == ReadState.Aborted) - { - abortRead = true; - completeRead = CleanupReadStateAndCheckPending(_state, ReadState.Aborted); - } - - if (_state.ShutdownState == ShutdownState.None) - { - _state.ShutdownState = ShutdownState.Pending; - } - } - - if (_state.Handle != null && !_state.Handle.IsInvalid && !_state.Handle.IsClosed) - { - if (callShutdown) - { - try - { - // Handle race condition when stream can be closed handling SHUTDOWN_COMPLETE. - StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0); - } - catch (ObjectDisposedException) { }; - } - - if (abortRead) - { - try - { - // TODO: error code used here MUST be specified by the application layer - StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_RECEIVE, 0xffffffff); - } - catch (ObjectDisposedException) { }; - } - } - - if (completeRead) - { - _state.ReceiveResettableCompletionSource.CompleteException( - ExceptionDispatchInfo.SetCurrentStackTrace(ThrowHelper.GetOperationAbortedException())); - } - - - // Check if we already got final event. - bool releaseHandles = Interlocked.Exchange(ref _state.ShutdownDone, State.ShutdownDone_Disposed) == State.ShutdownDone_NotificationReceived; - if (releaseHandles) - { - _state.Cleanup(); - } - - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(_state, $"{_state.Handle} Stream disposed"); - } - - private unsafe void EnableReceive() - { - Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)"); - ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->StreamReceiveSetEnabled(_state.Handle.QuicHandle, 1), "StreamReceiveSetEnabled failed"); - } - - /// - /// Callback calls for a single instance of a stream are serialized by msquic. - /// They happen on a msquic thread and shouldn't take too long to not to block msquic. - /// -#pragma warning disable CS3016 - [UnmanagedCallersOnly(CallConvs = new Type[] { typeof(CallConvCdecl) })] -#pragma warning restore CS3016 - private static unsafe int NativeCallback(QUIC_HANDLE* stream, void* context, QUIC_STREAM_EVENT* streamEvent) - { - GCHandle gcHandle = GCHandle.FromIntPtr((IntPtr)context); - Debug.Assert(gcHandle.IsAllocated); - Debug.Assert(gcHandle.Target is not null); - var state = (State)gcHandle.Target; - - if (NetEventSource.Log.IsEnabled()) - { - NetEventSource.Info(state, $"{state.Handle} Stream received event {streamEvent->Type}"); - } - - try - { - switch (streamEvent->Type) - { - // Stream has started. - // Will only be done for outbound streams (inbound streams have already started) - case QUIC_STREAM_EVENT_TYPE.START_COMPLETE: - return HandleEventStartComplete(state, ref *streamEvent); - // Received data on the stream - case QUIC_STREAM_EVENT_TYPE.RECEIVE: - return HandleEventReceive(state, ref *streamEvent); - // Send has completed. - // Contains a canceled bool to indicate if the send was canceled. - case QUIC_STREAM_EVENT_TYPE.SEND_COMPLETE: - return HandleEventSendComplete(state, ref *streamEvent); - // Peer has told us to shutdown the reading side of the stream. - case QUIC_STREAM_EVENT_TYPE.PEER_SEND_SHUTDOWN: - return HandleEventPeerSendShutdown(state); - // Peer has told us to abort the reading side of the stream. - case QUIC_STREAM_EVENT_TYPE.PEER_SEND_ABORTED: - return HandleEventPeerSendAborted(state, ref *streamEvent); - // Peer has stopped receiving data, don't send anymore. - case QUIC_STREAM_EVENT_TYPE.PEER_RECEIVE_ABORTED: - return HandleEventPeerRecvAborted(state, ref *streamEvent); - // Occurs when shutdown is completed for the send side. - // This only happens for shutdown on sending, not receiving - // Receive shutdown can only be abortive. - case QUIC_STREAM_EVENT_TYPE.SEND_SHUTDOWN_COMPLETE: - return HandleEventSendShutdownComplete(state, ref *streamEvent); - // Shutdown for both sending and receiving is completed. - case QUIC_STREAM_EVENT_TYPE.SHUTDOWN_COMPLETE: - return HandleEventShutdownComplete(state, ref *streamEvent); - // Asynchronous open finished, the stream is now within advertised stream limits. - case QUIC_STREAM_EVENT_TYPE.PEER_ACCEPTED: - return HandleEventPeerAccepted(state); - default: - return QUIC_STATUS_SUCCESS; - } - } - catch (Exception ex) - { - if (NetEventSource.Log.IsEnabled()) - { - NetEventSource.Error(state, $"{state.Handle} Exception occurred during handling Stream {streamEvent->Type} event: {ex}"); - } - - Debug.Fail($"{state.Handle} Exception occurred during handling Stream {streamEvent->Type} event: {ex}"); - - return QUIC_STATUS_INTERNAL_ERROR; - } - } - - private static unsafe int HandleEventReceive(State state, ref QUIC_STREAM_EVENT streamEvent) - { - // - // Handle MsQuic QUIC_STREAM_EVENT_RECEIVE event - // - // If there is a pending ReadAsync call, then we complete it. Otherwise we keep a pointer to the received data - // and use it to complete the next ReadAsync operation synchronously. - // - - ref var receiveEvent = ref streamEvent.RECEIVE; - - if (NetEventSource.Log.IsEnabled()) - { - NetEventSource.Info(state, $"{state.Handle} Stream received {receiveEvent.TotalBufferLength} bytes{(receiveEvent.Flags.HasFlag(QUIC_RECEIVE_FLAGS.FIN) ? " with FIN flag" : "")}"); - } - - int readLength; - - bool shouldComplete = false; - lock (state) - { - switch (state.ReadState) - { - // ReadAsync() hasn't been called yet. - case ReadState.None: - // A pending read has just been finished, and this is a second event in a row (before reading thread - // managed to clear the state) - case ReadState.PendingReadFinished: - // Stash the buffer so the next ReadAsync call completes synchronously. - - // We are overwriting state.ReceiveQuicBuffers here even if we only partially consumed them - // and it is intended, because unconsumed data will arrive again from the point we've stopped. - // New RECEIVE event wouldn't come until we call EnableReceive(), and we call it only after we've consumed - // as much as we could and said so to msquic in ReceiveComplete(taken), so new event will have all the - // remaining data. - - if ((uint)state.ReceiveQuicBuffers.Length < receiveEvent.BufferCount) - { - QUIC_BUFFER[] oldReceiveBuffers = state.ReceiveQuicBuffers; - state.ReceiveQuicBuffers = ArrayPool.Shared.Rent((int)receiveEvent.BufferCount); - - if (oldReceiveBuffers.Length != 0) // don't return Array.Empty. - { - ArrayPool.Shared.Return(oldReceiveBuffers); - } - } - - for (uint i = 0; i < receiveEvent.BufferCount; ++i) - { - state.ReceiveQuicBuffers[i] = receiveEvent.Buffers[i]; - } - - state.ReceiveQuicBuffersCount = (int)receiveEvent.BufferCount; - state.ReceiveQuicBuffersTotalBytes = checked((int)receiveEvent.TotalBufferLength); - state.ReceiveIsFinal = receiveEvent.Flags.HasFlag(QUIC_RECEIVE_FLAGS.FIN); - - // 0-length receive can happens once reads are finished (gracefully or otherwise). - if (state.ReceiveQuicBuffersTotalBytes == 0) - { - if (state.ReceiveIsFinal) - { - // We can complete the state without waiting for PEER_SEND_SHUTDOWN - state.ReadState = ReadState.ReadsCompleted; - } - - // if it was not a graceful shutdown, we defer aborting to PEER_SEND_ABORT event handler - return QUIC_STATUS_SUCCESS; - } - else - { - // Normal RECEIVE - data will be buffered until user calls ReadAsync() and no new event will be issued until EnableReceive() - state.ReadState = ReadState.IndividualReadComplete; - return QUIC_STATUS_PENDING; - } - - case ReadState.PendingRead: - // There is a pending ReadAsync(). - - state.ReceiveCancellationRegistration.Unregister(); - shouldComplete = true; - state.Stream = null; - state.ReadState = ReadState.PendingReadFinished; - // state.ReadState will be set to None later once the ReceiveResettableCompletionSource is awaited. - - readLength = CopyMsQuicBuffersToUserBuffer(new ReadOnlySpan(receiveEvent.Buffers, (int)receiveEvent.BufferCount), state.ReceiveUserBuffer.Span); - - // This was a final message and we've consumed everything. We can complete the state without waiting for PEER_SEND_SHUTDOWN - if (receiveEvent.Flags.HasFlag(QUIC_RECEIVE_FLAGS.FIN) && (uint)readLength == receiveEvent.TotalBufferLength) - { - state.ReadState = ReadState.ReadsCompleted; - } - // Else, if this was a final message, but we haven't consumed it fully, FIN flag will arrive again in the next RECEIVE event - - state.ReceiveUserBuffer = null; - break; - - default: - Debug.Assert(state.ReadState is ReadState.Aborted or ReadState.ConnectionClosed, $"Unexpected {nameof(ReadState)} '{state.ReadState}' in {nameof(HandleEventReceive)}."); - - // There was a race between a user aborting the read stream and the callback being ran. - // This will eat any received data. - return QUIC_STATUS_SUCCESS; - } - } - - if (shouldComplete) - { - state.ReceiveResettableCompletionSource.Complete(readLength); - // _state.ReadState will be reset to None on the reading thread. - } - - // Returning Success when the entire buffer hasn't been consumed will cause MsQuic to disable further receive events until EnableReceive() is called. - // Returning Continue will cause a second receive event to fire immediately after this returns, but allows MsQuic to clean up its buffers. - - int ret = (uint)readLength == receiveEvent.TotalBufferLength - ? QUIC_STATUS_SUCCESS - : QUIC_STATUS_CONTINUE; - - receiveEvent.TotalBufferLength = (uint)readLength; - return ret; - } - - private static int HandleEventPeerRecvAborted(State state, ref QUIC_STREAM_EVENT streamEvent) - { - bool shouldSendComplete = false; - bool shouldShutdownWriteComplete = false; - lock (state) - { - if (state.SendState == SendState.None || state.SendState == SendState.Pending) - { - shouldSendComplete = true; - } - - if (state.ShutdownWriteState == ShutdownWriteState.None) - { - state.ShutdownWriteState = ShutdownWriteState.Canceled; - shouldShutdownWriteComplete = true; - } - - state.SendErrorCode = (long)streamEvent.PEER_RECEIVE_ABORTED.ErrorCode; - // make sure the SendErrorCode above is commited to memory before we assign the state. This - // ensures that the code is read correctly in SetupWriteStartState when checking without lock - Volatile.Write(ref Unsafe.As(ref state.SendState), (int)SendState.Aborted); - } - - if (shouldSendComplete) - { - state.SendResettableCompletionSource.CompleteException( - ExceptionDispatchInfo.SetCurrentStackTrace(ThrowHelper.GetStreamAbortedException(state.SendErrorCode))); - } - - if (shouldShutdownWriteComplete) - { - state.ShutdownWriteCompletionSource.SetException( - ExceptionDispatchInfo.SetCurrentStackTrace(ThrowHelper.GetStreamAbortedException(state.SendErrorCode))); - } - - return QUIC_STATUS_SUCCESS; - } - - private static int HandleEventStartComplete(State state, ref QUIC_STREAM_EVENT streamEvent) - { - int status = streamEvent.START_COMPLETE.Status; - - // The way we expose Open(Uni|Bi)directionalStreamAsync operations is that the stream - // is also accepted by the peer (i.e. it is within advertised stream limits). However, - // We may receive START_COMPLETE notification before the stream is accepted, so we defer - // completing the StartcompletionSource until we get PeerAccepted notification. - - if (StatusSucceeded(status)) - { - state.StreamId = (long)streamEvent.START_COMPLETE.ID; - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(state, $"{state.Handle} StreamId = {state.StreamId}"); - - if (streamEvent.START_COMPLETE.PeerAccepted != 0) - { - // Start succeeded and we were within stream limits, stream already usable. - state.StartCompletionSource.TrySetResult(); - } - // if PeerAccepted == 0, we will later receive PEER_ACCEPTED event, which will - // complete the StartCompletionSource - } - else - { - // Start irrecoverably failed. The possible status codes are: - // - Aborted - connection aborted by peer - // - InvalidState - stream already started before, or connection aborted locally - // - StreamLimitReached - only if QUIC_STREAM_START_FLAG_FAIL_BLOCKED was specified (not in our case). - // - if (status == QUIC_STATUS_ABORTED) - { - state.StartCompletionSource.TrySetException( - ExceptionDispatchInfo.SetCurrentStackTrace(GetConnectionAbortedException(state))); - } - else - { - // TODO: Should we throw QuicOperationAbortedException when status is InvalidState? - // [ActiveIssue("https://github.com/dotnet/runtime/issues/55619")] - state.StartCompletionSource.TrySetException( - ExceptionDispatchInfo.SetCurrentStackTrace(ThrowHelper.GetExceptionForMsQuicStatus(status, "StreamStart failed"))); - } - } - - return QUIC_STATUS_SUCCESS; - } - - private static int HandleEventSendShutdownComplete(State state, ref QUIC_STREAM_EVENT streamEvent) - { - // Graceful will be false in three situations: - // 1. The peer aborted reads and the PEER_RECEIVE_ABORTED event was raised. - // ShutdownWriteCompletionSource is already complete with an error. - // 2. We aborted writes. - // ShutdownWriteCompletionSource is already complete with an error. - // 3. The connection was closed. - // SHUTDOWN_COMPLETE event will be raised immediately after this event. It will handle completing with an error. - // - // Only use this event with sends gracefully completed. - if (streamEvent.SEND_SHUTDOWN_COMPLETE.Graceful != 0) - { - bool shouldComplete = false; - lock (state) - { - if (state.ShutdownWriteState == ShutdownWriteState.None) - { - state.ShutdownWriteState = ShutdownWriteState.Finished; - shouldComplete = true; - } - } - - if (shouldComplete) - { - state.ShutdownWriteCompletionSource.SetResult(); - } - } - - return QUIC_STATUS_SUCCESS; - } - - private static int HandleEventShutdownComplete(State state, ref QUIC_STREAM_EVENT streamEvent) - { - var shutdownCompleteEvent = streamEvent.SHUTDOWN_COMPLETE; - - if (shutdownCompleteEvent.ConnectionShutdown != 0) - { - return HandleEventConnectionClose(state); - } - - bool shouldReadComplete = false; - bool shouldShutdownWriteComplete = false; - bool shouldShutdownComplete = false; - - lock (state) - { - // This event won't occur within the middle of a receive. - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(state, $"{state.Handle} Stream completing resettable event source."); - - shouldReadComplete = CleanupReadStateAndCheckPending(state, ReadState.ReadsCompleted); - - if (state.ShutdownWriteState == ShutdownWriteState.None) - { - // TODO: We can get to this point if the stream is unidirectional and there are no writes. - // Consider what is the best behavior here with write shutdown and the read side of - // unidirecitonal streams in the future. - state.ShutdownWriteState = ShutdownWriteState.Finished; - shouldShutdownWriteComplete = true; - } - - if (state.ShutdownState == ShutdownState.None) - { - state.ShutdownState = ShutdownState.Finished; - shouldShutdownComplete = true; - } - } - - if (shouldReadComplete) - { - if (state.StartCompletionSource.Task.IsCompletedSuccessfully) - { - state.ReceiveResettableCompletionSource.Complete(0); - } - else - { - state.ReceiveResettableCompletionSource.CompleteException( - ExceptionDispatchInfo.SetCurrentStackTrace(ThrowHelper.GetOperationAbortedException($"Stream start failed"))); - } - } - - if (shouldShutdownWriteComplete) - { - if (state.StartCompletionSource.Task.IsCompletedSuccessfully) - { - state.ShutdownWriteCompletionSource.SetResult(); - } - else - { - state.ShutdownWriteCompletionSource.SetException( - ExceptionDispatchInfo.SetCurrentStackTrace(ThrowHelper.GetOperationAbortedException($"Stream start failed"))); - } - } - - if (shouldShutdownComplete) - { - state.ShutdownCompletionSource.SetResult(); - } - - // If we are receiving stream shutdown notification, the start comletion source must have been already completed - // eihter by StreamOpen or PeerAccepted event, Connection closing, or it was cancelled by user. - Debug.Assert(state.StartCompletionSource.Task.IsCompleted); - - // Dispose was called before complete event. - bool releaseHandles = Interlocked.Exchange(ref state.ShutdownDone, State.ShutdownDone_NotificationReceived) == State.ShutdownDone_Disposed; - if (releaseHandles) - { - state.Cleanup(); - } - - return QUIC_STATUS_SUCCESS; - } - - private static int HandleEventPeerAccepted(State state) - { - state.StartCompletionSource.TrySetResult(); - return QUIC_STATUS_SUCCESS; - } - - private static int HandleEventPeerSendAborted(State state, ref QUIC_STREAM_EVENT streamEvent) - { - bool shouldComplete = false; - lock (state) - { - shouldComplete = CleanupReadStateAndCheckPending(state, ReadState.Aborted); - state.ReadErrorCode = (long)streamEvent.PEER_SEND_ABORTED.ErrorCode; - } - - if (shouldComplete) - { - state.ReceiveResettableCompletionSource.CompleteException( - ExceptionDispatchInfo.SetCurrentStackTrace(ThrowHelper.GetStreamAbortedException(state.ReadErrorCode))); - } - - return QUIC_STATUS_SUCCESS; - } - - private static int HandleEventPeerSendShutdown(State state) - { - bool shouldComplete = false; - - lock (state) - { - // This event won't occur within the middle of a receive. - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(state, $"{state.Handle} Stream completing resettable event source."); - - shouldComplete = CleanupReadStateAndCheckPending(state, ReadState.ReadsCompleted); - } - - if (shouldComplete) - { - state.ReceiveResettableCompletionSource.Complete(0); - } - - return QUIC_STATUS_SUCCESS; - } - - private static int HandleEventSendComplete(State state, ref QUIC_STREAM_EVENT streamEvent) - { - var sendCompleteEvent = streamEvent.SEND_COMPLETE; - bool canceled = sendCompleteEvent.Canceled != 0; - - bool complete = false; - - lock (state) - { - if (state.SendState == SendState.Pending) - { - state.SendState = SendState.Finished; - complete = true; - } - - if (canceled) - { - state.SendState = SendState.Aborted; - } - } - - if (complete) - { - CleanupSendState(state); - - if (!canceled) - { - state.SendResettableCompletionSource.Complete(QUIC_STATUS_SUCCESS); - } - else - { - // - // There are multiple reasons the send could have been cancelled: - // - Connection was aborted (either by transport or peer) => error-code already provided on the connection-level event - // - Stream's receive side was aborted by peer => already handled by HandleEventPeerRecvAborted - // and we will not set the exception due to complete == false - // - Stream's send side was aborted locally => no connection-level abort code and we return QuicOperationAbortException - // - state.SendResettableCompletionSource.CompleteException( - ExceptionDispatchInfo.SetCurrentStackTrace( - ThrowHelper.GetConnectionAbortedException(state.ConnectionState.AbortErrorCode))); - } - } - - return QUIC_STATUS_SUCCESS; - } - - private static void CleanupSendState(State state) - { - lock (state) - { - Debug.Assert(state.SendState != SendState.Pending); - state.SendBuffers.Reset(); - } - } - - private unsafe void ReceiveComplete(int bufferLength) - { - Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)"); - MsQuicApi.Api.ApiTable->StreamReceiveComplete(_state.Handle.QuicHandle, (ulong)bufferLength); - } - - // This can fail if the stream isn't started. - private static long GetStreamId(SafeMsQuicStreamHandle handle) - { - return (long)MsQuicParameterHelpers.GetULongParam(MsQuicApi.Api, handle, QUIC_PARAM_STREAM_ID); - } - - private void ThrowIfDisposed() - { - ObjectDisposedException.ThrowIf(_disposed == 1, this); - } - - private static int HandleEventConnectionClose(State state) - { - long errorCode = state.ConnectionState.AbortErrorCode; - if (NetEventSource.Log.IsEnabled()) - { - NetEventSource.Info(state, $"{state.Handle} Stream handling connection {state.ConnectionState.Handle} close" + - (errorCode != -1 ? $" with code {errorCode}" : "")); - } - - bool shouldCompleteRead = false; - bool shouldCompleteSend = false; - bool shouldCompleteShutdownWrite = false; - bool shouldCompleteShutdown = false; - - lock (state) - { - shouldCompleteRead = CleanupReadStateAndCheckPending(state, ReadState.ConnectionClosed); - - if (state.SendState == SendState.None || state.SendState == SendState.Pending) - { - shouldCompleteSend = true; - } - state.SendState = SendState.ConnectionClosed; - - if (state.ShutdownWriteState == ShutdownWriteState.None) - { - shouldCompleteShutdownWrite = true; - } - state.ShutdownWriteState = ShutdownWriteState.ConnectionClosed; - - if (state.ShutdownState == ShutdownState.None) - { - shouldCompleteShutdown = true; - } - state.ShutdownState = ShutdownState.ConnectionClosed; - } - - if (shouldCompleteRead) - { - state.ReceiveResettableCompletionSource.CompleteException( - ExceptionDispatchInfo.SetCurrentStackTrace(GetConnectionAbortedException(state))); - } - - if (shouldCompleteSend) - { - state.SendResettableCompletionSource.CompleteException( - ExceptionDispatchInfo.SetCurrentStackTrace(GetConnectionAbortedException(state))); - } - - if (shouldCompleteShutdownWrite) - { - state.ShutdownWriteCompletionSource.SetException( - ExceptionDispatchInfo.SetCurrentStackTrace(GetConnectionAbortedException(state))); - } - - if (shouldCompleteShutdown) - { - state.ShutdownCompletionSource.SetException( - ExceptionDispatchInfo.SetCurrentStackTrace(GetConnectionAbortedException(state))); - } - - if (!state.StartCompletionSource.Task.IsCompleted) - { - state.StartCompletionSource.TrySetException( - ExceptionDispatchInfo.SetCurrentStackTrace(GetConnectionAbortedException(state))); - } - - // Dispose was called before complete event. - bool releaseHandles = Interlocked.Exchange(ref state.ShutdownDone, State.ShutdownDone_NotificationReceived) == State.ShutdownDone_Disposed; - if (releaseHandles) - { - state.Cleanup(); - } - - return QUIC_STATUS_SUCCESS; - } - - private static Exception GetConnectionAbortedException(State state) => - ThrowHelper.GetConnectionAbortedException(state.ConnectionState.AbortErrorCode); - - private static bool CleanupReadStateAndCheckPending(State state, ReadState finalState) - { - Debug.Assert(finalState >= ReadState.ReadsCompleted, $"Expected final read state, got {finalState}"); - Debug.Assert(Monitor.IsEntered(state)); - - bool shouldComplete = false; - if (state.ReadState == ReadState.PendingRead) - { - shouldComplete = true; - state.Stream = null; - state.ReceiveUserBuffer = null; - state.ReceiveCancellationRegistration.Unregister(); - } - if (state.ReadState < ReadState.ReadsCompleted) - { - state.ReadState = finalState; - } - return shouldComplete; - } - - internal async ValueTask StartAsync(CancellationToken cancellationToken) - { - Debug.Assert(!Monitor.IsEntered(_state)); - - using var registration = cancellationToken.UnsafeRegister((state, token) => - { - ((State)state!).StartCompletionSource.TrySetCanceled(token); - }, _state); - - int status; - unsafe - { - status = MsQuicApi.Api.ApiTable->StreamStart( - _state.Handle.QuicHandle, - QUIC_STREAM_START_FLAGS.SHUTDOWN_ON_FAIL | QUIC_STREAM_START_FLAGS.INDICATE_PEER_ACCEPT); - } - - if (!StatusSucceeded(status)) - { - Exception exception = ThrowHelper.GetExceptionForMsQuicStatus(status, "Could not start stream"); - _state.StartCompletionSource.TrySetException(ExceptionDispatchInfo.SetCurrentStackTrace(exception)); - throw exception; - } - - await _state.StartCompletionSource.Task.ConfigureAwait(false); - } - - // Read state transitions: - // - // None --(data arrives in event RECV)-> IndividualReadComplete - // None --(data arrives in event RECV with FIN flag)-> IndividualReadComplete(+FIN) - // None --(0-byte data arrives in event RECV with FIN flag)-> ReadsCompleted - // None --(user calls ReadAsync() & waits)-> PendingRead - // - // IndividualReadComplete --(user calls ReadAsync())-> None - // IndividualReadComplete(+FIN) --(user calls ReadAsync() & consumes only partial data)-> None - // IndividualReadComplete(+FIN) --(user calls ReadAsync() & consumes full data)-> ReadsCompleted - // - // PendingRead --(data arrives in event RECV & completes user's ReadAsync())-> PendingReadFinished - // PendingRead --(data arrives in event RECV with FIN flag & completes user's ReadAsync() with only partial data)-> PendingReadFinished - // PendingRead --(data arrives in event RECV with FIN flag & completes user's ReadAsync() with full data)-> ReadsCompleted - // - // PendingReadFinished --(reading thread awaits ReceiveResettableCompletionSource)-> None - // - // Any non-final state --(event PEER_SEND_SHUTDOWN or SHUTDOWN_COMPLETED with ConnectionClosed=false)-> ReadsCompleted - // Any non-final state --(event PEER_SEND_ABORT)-> Aborted - // Any non-final state --(user calls AbortRead())-> Aborted - // Any non-final state --(CancellationToken's cancellation for ReadAsync())-> Aborted - // Any non-final state --(event SHUTDOWN_COMPLETED with ConnectionClosed=true)-> ConnectionClosed - // - // Closed - no transitions, set for Unidirectional write-only streams - private enum ReadState - { - /// - /// The stream is open, but there is no data available. - /// - None = 0, - - /// - /// Data is available in . - /// - IndividualReadComplete, - - /// - /// User called ReadAsync() - /// - PendingRead, - - /// - /// Read was completed from the MsQuic callback. - /// - PendingReadFinished, - - // following states are terminal: - - /// - /// The peer has gracefully shutdown their sends / our receives; the stream's reads are complete. - /// - ReadsCompleted, - - /// - /// User has aborted the stream, either via a cancellation token on ReadAsync(), or via AbortRead(). - /// - Aborted, - - /// - /// Connection was closed, either by user or by the peer. - /// - ConnectionClosed, - - /// - /// Stream is closed for reading (is send-only). - /// - Closed - } - - private enum ShutdownWriteState - { - None = 0, - Canceled, - Finished, - ConnectionClosed - } - - private enum ShutdownState - { - None = 0, - Canceled, - Pending, - Finished, - ConnectionClosed - } - - // Send state transitions: - // - // None --(user calls WriteAsync() & waits)-> Pending - // - // Pending --(event SEND_COMPLETE.Canceled == 0)-> Finished - // Pending --(event SEND_COMPLETE.Canceled == 1)-> Aborted - // - // Finished --(user awaits WriteAsync)-> None - // - // Any non-final state --(event PEER_RECEIVE_ABORTED)-> Aborted (With SendErrorCode) - // Any non-final state --(user calls AbortWrite())-> Aborted - // Any non-final state --(CancellationToken's cancellation for WriteAsync())-> Aborted - // Any non-final state --(event SHUTDOWN_COMPLETED with ConnectionClosed=true)-> ConnectionClosed - // - // Closed - no transitions, set for Unidirectional read-only streams - private enum SendState - { - /// - /// The stream is open and there are no pending write operations. - /// - None = 0, - - /// - /// There is a pending WriteAsync operation awaiting completion notification from MsQuic. - /// - Pending, - - /// - /// Send completion notification from MsQuic was received. - /// - Finished, - - // following states are terminal: - - /// - /// User has aborted the stream, either via a cancellation token on WriteAsync(), or via AbortWrite(). - /// - Aborted, - - /// - /// Connection was closed, either by user or by the peer. - /// - ConnectionClosed, - - /// - /// Stream is closed for writing (is receive-only). - /// - Closed - } - } -} diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicApi.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicApi.cs new file mode 100644 index 0000000000000..c4b1e00ada347 --- /dev/null +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicApi.cs @@ -0,0 +1,148 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; +using System.Runtime.InteropServices; +using Microsoft.Quic; + +using static Microsoft.Quic.MsQuic; + +#if TARGET_WINDOWS +using Microsoft.Win32; +#endif + +namespace System.Net.Quic; + +internal sealed unsafe class MsQuicApi +{ + private static readonly Version MinWindowsVersion = new Version(10, 0, 20145, 1000); + + private static readonly Version MsQuicVersion = new Version(2, 0); + + public MsQuicSafeHandle Registration { get; } + + public QUIC_API_TABLE* ApiTable { get; } + + // This is workaround for a bug in ILTrimmer. + // Without these DynamicDependency attributes, .ctor() will be removed from the safe handles. + // Remove once fixed: https://github.com/mono/linker/issues/1660 + [DynamicDependency(DynamicallyAccessedMemberTypes.PublicConstructors, typeof(MsQuicSafeHandle))] + [DynamicDependency(DynamicallyAccessedMemberTypes.PublicConstructors, typeof(MsQuicContextSafeHandle))] + private MsQuicApi(QUIC_API_TABLE* apiTable) + { + ApiTable = apiTable; + + fixed (byte* pAppName = "System.Net.Quic"u8) + { + var cfg = new QUIC_REGISTRATION_CONFIG + { + AppName = (sbyte*)pAppName, + ExecutionProfile = QUIC_EXECUTION_PROFILE.LOW_LATENCY + }; + + QUIC_HANDLE* handle; + ThrowHelper.ThrowIfMsQuicError(ApiTable->RegistrationOpen(&cfg, &handle), "RegistrationOpen failed"); + + Registration = new MsQuicSafeHandle(handle, apiTable->RegistrationClose, SafeHandleType.Registration); + } + } + + internal static MsQuicApi Api { get; } = null!; + + internal static bool IsQuicSupported { get; } + + internal static bool Tls13ServerMayBeDisabled { get; } + internal static bool Tls13ClientMayBeDisabled { get; } + + static MsQuicApi() + { + if (OperatingSystem.IsWindows()) + { + if (!IsWindowsVersionSupported()) + { + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Info(null, $"Current Windows version ({Environment.OSVersion}) is not supported by QUIC. Minimal supported version is {MinWindowsVersion}"); + } + + return; + } + + Tls13ServerMayBeDisabled = IsTls13Disabled(true); + Tls13ClientMayBeDisabled = IsTls13Disabled(false); + } + + IntPtr msQuicHandle; + if (NativeLibrary.TryLoad($"{Interop.Libraries.MsQuic}.{MsQuicVersion.Major}", typeof(MsQuicApi).Assembly, DllImportSearchPath.AssemblyDirectory, out msQuicHandle) || + NativeLibrary.TryLoad(Interop.Libraries.MsQuic, typeof(MsQuicApi).Assembly, DllImportSearchPath.AssemblyDirectory, out msQuicHandle)) + { + try + { + if (NativeLibrary.TryGetExport(msQuicHandle, "MsQuicOpenVersion", out IntPtr msQuicOpenVersionAddress)) + { + QUIC_API_TABLE* apiTable; + delegate* unmanaged[Cdecl] msQuicOpenVersion = (delegate* unmanaged[Cdecl])msQuicOpenVersionAddress; + if (StatusSucceeded(msQuicOpenVersion((uint)MsQuicVersion.Major, &apiTable))) + { + int arraySize = 4; + uint* libVersion = stackalloc uint[arraySize]; + uint size = (uint)arraySize * sizeof(uint); + if (StatusSucceeded(apiTable->GetParam(null, QUIC_PARAM_GLOBAL_LIBRARY_VERSION, &size, libVersion))) + { + var version = new Version((int)libVersion[0], (int)libVersion[1], (int)libVersion[2], (int)libVersion[3]); + if (version >= MsQuicVersion) + { + Api = new MsQuicApi(apiTable); + IsQuicSupported = true; + } + else + { + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Info(null, $"Incompatible MsQuic library version '{version}', expecting '{MsQuicVersion}'"); + } + } + } + } + } + } + finally + { + if (!IsQuicSupported) + { + NativeLibrary.Free(msQuicHandle); + } + } + } + } + + private static bool IsWindowsVersionSupported() => OperatingSystem.IsWindowsVersionAtLeast(MinWindowsVersion.Major, + MinWindowsVersion.Minor, MinWindowsVersion.Build, MinWindowsVersion.Revision); + + private static bool IsTls13Disabled(bool isServer) + { +#if TARGET_WINDOWS + string SChannelTls13RegistryKey = isServer + ? @"SYSTEM\CurrentControlSet\Control\SecurityProviders\SCHANNEL\Protocols\TLS 1.3\Server" + : @"SYSTEM\CurrentControlSet\Control\SecurityProviders\SCHANNEL\Protocols\TLS 1.3\Client"; + + using var regKey = Registry.LocalMachine.OpenSubKey(SChannelTls13RegistryKey); + + if (regKey is null) + { + return false; + } + + if (regKey.GetValue("Enabled") is int enabled && enabled == 0) + { + return true; + } + + if (regKey.GetValue("DisabledByDefault") is int disabled && disabled == 1) + { + return true; + } +#endif + return false; + } +} diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/MsQuicBuffers.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicBuffers.cs similarity index 74% rename from src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/MsQuicBuffers.cs rename to src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicBuffers.cs index cc0c9c177739c..b50e3c0f5c9a3 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Internal/MsQuicBuffers.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicBuffers.cs @@ -6,7 +6,7 @@ using System.Runtime.InteropServices; using Microsoft.Quic; -namespace System.Net.Quic.Implementations.MsQuic.Internal; +namespace System.Net.Quic; /// /// Helper class to convert managed data into QUIC_BUFFER* consumable by MsQuic. @@ -89,44 +89,6 @@ public void Initialize(ReadOnlyMemory buffer) SetBuffer(0, buffer); } - /// - /// Initializes QUIC_BUFFER* with the provided buffers. - /// Note that the struct either needs to be freshly created via new or previously cleaned up with Reset. - /// - /// Buffers to be passed to MsQuic as QUIC_BUFFER*. - public void Initialize(ReadOnlySequence buffers) - { - int count = 0; - foreach (ReadOnlyMemory _ in buffers) - { - ++count; - } - - Reserve(count); - int i = 0; - foreach (ReadOnlyMemory buffer in buffers) - { - SetBuffer(i++, buffer); - } - } - - /// - /// Initializes QUIC_BUFFER* with the provided buffers. - /// Note that the struct either needs to be freshly created via new or previously cleaned up with Reset. - /// - /// Buffers to be passed to MsQuic as QUIC_BUFFER*. - public void Initialize(ReadOnlyMemory> buffers) - { - int count = buffers.Length; - Reserve(count); - - ReadOnlySpan> span = buffers.Span; - for (int i = 0; i < span.Length; i++) - { - SetBuffer(i, span[i]); - } - } - /// /// Unpins the managed memory and allows reuse of this struct. /// diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicConfiguration.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicConfiguration.cs new file mode 100644 index 0000000000000..67a72e453ab72 --- /dev/null +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicConfiguration.cs @@ -0,0 +1,235 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Net.Security; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using Microsoft.Quic; +using static Microsoft.Quic.MsQuic; + +namespace System.Net.Quic; + +internal static class MsQuicConfiguration +{ + private static bool HasPrivateKey(this X509Certificate certificate) + => certificate is X509Certificate2 certificate2 && certificate2.Handle != IntPtr.Zero && certificate2.HasPrivateKey; + + public static MsQuicSafeHandle Create(QuicClientConnectionOptions options) + { + SslClientAuthenticationOptions authenticationOptions = options.ClientAuthenticationOptions; + + QUIC_CREDENTIAL_FLAGS flags = QUIC_CREDENTIAL_FLAGS.NONE; + flags |= QUIC_CREDENTIAL_FLAGS.CLIENT; + flags |= QUIC_CREDENTIAL_FLAGS.INDICATE_CERTIFICATE_RECEIVED; + flags |= QUIC_CREDENTIAL_FLAGS.NO_CERTIFICATE_VALIDATION; + if (OperatingSystem.IsWindows()) + { + flags |= QUIC_CREDENTIAL_FLAGS.USE_SUPPLIED_CREDENTIALS; + } + + // Find the first certificate with private key, either from selection callback or from a provided collection. + X509Certificate? certificate = null; + if (authenticationOptions.LocalCertificateSelectionCallback != null) + { + X509Certificate selectedCertificate = authenticationOptions.LocalCertificateSelectionCallback( + options, + authenticationOptions.TargetHost ?? string.Empty, + authenticationOptions.ClientCertificates ?? new X509CertificateCollection(), + null, + Array.Empty()); + if (selectedCertificate.HasPrivateKey()) + { + certificate = selectedCertificate; + } + } + else if (authenticationOptions.ClientCertificates != null) + { + foreach (X509Certificate clientCertificate in authenticationOptions.ClientCertificates) + { + if( clientCertificate.HasPrivateKey()) + { + certificate = clientCertificate; + break; + } + } + } + + return Create(options, flags, certificate, intermediates: null, authenticationOptions.ApplicationProtocols, authenticationOptions.CipherSuitesPolicy, authenticationOptions.EncryptionPolicy); + } + + public static MsQuicSafeHandle Create(QuicServerConnectionOptions options, string? targetHost) + { + SslServerAuthenticationOptions authenticationOptions = options.ServerAuthenticationOptions; + + QUIC_CREDENTIAL_FLAGS flags = QUIC_CREDENTIAL_FLAGS.NONE; + if (authenticationOptions.ClientCertificateRequired) + { + flags |= QUIC_CREDENTIAL_FLAGS.REQUIRE_CLIENT_AUTHENTICATION; + flags |= QUIC_CREDENTIAL_FLAGS.INDICATE_CERTIFICATE_RECEIVED; + flags |= QUIC_CREDENTIAL_FLAGS.NO_CERTIFICATE_VALIDATION; + } + + X509Certificate? certificate = null; + X509Certificate[]? intermediates = null; + if (authenticationOptions.ServerCertificateContext is not null) + { + certificate = authenticationOptions.ServerCertificateContext.Certificate; + intermediates = authenticationOptions.ServerCertificateContext.IntermediateCertificates; + } + + certificate ??= authenticationOptions.ServerCertificate ?? authenticationOptions.ServerCertificateSelectionCallback?.Invoke(authenticationOptions, targetHost); + if (certificate is null) + { + throw new ArgumentException(SR.Format(SR.net_quic_not_null_ceritifcate, nameof(SslServerAuthenticationOptions.ServerCertificate), nameof(SslServerAuthenticationOptions.ServerCertificateContext), nameof(SslServerAuthenticationOptions.ServerCertificateSelectionCallback)), nameof(options)); + } + + return Create(options, flags, certificate, intermediates, authenticationOptions.ApplicationProtocols, authenticationOptions.CipherSuitesPolicy, authenticationOptions.EncryptionPolicy); + } + + private static unsafe MsQuicSafeHandle Create(QuicConnectionOptions options, QUIC_CREDENTIAL_FLAGS flags, X509Certificate? certificate, X509Certificate[]? intermediates, List? alpnProtocols, CipherSuitesPolicy? cipherSuitesPolicy, EncryptionPolicy encryptionPolicy) + { + // Validate options and SSL parameters. + if (alpnProtocols is null || alpnProtocols.Count <= 0) + { + throw new ArgumentException(SR.Format(SR.net_quic_not_null_not_empty_connection, nameof(SslApplicationProtocol)), nameof(options)); + } + +#pragma warning disable SYSLIB0040 // NoEncryption and AllowNoEncryption are obsolete + if (encryptionPolicy == EncryptionPolicy.NoEncryption) + { + throw new PlatformNotSupportedException(SR.Format(SR.net_quic_ssl_option, encryptionPolicy)); + } +#pragma warning restore SYSLIB0040 + + QUIC_SETTINGS settings = default(QUIC_SETTINGS); + settings.IsSet.PeerUnidiStreamCount = 1; + settings.PeerUnidiStreamCount = (ushort)options.MaxInboundUnidirectionalStreams; + settings.IsSet.PeerBidiStreamCount = 1; + settings.PeerBidiStreamCount = (ushort)options.MaxInboundBidirectionalStreams; + if (options.IdleTimeout != TimeSpan.Zero) + { + settings.IsSet.IdleTimeoutMs = 1; + settings.IdleTimeoutMs = options.IdleTimeout != Timeout.InfiniteTimeSpan ? (ulong)options.IdleTimeout.TotalMilliseconds : 0; + } + + QUIC_HANDLE* handle; + + using MsQuicBuffers msquicBuffers = new MsQuicBuffers(); + msquicBuffers.Initialize(alpnProtocols, alpnProtocol => alpnProtocol.Protocol); + ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->ConfigurationOpen( + MsQuicApi.Api.Registration.QuicHandle, + msquicBuffers.Buffers, + (uint)alpnProtocols.Count, + &settings, + (uint)sizeof(QUIC_SETTINGS), + (void*)IntPtr.Zero, + &handle), + "ConfigurationOpen failed"); + MsQuicSafeHandle configurationHandle = new MsQuicSafeHandle(handle, MsQuicApi.Api.ApiTable->ConfigurationClose, SafeHandleType.Configuration); + + try + { + QUIC_CREDENTIAL_CONFIG config = new QUIC_CREDENTIAL_CONFIG { Flags = flags }; + config.Flags |= (OperatingSystem.IsWindows() ? QUIC_CREDENTIAL_FLAGS.NONE : QUIC_CREDENTIAL_FLAGS.USE_PORTABLE_CERTIFICATES); + + if (cipherSuitesPolicy != null) + { + config.Flags |= QUIC_CREDENTIAL_FLAGS.SET_ALLOWED_CIPHER_SUITES; + config.AllowedCipherSuites = CipherSuitePolicyToFlags(cipherSuitesPolicy); + } + + int status; + if (certificate is null) + { + config.Type = QUIC_CREDENTIAL_TYPE.NONE; + status = MsQuicApi.Api.ApiTable->ConfigurationLoadCredential(configurationHandle.QuicHandle, &config); + } + else if (OperatingSystem.IsWindows()) + { + config.Type = QUIC_CREDENTIAL_TYPE.CERTIFICATE_CONTEXT; + config.CertificateContext = (void*)certificate.Handle; + status = MsQuicApi.Api.ApiTable->ConfigurationLoadCredential(configurationHandle.QuicHandle, &config); + } + else + { + config.Type = QUIC_CREDENTIAL_TYPE.CERTIFICATE_PKCS12; + + byte[] certificateData; + + if (intermediates?.Length > 0) + { + X509Certificate2Collection collection = new X509Certificate2Collection(); + collection.Add(certificate); + collection.AddRange(intermediates); + certificateData = collection.Export(X509ContentType.Pkcs12)!; + } + else + { + certificateData = certificate.Export(X509ContentType.Pkcs12); + } + + fixed (byte* ptr = certificateData) + { + QUIC_CERTIFICATE_PKCS12 pkcs12Certificate = new QUIC_CERTIFICATE_PKCS12 + { + Asn1Blob = ptr, + Asn1BlobLength = (uint)certificateData.Length, + PrivateKeyPassword = (sbyte*)IntPtr.Zero + }; + config.CertificatePkcs12 = &pkcs12Certificate; + status = MsQuicApi.Api.ApiTable->ConfigurationLoadCredential(configurationHandle.QuicHandle, &config); + } + } + +#if TARGET_WINDOWS + if ((Interop.SECURITY_STATUS)status == Interop.SECURITY_STATUS.AlgorithmMismatch && + ((flags & QUIC_CREDENTIAL_FLAGS.CLIENT) == 0 ? MsQuicApi.Tls13ServerMayBeDisabled : MsQuicApi.Tls13ClientMayBeDisabled)) + { + ThrowHelper.ThrowIfMsQuicError(status, SR.net_quic_tls_version_notsupported); + } +#endif + + ThrowHelper.ThrowIfMsQuicError(status, "ConfigurationLoadCredential failed"); + } + catch + { + configurationHandle.Dispose(); + throw; + } + + return configurationHandle; + } + + private static QUIC_ALLOWED_CIPHER_SUITE_FLAGS CipherSuitePolicyToFlags(CipherSuitesPolicy cipherSuitesPolicy) + { + QUIC_ALLOWED_CIPHER_SUITE_FLAGS flags = QUIC_ALLOWED_CIPHER_SUITE_FLAGS.NONE; + + foreach (TlsCipherSuite cipher in cipherSuitesPolicy.AllowedCipherSuites) + { + switch (cipher) + { + case TlsCipherSuite.TLS_AES_128_GCM_SHA256: + flags |= QUIC_ALLOWED_CIPHER_SUITE_FLAGS.AES_128_GCM_SHA256; + break; + case TlsCipherSuite.TLS_AES_256_GCM_SHA384: + flags |= QUIC_ALLOWED_CIPHER_SUITE_FLAGS.AES_256_GCM_SHA384; + break; + case TlsCipherSuite.TLS_CHACHA20_POLY1305_SHA256: + flags |= QUIC_ALLOWED_CIPHER_SUITE_FLAGS.CHACHA20_POLY1305_SHA256; + break; + case TlsCipherSuite.TLS_AES_128_CCM_SHA256: // not supported by MsQuic (yet?), but QUIC RFC allows it so we ignore it. + default: + // ignore + break; + } + } + + if (flags == QUIC_ALLOWED_CIPHER_SUITE_FLAGS.NONE) + { + throw new ArgumentException(SR.net_quic_empty_cipher_suite, nameof(SslClientAuthenticationOptions.CipherSuitesPolicy)); + } + + return flags; + } +} diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs new file mode 100644 index 0000000000000..210cc20776875 --- /dev/null +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs @@ -0,0 +1,84 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Net.Sockets; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using Microsoft.Quic; +using static Microsoft.Quic.MsQuic; + +namespace System.Net.Quic; + +internal static class MsQuicHelpers +{ + internal static bool TryParse(this EndPoint endPoint, out string? host, out IPAddress? address, out int port) + { + if (endPoint is DnsEndPoint dnsEndPoint) + { + host = IPAddress.TryParse(dnsEndPoint.Host, out address) ? null : dnsEndPoint.Host; + port = dnsEndPoint.Port; + return true; + } + + if (endPoint is IPEndPoint ipEndPoint) + { + host = null; + address = ipEndPoint.Address; + port = ipEndPoint.Port; + return true; + } + + host = default; + address = default; + port = default; + return false; + } + + internal static unsafe IPEndPoint ToIPEndPoint(this ref QuicAddr quicAddress, AddressFamily? addressFamilyOverride = null) + { + // MsQuic always uses storage size as if IPv6 was used + Span addressBytes = new Span((byte*)Unsafe.AsPointer(ref quicAddress), Internals.SocketAddress.IPv6AddressSize); + return new Internals.SocketAddress(addressFamilyOverride ?? SocketAddressPal.GetAddressFamily(addressBytes), addressBytes).GetIPEndPoint(); + } + + internal static unsafe QuicAddr ToQuicAddr(this IPEndPoint iPEndPoint) + { + // TODO: is the layout same for SocketAddress.Buffer and QuicAddr on all platforms? + QuicAddr result = default; + Span rawAddress = MemoryMarshal.AsBytes(MemoryMarshal.CreateSpan(ref result, 1)); + + Internals.SocketAddress address = IPEndPointExtensions.Serialize(iPEndPoint); + Debug.Assert(address.Size <= rawAddress.Length); + + address.Buffer.AsSpan(0, address.Size).CopyTo(rawAddress); + return result; + } + + internal static unsafe T GetMsQuicParameter(MsQuicSafeHandle handle, uint parameter) + where T: unmanaged + { + T value; + uint length = (uint)sizeof(T); + + ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->GetParam( + handle.QuicHandle, + parameter, + &length, + (byte*)&value), + $"GetParam({handle}, {parameter}) failed"); + + return value; + } + + internal static unsafe void SetMsQuicParameter(MsQuicSafeHandle handle, uint parameter, T value) + where T: unmanaged + { + ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->SetParam( + handle.QuicHandle, + parameter, + (uint)sizeof(T), + (byte*)&value), + $"SetParam({handle}, {parameter}) failed"); + } +} diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Interop/MsQuicSafeHandle.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicSafeHandle.cs similarity index 89% rename from src/libraries/System.Net.Quic/src/System/Net/Quic/Interop/MsQuicSafeHandle.cs rename to src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicSafeHandle.cs index 1e6d924bbac98..b635e1876c3f4 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Interop/MsQuicSafeHandle.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicSafeHandle.cs @@ -34,7 +34,7 @@ public MsQuicSafeHandle(QUIC_HANDLE* handle, delegate* unmanaged[Cdecl] quicBuffers, int totalLength, bool final) + { + lock (_syncRoot) + { + if (_buffer.ActiveMemory.Length > MaxBufferedBytes - totalLength) + { + totalLength = MaxBufferedBytes - _buffer.ActiveMemory.Length; + final = false; + } + + _final = final; + _buffer.EnsureAvailableSpace(totalLength); + + int totalCopied = 0; + for (int i = 0; i < quicBuffers.Length; ++i) + { + Span quicBuffer = quicBuffers[i].Span; + if (totalLength < quicBuffer.Length) + { + quicBuffer = quicBuffer.Slice(0, totalLength); + } + _buffer.AvailableMemory.CopyFrom(quicBuffer); + _buffer.Commit(quicBuffer.Length); + totalCopied += quicBuffer.Length; + totalLength -= quicBuffer.Length; + } + return totalCopied; + } + } + + public int CopyTo(Memory buffer, out bool isCompleted, out bool isEmpty) + { + lock (_syncRoot) + { + int copied = 0; + if (!_buffer.IsEmpty) + { + MultiMemory activeBuffer = _buffer.ActiveMemory; + copied = Math.Min(buffer.Length, activeBuffer.Length); + activeBuffer.Slice(0, copied).CopyTo(buffer.Span); + _buffer.Discard(copied); + } + + isCompleted = _buffer.IsEmpty && _final; + isEmpty = _buffer.IsEmpty; + + return copied; + } + } +} diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ResettableValueTaskSource.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ResettableValueTaskSource.cs new file mode 100644 index 0000000000000..975bd86b1a6fa --- /dev/null +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ResettableValueTaskSource.cs @@ -0,0 +1,281 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.ExceptionServices; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; +using System.Threading.Tasks.Sources; + +namespace System.Net.Quic; + +internal sealed class ResettableValueTaskSource : IValueTaskSource +{ + // None -> [TryGetValueTask] -> Awaiting -> [TrySetResult|TrySetException(final: false)] -> Ready -> [GetResult] -> None + // None -> [TrySetResult|TrySetException(final: false)] -> Ready -> [TryGetValueTask] -> [GetResult] -> None + // None|Awaiting -> [TrySetResult|TrySetException(final: true)] -> Final(never leaves this state) + private enum State + { + None, + Awaiting, + Ready, + Completed + } + + private State _state; + private ManualResetValueTaskSourceCore _valueTaskSource; + private CancellationTokenRegistration _cancellationRegistration; + private Action? _cancellationAction; + private GCHandle _keepAlive; + + private FinalTaskSource _finalTaskSource; + + public ResettableValueTaskSource(bool runContinuationsAsynchronously = true) + { + _state = State.None; + _valueTaskSource = new ManualResetValueTaskSourceCore() { RunContinuationsAsynchronously = runContinuationsAsynchronously }; + _cancellationRegistration = default; + _keepAlive = default; + + _finalTaskSource = new FinalTaskSource(runContinuationsAsynchronously); + } + + /// + /// Allows setting additional cancellation action to be called if token passed to fires off. + /// The argument for the action is the keepAlive object from the same call. + /// + public Action CancellationAction { init { _cancellationAction = value; } } + + /// + /// Returns true is this task source has entered its final state, i.e. or + /// was called with final set to true and the result was propagated. + /// + public bool IsCompleted => (State)Volatile.Read(ref Unsafe.As(ref _state)) == State.Completed; + + public bool TryGetValueTask(out ValueTask valueTask, object? keepAlive = null, CancellationToken cancellationToken = default) + { + lock (this) + { + // Cancellation might kick off synchronously, re-entering the lock and changing the state to completed. + if (_state == State.None) + { + // Register cancellation if the token can be cancelled and the task is not completed yet. + if (cancellationToken.CanBeCanceled) + { + _cancellationRegistration = cancellationToken.UnsafeRegister(static (obj, cancellationToken) => + { + (ResettableValueTaskSource parent, object? target) = ((ResettableValueTaskSource, object?))obj!; + if (parent.TrySetException(new OperationCanceledException(cancellationToken))) + { + parent._cancellationAction?.Invoke(target); + } + }, (this, keepAlive)); + } + } + + State state = _state; + + // None: prepare for the actual operation happening and transition to Awaiting. + if (state == State.None) + { + // Keep alive the caller object until the result is read from the task. + // Used for keeping caller alive during async interop calls. + if (keepAlive is not null) + { + Debug.Assert(!_keepAlive.IsAllocated); + _keepAlive = GCHandle.Alloc(keepAlive); + } + + _state = State.Awaiting; + } + // None, Completed, Final: return the current task. + if (state == State.None || + state == State.Ready || + state == State.Completed) + { + valueTask = new ValueTask(this, _valueTaskSource.Version); + return true; + } + + // Awaiting: forbidden concurrent call. + valueTask = default; + return false; + } + } + + public Task GetFinalTask() => _finalTaskSource.Task; + + private bool TryComplete(Exception? exception, bool final) + { + CancellationTokenRegistration cancellationRegistration = default; + try + { + lock (this) + { + try + { + State state = _state; + + // None,Awaiting: clean up and finish the task source. + if (state == State.Awaiting || + state == State.None) + { + _state = final ? State.Completed : State.Ready; + + // Swap the cancellation registration so the one that's been registered gets eventually Disposed. + // Ideally, we would dispose it here, but if the callbacks kicks in, it tries to take the lock held by this thread leading to deadlock. + cancellationRegistration = _cancellationRegistration; + _cancellationRegistration = default; + + // Unblock the current task source and in case of a final also the final task source. + if (exception is not null) + { + // Set up the exception stack strace for the caller. + exception = exception.StackTrace is null ? ExceptionDispatchInfo.SetCurrentStackTrace(exception) : exception; + _valueTaskSource.SetException(exception); + } + else + { + _valueTaskSource.SetResult(final); + } + + if (final) + { + _finalTaskSource.TryComplete(exception); + _finalTaskSource.TrySignal(out _); + } + + return true; + } + + // Final: remember the first final result to set it once the current non-final result gets retrieved. + if (final) + { + return _finalTaskSource.TryComplete(exception); + } + + return false; + } + finally + { + // Un-root the the kept alive object in all cases. + if (_keepAlive.IsAllocated) + { + _keepAlive.Free(); + } + } + } + } + finally + { + // Dispose the cancellation if registered. + // Must be done outside of lock since Dispose will wait on pending cancellation callbacks which require taking the lock. + cancellationRegistration.Dispose(); + } + } + + public bool TrySetResult(bool final = false) + { + return TryComplete(null, final); + } + + public bool TrySetException(Exception exception, bool final = false) + { + return TryComplete(exception, final); + } + + ValueTaskSourceStatus IValueTaskSource.GetStatus(short token) + => _valueTaskSource.GetStatus(token); + + void IValueTaskSource.OnCompleted(Action continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags) + => _valueTaskSource.OnCompleted(continuation, state, token, flags); + + void IValueTaskSource.GetResult(short token) + { + try + { + _valueTaskSource.GetResult(token); + } + finally + { + lock (this) + { + State state = _state; + + if (state == State.Ready) + { + _valueTaskSource.Reset(); + if (_finalTaskSource.TrySignal(out Exception? exception)) + { + _state = State.Completed; + + if (exception is not null) + { + _valueTaskSource.SetException(exception); + } + else + { + _valueTaskSource.SetResult(true); + } + } + else + { + _state = State.None; + } + } + } + } + } + + private struct FinalTaskSource + { + private TaskCompletionSource _finalTaskSource; + private bool _isCompleted; + private Exception? _exception; + + public FinalTaskSource(bool runContinuationsAsynchronously = true) + { + // TODO: defer instantiation only after Task is retrieved + _finalTaskSource = new TaskCompletionSource(runContinuationsAsynchronously ? TaskCreationOptions.RunContinuationsAsynchronously : TaskCreationOptions.None); + _isCompleted = false; + _exception = null; + } + + public Task Task => _finalTaskSource.Task; + + public bool TryComplete(Exception? exception = null) + { + if (_isCompleted) + { + return false; + } + + _exception = exception; + _isCompleted = true; + return true; + } + + public bool TrySignal(out Exception? exception) + { + if (!_isCompleted) + { + exception = default; + return false; + } + + if (_exception is not null) + { + _finalTaskSource.SetException(_exception); + } + else + { + _finalTaskSource.SetResult(); + } + + exception = _exception; + return true; + } + } +} diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ThrowHelper.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ThrowHelper.cs index cdfebdb725457..01a6fd9806df9 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ThrowHelper.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ThrowHelper.cs @@ -5,151 +5,150 @@ using System.Security.Authentication; using static Microsoft.Quic.MsQuic; -namespace System.Net.Quic.Implementations.MsQuic +namespace System.Net.Quic; + +internal static class ThrowHelper { - internal static class ThrowHelper + internal static QuicException GetConnectionAbortedException(long errorCode) { - internal static QuicException GetConnectionAbortedException(long errorCode) + return errorCode switch { - return errorCode switch - { - -1 => GetOperationAbortedException(), // Shutdown initiated by us. - long err => new QuicException(QuicError.ConnectionAborted, err, SR.Format(SR.net_quic_connectionaborted, err)) // Shutdown initiated by peer. - }; - } + -1 => GetOperationAbortedException(), // Shutdown initiated by us. + long err => new QuicException(QuicError.ConnectionAborted, err, SR.Format(SR.net_quic_connectionaborted, err)) // Shutdown initiated by peer. + }; + } - internal static QuicException GetStreamAbortedException(long errorCode) + internal static QuicException GetStreamAbortedException(long errorCode) + { + return errorCode switch { - return errorCode switch - { - -1 => GetOperationAbortedException(), // Shutdown initiated by us. - long err => new QuicException(QuicError.StreamAborted, err, SR.Format(SR.net_quic_streamaborted, err)) // Shutdown initiated by peer. - }; - } + -1 => GetOperationAbortedException(), // Shutdown initiated by us. + long err => new QuicException(QuicError.StreamAborted, err, SR.Format(SR.net_quic_streamaborted, err)) // Shutdown initiated by peer. + }; + } - internal static QuicException GetOperationAbortedException(string? message = null) + internal static QuicException GetOperationAbortedException(string? message = null) + { + return new QuicException(QuicError.OperationAborted, null, message ?? SR.net_quic_operationaborted); + } + + internal static Exception GetExceptionForMsQuicStatus(int status, string? message = null) + { + Exception ex = GetExceptionInternal(status, message); + if (status != 0) { - return new QuicException(QuicError.OperationAborted, null, message ?? SR.net_quic_operationaborted); + // Include the raw MsQuic status in the HResult property for better diagnostics + ex.HResult = status; } - internal static Exception GetExceptionForMsQuicStatus(int status, string? message = null) + return ex; + + static Exception GetExceptionInternal(int status, string? message) { - Exception ex = GetExceptionInternal(status, message); - if (status != 0) + // + // Start by checking for statuses mapped to QuicError enum + // + if (status == QUIC_STATUS_ADDRESS_IN_USE) return new QuicException(QuicError.AddressInUse, null, SR.net_quic_address_in_use); + if (status == QUIC_STATUS_UNREACHABLE) return new QuicException(QuicError.HostUnreachable, null, SR.net_quic_host_unreachable); + if (status == QUIC_STATUS_CONNECTION_REFUSED) return new QuicException(QuicError.ConnectionRefused, null, SR.net_quic_connection_refused); + if (status == QUIC_STATUS_VER_NEG_ERROR) return new QuicException(QuicError.VersionNegotiationError, null, SR.net_quic_ver_neg_error); + if (status == QUIC_STATUS_INVALID_ADDRESS) return new QuicException(QuicError.InvalidAddress, null, SR.net_quic_invalid_address); + if (status == QUIC_STATUS_CONNECTION_IDLE) return new QuicException(QuicError.ConnectionIdle, null, SR.net_quic_connection_idle); + if (status == QUIC_STATUS_PROTOCOL_ERROR) return new QuicException(QuicError.ProtocolError, null, SR.net_quic_protocol_error); + + if (status == QUIC_STATUS_TLS_ERROR || + status == QUIC_STATUS_CERT_EXPIRED || + status == QUIC_STATUS_CERT_UNTRUSTED_ROOT || + status == QUIC_STATUS_CERT_NO_CERT) { - // Include the raw MsQuic status in the HResult property for better diagnostics - ex.HResult = status; + return new AuthenticationException(SR.Format(SR.net_quic_auth, GetErrorMessageForStatus(status, message))); } - return ex; - - static Exception GetExceptionInternal(int status, string? message) + // + // Although ALPN negotiation failure is triggered by a TLS Alert, it is mapped differently + // + if (status == QUIC_STATUS_ALPN_NEG_FAILURE) { - // - // Start by checking for statuses mapped to QuicError enum - // - if (status == QUIC_STATUS_ADDRESS_IN_USE) return new QuicException(QuicError.AddressInUse, null, SR.net_quic_address_in_use); - if (status == QUIC_STATUS_UNREACHABLE) return new QuicException(QuicError.HostUnreachable, null, SR.net_quic_host_unreachable); - if (status == QUIC_STATUS_CONNECTION_REFUSED) return new QuicException(QuicError.ConnectionRefused, null, SR.net_quic_connection_refused); - if (status == QUIC_STATUS_VER_NEG_ERROR) return new QuicException(QuicError.VersionNegotiationError, null, SR.net_quic_ver_neg_error); - if (status == QUIC_STATUS_INVALID_ADDRESS) return new QuicException(QuicError.InvalidAddress, null, SR.net_quic_invalid_address); - if (status == QUIC_STATUS_CONNECTION_IDLE) return new QuicException(QuicError.ConnectionIdle, null, SR.net_quic_connection_idle); - if (status == QUIC_STATUS_PROTOCOL_ERROR) return new QuicException(QuicError.ProtocolError, null, SR.net_quic_protocol_error); - - if (status == QUIC_STATUS_TLS_ERROR || - status == QUIC_STATUS_CERT_EXPIRED || - status == QUIC_STATUS_CERT_UNTRUSTED_ROOT || - status == QUIC_STATUS_CERT_NO_CERT) - { - return new AuthenticationException(SR.Format(SR.net_quic_auth, GetErrorMessageForStatus(status, message))); - } - - // - // Although ALPN negotiation failure is triggered by a TLS Alert, it is mapped differently - // - if (status == QUIC_STATUS_ALPN_NEG_FAILURE) - { - return new AuthenticationException(SR.net_quic_alpn_neg_error); - } - - // - // other TLS Alerts: MsQuic maps TLS alerts by offsetting them by a - // certain value. CloseNotify is the TLS Alert with value 0x00, so - // all TLS Alert codes are mapped to [QUIC_STATUS_CLOSE_NOTIFY, - // QUIC_STATUS_CLOSE_NOTIFY + 255] - // - // Mapped TLS alerts include following statuses: - // - QUIC_STATUS_CLOSE_NOTIFY - // - QUIC_STATUS_BAD_CERTIFICATE - // - QUIC_STATUS_UNSUPPORTED_CERTIFICATE - // - QUIC_STATUS_REVOKED_CERTIFICATE - // - QUIC_STATUS_EXPIRED_CERTIFICATE - // - QUIC_STATUS_UNKNOWN_CERTIFICATE - // - QUIC_STATUS_REQUIRED_CERTIFICATE - // - if ((uint)status >= (uint)QUIC_STATUS_CLOSE_NOTIFY && (uint)status < (uint)QUIC_STATUS_CLOSE_NOTIFY + 256) - { - int alert = status - QUIC_STATUS_CLOSE_NOTIFY; - return new AuthenticationException(SR.Format(SR.net_auth_tls_alert, alert)); - } - - // - // for everything else, use general InternalError - // - return new QuicException(QuicError.InternalError, null, SR.Format(SR.net_quic_internal_error, GetErrorMessageForStatus(status, message))); + return new AuthenticationException(SR.net_quic_alpn_neg_error); } - } - internal static void ThrowIfMsQuicError(int status, string? message = null) - { - if (StatusFailed(status)) + // + // other TLS Alerts: MsQuic maps TLS alerts by offsetting them by a + // certain value. CloseNotify is the TLS Alert with value 0x00, so + // all TLS Alert codes are mapped to [QUIC_STATUS_CLOSE_NOTIFY, + // QUIC_STATUS_CLOSE_NOTIFY + 255] + // + // Mapped TLS alerts include following statuses: + // - QUIC_STATUS_CLOSE_NOTIFY + // - QUIC_STATUS_BAD_CERTIFICATE + // - QUIC_STATUS_UNSUPPORTED_CERTIFICATE + // - QUIC_STATUS_REVOKED_CERTIFICATE + // - QUIC_STATUS_EXPIRED_CERTIFICATE + // - QUIC_STATUS_UNKNOWN_CERTIFICATE + // - QUIC_STATUS_REQUIRED_CERTIFICATE + // + if ((uint)status >= (uint)QUIC_STATUS_CLOSE_NOTIFY && (uint)status < (uint)QUIC_STATUS_CLOSE_NOTIFY + 256) { - throw GetExceptionForMsQuicStatus(status, message); + int alert = status - QUIC_STATUS_CLOSE_NOTIFY; + return new AuthenticationException(SR.Format(SR.net_auth_tls_alert, alert)); } - } - internal static string GetErrorMessageForStatus(int status, string? message) - { - return (message ?? "Status code") + ": " + GetErrorMessageForStatus(status); + // + // for everything else, use general InternalError + // + return new QuicException(QuicError.InternalError, null, SR.Format(SR.net_quic_internal_error, GetErrorMessageForStatus(status, message))); } + } - internal static string GetErrorMessageForStatus(int status) + internal static void ThrowIfMsQuicError(int status, string? message = null) + { + if (StatusFailed(status)) { - if (status == QUIC_STATUS_SUCCESS) return "QUIC_STATUS_SUCCESS"; - else if (status == QUIC_STATUS_PENDING) return "QUIC_STATUS_PENDING"; - else if (status == QUIC_STATUS_CONTINUE) return "QUIC_STATUS_CONTINUE"; - else if (status == QUIC_STATUS_OUT_OF_MEMORY) return "QUIC_STATUS_OUT_OF_MEMORY"; - else if (status == QUIC_STATUS_INVALID_PARAMETER) return "QUIC_STATUS_INVALID_PARAMETER"; - else if (status == QUIC_STATUS_INVALID_STATE) return "QUIC_STATUS_INVALID_STATE"; - else if (status == QUIC_STATUS_NOT_SUPPORTED) return "QUIC_STATUS_NOT_SUPPORTED"; - else if (status == QUIC_STATUS_NOT_FOUND) return "QUIC_STATUS_NOT_FOUND"; - else if (status == QUIC_STATUS_BUFFER_TOO_SMALL) return "QUIC_STATUS_BUFFER_TOO_SMALL"; - else if (status == QUIC_STATUS_HANDSHAKE_FAILURE) return "QUIC_STATUS_HANDSHAKE_FAILURE"; - else if (status == QUIC_STATUS_ABORTED) return "QUIC_STATUS_ABORTED"; - else if (status == QUIC_STATUS_ADDRESS_IN_USE) return "QUIC_STATUS_ADDRESS_IN_USE"; - else if (status == QUIC_STATUS_INVALID_ADDRESS) return "QUIC_STATUS_INVALID_ADDRESS"; - else if (status == QUIC_STATUS_CONNECTION_TIMEOUT) return "QUIC_STATUS_CONNECTION_TIMEOUT"; - else if (status == QUIC_STATUS_CONNECTION_IDLE) return "QUIC_STATUS_CONNECTION_IDLE"; - else if (status == QUIC_STATUS_UNREACHABLE) return "QUIC_STATUS_UNREACHABLE"; - else if (status == QUIC_STATUS_INTERNAL_ERROR) return "QUIC_STATUS_INTERNAL_ERROR"; - else if (status == QUIC_STATUS_CONNECTION_REFUSED) return "QUIC_STATUS_CONNECTION_REFUSED"; - else if (status == QUIC_STATUS_PROTOCOL_ERROR) return "QUIC_STATUS_PROTOCOL_ERROR"; - else if (status == QUIC_STATUS_VER_NEG_ERROR) return "QUIC_STATUS_VER_NEG_ERROR"; - else if (status == QUIC_STATUS_TLS_ERROR) return "QUIC_STATUS_TLS_ERROR"; - else if (status == QUIC_STATUS_USER_CANCELED) return "QUIC_STATUS_USER_CANCELED"; - else if (status == QUIC_STATUS_ALPN_NEG_FAILURE) return "QUIC_STATUS_ALPN_NEG_FAILURE"; - else if (status == QUIC_STATUS_STREAM_LIMIT_REACHED) return "QUIC_STATUS_STREAM_LIMIT_REACHED"; - else if (status == QUIC_STATUS_CLOSE_NOTIFY) return "QUIC_STATUS_CLOSE_NOTIFY"; - else if (status == QUIC_STATUS_BAD_CERTIFICATE) return "QUIC_STATUS_BAD_CERTIFICATE"; - else if (status == QUIC_STATUS_UNSUPPORTED_CERTIFICATE) return "QUIC_STATUS_UNSUPPORTED_CERTIFICATE"; - else if (status == QUIC_STATUS_REVOKED_CERTIFICATE) return "QUIC_STATUS_REVOKED_CERTIFICATE"; - else if (status == QUIC_STATUS_EXPIRED_CERTIFICATE) return "QUIC_STATUS_EXPIRED_CERTIFICATE"; - else if (status == QUIC_STATUS_UNKNOWN_CERTIFICATE) return "QUIC_STATUS_UNKNOWN_CERTIFICATE"; - else if (status == QUIC_STATUS_REQUIRED_CERTIFICATE) return "QUIC_STATUS_REQUIRED_CERTIFICATE"; - else if (status == QUIC_STATUS_CERT_EXPIRED) return "QUIC_STATUS_CERT_EXPIRED"; - else if (status == QUIC_STATUS_CERT_UNTRUSTED_ROOT) return "QUIC_STATUS_CERT_UNTRUSTED_ROOT"; - else if (status == QUIC_STATUS_CERT_NO_CERT) return "QUIC_STATUS_CERT_NO_CERT"; - else return $"Unknown (0x{status:x})"; + throw GetExceptionForMsQuicStatus(status, message); } } + + internal static string GetErrorMessageForStatus(int status, string? message) + { + return (message ?? "Status code") + ": " + GetErrorMessageForStatus(status); + } + + internal static string GetErrorMessageForStatus(int status) + { + if (status == QUIC_STATUS_SUCCESS) return "QUIC_STATUS_SUCCESS"; + else if (status == QUIC_STATUS_PENDING) return "QUIC_STATUS_PENDING"; + else if (status == QUIC_STATUS_CONTINUE) return "QUIC_STATUS_CONTINUE"; + else if (status == QUIC_STATUS_OUT_OF_MEMORY) return "QUIC_STATUS_OUT_OF_MEMORY"; + else if (status == QUIC_STATUS_INVALID_PARAMETER) return "QUIC_STATUS_INVALID_PARAMETER"; + else if (status == QUIC_STATUS_INVALID_STATE) return "QUIC_STATUS_INVALID_STATE"; + else if (status == QUIC_STATUS_NOT_SUPPORTED) return "QUIC_STATUS_NOT_SUPPORTED"; + else if (status == QUIC_STATUS_NOT_FOUND) return "QUIC_STATUS_NOT_FOUND"; + else if (status == QUIC_STATUS_BUFFER_TOO_SMALL) return "QUIC_STATUS_BUFFER_TOO_SMALL"; + else if (status == QUIC_STATUS_HANDSHAKE_FAILURE) return "QUIC_STATUS_HANDSHAKE_FAILURE"; + else if (status == QUIC_STATUS_ABORTED) return "QUIC_STATUS_ABORTED"; + else if (status == QUIC_STATUS_ADDRESS_IN_USE) return "QUIC_STATUS_ADDRESS_IN_USE"; + else if (status == QUIC_STATUS_INVALID_ADDRESS) return "QUIC_STATUS_INVALID_ADDRESS"; + else if (status == QUIC_STATUS_CONNECTION_TIMEOUT) return "QUIC_STATUS_CONNECTION_TIMEOUT"; + else if (status == QUIC_STATUS_CONNECTION_IDLE) return "QUIC_STATUS_CONNECTION_IDLE"; + else if (status == QUIC_STATUS_UNREACHABLE) return "QUIC_STATUS_UNREACHABLE"; + else if (status == QUIC_STATUS_INTERNAL_ERROR) return "QUIC_STATUS_INTERNAL_ERROR"; + else if (status == QUIC_STATUS_CONNECTION_REFUSED) return "QUIC_STATUS_CONNECTION_REFUSED"; + else if (status == QUIC_STATUS_PROTOCOL_ERROR) return "QUIC_STATUS_PROTOCOL_ERROR"; + else if (status == QUIC_STATUS_VER_NEG_ERROR) return "QUIC_STATUS_VER_NEG_ERROR"; + else if (status == QUIC_STATUS_TLS_ERROR) return "QUIC_STATUS_TLS_ERROR"; + else if (status == QUIC_STATUS_USER_CANCELED) return "QUIC_STATUS_USER_CANCELED"; + else if (status == QUIC_STATUS_ALPN_NEG_FAILURE) return "QUIC_STATUS_ALPN_NEG_FAILURE"; + else if (status == QUIC_STATUS_STREAM_LIMIT_REACHED) return "QUIC_STATUS_STREAM_LIMIT_REACHED"; + else if (status == QUIC_STATUS_CLOSE_NOTIFY) return "QUIC_STATUS_CLOSE_NOTIFY"; + else if (status == QUIC_STATUS_BAD_CERTIFICATE) return "QUIC_STATUS_BAD_CERTIFICATE"; + else if (status == QUIC_STATUS_UNSUPPORTED_CERTIFICATE) return "QUIC_STATUS_UNSUPPORTED_CERTIFICATE"; + else if (status == QUIC_STATUS_REVOKED_CERTIFICATE) return "QUIC_STATUS_REVOKED_CERTIFICATE"; + else if (status == QUIC_STATUS_EXPIRED_CERTIFICATE) return "QUIC_STATUS_EXPIRED_CERTIFICATE"; + else if (status == QUIC_STATUS_UNKNOWN_CERTIFICATE) return "QUIC_STATUS_UNKNOWN_CERTIFICATE"; + else if (status == QUIC_STATUS_REQUIRED_CERTIFICATE) return "QUIC_STATUS_REQUIRED_CERTIFICATE"; + else if (status == QUIC_STATUS_CERT_EXPIRED) return "QUIC_STATUS_CERT_EXPIRED"; + else if (status == QUIC_STATUS_CERT_UNTRUSTED_ROOT) return "QUIC_STATUS_CERT_UNTRUSTED_ROOT"; + else if (status == QUIC_STATUS_CERT_NO_CERT) return "QUIC_STATUS_CERT_NO_CERT"; + else return $"Unknown (0x{status:x})"; + } } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ValueTaskSource.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ValueTaskSource.cs index d51864562c9e8..4a6c49b5df560 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ValueTaskSource.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ValueTaskSource.cs @@ -36,6 +36,7 @@ public ValueTaskSource(bool runContinuationsAsynchronously = true) } public bool IsCompleted => (State)Volatile.Read(ref Unsafe.As(ref _state)) == State.Completed; + public bool IsCompletedSuccessfully => IsCompleted && _valueTaskSource.GetStatus(_valueTaskSource.Version) == ValueTaskSourceStatus.Succeeded; public bool TryInitialize(out ValueTask valueTask, object? keepAlive = null, CancellationToken cancellationToken = default) { @@ -60,7 +61,7 @@ public bool TryInitialize(out ValueTask valueTask, object? keepAlive = null, Can State state = _state; - // If we're the first here and we will return true. + // If we're the first here, we will return true. if (state == State.None) { // Keep alive the caller object until the result is read from the task. @@ -95,13 +96,13 @@ private bool TryComplete(Exception? exception) _state = State.Completed; // Swap the cancellation registration so the one that's been registered gets eventually Disposed. - // Ideally, we would dispose it here, but if the callbacks kicks in, it tries to take the lock held by this thread. + // Ideally, we would dispose it here, but if the callbacks kicks in, it tries to take the lock held by this thread leading to deadlock. cancellationRegistration = _cancellationRegistration; _cancellationRegistration = default; if (exception is not null) { - // Set up the exception stack strace for the caller. + // Set up the exception stack trace for the caller. exception = exception.StackTrace is null ? ExceptionDispatchInfo.SetCurrentStackTrace(exception) : exception; _valueTaskSource.SetException(exception); } @@ -117,7 +118,7 @@ private bool TryComplete(Exception? exception) } finally { - // Un-root the the kept alive object in all cases. + // Un-root the kept alive object in all cases. if (_keepAlive.IsAllocated) { _keepAlive.Free(); @@ -128,7 +129,7 @@ private bool TryComplete(Exception? exception) finally { // Dispose the cancellation if registered. - // Must be done outside of lock since Dispose will wait on pending cancellation callbacks which requires taking the lock. + // Must be done outside of lock since Dispose will wait on pending cancellation callbacks which require taking the lock. cancellationRegistration.Dispose(); } } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/msquic.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Interop/msquic.cs similarity index 100% rename from src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/msquic.cs rename to src/libraries/System.Net.Quic/src/System/Net/Quic/Interop/msquic.cs diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/msquic_extensions.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Interop/msquic_extensions.cs similarity index 100% rename from src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/msquic_extensions.cs rename to src/libraries/System.Net.Quic/src/System/Net/Quic/Interop/msquic_extensions.cs diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/msquic_generated.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Interop/msquic_generated.cs similarity index 100% rename from src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/msquic_generated.cs rename to src/libraries/System.Net.Quic/src/System/Net/Quic/Interop/msquic_generated.cs diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/msquic_generated_linux.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Interop/msquic_generated_linux.cs similarity index 100% rename from src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/msquic_generated_linux.cs rename to src/libraries/System.Net.Quic/src/System/Net/Quic/Interop/msquic_generated_linux.cs diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/msquic_generated_macos.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Interop/msquic_generated_macos.cs similarity index 100% rename from src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/msquic_generated_macos.cs rename to src/libraries/System.Net.Quic/src/System/Net/Quic/Interop/msquic_generated_macos.cs diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/msquic_generated_windows.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Interop/msquic_generated_windows.cs similarity index 100% rename from src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/msquic_generated_windows.cs rename to src/libraries/System.Net.Quic/src/System/Net/Quic/Interop/msquic_generated_windows.cs diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicAbortDirection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicAbortDirection.cs new file mode 100644 index 0000000000000..d28d3b4cc5faa --- /dev/null +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicAbortDirection.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Net.Quic; + +/// +/// Specifies direction of the which is to be aborted. +/// +[Flags] +public enum QuicAbortDirection +{ + /// + /// Abort read side of the stream. + /// + Read = 1, + /// + /// Abort write side of the stream. + /// + Write = 2, + /// + /// Abort both sides of the stream, i.e.: and ) at the same time. + /// + Both = Read | Write +} diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.SslConnectionOptions.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.SslConnectionOptions.cs new file mode 100644 index 0000000000000..0b9e13fb022ab --- /dev/null +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.SslConnectionOptions.cs @@ -0,0 +1,127 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net.Security; +using System.Security.Authentication; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using Microsoft.Quic; +using static Microsoft.Quic.MsQuic; + +namespace System.Net.Quic; + +public partial class QuicConnection +{ + private readonly struct SslConnectionOptions + { + private static readonly Oid s_serverAuthOid = new Oid("1.3.6.1.5.5.7.3.1", null); + private static readonly Oid s_clientAuthOid = new Oid("1.3.6.1.5.5.7.3.2", null); + + /// + /// The connection to which these options belong. + /// + private readonly QuicConnection _connection; + /// + /// Determines if the connection is outbound/client or inbound/server. + /// + private readonly bool _isClient; + /// + /// Host name send in SNI, set only for outbound/client connections. Configured via . + /// + private readonly string? _targetHost; + /// + /// Always true for outbound/client connections. Configured for inbound/server ones via . + /// + private readonly bool _certificateRequired; + /// + /// Configured via or . + /// + private readonly X509RevocationMode _revocationMode; + /// + /// Configured via or . + /// + private readonly RemoteCertificateValidationCallback? _validationCallback; + + public SslConnectionOptions(QuicConnection connection, bool isClient, string? targetHost, bool certificateRequired, X509RevocationMode revocationMode, RemoteCertificateValidationCallback? validationCallback) + { + _connection = connection; + _isClient = isClient; + _targetHost = targetHost; + _certificateRequired = certificateRequired; + _revocationMode = revocationMode; + _validationCallback = validationCallback; + } + + public unsafe int ValidateCertificate(QUIC_BUFFER* certificatePtr, QUIC_BUFFER* chainPtr, out X509Certificate2? certificate) + { + SslPolicyErrors sslPolicyErrors = SslPolicyErrors.None; + X509Chain? chain = null; + IntPtr certificateBuffer = default; + int certificateLength = default; + + certificate = null; + + if (certificatePtr is not null) + { + chain = new X509Chain(); + chain.ChainPolicy.RevocationMode = _revocationMode; + chain.ChainPolicy.RevocationFlag = X509RevocationFlag.ExcludeRoot; + chain.ChainPolicy.ApplicationPolicy.Add(_isClient ? s_serverAuthOid : s_clientAuthOid); + + if (OperatingSystem.IsWindows()) + { + certificate = new X509Certificate2((IntPtr)certificatePtr); + } + else + { + if (certificatePtr->Length > 0) + { + certificateBuffer = (IntPtr)certificatePtr->Buffer; + certificateLength = (int)certificatePtr->Length; + certificate = new X509Certificate2(certificatePtr->Span); + } + if (chainPtr->Length > 0) + { + X509Certificate2Collection additionalCertificates = new X509Certificate2Collection(); + additionalCertificates.Import(chainPtr->Span); + chain.ChainPolicy.ExtraStore.AddRange(additionalCertificates); + } + } + } + + if (certificate is not null) + { + sslPolicyErrors |= CertificateValidation.BuildChainAndVerifyProperties(chain!, certificate, true, !_isClient, _targetHost, certificateBuffer, certificateLength); + } + + if (certificate is null && _certificateRequired) + { + sslPolicyErrors |= SslPolicyErrors.RemoteCertificateNotAvailable; + } + + if (_validationCallback is not null) + { + if (!_validationCallback(_connection, certificate, chain, sslPolicyErrors)) + { + if (_isClient) + { + throw new AuthenticationException(SR.net_quic_cert_custom_validation); + } + return QUIC_STATUS_USER_CANCELED; + } + return QUIC_STATUS_SUCCESS; + } + + if (sslPolicyErrors != SslPolicyErrors.None) + { + if (_isClient) + { + throw new AuthenticationException(SR.Format(SR.net_quic_cert_chain_validation, sslPolicyErrors)); + } + return QUIC_STATUS_HANDSHAKE_FAILURE; + } + + return QUIC_STATUS_SUCCESS; + } + } +} diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs index d76d6c7700c8b..8b1411efc9a03 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs @@ -1,93 +1,613 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Net.Quic.Implementations; -using System.Net.Quic.Implementations.MsQuic; -using System.Net.Quic.Implementations.MsQuic.Internal; using System.Net.Security; +using System.Net.Sockets; +using System.Runtime.CompilerServices; +using System.Runtime.ExceptionServices; +using System.Runtime.InteropServices; using System.Security.Cryptography.X509Certificates; using System.Threading; +using System.Threading.Channels; using System.Threading.Tasks; +using Microsoft.Quic; +using static Microsoft.Quic.MsQuic; -namespace System.Net.Quic +using CONNECTED_DATA = Microsoft.Quic.QUIC_CONNECTION_EVENT._Anonymous_e__Union._CONNECTED_e__Struct; +using SHUTDOWN_INITIATED_BY_TRANSPORT_DATA = Microsoft.Quic.QUIC_CONNECTION_EVENT._Anonymous_e__Union._SHUTDOWN_INITIATED_BY_TRANSPORT_e__Struct; +using SHUTDOWN_INITIATED_BY_PEER_DATA = Microsoft.Quic.QUIC_CONNECTION_EVENT._Anonymous_e__Union._SHUTDOWN_INITIATED_BY_PEER_e__Struct; +using SHUTDOWN_COMPLETE_DATA = Microsoft.Quic.QUIC_CONNECTION_EVENT._Anonymous_e__Union._SHUTDOWN_COMPLETE_e__Struct; +using LOCAL_ADDRESS_CHANGED_DATA = Microsoft.Quic.QUIC_CONNECTION_EVENT._Anonymous_e__Union._LOCAL_ADDRESS_CHANGED_e__Struct; +using PEER_ADDRESS_CHANGED_DATA = Microsoft.Quic.QUIC_CONNECTION_EVENT._Anonymous_e__Union._PEER_ADDRESS_CHANGED_e__Struct; +using PEER_STREAM_STARTED_DATA = Microsoft.Quic.QUIC_CONNECTION_EVENT._Anonymous_e__Union._PEER_STREAM_STARTED_e__Struct; +using PEER_CERTIFICATE_RECEIVED_DATA = Microsoft.Quic.QUIC_CONNECTION_EVENT._Anonymous_e__Union._PEER_CERTIFICATE_RECEIVED_e__Struct; + +namespace System.Net.Quic; + +/// +/// Represents a QUIC connection, see RFC 9000: Connections for more details. +/// itself doesn't send or receive data but rather allows opening and/or accepting multiple . +/// +/// +/// can either be accepted from (inbound connection), +/// or create with a static method (outbound connection). +/// +/// Each connection can then open outbound stream: , +/// or accept an inbound stream: . +/// +/// After all the streams have been finished, connection should be properly closed with an application code: . +/// If not, the connection will not send the peer information about being closed and the peer's connection will have to wait on its idle timeout. +/// +public sealed partial class QuicConnection : IAsyncDisposable { - public sealed class QuicConnection : IDisposable + /// + /// Returns true if QUIC is supported on the current machine and can be used; otherwise, false. + /// + /// + /// The current implementation depends on MsQuic native library, this property checks its presence (Linux machines). + /// It also checks whether TLS 1.3, requirement for QUIC protocol, is available and enabled (Windows machines). + /// + public static bool IsSupported => MsQuicApi.IsQuicSupported; + + /// + /// Creates a new and connects it to the peer. + /// + /// Options for the connection. + /// A cancellation token that can be used to cancel the asynchronous operation. + /// An asynchronous task that completes with the connected connection. + public static async ValueTask ConnectAsync(QuicClientConnectionOptions options, CancellationToken cancellationToken = default) + { + if (!IsSupported) + { + throw new PlatformNotSupportedException(SR.SystemNetQuic_PlatformNotSupported); + } + + // Validate and fill in defaults for the options. + options.Validate(nameof(options)); + + QuicConnection connection = new QuicConnection(); + try + { + await connection.FinishConnectAsync(options, cancellationToken).ConfigureAwait(false); + } + catch + { + await connection.DisposeAsync().ConfigureAwait(false); + throw; + } + return connection; + } + + /// + /// Handle to MsQuic connection object. + /// + private MsQuicContextSafeHandle _handle; + + /// + /// Set to non-zero once disposed. Prevents double and/or concurrent disposal. + /// + private int _disposed; + + private readonly ValueTaskSource _connectedTcs = new ValueTaskSource(); + private readonly ValueTaskSource _shutdownTcs = new ValueTaskSource(); + + private readonly Channel _acceptQueue = Channel.CreateUnbounded(new UnboundedChannelOptions() + { + SingleWriter = true + }); + + /// + /// Holds options to validate peer certificate. + /// Set up either in for an inbound connection or in for an outbound. + /// + private SslConnectionOptions _sslConnectionOptions; + /// + /// Holds MsQuic connection configuration. + /// Set up either in for an inbound connection or in for an outbound. + /// + private MsQuicSafeHandle? _configuration; + + /// + /// Used by to throw in case no stream can be opened from the peer. + /// true when at least one of or is greater than 0. + /// + private bool _canAccept; + /// + /// From , passed to newly created . + /// + private long _defaultStreamErrorCode; + /// + /// From , used to close connection in . + /// + private long _defaultCloseErrorCode; + + // TODO: remove once/if https://github.com/microsoft/msquic/pull/2883 is merged + internal sealed class State + { + public long AbortErrorCode = -1; + } + private State _state = new State(); + + /// + /// Set when CONNECTED is received or inside the constructor for an inbound connection from NEW_CONNECTION data. + /// + private IPEndPoint _remoteEndPoint = null!; + /// + /// Set when CONNECTED is received or inside the constructor for an inbound connection from NEW_CONNECTION data. + /// + private IPEndPoint _localEndPoint = null!; + /// + /// Keeps track whether has been accessed so that we know whether to dispose the certificate or not. + /// + private bool _remoteCertificateExposed; + /// + /// Set when PEER_CERTIFICATE_RECEIVED is received (before CONNECTED). + /// For an outbound/client connection will always have the peer's (server) certificate; for an inbound/server one, only if the connection requested and the peer (client) provided one. + /// + private X509Certificate2? _remoteCertificate; + /// + /// Set when CONNECTED is received. + /// + private SslApplicationProtocol _negotiatedApplicationProtocol; + + /// + /// The remote endpoint used for this connection. + /// + public IPEndPoint RemoteEndPoint => _remoteEndPoint; + /// + /// The local endpoint used for this connection. + /// + public IPEndPoint LocalEndPoint => _localEndPoint; + + /// + /// The certificate provided by the peer. + /// For an outbound/client connection will always have the peer's (server) certificate; for an inbound/server one, only if the connection requested and the peer (client) provided one. + /// + public X509Certificate? RemoteCertificate + { + get + { + _remoteCertificateExposed = true; + return _remoteCertificate; + } + } + + /// + /// Final, negotiated application protocol. + /// + public SslApplicationProtocol NegotiatedApplicationProtocol => _negotiatedApplicationProtocol; + + public override string ToString() => _handle.ToString(); + + /// + /// Initializes a new instance of an outbound . + /// + private unsafe QuicConnection() + { + GCHandle context = GCHandle.Alloc(this, GCHandleType.Weak); + try + { + QUIC_HANDLE* handle; + ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->ConnectionOpen( + MsQuicApi.Api.Registration.QuicHandle, + &NativeCallback, + (void*)GCHandle.ToIntPtr(context), + &handle), + "ConnectionOpen failed"); + _handle = new MsQuicContextSafeHandle(handle, context, MsQuicApi.Api.ApiTable->ConnectionClose, SafeHandleType.Connection); + } + catch + { + context.Free(); + throw; + } + } + + /// + /// Initializes a new instance of an inbound . + /// + /// Native handle. + /// Related data from the NEW_CONNECTION listener event. + internal unsafe QuicConnection(QUIC_HANDLE* handle, QUIC_NEW_CONNECTION_INFO* info) + { + GCHandle context = GCHandle.Alloc(this, GCHandleType.Weak); + try + { + delegate* unmanaged[Cdecl] nativeCallback = &NativeCallback; + MsQuicApi.Api.ApiTable->SetCallbackHandler( + handle, + nativeCallback, + (void*)GCHandle.ToIntPtr(context)); + _handle = new MsQuicContextSafeHandle(handle, context, MsQuicApi.Api.ApiTable->ConnectionClose, SafeHandleType.Connection); + } + catch + { + context.Free(); + throw; + } + + _remoteEndPoint = info->RemoteAddress->ToIPEndPoint(); + _localEndPoint = info->LocalAddress->ToIPEndPoint(); + } + + private async ValueTask FinishConnectAsync(QuicClientConnectionOptions options, CancellationToken cancellationToken = default) + { + ObjectDisposedException.ThrowIf(_disposed == 1, this); + + if (_connectedTcs.TryInitialize(out ValueTask valueTask, this, cancellationToken)) + { + _canAccept = options.MaxInboundBidirectionalStreams > 0 || options.MaxInboundUnidirectionalStreams > 0; + _defaultStreamErrorCode = options.DefaultStreamErrorCode; + _defaultCloseErrorCode = options.DefaultCloseErrorCode; + + if (!options.RemoteEndPoint.TryParse(out string? host, out IPAddress? address, out int port)) + { + throw new ArgumentException(SR.Format(SR.net_quic_unsupported_endpoint_type, options.RemoteEndPoint.GetType()), nameof(options)); + } + int addressFamily = QUIC_ADDRESS_FAMILY_UNSPEC; + + // RemoteEndPoint is either IPEndPoint or DnsEndPoint containing IPAddress string. + // --> Set the IP directly, no name resolution needed. + if (address is not null) + { + QuicAddr quicAddress = new IPEndPoint(address, port).ToQuicAddr(); + MsQuicHelpers.SetMsQuicParameter(_handle, QUIC_PARAM_CONN_REMOTE_ADDRESS, quicAddress); + } + // RemoteEndPoint is DnsEndPoint containing hostname that is different from requested SNI. + // --> Resolve the hostname and set the IP directly, use requested SNI in ConnectionStart. + else if (host is not null && + !host.Equals(options.ClientAuthenticationOptions.TargetHost, StringComparison.InvariantCultureIgnoreCase)) + { + IPAddress[] addresses = await Dns.GetHostAddressesAsync(host!, cancellationToken).ConfigureAwait(false); + cancellationToken.ThrowIfCancellationRequested(); + if (addresses.Length == 0) + { + throw new SocketException((int)SocketError.HostNotFound); + } + + QuicAddr quicAddress = new IPEndPoint(addresses[0], port).ToQuicAddr(); + MsQuicHelpers.SetMsQuicParameter(_handle, QUIC_PARAM_CONN_REMOTE_ADDRESS, quicAddress); + } + // RemoteEndPoint is DnsEndPoint containing hostname that is the same as the requested SNI. + // --> Let MsQuic resolve the hostname/SNI, give address family hint is specified in DnsEndPoint. + else + { + if (options.RemoteEndPoint.AddressFamily == AddressFamily.InterNetwork) + { + addressFamily = QUIC_ADDRESS_FAMILY_INET; + } + if (options.RemoteEndPoint.AddressFamily == AddressFamily.InterNetworkV6) + { + addressFamily = QUIC_ADDRESS_FAMILY_INET6; + } + } + + if (options.LocalEndPoint is not null) + { + QuicAddr quicAddress = options.LocalEndPoint.ToQuicAddr(); + MsQuicHelpers.SetMsQuicParameter(_handle, QUIC_PARAM_CONN_LOCAL_ADDRESS, quicAddress); + } + + _sslConnectionOptions = new SslConnectionOptions( + this, + isClient: true, + options.ClientAuthenticationOptions.TargetHost, + certificateRequired: true, + options.ClientAuthenticationOptions.CertificateRevocationCheckMode, + options.ClientAuthenticationOptions.RemoteCertificateValidationCallback); + _configuration = MsQuicConfiguration.Create(options); + + IntPtr targetHostPtr = Marshal.StringToCoTaskMemUTF8(options.ClientAuthenticationOptions.TargetHost ?? host ?? address?.ToString()); + try + { + unsafe + { + ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->ConnectionStart( + _handle.QuicHandle, + _configuration.QuicHandle, + (ushort)addressFamily, + (sbyte*)targetHostPtr, + (ushort)port), + "ConnectionStart failed"); + } + } + finally + { + Marshal.FreeCoTaskMem(targetHostPtr); + } + } + + await valueTask.ConfigureAwait(false); + } + + internal ValueTask FinishHandshakeAsync(QuicServerConnectionOptions options, string? targetHost, CancellationToken cancellationToken = default) + { + ObjectDisposedException.ThrowIf(_disposed == 1, this); + + if (_connectedTcs.TryInitialize(out ValueTask valueTask, this, cancellationToken)) + { + _canAccept = options.MaxInboundBidirectionalStreams > 0 || options.MaxInboundUnidirectionalStreams > 0; + _defaultStreamErrorCode = options.DefaultStreamErrorCode; + _defaultCloseErrorCode = options.DefaultCloseErrorCode; + + _sslConnectionOptions = new SslConnectionOptions( + this, + isClient: false, + targetHost: null, + options.ServerAuthenticationOptions.ClientCertificateRequired, + options.ServerAuthenticationOptions.CertificateRevocationCheckMode, + options.ServerAuthenticationOptions.RemoteCertificateValidationCallback); + _configuration = MsQuicConfiguration.Create(options, targetHost); + + unsafe + { + ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->ConnectionSetConfiguration( + _handle.QuicHandle, + _configuration.QuicHandle), + "ConnectionSetConfiguration failed"); + } + } + + return valueTask; + } + + /// + /// Create an outbound uni/bidirectional . + /// In case the connection doesn't have any available stream capacity, i.e.: the peer limits the concurrent stream count, + /// the operation will pend until the stream can be opened (other stream gets closed or peer increases the limit). + /// + /// The type of the stream, i.e. unidirectional or bidirectional. + /// A cancellation token that can be used to cancel the asynchronous operation. + /// An asynchronous task that completes with the opened . + public async ValueTask OpenOutboundStreamAsync(QuicStreamType type, CancellationToken cancellationToken = default) { - public static bool IsSupported => MsQuicApi.IsQuicSupported; + ObjectDisposedException.ThrowIf(_disposed == 1, this); - public static ValueTask ConnectAsync(QuicClientConnectionOptions options, CancellationToken cancellationToken = default) + QuicStream? stream = null; + try { - if (!IsSupported) + stream = new QuicStream(_state, _handle, type, _defaultStreamErrorCode); + await stream.StartAsync(cancellationToken).ConfigureAwait(false); + } + catch + { + if (stream is not null) + { + await stream.DisposeAsync().ConfigureAwait(false); + } + // Propagate connection error if present. + if (_acceptQueue.Reader.Completion.IsFaulted) { - throw new PlatformNotSupportedException(SR.SystemNetQuic_PlatformNotSupported); + await _acceptQueue.Reader.Completion.ConfigureAwait(false); } + throw; + } + return stream; + } + + /// + /// Accepts an inbound . + /// + /// A cancellation token that can be used to cancel the asynchronous operation. + /// An asynchronous task that completes with the accepted . + public async ValueTask AcceptInboundStreamAsync(CancellationToken cancellationToken = default) + { + ObjectDisposedException.ThrowIf(_disposed == 1, this); - return ValueTask.FromResult(new QuicConnection(new MsQuicConnection(options))); + if (!_canAccept) + { + throw new InvalidOperationException(SR.net_quic_accept_not_allowed); + } + + try + { + return await _acceptQueue.Reader.ReadAsync(cancellationToken).ConfigureAwait(false); + } + catch (ChannelClosedException ex) when (ex.InnerException is not null) + { + ExceptionDispatchInfo.Capture(ex.InnerException).Throw(); + throw; } + } - private readonly MsQuicConnection _provider; + /// + /// Closes the connection with the application provided code, see RFC 9000: Connection Termination for more details. + /// + /// + /// Connection close is not graceful in regards to its streams, i.e.: calling will immediately abort all streams associated with this connection. + /// Make sure, that all streams have been closed and all their data consumed before calling this method; + /// otherwise, all the data that were received but not consumed yet, will be lost. + /// + /// If is not called before disposing the connection, + /// the will be used by to close the connection. + /// + /// Application provided code with the reason for closure. + /// A cancellation token that can be used to cancel the asynchronous operation. + /// An asynchronous task that completes when the connection is closed. + public ValueTask CloseAsync(long errorCode, CancellationToken cancellationToken = default) + { + ObjectDisposedException.ThrowIf(_disposed == 1, this); - internal QuicConnection(MsQuicConnection provider) + if (_shutdownTcs.TryInitialize(out ValueTask valueTask, this, cancellationToken)) { - _provider = provider; + unsafe + { + MsQuicApi.Api.ApiTable->ConnectionShutdown( + _handle.QuicHandle, + QUIC_CONNECTION_SHUTDOWN_FLAGS.NONE, + (ulong)errorCode); + } } - /// - /// Indicates whether the QuicConnection is connected. - /// - public bool Connected => _provider.Connected; + return valueTask; + } - public IPEndPoint? LocalEndPoint => _provider.LocalEndPoint; + private unsafe int HandleEventConnected(ref CONNECTED_DATA data) + { + _negotiatedApplicationProtocol = new SslApplicationProtocol(new Span(data.NegotiatedAlpn, data.NegotiatedAlpnLength).ToArray()); - public EndPoint RemoteEndPoint => _provider.RemoteEndPoint; + QuicAddr remoteAddress = MsQuicHelpers.GetMsQuicParameter(_handle, QUIC_PARAM_CONN_REMOTE_ADDRESS); + _remoteEndPoint = remoteAddress.ToIPEndPoint(); - public X509Certificate? RemoteCertificate => _provider.RemoteCertificate; + QuicAddr localAddress = MsQuicHelpers.GetMsQuicParameter(_handle, QUIC_PARAM_CONN_LOCAL_ADDRESS); + _localEndPoint = localAddress.ToIPEndPoint(); - public SslApplicationProtocol NegotiatedApplicationProtocol => _provider.NegotiatedApplicationProtocol; + _connectedTcs.TrySetResult(); - internal ValueTask FinishHandshakeAsync(QuicServerConnectionOptions options, string? targetHost, CancellationToken cancellationToken = default) => _provider.FinishHandshakeAsync(options, targetHost, cancellationToken); + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Info(this, $"{this} Connection connected {LocalEndPoint} -> {RemoteEndPoint}"); + } + return QUIC_STATUS_SUCCESS; + } + private unsafe int HandleEventShutdownInitiatedByTransport(ref SHUTDOWN_INITIATED_BY_TRANSPORT_DATA data) + { + _state.AbortErrorCode = 0; + Exception exception = ExceptionDispatchInfo.SetCurrentStackTrace(ThrowHelper.GetExceptionForMsQuicStatus(data.Status)); + _connectedTcs.TrySetException(exception); + _acceptQueue.Writer.TryComplete(exception); + return QUIC_STATUS_SUCCESS; + } + private unsafe int HandleEventShutdownInitiatedByPeer(ref SHUTDOWN_INITIATED_BY_PEER_DATA data) + { + _state.AbortErrorCode = (long)data.ErrorCode; + _acceptQueue.Writer.TryComplete(ExceptionDispatchInfo.SetCurrentStackTrace(ThrowHelper.GetConnectionAbortedException((long)data.ErrorCode))); + return QUIC_STATUS_SUCCESS; + } + private unsafe int HandleEventShutdownComplete(ref SHUTDOWN_COMPLETE_DATA data) + { + _shutdownTcs.TrySetResult(); + _acceptQueue.Writer.TryComplete(ExceptionDispatchInfo.SetCurrentStackTrace(ThrowHelper.GetOperationAbortedException())); + return QUIC_STATUS_SUCCESS; + } + private unsafe int HandleEventLocalAddressChanged(ref LOCAL_ADDRESS_CHANGED_DATA data) + { + _localEndPoint = data.Address->ToIPEndPoint(); + return QUIC_STATUS_SUCCESS; + } + private unsafe int HandleEventPeerAddressChanged(ref PEER_ADDRESS_CHANGED_DATA data) + { + _remoteEndPoint = data.Address->ToIPEndPoint(); + return QUIC_STATUS_SUCCESS; + } + private unsafe int HandleEventPeerStreamStarted(ref PEER_STREAM_STARTED_DATA data) + { + QuicStream stream = new QuicStream(_state, _handle, data.Stream, data.Flags, _defaultStreamErrorCode); + if (!_acceptQueue.Writer.TryWrite(stream)) + { + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Error(this, $"{this} Unable to enqueue incoming stream {stream}"); + } + stream.Dispose(); + return QUIC_STATUS_SUCCESS; + } + return QUIC_STATUS_SUCCESS; + } + private unsafe int HandleEventPeerCertificateReceived(ref PEER_CERTIFICATE_RECEIVED_DATA data) + { + try + { + return _sslConnectionOptions.ValidateCertificate((QUIC_BUFFER*)data.Certificate, (QUIC_BUFFER*)data.Chain, out _remoteCertificate); + } + catch (Exception ex) + { + _connectedTcs.TrySetException(ex); + return QUIC_STATUS_HANDSHAKE_FAILURE; + } + } - /// - /// Connect to the remote endpoint. - /// - /// - /// - public ValueTask ConnectAsync(CancellationToken cancellationToken = default) => _provider.ConnectAsync(cancellationToken); + private unsafe int HandleConnectionEvent(ref QUIC_CONNECTION_EVENT connectionEvent) + => connectionEvent.Type switch + { + QUIC_CONNECTION_EVENT_TYPE.CONNECTED => HandleEventConnected(ref connectionEvent.CONNECTED), + QUIC_CONNECTION_EVENT_TYPE.SHUTDOWN_INITIATED_BY_TRANSPORT => HandleEventShutdownInitiatedByTransport(ref connectionEvent.SHUTDOWN_INITIATED_BY_TRANSPORT), + QUIC_CONNECTION_EVENT_TYPE.SHUTDOWN_INITIATED_BY_PEER => HandleEventShutdownInitiatedByPeer(ref connectionEvent.SHUTDOWN_INITIATED_BY_PEER), + QUIC_CONNECTION_EVENT_TYPE.SHUTDOWN_COMPLETE => HandleEventShutdownComplete(ref connectionEvent.SHUTDOWN_COMPLETE), + QUIC_CONNECTION_EVENT_TYPE.LOCAL_ADDRESS_CHANGED => HandleEventLocalAddressChanged(ref connectionEvent.LOCAL_ADDRESS_CHANGED), + QUIC_CONNECTION_EVENT_TYPE.PEER_ADDRESS_CHANGED => HandleEventPeerAddressChanged(ref connectionEvent.PEER_ADDRESS_CHANGED), + QUIC_CONNECTION_EVENT_TYPE.PEER_STREAM_STARTED => HandleEventPeerStreamStarted(ref connectionEvent.PEER_STREAM_STARTED), + QUIC_CONNECTION_EVENT_TYPE.PEER_CERTIFICATE_RECEIVED => HandleEventPeerCertificateReceived(ref connectionEvent.PEER_CERTIFICATE_RECEIVED), + _ => QUIC_STATUS_SUCCESS + }; - /// - /// Create an outbound unidirectional stream. - /// - /// - public async ValueTask OpenUnidirectionalStreamAsync(CancellationToken cancellationToken = default) => new QuicStream(await _provider.OpenUnidirectionalStreamAsync(cancellationToken).ConfigureAwait(false)); +#pragma warning disable CS3016 + [UnmanagedCallersOnly(CallConvs = new Type[] { typeof(CallConvCdecl) })] +#pragma warning restore CS3016 + private static unsafe int NativeCallback(QUIC_HANDLE* connection, void* context, QUIC_CONNECTION_EVENT* connectionEvent) + { + GCHandle stateHandle = GCHandle.FromIntPtr((IntPtr)context); - /// - /// Create an outbound bidirectional stream. - /// - /// - public async ValueTask OpenBidirectionalStreamAsync(CancellationToken cancellationToken = default) => new QuicStream(await _provider.OpenBidirectionalStreamAsync(cancellationToken).ConfigureAwait(false)); + // Check if the instance hasn't been collected. + if (!stateHandle.IsAllocated || stateHandle.Target is not QuicConnection instance) + { + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Error(null, $"Received event {connectionEvent->Type} while connection is already disposed"); + } + return QUIC_STATUS_INVALID_STATE; + } + try + { + // Process the event. + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Info(instance, $"{instance} Received event {connectionEvent->Type}"); + } + return instance.HandleConnectionEvent(ref *connectionEvent); + } + catch (Exception ex) + { + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Error(instance, $"{instance} Exception while processing event {connectionEvent->Type}: {ex}"); + } + return QUIC_STATUS_INTERNAL_ERROR; + } + } - /// - /// Accept an incoming stream. - /// - /// - public async ValueTask AcceptStreamAsync(CancellationToken cancellationToken = default) => new QuicStream(await _provider.AcceptStreamAsync(cancellationToken).ConfigureAwait(false)); + /// + /// If not closed explicitly by , closes the connection silently (leading to idle timeout on the peer side). + /// And releases all resources associated with the connection. + /// + /// A task that represents the asynchronous dispose operation. + public async ValueTask DisposeAsync() + { + if (Interlocked.Exchange(ref _disposed, 1) != 0) + { + return; + } - /// - /// Close the connection and terminate any active streams. - /// - public ValueTask CloseAsync(long errorCode, CancellationToken cancellationToken = default) => _provider.CloseAsync(errorCode, cancellationToken); + // Check if the connection has been shut down and if not, shut it down silently. + if (_shutdownTcs.TryInitialize(out ValueTask valueTask, this)) + { + unsafe + { + MsQuicApi.Api.ApiTable->ConnectionShutdown( + _handle.QuicHandle, + QUIC_CONNECTION_SHUTDOWN_FLAGS.NONE, + (ulong)_defaultCloseErrorCode); + } + } - public void Dispose() => _provider.Dispose(); + // Wait for SHUTDOWN_COMPLETE, the last event, so that all resources can be safely released. + await valueTask.ConfigureAwait(false); + _handle.Dispose(); - /// - /// Gets the maximum number of bidirectional streams that can be made to the peer. - /// - public int GetRemoteAvailableUnidirectionalStreamCount() => _provider.GetRemoteAvailableUnidirectionalStreamCount(); + _configuration?.Dispose(); - /// - /// Gets the maximum number of unidirectional streams that can be made to the peer. - /// - public int GetRemoteAvailableBidirectionalStreamCount() => _provider.GetRemoteAvailableBidirectionalStreamCount(); + // Dispose remote certificate only if it hasn't been accessed via getter, in which case the accessing code becomes the owner of the certificate lifetime. + if (!_remoteCertificateExposed) + { + _remoteCertificate?.Dispose(); + } + + // Flush the queue and dispose all remaining streams. + _acceptQueue.Writer.TryComplete(ExceptionDispatchInfo.SetCurrentStackTrace(ThrowHelper.GetOperationAbortedException())); + while (_acceptQueue.Reader.TryRead(out QuicStream? stream)) + { + await stream.DisposeAsync().ConfigureAwait(false); + } } } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnectionOptions.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnectionOptions.cs index 34f7aa47d0d9a..30dd916404dc7 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnectionOptions.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnectionOptions.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Net.Security; +using System.Threading; namespace System.Net.Quic; @@ -20,19 +21,62 @@ internal QuicConnectionOptions() /// Limit on the number of bidirectional streams the remote peer connection can create on an open connection. /// Default 0 for client and 100 for server connection. /// - public int MaxBidirectionalStreams { get; set; } + public int MaxInboundBidirectionalStreams { get; set; } /// /// Limit on the number of unidirectional streams the remote peer connection can create on an open connection. /// Default 0 for client and 10 for server connection. /// - public int MaxUnidirectionalStreams { get; set; } + public int MaxInboundUnidirectionalStreams { get; set; } /// /// Idle timeout for connections, after which the connection will be closed. /// Default means underlying implementation default idle timeout. /// public TimeSpan IdleTimeout { get; set; } = TimeSpan.Zero; + + /// + /// Error code used when the stream needs to abort read or write side of the stream internally. + /// + // QUIC doesn't allow negative value: https://www.rfc-editor.org/rfc/rfc9000.html#integer-encoding + // We can safely use this to distinguish if user provided value during validation. + public long DefaultStreamErrorCode { get; set; } = -1; + + /// + /// Error code used for when the connection gets disposed. + /// To use different close error code, call explicitly before disposing. + /// + // QUIC doesn't allow negative value: https://www.rfc-editor.org/rfc/rfc9000.html#integer-encoding + // We can safely use this to distinguish if user provided value during validation. + public long DefaultCloseErrorCode { get; set; } = -1; + + /// + /// Validates the options and potentially sets platform specific defaults. + /// + /// Name of the from the caller. + internal virtual void Validate(string argumentName) + { + if (MaxInboundBidirectionalStreams < 0 || MaxInboundBidirectionalStreams > ushort.MaxValue) + { + throw new ArgumentNullException(SR.Format(SR.net_quic_in_range, nameof(QuicConnectionOptions.MaxInboundBidirectionalStreams), ushort.MaxValue), argumentName); + } + if (MaxInboundUnidirectionalStreams < 0 || MaxInboundUnidirectionalStreams > ushort.MaxValue) + { + throw new ArgumentNullException(SR.Format(SR.net_quic_in_range, nameof(QuicConnectionOptions.MaxInboundUnidirectionalStreams), ushort.MaxValue), argumentName); + } + if (IdleTimeout < TimeSpan.Zero && IdleTimeout != Timeout.InfiniteTimeSpan) + { + throw new ArgumentOutOfRangeException(nameof(QuicConnectionOptions.IdleTimeout), SR.net_quic_timeout_use_gt_zero); + } + if (DefaultStreamErrorCode < 0 || DefaultStreamErrorCode > QuicDefaults.MaxErrorCodeValue) + { + throw new ArgumentNullException(SR.Format(SR.net_quic_in_range, nameof(QuicConnectionOptions.DefaultStreamErrorCode), QuicDefaults.MaxErrorCodeValue), argumentName); + } + if (DefaultCloseErrorCode < 0 || DefaultCloseErrorCode > QuicDefaults.MaxErrorCodeValue) + { + throw new ArgumentNullException(SR.Format(SR.net_quic_in_range, nameof(QuicConnectionOptions.DefaultCloseErrorCode), QuicDefaults.MaxErrorCodeValue), argumentName); + } + } } /// @@ -40,15 +84,24 @@ internal QuicConnectionOptions() /// public sealed class QuicClientConnectionOptions : QuicConnectionOptions { + /// + /// Initializes a new instance of the class. + /// + public QuicClientConnectionOptions() + { + MaxInboundBidirectionalStreams = QuicDefaults.DefaultClientMaxInboundBidirectionalStreams; + MaxInboundUnidirectionalStreams = QuicDefaults.DefaultClientMaxInboundUnidirectionalStreams; + } + /// /// Client authentication options to use when establishing a new connection. /// - public required SslClientAuthenticationOptions ClientAuthenticationOptions { get; set; } + public SslClientAuthenticationOptions ClientAuthenticationOptions { get; set; } = null!; /// /// The remote endpoint to connect to. May be both , which will get resolved to an IP before connecting, or directly . /// - public required EndPoint RemoteEndPoint { get; set; } + public EndPoint RemoteEndPoint { get; set; } = null!; /// /// The optional local endpoint that will be bound to. @@ -56,12 +109,22 @@ public sealed class QuicClientConnectionOptions : QuicConnectionOptions public IPEndPoint? LocalEndPoint { get; set; } /// - /// Initializes a new instance of the class. + /// Validates the options and potentially sets platform specific defaults. /// - public QuicClientConnectionOptions() + /// Name of the from the caller. + internal override void Validate(string argumentName) { - MaxBidirectionalStreams = 0; - MaxUnidirectionalStreams = 0; + base.Validate(argumentName); + + // The content of ClientAuthenticationOptions gets validate in MsQuicConfiguration.Create. + if (ClientAuthenticationOptions is null) + { + throw new ArgumentNullException(SR.Format(SR.net_quic_not_null_open_connection, nameof(QuicClientConnectionOptions.ClientAuthenticationOptions)), argumentName); + } + if (RemoteEndPoint is null) + { + throw new ArgumentNullException(SR.Format(SR.net_quic_not_null_open_connection, nameof(QuicClientConnectionOptions.RemoteEndPoint)), argumentName); + } } } @@ -70,17 +133,32 @@ public QuicClientConnectionOptions() /// public sealed class QuicServerConnectionOptions : QuicConnectionOptions { + /// + /// Initializes a new instance of the class. + /// + public QuicServerConnectionOptions() + { + MaxInboundBidirectionalStreams = QuicDefaults.DefaultServerMaxInboundBidirectionalStreams; + MaxInboundUnidirectionalStreams = QuicDefaults.DefaultServerMaxInboundUnidirectionalStreams; + } + /// /// Server authentication options to use when accepting a new connection. /// - public required SslServerAuthenticationOptions ServerAuthenticationOptions { get; set; } + public SslServerAuthenticationOptions ServerAuthenticationOptions { get; set; } = null!; /// - /// Initializes a new instance of the class. + /// Validates the options and potentially sets platform specific defaults. /// - public QuicServerConnectionOptions() + /// Name of the from the caller. + internal override void Validate(string argumentName) { - MaxBidirectionalStreams = 100; - MaxUnidirectionalStreams = 10; + base.Validate(argumentName); + + // The content of ServerAuthenticationOptions gets validate in MsQuicConfiguration.Create. + if (ServerAuthenticationOptions is null) + { + throw new ArgumentNullException(SR.Format(SR.net_quic_not_null_accept_connection, nameof(QuicServerConnectionOptions.ServerAuthenticationOptions)), argumentName); + } } } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicDefaults.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicDefaults.cs new file mode 100644 index 0000000000000..08eacf261caef --- /dev/null +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicDefaults.cs @@ -0,0 +1,35 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Net.Quic; + +/// +/// Default values for , and . +/// +internal static partial class QuicDefaults +{ + /// + /// . + /// + public const int DefaultListenBacklog = 512; + /// + /// .. + /// + public const int DefaultClientMaxInboundBidirectionalStreams = 0; + /// + /// .. + /// + public const int DefaultClientMaxInboundUnidirectionalStreams = 0; + /// + /// .. + /// + public const int DefaultServerMaxInboundBidirectionalStreams = 100; + /// + /// .. + /// + public const int DefaultServerMaxInboundUnidirectionalStreams = 10; + /// + /// Max value for application error codes that can be sent by QUIC, see . + /// + public const long MaxErrorCodeValue = (1L << 62) - 1; +} diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicError.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicError.cs index 884360b5332c8..7207ebbaa33c1 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicError.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicError.cs @@ -14,7 +14,7 @@ public enum QuicError Success, /// - /// An internal implementation error has occured. + /// An internal implementation error has occurred. /// InternalError, diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.PendingConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.PendingConnection.cs index fe51fd3774ec0..3b139ccc61663 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.PendingConnection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.PendingConnection.cs @@ -57,6 +57,7 @@ public async void StartHandshake(QuicConnection connection, SslClientHelloInfo c { _cancellationTokenSource.CancelAfter(s_handshakeTimeout); QuicServerConnectionOptions options = await connectionOptionsCallback(connection, clientHello, _cancellationTokenSource.Token).ConfigureAwait(false); + options.Validate(nameof(options)); // Validate and fill in defaults for the options. await connection.FinishHandshakeAsync(options, clientHello.ServerName, _cancellationTokenSource.Token).ConfigureAwait(false); _finishHandshakeTask.SetResult(connection); } @@ -68,11 +69,11 @@ public async void StartHandshake(QuicConnection connection, SslClientHelloInfo c if (NetEventSource.Log.IsEnabled()) { - NetEventSource.Error(connection, $"Connection handshake failed: {ex}"); + NetEventSource.Error(connection, $"{connection} Connection handshake failed: {ex}"); } await connection.CloseAsync(default).ConfigureAwait(false); - connection.Dispose(); + await connection.DisposeAsync().ConfigureAwait(false); _finishHandshakeTask.SetResult(null); } } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs index 5b895fe4a9396..f1b9ee1e40c29 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs @@ -1,8 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Net.Quic.Implementations.MsQuic; -using System.Net.Quic.Implementations.MsQuic.Internal; using System.Net.Security; using System.Runtime.CompilerServices; using System.Runtime.ExceptionServices; @@ -12,6 +10,7 @@ using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.Quic; +using static System.Net.Quic.MsQuicHelpers; using static Microsoft.Quic.MsQuic; using NEW_CONNECTION_DATA = Microsoft.Quic.QUIC_LISTENER_EVENT._Anonymous_e__Union._NEW_CONNECTION_e__Struct; @@ -52,20 +51,13 @@ public static ValueTask ListenAsync(QuicListenerOptions options, C } // Validate and fill in defaults for the options. - if (options.ApplicationProtocols.Count <= 0) - { - throw new ArgumentException($"Expected at least one item in '{nameof(QuicListenerOptions.ApplicationProtocols)}' to start the listener.", nameof(options)); - } - if (options.ListenBacklog == 0) - { - options.ListenBacklog = 512; - } + options.Validate(nameof(options)); QuicListener listener = new QuicListener(options); if (NetEventSource.Log.IsEnabled()) { - NetEventSource.Info(listener, $"Listener listens on {listener.LocalEndPoint}"); + NetEventSource.Info(listener, $"{listener} Listener listens on {listener.LocalEndPoint}"); } return ValueTask.FromResult(listener); @@ -103,6 +95,10 @@ public static ValueTask ListenAsync(QuicListenerOptions options, C public override string ToString() => _handle.ToString(); + /// + /// Initializes and starts a new instance of a . + /// + /// Options to start the listener. private unsafe QuicListener(QuicListenerOptions options) { GCHandle context = GCHandle.Alloc(this, GCHandleType.Weak); @@ -146,7 +142,8 @@ private unsafe QuicListener(QuicListenerOptions options) "ListenerStart failed"); // Get the actual listening endpoint. - LocalEndPoint = MsQuicParameterHelpers.GetIPEndPointParam(MsQuicApi.Api, _handle, QUIC_PARAM_LISTENER_LOCAL_ADDRESS, options.ListenEndPoint.AddressFamily); + address = GetMsQuicParameter(_handle, QUIC_PARAM_LISTENER_LOCAL_ADDRESS); + LocalEndPoint = address.ToIPEndPoint(options.ListenEndPoint.AddressFamily); } /// @@ -156,7 +153,6 @@ private unsafe QuicListener(QuicListenerOptions options) /// Note that doesn't have a mechanism to report inbound connections that fail the handshake process. /// Such connections are only logged by the listener and never surfaced on the outside. /// - /// A cancellation token that can be used to cancel the asynchronous operation. /// A task that will contain a fully connected which successfully finished the handshake and is ready to be used. public async ValueTask AcceptConnectionAsync(CancellationToken cancellationToken = default) { @@ -196,7 +192,7 @@ private unsafe int HandleEventNewConnection(ref NEW_CONNECTION_DATA data) return QUIC_STATUS_CONNECTION_REFUSED; } - QuicConnection connection = new QuicConnection(new MsQuicConnection(data.Connection, data.Info)); + QuicConnection connection = new QuicConnection(data.Connection, data.Info); SslClientHelloInfo clientHello = new SslClientHelloInfo(data.Info->ServerNameLength > 0 ? Marshal.PtrToStringUTF8((IntPtr)data.Info->ServerName, data.Info->ServerNameLength) : "", SslProtocols.Tls13); // Kicks off the rest of the handshake in the background. @@ -231,7 +227,7 @@ private static unsafe int NativeCallback(QUIC_HANDLE* listener, void* context, Q { if (NetEventSource.Log.IsEnabled()) { - NetEventSource.Error(null, $"Received event {listenerEvent->Type}"); + NetEventSource.Error(null, $"Received event {listenerEvent->Type} while listener is already disposed"); } return QUIC_STATUS_INVALID_STATE; } @@ -241,7 +237,7 @@ private static unsafe int NativeCallback(QUIC_HANDLE* listener, void* context, Q // Process the event. if (NetEventSource.Log.IsEnabled()) { - NetEventSource.Info(instance, $"Received event {listenerEvent->Type}"); + NetEventSource.Info(instance, $"{instance} Received event {listenerEvent->Type}"); } return instance.HandleListenerEvent(ref *listenerEvent); } @@ -249,7 +245,7 @@ private static unsafe int NativeCallback(QUIC_HANDLE* listener, void* context, Q { if (NetEventSource.Log.IsEnabled()) { - NetEventSource.Error(instance, $"Exception while processing event {listenerEvent->Type}: {ex}"); + NetEventSource.Error(instance, $"{instance} Exception while processing event {listenerEvent->Type}: {ex}"); } return QUIC_STATUS_INTERNAL_ERROR; } @@ -275,6 +271,7 @@ public async ValueTask DisposeAsync() } } + // Wait for STOP_COMPLETE, the last event, so that all resources can be safely released. await valueTask.ConfigureAwait(false); _handle.Dispose(); diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListenerOptions.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListenerOptions.cs index 36241a36e2405..24a2b1e6050ec 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListenerOptions.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListenerOptions.cs @@ -16,21 +16,44 @@ public sealed class QuicListenerOptions /// /// The endpoint to listen on. /// - public required IPEndPoint ListenEndPoint { get; set; } + public IPEndPoint ListenEndPoint { get; set; } = null!; /// /// List of application protocols which the listener will accept. At least one must be specified. /// - public required List ApplicationProtocols { get; set; } + public List ApplicationProtocols { get; set; } = null!; /// - /// Number of connections to be held without accepting the connection. - /// + /// Number of connections to be held without accepting any them, i.e. maximum size of the pending connection queue. /// public int ListenBacklog { get; set; } /// /// Selection callback to choose inbound connection options dynamically. /// - public required Func> ConnectionOptionsCallback { get; set; } + public Func> ConnectionOptionsCallback { get; set; } = null!; + + /// + /// Validates the options and potentially sets platform specific defaults. + /// + /// Name of the from the caller. + internal void Validate(string argumentName) + { + if (ListenEndPoint is null) + { + throw new ArgumentNullException(SR.Format(SR.net_quic_not_null_listener, nameof(QuicListenerOptions.ListenEndPoint)), argumentName); + } + if (ApplicationProtocols is null || ApplicationProtocols.Count <= 0) + { + throw new ArgumentNullException(SR.Format(SR.net_quic_not_null_not_empty_listener, nameof(QuicListenerOptions.ApplicationProtocols)), argumentName); + } + if (ListenBacklog == 0) + { + ListenBacklog = QuicDefaults.DefaultListenBacklog; + } + if (ConnectionOptionsCallback is null) + { + throw new ArgumentNullException(SR.Format(SR.net_quic_not_null_listener, nameof(QuicListenerOptions.ConnectionOptionsCallback)), argumentName); + } + } } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.Stream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.Stream.cs new file mode 100644 index 0000000000000..eb8f17346fce1 --- /dev/null +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.Stream.cs @@ -0,0 +1,167 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.IO; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.Quic; + +// Boilerplate implementation of Stream methods. +public partial class QuicStream : Stream +{ + // Seek and length. + public override bool CanSeek => false; + public override long Length => throw new NotSupportedException(); + public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); } + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => throw new NotSupportedException(); + + // Read and Write timeouts. + public override bool CanTimeout => true; + private TimeSpan _readTimeout = Timeout.InfiniteTimeSpan; + private TimeSpan _writeTimeout = Timeout.InfiniteTimeSpan; + public override int ReadTimeout + { + get + { + ObjectDisposedException.ThrowIf(_disposed == 1, this); + return (int)_readTimeout.TotalMilliseconds; + } + set + { + ObjectDisposedException.ThrowIf(_disposed == 1, this); + if (value <= 0 && value != Timeout.Infinite) + { + throw new ArgumentOutOfRangeException(nameof(value), SR.net_quic_timeout_use_gt_zero); + } + _readTimeout = TimeSpan.FromMilliseconds(value); + } + } + public override int WriteTimeout + { + get + { + ObjectDisposedException.ThrowIf(_disposed == 1, this); + return (int)_writeTimeout.TotalMilliseconds; + } + set + { + ObjectDisposedException.ThrowIf(_disposed == 1, this); + if (value <= 0 && value != Timeout.Infinite) + { + throw new ArgumentOutOfRangeException(nameof(value), SR.net_quic_timeout_use_gt_zero); + } + _writeTimeout = TimeSpan.FromMilliseconds(value); + } + } + + // Read boilerplate. + public override bool CanRead => Volatile.Read(ref _disposed) == 0 && _canRead; + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) + => TaskToApm.Begin(ReadAsync(buffer, offset, count, default), callback, state); + public override int EndRead(IAsyncResult asyncResult) + => TaskToApm.End(asyncResult); + public override int Read(byte[] buffer, int offset, int count) + { + ValidateBufferArguments(buffer, offset, count); + return Read(buffer.AsSpan(offset, count)); + } + public override int ReadByte() + { + byte b = 0; + return Read(MemoryMarshal.CreateSpan(ref b, 1)) != 0 ? b : -1; + } + public override int Read(Span buffer) + { + ObjectDisposedException.ThrowIf(_disposed == 1, this); + + byte[] rentedBuffer = ArrayPool.Shared.Rent(buffer.Length); + CancellationTokenSource? cts = null; + try + { + if (_readTimeout > TimeSpan.Zero) + { + cts = new CancellationTokenSource(_readTimeout); + } + int readLength = ReadAsync(new Memory(rentedBuffer, 0, buffer.Length), cts?.Token ?? default).AsTask().GetAwaiter().GetResult(); + rentedBuffer.AsSpan(0, readLength).CopyTo(buffer); + return readLength; + } + catch (OperationCanceledException) when (cts?.IsCancellationRequested == true) + { + // sync operations do not have Cancellation + throw new IOException(SR.net_quic_timeout); + } + finally + { + ArrayPool.Shared.Return(rentedBuffer); + cts?.Dispose(); + } + } + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) + { + ValidateBufferArguments(buffer, offset, count); + return ReadAsync(new Memory(buffer, offset, count), cancellationToken).AsTask(); + } + + // Write boilerplate. + public override bool CanWrite => Volatile.Read(ref _disposed) == 0 && _canWrite; + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) + => TaskToApm.Begin(WriteAsync(buffer, offset, count, default), callback, state); + public override void EndWrite(IAsyncResult asyncResult) + => TaskToApm.End(asyncResult); + public override void Write(byte[] buffer, int offset, int count) + { + ValidateBufferArguments(buffer, offset, count); + Write(buffer.AsSpan(offset, count)); + } + public override void WriteByte(byte value) + { + Write(MemoryMarshal.CreateReadOnlySpan(ref value, 1)); + } + public override void Write(ReadOnlySpan buffer) + { + ObjectDisposedException.ThrowIf(_disposed == 1, this); + + CancellationTokenSource? cts = null; + if (_writeTimeout > TimeSpan.Zero) + { + cts = new CancellationTokenSource(_writeTimeout); + } + try + { + WriteAsync(buffer.ToArray(), cts?.Token ?? default).AsTask().GetAwaiter().GetResult(); + } + catch (OperationCanceledException) when (cts?.IsCancellationRequested == true) + { + // sync operations do not have Cancellation + throw new IOException(SR.net_quic_timeout); + } + finally + { + cts?.Dispose(); + } + } + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) + { + ValidateBufferArguments(buffer, offset, count); + return WriteAsync(new ReadOnlyMemory(buffer, offset, count), cancellationToken).AsTask(); + } + + // Flush. + public override void Flush() + => FlushAsync().GetAwaiter().GetResult(); + public override Task FlushAsync(CancellationToken cancellationToken = default) + // NOP for now + => Task.CompletedTask; + + // Dispose. + protected override void Dispose(bool disposing) + { + DisposeAsync().AsTask().GetAwaiter().GetResult(); + base.Dispose(disposing); + } +} diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs index ae5f7368fe09a..a974a54257ebb 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs @@ -3,138 +3,607 @@ using System.Buffers; using System.IO; -using System.Net.Quic.Implementations; -using System.Net.Quic.Implementations.MsQuic; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Threading; using System.Threading.Tasks; - -namespace System.Net.Quic +using Microsoft.Quic; +using static System.Net.Quic.MsQuicHelpers; +using static System.Net.Quic.QuicDefaults; +using static Microsoft.Quic.MsQuic; + +using START_COMPLETE = Microsoft.Quic.QUIC_STREAM_EVENT._Anonymous_e__Union._START_COMPLETE_e__Struct; +using RECEIVE = Microsoft.Quic.QUIC_STREAM_EVENT._Anonymous_e__Union._RECEIVE_e__Struct; +using SEND_COMPLETE = Microsoft.Quic.QUIC_STREAM_EVENT._Anonymous_e__Union._SEND_COMPLETE_e__Struct; +using PEER_SEND_ABORTED = Microsoft.Quic.QUIC_STREAM_EVENT._Anonymous_e__Union._PEER_SEND_ABORTED_e__Struct; +using PEER_RECEIVE_ABORTED = Microsoft.Quic.QUIC_STREAM_EVENT._Anonymous_e__Union._PEER_RECEIVE_ABORTED_e__Struct; +using SEND_SHUTDOWN_COMPLETE = Microsoft.Quic.QUIC_STREAM_EVENT._Anonymous_e__Union._SEND_SHUTDOWN_COMPLETE_e__Struct; +using SHUTDOWN_COMPLETE = Microsoft.Quic.QUIC_STREAM_EVENT._Anonymous_e__Union._SHUTDOWN_COMPLETE_e__Struct; + +namespace System.Net.Quic; + +public sealed partial class QuicStream { - public sealed class QuicStream : Stream + /// + /// Handle to MsQuic connection object. + /// + private MsQuicContextSafeHandle _handle; + + /// + /// Set to non-zero once disposed. Prevents double and/or concurrent disposal. + /// + private int _disposed; + + private readonly ValueTaskSource _startedTcs = new ValueTaskSource(); + private readonly ValueTaskSource _shutdownTcs = new ValueTaskSource(); + + private readonly ResettableValueTaskSource _receiveTcs = new ResettableValueTaskSource() { - private readonly MsQuicStream _provider; + CancellationAction = target => + { + if (target is QuicStream stream) + { + stream.Abort(QuicAbortDirection.Read, stream._defaultErrorCode); + } + } + }; +// [ActiveIssue("https://github.com/dotnet/roslyn-analyzers/issues/5750")] Structs can have parameterless ctor now and thus the behavior differs from just defaulting the struct to zeros. +#pragma warning disable CA1805 + private ReceiveBuffers _receiveBuffers = new ReceiveBuffers(); +#pragma warning restore CA1805 + private int _receivedNeedsEnable; + + private readonly ResettableValueTaskSource _sendTcs = new ResettableValueTaskSource() + { + CancellationAction = target => + { + if (target is QuicStream stream) + { + stream.Abort(QuicAbortDirection.Write, stream._defaultErrorCode); + } + } + }; +// [ActiveIssue("https://github.com/dotnet/roslyn-analyzers/issues/5750")] Structs can have parameterless ctor now and thus the behavior differs from just defaulting the struct to zeros. +#pragma warning disable CA1805 + private MsQuicBuffers _sendBuffers = new MsQuicBuffers(); +#pragma warning restore CA1805 + + private readonly long _defaultErrorCode; + + private bool _canRead; + private bool _canWrite; + + // TODO: remove once/if https://github.com/microsoft/msquic/pull/2883 is merged + private readonly QuicConnection.State _connectionState; + + private long _id = -1; + private QuicStreamType _type; + + /// + /// Stream id, see . + /// + public long Id => _id; + + /// + /// Stream type, see . + /// + public QuicStreamType Type => _type; + + public Task ReadsClosed => _receiveTcs.GetFinalTask(); + + public Task WritesClosed => _sendTcs.GetFinalTask(); - internal QuicStream(MsQuicStream provider) + public override string ToString() => _handle.ToString(); + + /// + /// Initializes a new instance of an outbound . + /// + /// Connection state + /// safe handle, used to increment/decrement reference count with each associated stream. + /// The type of the stream to open. + /// Error code used when the stream needs to abort read or write side of the stream internally. + internal unsafe QuicStream(QuicConnection.State connectionState, MsQuicContextSafeHandle connectionHandle, QuicStreamType type, long defaultErrorCode) + { + _connectionState = connectionState; + GCHandle context = GCHandle.Alloc(this, GCHandleType.Weak); + try + { + QUIC_HANDLE* handle; + ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->StreamOpen( + connectionHandle.QuicHandle, + type == QuicStreamType.Unidirectional ? QUIC_STREAM_OPEN_FLAGS.UNIDIRECTIONAL : QUIC_STREAM_OPEN_FLAGS.NONE, + &NativeCallback, + (void*)GCHandle.ToIntPtr(context), + &handle), + "StreamOpen failed"); + _handle = new MsQuicContextSafeHandle(handle, context, MsQuicApi.Api.ApiTable->StreamClose, SafeHandleType.Stream, connectionHandle); + } + catch { - _provider = provider; + context.Free(); + throw; } - // - // Boilerplate implementation stuff - // + _defaultErrorCode = defaultErrorCode; - public override bool CanSeek => false; - public override long Length => throw new NotSupportedException(); - public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); - public override void SetLength(long value) => throw new NotSupportedException(); - public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); } + _canRead = type == QuicStreamType.Bidirectional; + _canWrite = true; + if (!_canRead) + { + _receiveTcs.TrySetResult(final: true); + } + _type = type; + } - public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) => - TaskToApm.Begin(ReadAsync(buffer, offset, count, default), callback, state); + /// + /// Initializes a new instance of an inbound . + /// + /// Connection state + /// safe handle, used to increment/decrement reference count with each associated stream. + /// Native handle. + /// Related data from the PEER_STREAM_STARTED connection event. + /// Error code used when the stream needs to abort read or write side of the stream internally. + internal unsafe QuicStream(QuicConnection.State connectionState, MsQuicContextSafeHandle connectionHandle, QUIC_HANDLE* handle, QUIC_STREAM_OPEN_FLAGS flags, long defaultErrorCode) + { + _connectionState = connectionState; + GCHandle context = GCHandle.Alloc(this, GCHandleType.Weak); + try + { + delegate* unmanaged[Cdecl] nativeCallback = &NativeCallback; + MsQuicApi.Api.ApiTable->SetCallbackHandler( + handle, + nativeCallback, + (void*)GCHandle.ToIntPtr(context)); + _handle = new MsQuicContextSafeHandle(handle, context, MsQuicApi.Api.ApiTable->StreamClose, SafeHandleType.Stream, connectionHandle); + } + catch + { + context.Free(); + throw; + } - public override int EndRead(IAsyncResult asyncResult) => - TaskToApm.End(asyncResult); + _defaultErrorCode = defaultErrorCode; - public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) => - TaskToApm.Begin(WriteAsync(buffer, offset, count, default), callback, state); + _canRead = true; + _canWrite = !flags.HasFlag(QUIC_STREAM_OPEN_FLAGS.UNIDIRECTIONAL); + if (!_canWrite) + { + _sendTcs.TrySetResult(final: true); + } + _id = (long)GetMsQuicParameter(_handle, QUIC_PARAM_STREAM_ID); + _type = flags.HasFlag(QUIC_STREAM_OPEN_FLAGS.UNIDIRECTIONAL) ? QuicStreamType.Unidirectional : QuicStreamType.Bidirectional; + _startedTcs.TrySetResult(); + } - public override void EndWrite(IAsyncResult asyncResult) => - TaskToApm.End(asyncResult); + internal ValueTask StartAsync(CancellationToken cancellationToken = default) + { + ObjectDisposedException.ThrowIf(_disposed == 1, this); - public override int Read(byte[] buffer, int offset, int count) + _startedTcs.TryInitialize(out ValueTask valueTask, this, cancellationToken); { - ValidateBufferArguments(buffer, offset, count); - return Read(buffer.AsSpan(offset, count)); + unsafe + { + int status = MsQuicApi.Api.ApiTable->StreamStart( + _handle.QuicHandle, + QUIC_STREAM_START_FLAGS.SHUTDOWN_ON_FAIL | QUIC_STREAM_START_FLAGS.INDICATE_PEER_ACCEPT); + if (StatusFailed(status)) + { + // TODO: aborted and the exception type + _startedTcs.TrySetException(ThrowHelper.GetExceptionForMsQuicStatus(status)); + } + } } - public override int ReadByte() + return valueTask; + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + ObjectDisposedException.ThrowIf(_disposed == 1, this); + + if (!_canRead) { - byte b = 0; - return Read(new Span(ref b)) != 0 ? b : -1; + throw new InvalidOperationException(SR.net_quic_reading_notallowed); } - public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + if (NetEventSource.Log.IsEnabled()) { - ValidateBufferArguments(buffer, offset, count); - return ReadAsync(new Memory(buffer, offset, count), cancellationToken).AsTask(); + NetEventSource.Info(this, $"{this} Stream reading into memory of '{buffer.Length}' bytes."); } - public override void Write(byte[] buffer, int offset, int count) + if (_receiveTcs.IsCompleted) { - ValidateBufferArguments(buffer, offset, count); - Write(buffer.AsSpan(offset, count)); + // Special case exception type for pre-canceled token while we've already transitioned to a final state and don't need to abort read. + // It must happen before we try to get the value task, since the task source is versioned and each instance must be awaited. + cancellationToken.ThrowIfCancellationRequested(); } - public override void WriteByte(byte value) + // The following loop will repeat at most twice depending whether some data are readily available in the buffer (one iteration) or not. + // In which case, it'll wait on RECEIVE or any of PEER_SEND_(SHUTDOWN|ABORTED) event and attempt to copy data in the second iteration. + int totalCopied = 0; + do { - Write(new ReadOnlySpan(in value)); + // Concurrent call, this one lost the race. + if (!_receiveTcs.TryGetValueTask(out ValueTask valueTask, this, cancellationToken)) + { + throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, "read")); + } + + // Copy data from the buffer, reduce target and increment total. + int copied = _receiveBuffers.CopyTo(buffer, out bool complete, out bool empty); + buffer = buffer.Slice(copied); + totalCopied += copied; + + // Make sure the task transitions into final state before the method finishes. + if (complete) + { + _receiveTcs.TrySetResult(final: true); + } + + // Unblock the next await to end immediately, i.e. there were/are any data in the buffer. + if (totalCopied > 0 || !empty) + { + _receiveTcs.TrySetResult(); + } + + // This will either wait for RECEIVE event (no data in buffer) or complete immediately and reset the task. + await valueTask.ConfigureAwait(false); + + // This is the last read, finish even despite not copying anything. + if (complete) + { + break; + } + } while (!buffer.IsEmpty && totalCopied == 0); // Exit the loop if target buffer is full we at least copied something. + + if (totalCopied > 0 && Interlocked.CompareExchange(ref _receivedNeedsEnable, 0, 1) == 1) + { + unsafe + { + ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->StreamReceiveSetEnabled( + _handle.QuicHandle, + 1), + "StreamReceivedSetEnabled failed"); + } } - public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + return totalCopied; + } + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + => WriteAsync(buffer, completeWrites: false, cancellationToken); + + public ValueTask WriteAsync(ReadOnlyMemory buffer, bool completeWrites, CancellationToken cancellationToken = default) + { + ObjectDisposedException.ThrowIf(_disposed == 1, this); + + if (!_canWrite) + { + throw new InvalidOperationException(SR.net_quic_writing_notallowed); + } + + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Info(this, $"{this} Stream writing memory of '{buffer.Length}' bytes while {(completeWrites ? "completing" : "not completing")} writes."); + } + + if (_sendTcs.IsCompleted) + { + // Special case exception type for pre-canceled token while we've already transitioned to a final state and don't need to abort write. + // It must happen before we try to get the value task, since the task source is versioned and each instance must be awaited. + cancellationToken.ThrowIfCancellationRequested(); + } + + // Concurrent call, this one lost the race. + if (!_sendTcs.TryGetValueTask(out ValueTask valueTask, this, cancellationToken)) { - ValidateBufferArguments(buffer, offset, count); - return WriteAsync(new ReadOnlyMemory(buffer, offset, count), cancellationToken).AsTask(); + throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, "write")); } - /// - /// QUIC stream ID. - /// - public long StreamId => _provider.StreamId; + // No need to call anything since we already have a result, most likely an exception. + if (valueTask.IsCompleted) + { + return valueTask; + } + + // For an empty buffer complete immediately, close the writing side of the stream if necessary. + if (buffer.IsEmpty) + { + _sendTcs.TrySetResult(); + if (completeWrites) + { + CompleteWrites(); + } + return valueTask; + } - public override bool CanRead => _provider.CanRead; + try + { + _sendBuffers.Initialize(buffer); + unsafe + { + int status = MsQuicApi.Api.ApiTable->StreamSend( + _handle.QuicHandle, + _sendBuffers.Buffers, + (uint)_sendBuffers.Count, + completeWrites ? QUIC_SEND_FLAGS.FIN : QUIC_SEND_FLAGS.NONE, + null); + if (status == QUIC_STATUS_ABORTED) + { + // If status == QUIC_STATUS_ABORTED, we either received PEER_RECEIVE_ABORTED or will receive SHUTDOWN_COMPLETE(ConnectionClose) later, all of which completes the _sendTcs. + _sendBuffers.Reset(); + return valueTask; + } + ThrowHelper.ThrowIfMsQuicError(status, "StreamSend failed"); + } + } + catch (Exception ex) + { + _sendTcs.TrySetException(ex, final: true); + _sendBuffers.Reset(); + throw; + } - public bool ReadsCompleted => _provider.ReadsCompleted; + return valueTask; + } - public override int Read(Span buffer) => _provider.Read(buffer); - public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) => _provider.ReadAsync(buffer, cancellationToken); + public void Abort(QuicAbortDirection abortDirection, long errorCode) + { + if (_disposed == 1) + { + return; + } - public override bool CanWrite => _provider.CanWrite; + QUIC_STREAM_SHUTDOWN_FLAGS flags = QUIC_STREAM_SHUTDOWN_FLAGS.NONE; + if (abortDirection.HasFlag(QuicAbortDirection.Read)) + { + flags |= QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_RECEIVE; + if (_receiveTcs.TrySetException(ThrowHelper.GetOperationAbortedException(SR.net_quic_reading_aborted), final: true)) + { + flags |= QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_RECEIVE; + } + } + if (abortDirection.HasFlag(QuicAbortDirection.Write)) + { + if (_sendTcs.TrySetException(ThrowHelper.GetOperationAbortedException(SR.net_quic_writing_aborted), final: true)) + { + flags |= QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_SEND; + } + } + // Nothing to abort, the requested sides to abort are already closed. + if (flags == QUIC_STREAM_SHUTDOWN_FLAGS.NONE) + { + return; + } - public override void Write(ReadOnlySpan buffer) => _provider.Write(buffer); + unsafe + { + ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->StreamShutdown( + _handle.QuicHandle, + flags, + (ulong)errorCode), + "StreamShutdown failed"); + } + } - public override bool CanTimeout => _provider.CanTimeout; + public void CompleteWrites() + { + ObjectDisposedException.ThrowIf(_disposed == 1, this); - public override int ReadTimeout + if (_shutdownTcs.TryInitialize(out _, this)) { - get => _provider.ReadTimeout; - set => _provider.ReadTimeout = value; + unsafe + { + ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->StreamShutdown( + _handle.QuicHandle, + QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, + default), + "StreamShutdown failed"); + } } + } - public override int WriteTimeout + private unsafe int HandleEventStartComplete(ref START_COMPLETE data) + { + _id = unchecked((long)data.ID); + if (StatusSucceeded(data.Status)) { - get => _provider.WriteTimeout; - set => _provider.WriteTimeout = value; + if (data.PeerAccepted != 0) + { + _startedTcs.TrySetResult(); + } + // If PeerAccepted == 0, we will later receive PEER_ACCEPTED event, which will complete the _startedTcs. + } + else + { + _startedTcs.TrySetException(ThrowHelper.GetExceptionForMsQuicStatus(data.Status)); + // TODO: aborted and exception type } - public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) => _provider.WriteAsync(buffer, cancellationToken); + return QUIC_STATUS_SUCCESS; + } + private unsafe int HandleEventReceive(ref RECEIVE data) + { + ulong totalCopied = (ulong)_receiveBuffers.CopyFrom( + new ReadOnlySpan(data.Buffers, (int) data.BufferCount), + (int) data.TotalBufferLength, + data.Flags.HasFlag(QUIC_RECEIVE_FLAGS.FIN)); + if (totalCopied < data.TotalBufferLength) + { + Volatile.Write(ref _receivedNeedsEnable, 1); + } - public override void Flush() => _provider.Flush(); + _receiveTcs.TrySetResult(); - public override Task FlushAsync(CancellationToken cancellationToken) => _provider.FlushAsync(cancellationToken); + data.TotalBufferLength = totalCopied; + return QUIC_STATUS_SUCCESS; + } + private unsafe int HandleEventSendComplete(ref SEND_COMPLETE data) + { + _sendBuffers.Reset(); + if (data.Canceled == 0) + { + _sendTcs.TrySetResult(); + } + // If Canceled != 0, we either aborted write, received PEER_RECEIVE_ABORTED or will receive SHUTDOWN_COMPLETE(ConnectionClose) later, all of which completes the _sendTcs. + return QUIC_STATUS_SUCCESS; + } + private unsafe int HandleEventPeerSendShutdown() + { + // Same as RECEIVE with FIN flag. Remember that no more RECEIVE events will come. + // Don't set the task to its final state yet, but wait for all the buffered data to get consumed first. + _receiveBuffers.SetFinal(); + _receiveTcs.TrySetResult(); + return QUIC_STATUS_SUCCESS; + } + private unsafe int HandleEventPeerSendAborted(ref PEER_SEND_ABORTED data) + { + _receiveBuffers.SetFinal(); + _receiveTcs.TrySetException(ThrowHelper.GetStreamAbortedException((long)data.ErrorCode), final: true); + return QUIC_STATUS_SUCCESS; + } + private unsafe int HandleEventPeerReceiveAborted(ref PEER_RECEIVE_ABORTED data) + { + _sendTcs.TrySetException(ThrowHelper.GetStreamAbortedException((long)data.ErrorCode), final: true); + return QUIC_STATUS_SUCCESS; + } + private unsafe int HandleEventSendShutdownComplete(ref SEND_SHUTDOWN_COMPLETE data) + { + if (data.Graceful != 0) + { + _sendTcs.TrySetResult(final: true); + } + // If Graceful == 0, we either aborted write, received PEER_RECEIVE_ABORTED or will receive SHUTDOWN_COMPLETE(ConnectionClose) later, all of which completes the _sendTcs. + return QUIC_STATUS_SUCCESS; + } + private unsafe int HandleEventShutdownComplete(ref SHUTDOWN_COMPLETE data) + { + if (data.ConnectionShutdown != 0) + { + Exception exception = ThrowHelper.GetConnectionAbortedException(_connectionState.AbortErrorCode); + _startedTcs.TrySetException(exception); + _receiveTcs.TrySetException(exception, final: true); + _sendTcs.TrySetException(exception, final: true); + } + _shutdownTcs.TrySetResult(); + return QUIC_STATUS_SUCCESS; + } + private unsafe int HandleEventPeerAccepted() + { + _startedTcs.TrySetResult(); + return QUIC_STATUS_SUCCESS; + } - public void AbortRead(long errorCode) => _provider.AbortRead(errorCode); + private unsafe int HandleStreamEvent(ref QUIC_STREAM_EVENT streamEvent) + => streamEvent.Type switch + { + QUIC_STREAM_EVENT_TYPE.START_COMPLETE => HandleEventStartComplete(ref streamEvent.START_COMPLETE), + QUIC_STREAM_EVENT_TYPE.RECEIVE => HandleEventReceive(ref streamEvent.RECEIVE), + QUIC_STREAM_EVENT_TYPE.SEND_COMPLETE => HandleEventSendComplete(ref streamEvent.SEND_COMPLETE), + QUIC_STREAM_EVENT_TYPE.PEER_SEND_SHUTDOWN => HandleEventPeerSendShutdown(), + QUIC_STREAM_EVENT_TYPE.PEER_SEND_ABORTED => HandleEventPeerSendAborted(ref streamEvent.PEER_SEND_ABORTED), + QUIC_STREAM_EVENT_TYPE.PEER_RECEIVE_ABORTED => HandleEventPeerReceiveAborted(ref streamEvent.PEER_RECEIVE_ABORTED), + QUIC_STREAM_EVENT_TYPE.SEND_SHUTDOWN_COMPLETE => HandleEventSendShutdownComplete(ref streamEvent.SEND_SHUTDOWN_COMPLETE), + QUIC_STREAM_EVENT_TYPE.SHUTDOWN_COMPLETE => HandleEventShutdownComplete(ref streamEvent.SHUTDOWN_COMPLETE), + QUIC_STREAM_EVENT_TYPE.PEER_ACCEPTED => HandleEventPeerAccepted(), + _ => QUIC_STATUS_SUCCESS + }; + +#pragma warning disable CS3016 + [UnmanagedCallersOnly(CallConvs = new Type[] { typeof(CallConvCdecl) })] +#pragma warning restore CS3016 + private static unsafe int NativeCallback(QUIC_HANDLE* connection, void* context, QUIC_STREAM_EVENT* streamEvent) + { + GCHandle stateHandle = GCHandle.FromIntPtr((IntPtr)context); - public void AbortWrite(long errorCode) => _provider.AbortWrite(errorCode); + // Check if the instance hasn't been collected. + if (!stateHandle.IsAllocated || stateHandle.Target is not QuicStream instance) + { + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Error(null, $"Received event {streamEvent->Type} while connection is already disposed"); + } + return QUIC_STATUS_INVALID_STATE; + } - public ValueTask WriteAsync(ReadOnlyMemory buffer, bool endStream, CancellationToken cancellationToken = default) => _provider.WriteAsync(buffer, endStream, cancellationToken); + try + { + // Process the event. + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Info(instance, $"{instance} Received event {streamEvent->Type}"); + } + return instance.HandleStreamEvent(ref *streamEvent); + } + catch (Exception ex) + { + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Error(instance, $"{instance} Exception while processing event {streamEvent->Type}: {ex}"); + } + return QUIC_STATUS_INTERNAL_ERROR; + } + } - public ValueTask WriteAsync(ReadOnlySequence buffers, CancellationToken cancellationToken = default) => _provider.WriteAsync(buffers, cancellationToken); + /// + /// If the read side is not fully consumed, i.e.: is completed and/or returned 0, + /// dispose will abort the read side with provided . + /// If the write side hasn't been closed, it'll be closed gracefully as if was called. + /// Finally, all resources associated with the stream will be released. + /// + /// A task that represents the asynchronous dispose operation. + public override async ValueTask DisposeAsync() + { + if (Interlocked.Exchange(ref _disposed, 1) != 0) + { + return; + } - public ValueTask WriteAsync(ReadOnlySequence buffers, bool endStream, CancellationToken cancellationToken = default) => _provider.WriteAsync(buffers, endStream, cancellationToken); + ValueTask valueTask; - public ValueTask ShutdownCompleted(CancellationToken cancellationToken = default) => _provider.ShutdownCompleted(cancellationToken); + // If the stream wasn't started successfully, gracelessly abort it. + if (!_startedTcs.IsCompletedSuccessfully) + { + // Check if the stream has been shut down and if not, shut it down. + if (_shutdownTcs.TryInitialize(out valueTask, this)) + { + StreamShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.ABORT | QUIC_STREAM_SHUTDOWN_FLAGS.IMMEDIATE, _defaultErrorCode); + } + } + else + { + // Abort the read side of the stream if it hasn't been fully consumed. + if (_receiveTcs.TrySetException(ThrowHelper.GetOperationAbortedException(), final: true)) + { + StreamShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_RECEIVE, _defaultErrorCode); + } + // Check if the stream has been shut down and if not, shut it down. + if (_shutdownTcs.TryInitialize(out valueTask, this)) + { + StreamShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, default); + } + } - public ValueTask WaitForWriteCompletionAsync(CancellationToken cancellationToken = default) => _provider.WaitForWriteCompletionAsync(cancellationToken); + // Wait for SHUTDOWN_COMPLETE, the last event, so that all resources can be safely released. + await valueTask.ConfigureAwait(false); + _handle.Dispose(); - public void Shutdown() => _provider.Shutdown(); + // TODO: memory leak if not disposed + _sendBuffers.Dispose(); - protected override void Dispose(bool disposing) + unsafe void StreamShutdown(QUIC_STREAM_SHUTDOWN_FLAGS flags, long errorCode) { - if (disposing) + int status = MsQuicApi.Api.ApiTable->StreamShutdown( + _handle.QuicHandle, + flags, + (ulong)errorCode); + if (StatusFailed(status)) { - _provider.Dispose(); + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Error(this, $"{this} StreamShutdown({flags}) failed: {ThrowHelper.GetErrorMessageForStatus(status)}."); + } } } } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStreamType.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStreamType.cs new file mode 100644 index 0000000000000..d03744887e489 --- /dev/null +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStreamType.cs @@ -0,0 +1,21 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Net.Quic; + +/// +/// Represents type of the stream. +/// +/// +public enum QuicStreamType +{ + /// + /// Write-only for the peer that opened the stream. Read-only for the peer that accepted the stream. + /// + Unidirectional, + + /// + /// For both peers, read and write capable. + /// + Bidirectional +} diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicCipherSuitesPolicyTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicCipherSuitesPolicyTests.cs index 1996b0d774da4..a16681154fd9d 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicCipherSuitesPolicyTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicCipherSuitesPolicyTests.cs @@ -32,9 +32,8 @@ private async Task TestConnection(CipherSuitesPolicy serverPolicy, CipherSuitesP var clientOptions = CreateQuicClientOptions(listener.LocalEndPoint); clientOptions.ClientAuthenticationOptions.CipherSuitesPolicy = clientPolicy; - using QuicConnection clientConnection = await CreateQuicConnection(clientOptions); + await using QuicConnection clientConnection = await CreateQuicConnection(clientOptions); - await clientConnection.ConnectAsync(); await clientConnection.CloseAsync(0); } diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs index f6e485618a305..b33f2f86e9c7c 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs @@ -29,37 +29,6 @@ public class MsQuicTests : QuicTestBase public MsQuicTests(ITestOutputHelper output) : base(output) { } - [Fact] - public async Task UnidirectionalAndBidirectionalStreamCountsWork() - { - (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(); - - Assert.Equal(0, serverConnection.GetRemoteAvailableBidirectionalStreamCount()); - Assert.Equal(0, serverConnection.GetRemoteAvailableUnidirectionalStreamCount()); - serverConnection.Dispose(); - clientConnection.Dispose(); - } - - [Fact] - public async Task UnidirectionalAndBidirectionalChangeValues() - { - QuicClientConnectionOptions clientOptions = new QuicClientConnectionOptions() - { - MaxBidirectionalStreams = 10, - MaxUnidirectionalStreams = 20, - RemoteEndPoint = new IPEndPoint(IPAddress.Loopback, 0), - ClientAuthenticationOptions = GetSslClientAuthenticationOptions() - }; - - (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(clientOptions); - Assert.Equal(100, clientConnection.GetRemoteAvailableBidirectionalStreamCount()); - Assert.Equal(10, clientConnection.GetRemoteAvailableUnidirectionalStreamCount()); - Assert.Equal(10, serverConnection.GetRemoteAvailableBidirectionalStreamCount()); - Assert.Equal(20, serverConnection.GetRemoteAvailableUnidirectionalStreamCount()); - serverConnection.Dispose(); - clientConnection.Dispose(); - } - [Fact] public async Task ConnectWithCertificateChain() { @@ -110,8 +79,8 @@ public async Task ConnectWithCertificateChain() (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(clientOptions, listenerOptions); Assert.Equal(certificate, clientConnection.RemoteCertificate); Assert.Null(serverConnection.RemoteCertificate); - serverConnection.Dispose(); - clientConnection.Dispose(); + await serverConnection.DisposeAsync(); + await clientConnection.DisposeAsync(); } [ConditionalFact] @@ -141,21 +110,19 @@ public async Task UntrustedClientCertificateFails() await using QuicListener listener = await CreateQuicListener(listenerOptions); QuicClientConnectionOptions clientOptions = CreateQuicClientOptions(listener.LocalEndPoint); clientOptions.ClientAuthenticationOptions.ClientCertificates = new X509CertificateCollection() { ClientCertificate }; - QuicConnection clientConnection = await CreateQuicConnection(clientOptions); + ValueTask clientTask = CreateQuicConnection(clientOptions); using CancellationTokenSource cts = new CancellationTokenSource(); cts.CancelAfter(500); //Some delay to see if we would get failed connection. Task serverTask = listener.AcceptConnectionAsync(cts.Token).AsTask(); - ValueTask t = clientConnection.ConnectAsync(cts.Token); - - t.AsTask().Wait(PassingTestTimeout); + clientTask.AsTask().Wait(PassingTestTimeout); await Assert.ThrowsAsync(() => serverTask); // The task will likely succeed but we don't really care. // It may fail if the server aborts quickly. try { - await t; + await clientTask; } catch (Exception ex) { @@ -167,7 +134,6 @@ public async Task UntrustedClientCertificateFails() public async Task CertificateCallbackThrowPropagates() { using CancellationTokenSource cts = new CancellationTokenSource(PassingTestTimeout); - X509Certificate? receivedCertificate = null; bool validationResult = false; var listenerOptions = new QuicListenerOptions() @@ -181,7 +147,7 @@ public async Task CertificateCallbackThrowPropagates() QuicClientConnectionOptions clientOptions = CreateQuicClientOptions(listener.LocalEndPoint); clientOptions.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) => { - receivedCertificate = cert; + Assert.Equal(ServerCertificate, cert); if (validationResult) { return validationResult; @@ -191,19 +157,15 @@ public async Task CertificateCallbackThrowPropagates() }; clientOptions.ClientAuthenticationOptions.TargetHost = "foobar1"; - QuicConnection clientConnection = await CreateQuicConnection(clientOptions); - await Assert.ThrowsAsync(() => clientConnection.ConnectAsync(cts.Token).AsTask()); - - Assert.Equal(ServerCertificate, receivedCertificate); - clientConnection.Dispose(); + await Assert.ThrowsAsync(() => CreateQuicConnection(clientOptions).AsTask()); // Make sure the listener is still usable and there is no lingering bad connection validationResult = true; - (clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(listener); + (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(listener); await PingPong(clientConnection, serverConnection); - clientConnection.Dispose(); - serverConnection.Dispose(); + await clientConnection.DisposeAsync(); + await serverConnection.DisposeAsync(); } [Fact] @@ -256,17 +218,15 @@ public async Task ConnectWithServerCertificateCallback() (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(clientOptions, listener); Assert.Equal(clientOptions.ClientAuthenticationOptions.TargetHost, receivedHostName); Assert.Equal(c1, receivedCertificate); - clientConnection.Dispose(); - serverConnection.Dispose(); + await clientConnection.DisposeAsync(); + await serverConnection.DisposeAsync(); // This should fail when callback return null. clientOptions.ClientAuthenticationOptions.TargetHost = "foobar3"; - clientConnection = await CreateQuicConnection(clientOptions); - Task clientTask = clientConnection.ConnectAsync(cts.Token).AsTask(); + Task clientTask = CreateQuicConnection(clientOptions).AsTask(); await Assert.ThrowsAnyAsync(() => clientTask); Assert.Equal(clientOptions.ClientAuthenticationOptions.TargetHost, receivedHostName); - clientConnection.Dispose(); // Do this last to make sure Listener is still functional. clientOptions.ClientAuthenticationOptions.TargetHost = "foobar2"; @@ -275,8 +235,8 @@ public async Task ConnectWithServerCertificateCallback() (clientConnection, serverConnection) = await CreateConnectedQuicConnection(clientOptions, listener); Assert.Equal(clientOptions.ClientAuthenticationOptions.TargetHost, receivedHostName); Assert.Equal(c2, receivedCertificate); - clientConnection.Dispose(); - serverConnection.Dispose(); + await clientConnection.DisposeAsync(); + await serverConnection.DisposeAsync(); } [Theory] @@ -313,8 +273,8 @@ public async Task ConnectWithIpSetsSni(string destination) (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(clientOptions, listener); Assert.Equal(expectedName, receivedHostName); - clientConnection.Dispose(); - serverConnection.Dispose(); + await clientConnection.DisposeAsync(); + await serverConnection.DisposeAsync(); } [Fact] @@ -346,10 +306,7 @@ public async Task ConnectWithCertificateForDifferentName_Throws() return SslPolicyErrors.None == errors; }; - using QuicConnection clientConnection = await CreateQuicConnection(clientOptions); - ValueTask clientTask = clientConnection.ConnectAsync(); - - await Assert.ThrowsAsync(async () => await clientTask); + await Assert.ThrowsAsync(async () => await CreateQuicConnection(clientOptions)); } [ConditionalTheory] @@ -449,8 +406,8 @@ public async Task ConnectWithClientCertificate(bool sendCertificate, bool useCli Assert.Equal(sendCertificate ? ClientCertificate : null, serverConnection.RemoteCertificate); await serverConnection.CloseAsync(0); - clientConnection.Dispose(); - serverConnection.Dispose(); + await clientConnection.DisposeAsync(); + await serverConnection.DisposeAsync(); } [Theory] @@ -459,8 +416,8 @@ public async Task ConnectWithClientCertificate(bool sendCertificate, bool useCli public async Task OpenStreamAsync_BlocksUntilAvailable(bool unidirectional) { ValueTask OpenStreamAsync(QuicConnection connection) => unidirectional - ? connection.OpenUnidirectionalStreamAsync() - : connection.OpenBidirectionalStreamAsync(); + ? connection.OpenOutboundStreamAsync(QuicStreamType.Unidirectional) + : connection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); QuicListenerOptions listenerOptions = new QuicListenerOptions() @@ -470,8 +427,8 @@ ValueTask OpenStreamAsync(QuicConnection connection) => unidirection ConnectionOptionsCallback = (_, _, _) => { var serverOptions = CreateQuicServerOptions(); - serverOptions.MaxBidirectionalStreams = 1; - serverOptions.MaxUnidirectionalStreams = 1; + serverOptions.MaxInboundBidirectionalStreams = 1; + serverOptions.MaxInboundUnidirectionalStreams = 1; return ValueTask.FromResult(serverOptions); } }; @@ -483,15 +440,15 @@ ValueTask OpenStreamAsync(QuicConnection connection) => unidirection Assert.False(waitTask.IsCompleted); // Close the streams, the waitTask should finish as a result. - stream.Dispose(); - QuicStream newStream = await serverConnection.AcceptStreamAsync(); - newStream.Dispose(); + await stream.DisposeAsync(); + QuicStream newStream = await serverConnection.AcceptInboundStreamAsync(); + await newStream.DisposeAsync(); newStream = await waitTask.AsTask().WaitAsync(TimeSpan.FromSeconds(10)); - newStream.Dispose(); + await newStream.DisposeAsync(); - clientConnection.Dispose(); - serverConnection.Dispose(); + await clientConnection.DisposeAsync(); + await serverConnection.DisposeAsync(); } [Theory] @@ -500,8 +457,8 @@ ValueTask OpenStreamAsync(QuicConnection connection) => unidirection public async Task OpenStreamAsync_Canceled_Throws_OperationCanceledException(bool unidirectional) { ValueTask OpenStreamAsync(QuicConnection connection, CancellationToken token = default) => unidirectional - ? connection.OpenUnidirectionalStreamAsync(token) - : connection.OpenBidirectionalStreamAsync(token); + ? connection.OpenOutboundStreamAsync(QuicStreamType.Unidirectional, token) + : connection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional, token); QuicListenerOptions listenerOptions = new QuicListenerOptions() { @@ -510,8 +467,8 @@ ValueTask OpenStreamAsync(QuicConnection connection, CancellationTok ConnectionOptionsCallback = (_, _, _) => { var serverOptions = CreateQuicServerOptions(); - serverOptions.MaxBidirectionalStreams = 1; - serverOptions.MaxUnidirectionalStreams = 1; + serverOptions.MaxInboundBidirectionalStreams = 1; + serverOptions.MaxInboundUnidirectionalStreams = 1; return ValueTask.FromResult(serverOptions); } }; @@ -531,16 +488,16 @@ ValueTask OpenStreamAsync(QuicConnection connection, CancellationTok Assert.Equal(cts.Token, ex.CancellationToken); // Close the streams, the waitTask should finish as a result. - stream.Dispose(); - QuicStream newStream = await serverConnection.AcceptStreamAsync(); - newStream.Dispose(); + await stream.DisposeAsync(); + QuicStream newStream = await serverConnection.AcceptInboundStreamAsync(); + await newStream.DisposeAsync(); // next call should work as intended newStream = await OpenStreamAsync(clientConnection).AsTask().WaitAsync(TimeSpan.FromSeconds(10)); - newStream.Dispose(); + await newStream.DisposeAsync(); - clientConnection.Dispose(); - serverConnection.Dispose(); + await clientConnection.DisposeAsync(); + await serverConnection.DisposeAsync(); } [Theory] @@ -549,8 +506,8 @@ ValueTask OpenStreamAsync(QuicConnection connection, CancellationTok public async Task OpenStreamAsync_PreCanceled_Throws_OperationCanceledException(bool unidirectional) { ValueTask OpenStreamAsync(QuicConnection connection, CancellationToken token = default) => unidirectional - ? connection.OpenUnidirectionalStreamAsync(token) - : connection.OpenBidirectionalStreamAsync(token); + ? connection.OpenOutboundStreamAsync(QuicStreamType.Unidirectional, token) + : connection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional, token); (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(null, CreateQuicListenerOptions()); @@ -560,8 +517,8 @@ ValueTask OpenStreamAsync(QuicConnection connection, CancellationTok var ex = await Assert.ThrowsAnyAsync(() => OpenStreamAsync(clientConnection, cts.Token).AsTask().WaitAsync(TimeSpan.FromSeconds(3))); Assert.Equal(cts.Token, ex.CancellationToken); - clientConnection.Dispose(); - serverConnection.Dispose(); + await clientConnection.DisposeAsync(); + await serverConnection.DisposeAsync(); } [Theory] @@ -570,8 +527,8 @@ ValueTask OpenStreamAsync(QuicConnection connection, CancellationTok public async Task OpenStreamAsync_ConnectionAbort_Throws(bool unidirectional, bool localAbort) { ValueTask OpenStreamAsync(QuicConnection connection, CancellationToken token = default) => unidirectional - ? connection.OpenUnidirectionalStreamAsync(token) - : connection.OpenBidirectionalStreamAsync(token); + ? connection.OpenOutboundStreamAsync(QuicStreamType.Unidirectional, token) + : connection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional, token); QuicListenerOptions listenerOptions = new QuicListenerOptions() { @@ -580,8 +537,8 @@ ValueTask OpenStreamAsync(QuicConnection connection, CancellationTok ConnectionOptionsCallback = (_, _, _) => { var serverOptions = CreateQuicServerOptions(); - serverOptions.MaxBidirectionalStreams = 1; - serverOptions.MaxUnidirectionalStreams = 1; + serverOptions.MaxInboundBidirectionalStreams = 1; + serverOptions.MaxInboundUnidirectionalStreams = 1; return ValueTask.FromResult(serverOptions); } }; @@ -607,8 +564,8 @@ ValueTask OpenStreamAsync(QuicConnection connection, CancellationTok await AssertThrowsQuicExceptionAsync(QuicError.ConnectionAborted, () => waitTask.AsTask().WaitAsync(TimeSpan.FromSeconds(3))); } - clientConnection.Dispose(); - serverConnection.Dispose(); + await clientConnection.DisposeAsync(); + await serverConnection.DisposeAsync(); } @@ -624,54 +581,33 @@ public async Task SetListenerTimeoutWorksWithSmallTimeout() }; (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(null, listenerOptions); - await AssertThrowsQuicExceptionAsync(QuicError.ConnectionAborted, async () => await serverConnection.AcceptStreamAsync().AsTask().WaitAsync(TimeSpan.FromSeconds(100))); - serverConnection.Dispose(); - clientConnection.Dispose(); + await AssertThrowsQuicExceptionAsync(QuicError.ConnectionAborted, async () => await serverConnection.AcceptInboundStreamAsync().AsTask().WaitAsync(TimeSpan.FromSeconds(100))); + await serverConnection.DisposeAsync(); + await clientConnection.DisposeAsync(); } [Theory] [MemberData(nameof(WriteData))] - public async Task WriteTests(int[][] writes, WriteType writeType) + public async Task WriteTests(int[][] writes) { await RunClientServer( async clientConnection => { - await using QuicStream stream = await clientConnection.OpenUnidirectionalStreamAsync(); + await using QuicStream stream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Unidirectional); foreach (int[] bufferLengths in writes) { - switch (writeType) + foreach (int bufferLength in bufferLengths) { - case WriteType.SingleBuffer: - foreach (int bufferLength in bufferLengths) - { - await stream.WriteAsync(new byte[bufferLength]); - } - break; - case WriteType.GatheredSequence: - var firstSegment = new BufferSegment(new byte[bufferLengths[0]]); - BufferSegment lastSegment = firstSegment; - - foreach (int bufferLength in bufferLengths.Skip(1)) - { - lastSegment = lastSegment.Append(new byte[bufferLength]); - } - - var buffer = new ReadOnlySequence(firstSegment, 0, lastSegment, lastSegment.Memory.Length); - await stream.WriteAsync(buffer); - break; - default: - Debug.Fail("Unknown write type."); - break; + await stream.WriteAsync(new byte[bufferLength]); } } - stream.Shutdown(); - await stream.ShutdownCompleted(); + stream.CompleteWrites(); }, async serverConnection => { - await using QuicStream stream = await serverConnection.AcceptStreamAsync(); + await using QuicStream stream = await serverConnection.AcceptInboundStreamAsync(); var buffer = new byte[4096]; int receivedBytes = 0, totalBytes = 0; @@ -684,8 +620,7 @@ await RunClientServer( int expectedTotalBytes = writes.SelectMany(x => x).Sum(); Assert.Equal(expectedTotalBytes, totalBytes); - stream.Shutdown(); - await stream.ShutdownCompleted(); + stream.CompleteWrites(); }); } @@ -696,7 +631,6 @@ public static IEnumerable WriteData() return from bufferCount in new[] { 1, 2, 3, 10 } - from writeType in Enum.GetValues() let writes = Enumerable.Range(0, 5) .Select(_ => @@ -704,13 +638,7 @@ from writeType in Enum.GetValues() .Select(_ => bufferSizes[r.Next(bufferSizes.Length)]) .ToArray()) .ToArray() - select new object[] { writes, writeType }; - } - - public enum WriteType - { - SingleBuffer, - GatheredSequence + select new object[] { writes }; } [Fact] @@ -719,20 +647,18 @@ public async Task CallDifferentWriteMethodsWorks() (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(); ReadOnlyMemory helloWorld = "Hello world!"u8.ToArray(); - ReadOnlySequence ros = CreateReadOnlySequenceFromBytes(helloWorld.ToArray()); - Assert.False(ros.IsSingleSegment); - using QuicStream clientStream = await clientConnection.OpenBidirectionalStreamAsync(); - ValueTask writeTask = clientStream.WriteAsync(ros); - using QuicStream serverStream = await serverConnection.AcceptStreamAsync(); + await using QuicStream clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); + ValueTask writeTask = clientStream.WriteAsync(helloWorld, completeWrites: true); + await using QuicStream serverStream = await serverConnection.AcceptInboundStreamAsync(); await writeTask; byte[] memory = new byte[24]; int res = await serverStream.ReadAsync(memory); Assert.Equal(12, res); - clientConnection.Dispose(); - serverConnection.Dispose(); + await clientConnection.DisposeAsync(); + await serverConnection.DisposeAsync(); } [Fact] @@ -740,10 +666,10 @@ public async Task CloseAsync_ByServer_AcceptThrows() { (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(); - using (clientConnection) - using (serverConnection) + await using (clientConnection) + await using (serverConnection) { - var acceptTask = serverConnection.AcceptStreamAsync(); + var acceptTask = serverConnection.AcceptInboundStreamAsync(); await serverConnection.CloseAsync(errorCode: 0); // make sure we throw await AssertThrowsQuicExceptionAsync(QuicError.OperationAborted, () => acceptTask.AsTask()); @@ -758,8 +684,8 @@ public async Task CloseAsync_MultipleCalls_FollowingCallsAreIgnored(bool client) (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(); - using (clientConnection) - using (serverConnection) + await using (clientConnection) + await using (serverConnection) { if (client) { @@ -870,7 +796,7 @@ await RunClientServer( iterations: 20, serverFunction: async connection => { - await using QuicStream stream = await connection.AcceptStreamAsync(); + await using QuicStream stream = await connection.AcceptInboundStreamAsync(); byte[] buffer = new byte[data.Length]; int bytesRead = await ReadAll(stream, buffer); @@ -881,26 +807,22 @@ await RunClientServer( { await stream.WriteAsync(data[pos..(pos + writeSize)]); } - await stream.WriteAsync(Memory.Empty, endStream: true); - - await stream.ShutdownCompleted(); + await stream.WriteAsync(Memory.Empty, completeWrites: true); }, clientFunction: async connection => { - await using QuicStream stream = await connection.OpenBidirectionalStreamAsync(); + await using QuicStream stream = await connection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); for (int pos = 0; pos < data.Length; pos += writeSize) { await stream.WriteAsync(data[pos..(pos + writeSize)]); } - await stream.WriteAsync(Memory.Empty, endStream: true); + await stream.WriteAsync(Memory.Empty, completeWrites: true); byte[] buffer = new byte[data.Length]; int bytesRead = await ReadAll(stream, buffer); Assert.Equal(data.Length, bytesRead); AssertExtensions.SequenceEqual(data, buffer); - - await stream.ShutdownCompleted(); } ); } @@ -915,12 +837,12 @@ async Task GetStreamIdWithoutStartWorks() { (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(); - using QuicStream clientStream = await clientConnection.OpenBidirectionalStreamAsync(); - Assert.Equal(0, clientStream.StreamId); + await using QuicStream clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); + Assert.Equal(0, clientStream.Id); // TODO: stream that is opened by client but left unaccepted by server may cause AccessViolationException in its Finalizer - clientConnection.Dispose(); - serverConnection.Dispose(); + await clientConnection.DisposeAsync(); + await serverConnection.DisposeAsync(); } await GetStreamIdWithoutStartWorks().WaitAsync(TimeSpan.FromSeconds(15)); @@ -935,12 +857,12 @@ async Task GetStreamIdWithoutStartWorks() { (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(); - using QuicStream clientStream = await clientConnection.OpenBidirectionalStreamAsync(); - Assert.Equal(0, clientStream.StreamId); + await using QuicStream clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); + Assert.Equal(0, clientStream.Id); // Dispose all connections before the streams; - clientConnection.Dispose(); - serverConnection.Dispose(); + await clientConnection.DisposeAsync(); + await serverConnection.DisposeAsync(); } await GetStreamIdWithoutStartWorks(); @@ -957,15 +879,16 @@ await Task.Run(async () => { (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(); - await using QuicStream clientStream = await clientConnection.OpenBidirectionalStreamAsync(); + await using QuicStream clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); await clientStream.WriteAsync(new byte[1]); - await using QuicStream serverStream = await serverConnection.AcceptStreamAsync(); + await using QuicStream serverStream = await serverConnection.AcceptInboundStreamAsync(); await serverStream.ReadAsync(new byte[1]); await clientConnection.CloseAsync(ExpectedErrorCode); byte[] buffer = new byte[100]; + await AssertThrowsQuicExceptionAsync(QuicError.OperationAborted, () => clientStream.ReadAsync(buffer).AsTask()); QuicException ex = await AssertThrowsQuicExceptionAsync(QuicError.ConnectionAborted, () => serverStream.ReadAsync(buffer).AsTask()); Assert.Equal(ExpectedErrorCode, ex.ApplicationErrorCode); }).WaitAsync(TimeSpan.FromMilliseconds(PassingTestTimeoutMilliseconds)); @@ -974,19 +897,23 @@ await Task.Run(async () => [Fact] public async Task Read_ConnectionAbortedByUser_Throws() { + const int ExpectedErrorCode = 1234; + await Task.Run(async () => { (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(); - await using QuicStream clientStream = await clientConnection.OpenBidirectionalStreamAsync(); + await using QuicStream clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); await clientStream.WriteAsync(new byte[1]); - await using QuicStream serverStream = await serverConnection.AcceptStreamAsync(); + await using QuicStream serverStream = await serverConnection.AcceptInboundStreamAsync(); await serverStream.ReadAsync(new byte[1]); - await serverConnection.CloseAsync(0); + await serverConnection.CloseAsync(ExpectedErrorCode); byte[] buffer = new byte[100]; + QuicException ex = await AssertThrowsQuicExceptionAsync(QuicError.ConnectionAborted, () => clientStream.ReadAsync(buffer).AsTask()); + Assert.Equal(ExpectedErrorCode, ex.ApplicationErrorCode); await AssertThrowsQuicExceptionAsync(QuicError.OperationAborted, () => serverStream.ReadAsync(buffer).AsTask()); }).WaitAsync(TimeSpan.FromMilliseconds(PassingTestTimeoutMilliseconds)); } @@ -998,13 +925,13 @@ public async Task BigWrite_SmallRead_Success(bool closeWithData) { const int size = 100; (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(); - using (clientConnection) - using (serverConnection) + await using (clientConnection) + await using (serverConnection) { byte[] buffer = new byte[1] { 42 }; - QuicStream clientStream = await clientConnection.OpenBidirectionalStreamAsync(); - Task t = serverConnection.AcceptStreamAsync().AsTask(); + QuicStream clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); + Task t = serverConnection.AcceptInboundStreamAsync().AsTask(); await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientStream.WriteAsync(buffer).AsTask(), t, PassingTestTimeoutMilliseconds); QuicStream serverStream = t.Result; Assert.Equal(1, await serverStream.ReadAsync(buffer)); @@ -1034,7 +961,7 @@ public async Task BigWrite_SmallRead_Success(bool closeWithData) if (!closeWithData) { - serverStream.Shutdown(); + serverStream.CompleteWrites(); } readLength = await clientStream.ReadAsync(actual); @@ -1052,34 +979,31 @@ await RunClientServer( iterations: 100, serverFunction: async connection => { - using QuicStream stream = await connection.AcceptStreamAsync(); - Assert.False(stream.ReadsCompleted); + await using QuicStream stream = await connection.AcceptInboundStreamAsync(); + Assert.False(stream.ReadsClosed.IsCompleted); byte[] buffer = new byte[s_data.Length]; int bytesRead = await ReadAll(stream, buffer); - Assert.True(stream.ReadsCompleted); + Assert.True(stream.ReadsClosed.IsCompletedSuccessfully); Assert.Equal(s_data.Length, bytesRead); Assert.Equal(s_data, buffer); - await stream.WriteAsync(s_data, endStream: true); - await stream.ShutdownCompleted(); + await stream.WriteAsync(s_data, completeWrites: true); }, clientFunction: async connection => { - using QuicStream stream = await connection.OpenBidirectionalStreamAsync(); - Assert.False(stream.ReadsCompleted); + await using QuicStream stream = await connection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); + Assert.False(stream.ReadsClosed.IsCompleted); - await stream.WriteAsync(s_data, endStream: true); + await stream.WriteAsync(s_data, completeWrites: true); byte[] buffer = new byte[s_data.Length]; int bytesRead = await ReadAll(stream, buffer); - Assert.True(stream.ReadsCompleted); + Assert.True(stream.ReadsClosed.IsCompletedSuccessfully); Assert.Equal(s_data.Length, bytesRead); Assert.Equal(s_data, buffer); - - await stream.ShutdownCompleted(); } ); } @@ -1090,22 +1014,22 @@ public async Task Read_ReadsCompleted_ReportedBeforeReturning0() await RunBidirectionalClientServer( async clientStream => { - await clientStream.WriteAsync(new byte[1], endStream: true); + await clientStream.WriteAsync(new byte[1], completeWrites: true); }, async serverStream => { - Assert.False(serverStream.ReadsCompleted); + Assert.False(serverStream.ReadsClosed.IsCompleted); var received = await serverStream.ReadAsync(new byte[1]); Assert.Equal(1, received); - Assert.True(serverStream.ReadsCompleted); + Assert.True(serverStream.ReadsClosed.IsCompletedSuccessfully); var task = serverStream.ReadAsync(new byte[1]); Assert.True(task.IsCompleted); received = await task; Assert.Equal(0, received); - Assert.True(serverStream.ReadsCompleted); + Assert.True(serverStream.ReadsClosed.IsCompletedSuccessfully); }); } } diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicConnectionTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicConnectionTests.cs index c19188861e941..d27aae22484ac 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicConnectionTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicConnectionTests.cs @@ -21,19 +21,13 @@ public async Task TestConnect() { await using QuicListener listener = await CreateQuicListener(); - using QuicConnection clientConnection = await CreateQuicConnection(listener.LocalEndPoint); - - Assert.False(clientConnection.Connected); - Assert.Equal(listener.LocalEndPoint, clientConnection.RemoteEndPoint); - - ValueTask connectTask = clientConnection.ConnectAsync(); + ValueTask connectTask = CreateQuicConnection(listener.LocalEndPoint); ValueTask acceptTask = listener.AcceptConnectionAsync(); await new Task[] { connectTask.AsTask(), acceptTask.AsTask() }.WhenAllOrAnyFailed(PassingTestTimeoutMilliseconds); - QuicConnection serverConnection = acceptTask.Result; + await using QuicConnection serverConnection = acceptTask.Result; + await using QuicConnection clientConnection = connectTask.Result; - Assert.True(clientConnection.Connected); - Assert.True(serverConnection.Connected); Assert.Equal(listener.LocalEndPoint, serverConnection.LocalEndPoint); Assert.Equal(listener.LocalEndPoint, clientConnection.RemoteEndPoint); Assert.Equal(clientConnection.LocalEndPoint, serverConnection.RemoteEndPoint); @@ -43,7 +37,7 @@ public async Task TestConnect() private static async Task OpenAndUseStreamAsync(QuicConnection c) { - QuicStream s = await c.OpenBidirectionalStreamAsync(); + QuicStream s = await c.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); // This will pend await s.ReadAsync(new byte[1]); @@ -64,7 +58,7 @@ await RunClientServer( async serverConnection => { // Pend operations before the client closes. - Task acceptTask = serverConnection.AcceptStreamAsync().AsTask(); + Task acceptTask = serverConnection.AcceptInboundStreamAsync().AsTask(); Assert.False(acceptTask.IsCompleted); Task connectTask = OpenAndUseStreamAsync(serverConnection); Assert.False(connectTask.IsCompleted); @@ -83,7 +77,7 @@ await RunClientServer( // Subsequent attempts should fail // TODO: Which exception is correct? - await AssertThrowsQuicExceptionAsync(QuicError.OperationAborted, async () => await serverConnection.AcceptStreamAsync()); + await AssertThrowsQuicExceptionAsync(QuicError.OperationAborted, async () => await serverConnection.AcceptInboundStreamAsync()); await Assert.ThrowsAnyAsync(() => OpenAndUseStreamAsync(serverConnection)); }); } @@ -101,12 +95,12 @@ await RunClientServer( async serverConnection => { // Pend operations before the client closes. - Task acceptTask = serverConnection.AcceptStreamAsync().AsTask(); + Task acceptTask = serverConnection.AcceptInboundStreamAsync().AsTask(); Assert.False(acceptTask.IsCompleted); Task connectTask = OpenAndUseStreamAsync(serverConnection); Assert.False(connectTask.IsCompleted); - serverConnection.Dispose(); + await serverConnection.DisposeAsync(); sync.Release(); @@ -115,13 +109,13 @@ await RunClientServer( // TODO: This may not always throw QuicOperationAbortedException due to a data race with MsQuic worker threads // (CloseAsync may be processed before OpenStreamAsync as it is scheduled to the front of the operation queue) - // To be revisited once we standartize on exceptions. + // To be revisited once we standardize on exceptions. // [ActiveIssue("https://github.com/dotnet/runtime/issues/55619")] await Assert.ThrowsAsync(() => connectTask); // Subsequent attempts should fail // TODO: Should these be QuicOperationAbortedException, to match above? Or vice-versa? - await Assert.ThrowsAsync(async () => await serverConnection.AcceptStreamAsync()); + await Assert.ThrowsAsync(async () => await serverConnection.AcceptInboundStreamAsync()); await Assert.ThrowsAsync(async () => await OpenAndUseStreamAsync(serverConnection)); }); } @@ -137,11 +131,13 @@ await RunClientServer( await sync.WaitAsync(); await clientConnection.CloseAsync(ExpectedErrorCode); + + sync.Release(); }, async serverConnection => { // Pend operations before the client closes. - Task acceptTask = serverConnection.AcceptStreamAsync().AsTask(); + Task acceptTask = serverConnection.AcceptInboundStreamAsync().AsTask(); Assert.False(acceptTask.IsCompleted); Task connectTask = OpenAndUseStreamAsync(serverConnection); Assert.False(connectTask.IsCompleted); @@ -156,8 +152,10 @@ await RunClientServer( ex = await AssertThrowsQuicExceptionAsync(QuicError.ConnectionAborted, () => connectTask); Assert.Equal(ExpectedErrorCode, ex.ApplicationErrorCode); + await sync.WaitAsync(); + // Subsequent attempts should fail - ex = await AssertThrowsQuicExceptionAsync(QuicError.ConnectionAborted, () => serverConnection.AcceptStreamAsync().AsTask()); + ex = await AssertThrowsQuicExceptionAsync(QuicError.ConnectionAborted, () => serverConnection.AcceptInboundStreamAsync().AsTask()); Assert.Equal(ExpectedErrorCode, ex.ApplicationErrorCode); ex = await AssertThrowsQuicExceptionAsync(QuicError.ConnectionAborted, () => OpenAndUseStreamAsync(serverConnection)); Assert.Equal(ExpectedErrorCode, ex.ApplicationErrorCode); @@ -191,7 +189,7 @@ public async Task CloseAsync_WithOpenStream_LocalAndPeerStreamsFailWithQuicOpera await RunClientServer( async clientConnection => { - using QuicStream clientStream = await clientConnection.OpenBidirectionalStreamAsync(); + await using QuicStream clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); await DoWrites(clientStream, writesBeforeClose); // Wait for peer to receive data @@ -204,7 +202,7 @@ await RunClientServer( }, async serverConnection => { - using QuicStream serverStream = await serverConnection.AcceptStreamAsync(); + await using QuicStream serverStream = await serverConnection.AcceptInboundStreamAsync(); await DoReads(serverStream, writesBeforeClose); sync.Release(); @@ -218,6 +216,45 @@ await RunClientServer( }); } + [Theory] + [InlineData(1)] + [InlineData(10)] + public async Task Dispose_WithoutClose_ConnectionClosesWithDefault(int writesBeforeClose) + { + QuicListenerOptions listenerOptions = CreateQuicListenerOptions(); + + using var sync = new SemaphoreSlim(0); + + await RunClientServer( + async clientConnection => + { + using QuicStream clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); + await DoWrites(clientStream, writesBeforeClose); + + // Wait for peer to receive data + await sync.WaitAsync(); + + await clientConnection.DisposeAsync(); + + await AssertThrowsQuicExceptionAsync(QuicError.OperationAborted, async () => await clientStream.ReadAsync(new byte[1])); + await AssertThrowsQuicExceptionAsync(QuicError.OperationAborted, async () => await clientStream.WriteAsync(new byte[1])); + }, + async serverConnection => + { + using QuicStream serverStream = await serverConnection.AcceptInboundStreamAsync(); + await DoReads(serverStream, writesBeforeClose); + + sync.Release(); + + // Since the peer did the abort, we should receive the abort error code in the exception. + QuicException ex; + ex = await AssertThrowsQuicExceptionAsync(QuicError.ConnectionAborted, async () => await serverStream.ReadAsync(new byte[1])); + Assert.Equal(DefaultCloseErrorCodeClient, ex.ApplicationErrorCode); + ex = await AssertThrowsQuicExceptionAsync(QuicError.ConnectionAborted, async () => await serverStream.WriteAsync(new byte[1])); + Assert.Equal(DefaultCloseErrorCodeClient, ex.ApplicationErrorCode); + }, listenerOptions: listenerOptions); + } + [OuterLoop("Depends on IdleTimeout")] [Theory] [InlineData(1)] @@ -232,20 +269,20 @@ public async Task Dispose_WithOpenLocalStream_LocalStreamFailsWithQuicOperationA await RunClientServer( async clientConnection => { - using QuicStream clientStream = await clientConnection.OpenBidirectionalStreamAsync(); + await using QuicStream clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); await DoWrites(clientStream, writesBeforeClose); // Wait for peer to receive data await sync.WaitAsync(); - clientConnection.Dispose(); + await clientConnection.DisposeAsync(); await AssertThrowsQuicExceptionAsync(QuicError.OperationAborted, async () => await clientStream.ReadAsync(new byte[1])); await AssertThrowsQuicExceptionAsync(QuicError.OperationAborted, async () => await clientStream.WriteAsync(new byte[1])); }, async serverConnection => { - using QuicStream serverStream = await serverConnection.AcceptStreamAsync(); + await using QuicStream serverStream = await serverConnection.AcceptInboundStreamAsync(); await DoReads(serverStream, writesBeforeClose); sync.Release(); diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicListenerTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicListenerTests.cs index 44d30bcab1d06..45cf135c503bf 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicListenerTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicListenerTests.cs @@ -20,11 +20,9 @@ await Task.Run(async () => { await using QuicListener listener = await CreateQuicListener(); - using QuicConnection clientConnection = await CreateQuicConnection(listener.LocalEndPoint); - var clientStreamTask = clientConnection.ConnectAsync(); - - using QuicConnection serverConnection = await listener.AcceptConnectionAsync(); - await clientStreamTask; + var clientStreamTask = CreateQuicConnection(listener.LocalEndPoint); + await using QuicConnection serverConnection = await listener.AcceptConnectionAsync(); + await using QuicConnection clientConnection = await clientStreamTask; }).WaitAsync(TimeSpan.FromSeconds(6)); } @@ -35,11 +33,9 @@ await Task.Run(async () => { await using QuicListener listener = await CreateQuicListener(new IPEndPoint(IPAddress.IPv6Loopback, 0)); - using QuicConnection clientConnection = await CreateQuicConnection(listener.LocalEndPoint); - var clientStreamTask = clientConnection.ConnectAsync(); - - using QuicConnection serverConnection = await listener.AcceptConnectionAsync(); - await clientStreamTask; + var clientStreamTask = CreateQuicConnection(listener.LocalEndPoint); + await using QuicConnection serverConnection = await listener.AcceptConnectionAsync(); + await using QuicConnection clientConnection = await clientStreamTask; }).WaitAsync(TimeSpan.FromSeconds(6)); } @@ -54,11 +50,9 @@ await Task.Run(async () => await using QuicListener listener = await CreateQuicListener(new IPEndPoint(IPv6Any, 0)); - using QuicConnection clientConnection = await CreateQuicConnection(new IPEndPoint(IPAddress.Loopback, listener.LocalEndPoint.Port)); - var clientStreamTask = clientConnection.ConnectAsync(); - - using QuicConnection serverConnection = await listener.AcceptConnectionAsync(); - await clientStreamTask; + var clientStreamTask = CreateQuicConnection(new IPEndPoint(IPAddress.Loopback, listener.LocalEndPoint.Port)); + await using QuicConnection serverConnection = await listener.AcceptConnectionAsync(); + await using QuicConnection clientConnection = await clientStreamTask; }).WaitAsync(TimeSpan.FromSeconds(6)); } } diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamConnectedStreamConformanceTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamConnectedStreamConformanceTests.cs index 9034aaf94ee3d..8a49e4eac24c1 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamConnectedStreamConformanceTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamConnectedStreamConformanceTests.cs @@ -57,6 +57,8 @@ protected override async Task CreateConnectedStreamsAsync() ApplicationProtocols = new List() { new SslApplicationProtocol("quictest") }, ConnectionOptionsCallback = (_, _, _) => ValueTask.FromResult(new QuicServerConnectionOptions() { + DefaultStreamErrorCode = QuicTestBase.DefaultStreamErrorCodeServer, + DefaultCloseErrorCode = QuicTestBase.DefaultCloseErrorCodeServer, ServerAuthenticationOptions = GetSslServerAuthenticationOptions() }) }); @@ -68,7 +70,7 @@ await WhenAllOrAnyFailed( Task.Run(async () => { connection1 = await listener.AcceptConnectionAsync(); - stream1 = await connection1.AcceptStreamAsync(); + stream1 = await connection1.AcceptInboundStreamAsync(); Assert.Equal(1, await stream1.ReadAsync(buffer)); }), Task.Run(async () => @@ -77,11 +79,12 @@ await WhenAllOrAnyFailed( { connection2 = await QuicConnection.ConnectAsync(new QuicClientConnectionOptions() { + DefaultStreamErrorCode = QuicTestBase.DefaultStreamErrorCodeClient, + DefaultCloseErrorCode = QuicTestBase.DefaultCloseErrorCodeClient, RemoteEndPoint = listener.LocalEndPoint, ClientAuthenticationOptions = GetSslClientAuthenticationOptions() }); - await connection2.ConnectAsync(); - stream2 = await connection2.OpenBidirectionalStreamAsync(); + stream2 = await connection2.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); // OpenBidirectionalStream only allocates ID. We will force stream opening // by Writing there and receiving data on the other side. await stream2.WriteAsync(buffer); @@ -105,14 +108,17 @@ await WhenAllOrAnyFailed( private sealed class StreamPairWithOtherDisposables : StreamPair { - public readonly List Disposables = new List(); + public readonly List Disposables = new List(); public StreamPairWithOtherDisposables(Stream stream1, Stream stream2) : base(stream1, stream2) { } public override void Dispose() { base.Dispose(); - Disposables.ForEach(d => d.Dispose()); + foreach (IAsyncDisposable disposable in Disposables) + { + disposable.DisposeAsync().GetAwaiter().GetResult(); + } } } } diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs index 5101ade6a61db..3d00da3cdc742 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs @@ -26,7 +26,7 @@ await RunClientServer( iterations: 100, serverFunction: async connection => { - await using QuicStream stream = await connection.AcceptStreamAsync(); + await using QuicStream stream = await connection.AcceptInboundStreamAsync(); byte[] buffer = new byte[s_data.Length]; int bytesRead = await ReadAll(stream, buffer); @@ -34,22 +34,19 @@ await RunClientServer( Assert.Equal(s_data.Length, bytesRead); Assert.Equal(s_data, buffer); - await stream.WriteAsync(s_data, endStream: true); - await stream.ShutdownCompleted(); + await stream.WriteAsync(s_data, completeWrites: true); }, clientFunction: async connection => { - await using QuicStream stream = await connection.OpenBidirectionalStreamAsync(); + await using QuicStream stream = await connection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); - await stream.WriteAsync(s_data, endStream: true); + await stream.WriteAsync(s_data, completeWrites: true); byte[] buffer = new byte[s_data.Length]; int bytesRead = await ReadAll(stream, buffer); Assert.Equal(s_data.Length, bytesRead); Assert.Equal(s_data, buffer); - - await stream.ShutdownCompleted(); } ); } @@ -73,7 +70,7 @@ await RunClientServer( iterations: 100, serverFunction: async connection => { - await using QuicStream stream = await connection.AcceptStreamAsync(cts.Token); + await using QuicStream stream = await connection.AcceptInboundStreamAsync(cts.Token); byte[] buffer = new byte[expectedBytesCount]; int bytesRead = await ReadAll(stream, buffer); @@ -84,26 +81,22 @@ await RunClientServer( { await stream.WriteAsync(s_data); } - await stream.WriteAsync(Memory.Empty, endStream: true, cts.Token); - - await stream.ShutdownCompleted(cts.Token); + await stream.WriteAsync(Memory.Empty, completeWrites: true, cts.Token); }, clientFunction: async connection => { - await using QuicStream stream = await connection.OpenBidirectionalStreamAsync(); + await using QuicStream stream = await connection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); for (int i = 0; i < sendCount; i++) { await stream.WriteAsync(s_data, cts.Token); } - await stream.WriteAsync(Memory.Empty, endStream: true, cts.Token); + await stream.WriteAsync(Memory.Empty, completeWrites: true, cts.Token); byte[] buffer = new byte[expectedBytesCount]; int bytesRead = await ReadAll(stream, buffer); Assert.Equal(expectedBytesCount, bytesRead); Assert.Equal(expected, buffer); - - await stream.ShutdownCompleted(cts.Token); } ); } @@ -114,8 +107,8 @@ public async Task MultipleStreamsOnSingleConnection() await RunClientServer( serverFunction: async connection => { - await using QuicStream stream = await connection.AcceptStreamAsync(); - await using QuicStream stream2 = await connection.AcceptStreamAsync(); + await using QuicStream stream = await connection.AcceptInboundStreamAsync(); + await using QuicStream stream2 = await connection.AcceptInboundStreamAsync(); byte[] buffer = new byte[s_data.Length]; byte[] buffer2 = new byte[s_data.Length]; @@ -128,19 +121,16 @@ await RunClientServer( Assert.Equal(s_data.Length, bytesRead2); Assert.Equal(s_data, buffer2); - await stream.WriteAsync(s_data, endStream: true); - await stream2.WriteAsync(s_data, endStream: true); - - await stream.ShutdownCompleted(); - await stream2.ShutdownCompleted(); + await stream.WriteAsync(s_data, completeWrites: true); + await stream2.WriteAsync(s_data, completeWrites: true); }, clientFunction: async connection => { - await using QuicStream stream = await connection.OpenBidirectionalStreamAsync(); - await using QuicStream stream2 = await connection.OpenBidirectionalStreamAsync(); + await using QuicStream stream = await connection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); + await using QuicStream stream2 = await connection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); - await stream.WriteAsync(s_data, endStream: true); - await stream2.WriteAsync(s_data, endStream: true); + await stream.WriteAsync(s_data, completeWrites: true); + await stream2.WriteAsync(s_data, completeWrites: true); byte[] buffer = new byte[s_data.Length]; byte[] buffer2 = new byte[s_data.Length]; @@ -152,9 +142,6 @@ await RunClientServer( int bytesRead2 = await ReadAll(stream2, buffer2); Assert.Equal(s_data.Length, bytesRead2); Assert.Equal(s_data, buffer2); - - await stream.ShutdownCompleted(); - await stream2.ShutdownCompleted(); } ); } @@ -166,8 +153,8 @@ public async Task MultipleConcurrentStreamsOnSingleConnection() Task[] tasks = new Task[count]; (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(); - using (clientConnection) - using (serverConnection) + await using (clientConnection) + await using (serverConnection) { for (int i = 0; i < count; i++) { @@ -179,9 +166,9 @@ public async Task MultipleConcurrentStreamsOnSingleConnection() static async Task MakeStreams(QuicConnection clientConnection, QuicConnection serverConnection) { byte[] buffer = new byte[64]; - QuicStream clientStream = await clientConnection.OpenBidirectionalStreamAsync(); - ValueTask writeTask = clientStream.WriteAsync("PING"u8.ToArray(), endStream: true); - ValueTask acceptTask = serverConnection.AcceptStreamAsync(); + QuicStream clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); + ValueTask writeTask = clientStream.WriteAsync("PING"u8.ToArray(), completeWrites: true); + ValueTask acceptTask = serverConnection.AcceptInboundStreamAsync(); await new Task[] { writeTask.AsTask(), acceptTask.AsTask() }.WhenAllOrAnyFailed(PassingTestTimeoutMilliseconds); QuicStream serverStream = acceptTask.Result; await serverStream.ReadAsync(buffer); @@ -193,11 +180,11 @@ public async Task GetStreamIdWithoutStartWorks() { (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(); - using (clientConnection) - using (serverConnection) + await using (clientConnection) + await using (serverConnection) { - using QuicStream clientStream = await clientConnection.OpenBidirectionalStreamAsync(); - Assert.Equal(0, clientStream.StreamId); + await using QuicStream clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); + Assert.Equal(0, clientStream.Id); // TODO: stream that is opened by client but left unaccepted by server may cause AccessViolationException in its Finalizer // explicitly closing the connections seems to help, but the problem should still be investigated, we should have a meaningful @@ -217,7 +204,7 @@ await RunClientServer( iterations: 5, serverFunction: async connection => { - await using QuicStream stream = await connection.AcceptStreamAsync(); + await using QuicStream stream = await connection.AcceptInboundStreamAsync(); byte[] buffer = new byte[data.Length]; int bytesRead = await ReadAll(stream, buffer); @@ -228,26 +215,22 @@ await RunClientServer( { await stream.WriteAsync(data[pos..(pos + writeSize)]); } - await stream.WriteAsync(Memory.Empty, endStream: true); - - await stream.ShutdownCompleted(); + await stream.WriteAsync(Memory.Empty, completeWrites: true); }, clientFunction: async connection => { - await using QuicStream stream = await connection.OpenBidirectionalStreamAsync(); + await using QuicStream stream = await connection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); for (int pos = 0; pos < data.Length; pos += writeSize) { await stream.WriteAsync(data[pos..(pos + writeSize)]); } - await stream.WriteAsync(Memory.Empty, endStream: true); + await stream.WriteAsync(Memory.Empty, completeWrites: true); byte[] buffer = new byte[data.Length]; int bytesRead = await ReadAll(stream, buffer); Assert.Equal(data.Length, bytesRead); AssertExtensions.SequenceEqual(data, buffer); - - await stream.ShutdownCompleted(); } ); } @@ -257,14 +240,12 @@ public async Task TestStreams() { await using QuicListener listener = await CreateQuicListener(); var clientOptions = CreateQuicClientOptions(listener.LocalEndPoint); - clientOptions.MaxBidirectionalStreams = 1; - clientOptions.MaxUnidirectionalStreams = 1; + clientOptions.MaxInboundBidirectionalStreams = 1; + clientOptions.MaxInboundUnidirectionalStreams = 1; (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(clientOptions, listener); - using (clientConnection) - using (serverConnection) + await using (clientConnection) + await using (serverConnection) { - Assert.True(clientConnection.Connected); - Assert.True(serverConnection.Connected); Assert.Equal(listener.LocalEndPoint, serverConnection.LocalEndPoint); Assert.Equal(listener.LocalEndPoint, clientConnection.RemoteEndPoint); Assert.Equal(clientConnection.LocalEndPoint, serverConnection.RemoteEndPoint); @@ -279,13 +260,13 @@ public async Task TestStreams() private static async Task CreateAndTestBidirectionalStream(QuicConnection c1, QuicConnection c2) { - using QuicStream s1 = await c1.OpenBidirectionalStreamAsync(); + await using QuicStream s1 = await c1.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); Assert.True(s1.CanRead); Assert.True(s1.CanWrite); ValueTask writeTask = s1.WriteAsync(s_data); - using QuicStream s2 = await c2.AcceptStreamAsync(); + await using QuicStream s2 = await c2.AcceptInboundStreamAsync(); await ReceiveDataAsync(s_data, s2); await writeTask; await TestBidirectionalStream(s1, s2); @@ -293,14 +274,14 @@ private static async Task CreateAndTestBidirectionalStream(QuicConnection c1, Qu private static async Task CreateAndTestUnidirectionalStream(QuicConnection c1, QuicConnection c2) { - using QuicStream s1 = await c1.OpenUnidirectionalStreamAsync(); + await using QuicStream s1 = await c1.OpenOutboundStreamAsync(QuicStreamType.Unidirectional); Assert.False(s1.CanRead); Assert.True(s1.CanWrite); ValueTask writeTask = s1.WriteAsync(s_data); - using QuicStream s2 = await c2.AcceptStreamAsync(); + await using QuicStream s2 = await c2.AcceptInboundStreamAsync(); await ReceiveDataAsync(s_data, s2); await writeTask; await TestUnidirectionalStream(s1, s2); @@ -312,7 +293,7 @@ private static async Task TestBidirectionalStream(QuicStream s1, QuicStream s2) Assert.True(s1.CanWrite); Assert.True(s2.CanRead); Assert.True(s2.CanWrite); - Assert.Equal(s1.StreamId, s2.StreamId); + Assert.Equal(s1.Id, s2.Id); await SendAndReceiveDataAsync(s_data, s1, s2); await SendAndReceiveDataAsync(s_data, s2, s1); @@ -321,9 +302,6 @@ private static async Task TestBidirectionalStream(QuicStream s1, QuicStream s2) await SendAndReceiveEOFAsync(s1, s2); await SendAndReceiveEOFAsync(s2, s1); - - await s1.ShutdownCompleted(); - await s2.ShutdownCompleted(); } private static async Task TestUnidirectionalStream(QuicStream s1, QuicStream s2) @@ -332,15 +310,12 @@ private static async Task TestUnidirectionalStream(QuicStream s1, QuicStream s2) Assert.True(s1.CanWrite); Assert.True(s2.CanRead); Assert.False(s2.CanWrite); - Assert.Equal(s1.StreamId, s2.StreamId); + Assert.Equal(s1.Id, s2.Id); await SendAndReceiveDataAsync(s_data, s1, s2); await SendAndReceiveDataAsync(s_data, s1, s2); await SendAndReceiveEOFAsync(s1, s2); - - await s1.ShutdownCompleted(); - await s2.ShutdownCompleted(); } private static async Task SendAndReceiveDataAsync(byte[] data, QuicStream s1, QuicStream s2) @@ -367,7 +342,7 @@ private static async Task SendAndReceiveEOFAsync(QuicStream s1, QuicStream s2) { byte[] readBuffer = new byte[1]; - await s1.WriteAsync(Memory.Empty, endStream: true); + await s1.WriteAsync(Memory.Empty, completeWrites: true); int bytesRead = await s2.ReadAsync(readBuffer); Assert.Equal(0, bytesRead); @@ -387,7 +362,7 @@ public async Task ReadWrite_Random_Success(int readSize, int writeSize) await RunClientServer( async clientConnection => { - await using QuicStream clientStream = await clientConnection.OpenUnidirectionalStreamAsync(); + await using QuicStream clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Unidirectional); ReadOnlyMemory sendBuffer = testBuffer; while (sendBuffer.Length != 0) @@ -397,12 +372,11 @@ await RunClientServer( sendBuffer = sendBuffer.Slice(chunk.Length); } - await clientStream.WriteAsync(Memory.Empty, endStream: true); - await clientStream.ShutdownCompleted(); + await clientStream.WriteAsync(Memory.Empty, completeWrites: true); }, async serverConnection => { - await using QuicStream serverStream = await serverConnection.AcceptStreamAsync(); + await using QuicStream serverStream = await serverConnection.AcceptInboundStreamAsync(); byte[] receiveBuffer = new byte[testBuffer.Length]; int totalBytesRead = 0; @@ -421,8 +395,6 @@ await RunClientServer( Assert.Equal(testBuffer.Length, totalBytesRead); AssertExtensions.SequenceEqual(testBuffer, receiveBuffer); - - await serverStream.ShutdownCompleted(); }); } @@ -440,7 +412,6 @@ from writeSize in sizes public async Task Read_WriteAborted_Throws() { const int ExpectedErrorCode = 0xfffffff; - using SemaphoreSlim sem = new SemaphoreSlim(0); await RunBidirectionalClientServer( @@ -449,7 +420,7 @@ await RunBidirectionalClientServer( await clientStream.WriteAsync(new byte[1]); await sem.WaitAsync(); - clientStream.AbortWrite(ExpectedErrorCode); + clientStream.Abort(QuicAbortDirection.Write, ExpectedErrorCode); }, async serverStream => { @@ -474,7 +445,7 @@ await RunBidirectionalClientServer( { await clientStream.WriteAsync(new byte[1]); sem.Release(); - clientStream.Shutdown(); + clientStream.CompleteWrites(); sem.Release(); }, async serverStream => @@ -503,14 +474,14 @@ await RunBidirectionalClientServer( public async Task ReadOutstanding_ReadAborted_Throws() { (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(); - using (clientConnection) - using (serverConnection) + await using (clientConnection) + await using (serverConnection) { byte[] buffer = new byte[1] { 42 }; const int ExpectedErrorCode = 0xfffffff; - QuicStream clientStream = await clientConnection.OpenBidirectionalStreamAsync(); - Task t = serverConnection.AcceptStreamAsync().AsTask(); + QuicStream clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); + Task t = serverConnection.AcceptInboundStreamAsync().AsTask(); await TaskTimeoutExtensions.WhenAllOrAnyFailed(clientStream.WriteAsync(buffer).AsTask(), t, PassingTestTimeoutMilliseconds); QuicStream serverStream = t.Result; Assert.Equal(1, await serverStream.ReadAsync(buffer)); @@ -522,7 +493,7 @@ public async Task ReadOutstanding_ReadAborted_Throws() Task exTask = AssertThrowsQuicExceptionAsync(QuicError.OperationAborted, () => serverStream.ReadAsync(new byte[1]).AsTask()); Assert.False(exTask.IsCompleted); - serverStream.AbortRead(ExpectedErrorCode); + serverStream.Abort(QuicAbortDirection.Read, ExpectedErrorCode); await exTask; } @@ -537,12 +508,12 @@ public async Task WriteAbortedWithoutWriting_ReadThrows() await RunClientServer( clientFunction: async connection => { - await using QuicStream stream = await connection.OpenUnidirectionalStreamAsync(); - stream.AbortWrite(expectedErrorCode); + await using QuicStream stream = await connection.OpenOutboundStreamAsync(QuicStreamType.Unidirectional); + stream.Abort(QuicAbortDirection.Write, expectedErrorCode); }, serverFunction: async connection => { - await using QuicStream stream = await connection.AcceptStreamAsync(); + await using QuicStream stream = await connection.AcceptInboundStreamAsync(); byte[] buffer = new byte[1]; @@ -563,12 +534,12 @@ public async Task ReadAbortedWithoutReading_WriteThrows() await RunClientServer( clientFunction: async connection => { - await using QuicStream stream = await connection.OpenBidirectionalStreamAsync(); - stream.AbortRead(expectedErrorCode); + await using QuicStream stream = await connection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); + stream.Abort(QuicAbortDirection.Read, expectedErrorCode); }, serverFunction: async connection => { - await using QuicStream stream = await connection.AcceptStreamAsync(); + await using QuicStream stream = await connection.AcceptInboundStreamAsync(); QuicException ex = await AssertThrowsQuicExceptionAsync(QuicError.StreamAborted, () => WriteForever(stream)); Assert.Equal(expectedErrorCode, ex.ApplicationErrorCode); @@ -582,12 +553,10 @@ await RunClientServer( [Fact] public async Task WritePreCanceled_Throws() { - const long expectedErrorCode = 1234; - await RunClientServer( clientFunction: async connection => { - await using QuicStream stream = await connection.OpenUnidirectionalStreamAsync(); + await using QuicStream stream = await connection.OpenOutboundStreamAsync(QuicStreamType.Unidirectional); CancellationTokenSource cts = new CancellationTokenSource(); cts.Cancel(); @@ -596,21 +565,15 @@ await RunClientServer( // aborting write causes the write direction to throw on subsequent operations await AssertThrowsQuicExceptionAsync(QuicError.OperationAborted, () => stream.WriteAsync(new byte[1]).AsTask()); - - // manual write abort is still required - stream.AbortWrite(expectedErrorCode); - - await stream.ShutdownCompleted(); }, serverFunction: async connection => { - await using QuicStream stream = await connection.AcceptStreamAsync(); + await using QuicStream stream = await connection.AcceptInboundStreamAsync(); byte[] buffer = new byte[1024 * 1024]; - await AssertThrowsQuicExceptionAsync(QuicError.StreamAborted, () => ReadAll(stream, buffer)); - - await stream.ShutdownCompleted(); + QuicException ex = await AssertThrowsQuicExceptionAsync(QuicError.StreamAborted, () => ReadAll(stream, buffer)); + Assert.Equal(DefaultStreamErrorCodeClient, ex.ApplicationErrorCode); } ); } @@ -618,12 +581,10 @@ await RunClientServer( [Fact] public async Task WriteCanceled_NextWriteThrows() { - const long expectedErrorCode = 1234; - await RunClientServer( clientFunction: async connection => { - await using QuicStream stream = await connection.OpenUnidirectionalStreamAsync(); + await using QuicStream stream = await connection.OpenOutboundStreamAsync(QuicStreamType.Unidirectional); CancellationTokenSource cts = new CancellationTokenSource(500); @@ -641,15 +602,10 @@ async Task WriteUntilCanceled() // next write would also throw await AssertThrowsQuicExceptionAsync(QuicError.OperationAborted, () => stream.WriteAsync(new byte[1]).AsTask()); - - // manual write abort is still required - stream.AbortWrite(expectedErrorCode); - - await stream.ShutdownCompleted(); }, serverFunction: async connection => { - await using QuicStream stream = await connection.AcceptStreamAsync(); + await using QuicStream stream = await connection.AcceptInboundStreamAsync(); async Task ReadUntilAborted() { @@ -664,9 +620,8 @@ async Task ReadUntilAborted() } } - await AssertThrowsQuicExceptionAsync(QuicError.StreamAborted, () => ReadUntilAborted()); - - await stream.ShutdownCompleted(); + QuicException ex = await AssertThrowsQuicExceptionAsync(QuicError.StreamAborted, () => ReadUntilAborted()); + Assert.Equal(DefaultStreamErrorCodeClient, ex.ApplicationErrorCode); } ); } @@ -680,25 +635,24 @@ public async Task AbortAfterDispose_ProperlyOpenedStream_Success() await RunClientServer( clientFunction: async connection => { - QuicStream stream = await connection.OpenBidirectionalStreamAsync(); + QuicStream stream = await connection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); // Force stream to open on the wire await stream.WriteAsync(buffer); await sem.WaitAsync(); - stream.Dispose(); + await stream.DisposeAsync(); // should not throw ODE on aborting - stream.AbortRead(1234); - stream.AbortWrite(5675); + stream.Abort(QuicAbortDirection.Read, 1234); + stream.Abort(QuicAbortDirection.Write, 5675); }, serverFunction: async connection => { - await using QuicStream stream = await connection.AcceptStreamAsync(); + await using QuicStream stream = await connection.AcceptInboundStreamAsync(); Assert.Equal(1, await stream.ReadAsync(buffer)); sem.Release(); // client will abort both sides, so we will receive the final event - await stream.ShutdownCompleted(); } ); } @@ -709,21 +663,20 @@ public async Task AbortAfterDispose_StreamCreationFlushedByDispose_Success() await RunClientServer( clientFunction: async connection => { - QuicStream stream = await connection.OpenBidirectionalStreamAsync(); + QuicStream stream = await connection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); // dispose will flush stream creation on the wire - stream.Dispose(); + await stream.DisposeAsync(); // should not throw ODE on aborting - stream.AbortRead(1234); - stream.AbortWrite(5675); + stream.Abort(QuicAbortDirection.Read, 1234); + stream.Abort(QuicAbortDirection.Write, 5675); }, serverFunction: async connection => { - await using QuicStream stream = await connection.AcceptStreamAsync(); + await using QuicStream stream = await connection.AcceptInboundStreamAsync(); // client will abort both sides, so we will receive the final event - await stream.ShutdownCompleted(); } ); } @@ -739,12 +692,12 @@ public async Task WaitForWriteCompletionAsync_ClientReadAborted_Throws() await RunBidirectionalClientServer( async clientStream => { - await clientStream.WriteAsync(new byte[1], endStream: true); + await clientStream.WriteAsync(new byte[1], completeWrites: true); // Wait for server to read data await sem.WaitAsync(); - clientStream.AbortRead(ExpectedErrorCode); + clientStream.Abort(QuicAbortDirection.Read, ExpectedErrorCode); }, async serverStream => { @@ -769,7 +722,7 @@ async ValueTask ReleaseOnWriteCompletionAsync() { try { - await serverStream.WaitForWriteCompletionAsync(); + await serverStream.WritesClosed; waitForAbortTcs.SetException(new Exception("WaitForWriteCompletionAsync didn't throw stream aborted.")); } catch (QuicException ex) when (ex.QuicError == QuicError.StreamAborted) @@ -803,7 +756,7 @@ await RunBidirectionalClientServer( // But in most cases it will still exercise aborting the outstanding write task. var writeTask = WriteForever(serverStream, 1024 * 1024); - serverStream.AbortWrite(ExpectedErrorCode); + serverStream.Abort(QuicAbortDirection.Write, ExpectedErrorCode); await AssertThrowsQuicExceptionAsync(QuicError.OperationAborted, () => writeTask.WaitAsync(TimeSpan.FromSeconds(3))); sem.Release(); @@ -814,13 +767,15 @@ await RunBidirectionalClientServer( public async Task WaitForWriteCompletionAsync_ServerWriteAborted_Throws() { const int ExpectedErrorCode = 0xfffffff; + SemaphoreSlim sem = new SemaphoreSlim(0); TaskCompletionSource waitForAbortTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); await RunBidirectionalClientServer( async clientStream => { - await clientStream.WriteAsync(new byte[1], endStream: true); + await clientStream.WriteAsync(new byte[1], completeWrites: true); + await sem.WaitAsync(); }, async serverStream => { @@ -833,7 +788,8 @@ await RunBidirectionalClientServer( Assert.False(writeCompletionTask.IsCompleted, "Server is still writing."); - serverStream.AbortWrite(ExpectedErrorCode); + serverStream.Abort(QuicAbortDirection.Write, ExpectedErrorCode); + sem.Release(); await waitForAbortTcs.Task; await writeCompletionTask; @@ -842,7 +798,7 @@ async ValueTask ReleaseOnWriteCompletionAsync() { try { - await serverStream.WaitForWriteCompletionAsync(); + await serverStream.WritesClosed; waitForAbortTcs.SetException(new Exception("WaitForWriteCompletionAsync didn't throw stream aborted.")); } catch (QuicException ex) when (ex.QuicError == QuicError.OperationAborted) @@ -863,7 +819,7 @@ public async Task WaitForWriteCompletionAsync_ServerShutdown_Success() await RunBidirectionalClientServer( async clientStream => { - await clientStream.WriteAsync(new byte[1], endStream: true); + await clientStream.WriteAsync(new byte[1], completeWrites: true); int readCount = await clientStream.ReadAsync(new byte[1]); Assert.Equal(1, readCount); @@ -873,7 +829,7 @@ await RunBidirectionalClientServer( }, async serverStream => { - var writeCompletionTask = serverStream.WaitForWriteCompletionAsync(); + var writeCompletionTask = serverStream.WritesClosed; int received = await serverStream.ReadAsync(new byte[1]); Assert.Equal(1, received); @@ -884,7 +840,7 @@ await RunBidirectionalClientServer( Assert.False(writeCompletionTask.IsCompleted, "Server is still writing."); - serverStream.Shutdown(); + serverStream.CompleteWrites(); await writeCompletionTask; }); @@ -896,7 +852,7 @@ public async Task WaitForWriteCompletionAsync_GracefulShutdown_Success() await RunBidirectionalClientServer( async clientStream => { - await clientStream.WriteAsync(new byte[1], endStream: true); + await clientStream.WriteAsync(new byte[1], completeWrites: true); int readCount = await clientStream.ReadAsync(new byte[1]); Assert.Equal(1, readCount); @@ -906,7 +862,7 @@ await RunBidirectionalClientServer( }, async serverStream => { - var writeCompletionTask = serverStream.WaitForWriteCompletionAsync(); + var writeCompletionTask = serverStream.WritesClosed; int received = await serverStream.ReadAsync(new byte[1]); Assert.Equal(1, received); @@ -915,7 +871,7 @@ await RunBidirectionalClientServer( Assert.False(writeCompletionTask.IsCompleted, "Server is still writing."); - await serverStream.WriteAsync(new byte[1], endStream: true); + await serverStream.WriteAsync(new byte[1], completeWrites: true); await writeCompletionTask; }); @@ -932,7 +888,7 @@ public async Task WaitForWriteCompletionAsync_ConnectionClosed_Throws() await RunClientServer( serverFunction: async connection => { - await using QuicStream stream = await connection.AcceptStreamAsync(); + await using QuicStream stream = await connection.AcceptInboundStreamAsync(); var writeCompletionTask = ReleaseOnWriteCompletionAsync(); @@ -953,7 +909,7 @@ async ValueTask ReleaseOnWriteCompletionAsync() { try { - await stream.WaitForWriteCompletionAsync(); + await stream.WritesClosed; waitForAbortTcs.SetException(new Exception("WaitForWriteCompletionAsync didn't throw connection aborted.")); } catch (QuicException ex) when (ex.QuicError == QuicError.ConnectionAborted) @@ -964,11 +920,11 @@ async ValueTask ReleaseOnWriteCompletionAsync() }, clientFunction: async connection => { - await using QuicStream stream = await connection.OpenBidirectionalStreamAsync(); + await using QuicStream stream = await connection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); - await stream.WriteAsync(new byte[1], endStream: true); + await stream.WriteAsync(new byte[1], completeWrites: true); - await stream.WaitForWriteCompletionAsync(); + await stream.WritesClosed; // Wait for the server to read data before closing the connection await sem.WaitAsync(); diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs index 53685907d9e9a..8dedd37acf0a3 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs @@ -18,6 +18,11 @@ namespace System.Net.Quic.Tests { public abstract class QuicTestBase { + public const long DefaultStreamErrorCodeClient = 123456; + public const long DefaultStreamErrorCodeServer = 654321; + public const long DefaultCloseErrorCodeClient = 789; + public const long DefaultCloseErrorCodeServer = 987; + private static readonly byte[] s_ping = "PING"u8.ToArray(); private static readonly byte[] s_pong = "PONG"u8.ToArray(); @@ -53,6 +58,8 @@ public QuicServerConnectionOptions CreateQuicServerOptions() { return new QuicServerConnectionOptions() { + DefaultStreamErrorCode = DefaultStreamErrorCodeServer, + DefaultCloseErrorCode = DefaultCloseErrorCodeServer, ServerAuthenticationOptions = GetSslServerAuthenticationOptions() }; } @@ -80,6 +87,8 @@ public QuicClientConnectionOptions CreateQuicClientOptions(EndPoint endpoint) { return new QuicClientConnectionOptions() { + DefaultStreamErrorCode = DefaultStreamErrorCodeClient, + DefaultCloseErrorCode = DefaultCloseErrorCodeClient, RemoteEndPoint = endpoint, ClientAuthenticationOptions = GetSslClientAuthenticationOptions() }; @@ -106,7 +115,7 @@ internal QuicListenerOptions CreateQuicListenerOptions() }; } - internal ValueTask CreateQuicListener(int maxUnidirectionalStreams = 100, int maxBidirectionalStreams = 100) + internal ValueTask CreateQuicListener(int MaxInboundUnidirectionalStreams = 100, int MaxInboundBidirectionalStreams = 100) { var options = CreateQuicListenerOptions(); return CreateQuicListener(options); @@ -130,11 +139,7 @@ internal ValueTask CreateQuicListener(IPEndPoint endpoint) { await using (QuicListener listener = await CreateQuicListener(listenerOptions)) { - clientOptions ??= new QuicClientConnectionOptions() - { - RemoteEndPoint = listener.LocalEndPoint, - ClientAuthenticationOptions = GetSslClientAuthenticationOptions() - }; + clientOptions ??= CreateQuicClientOptions(listener.LocalEndPoint); if (clientOptions.RemoteEndPoint is IPEndPoint iPEndPoint && !iPEndPoint.Equals(listener.LocalEndPoint)) { clientOptions.RemoteEndPoint = listener.LocalEndPoint; @@ -165,11 +170,10 @@ internal ValueTask CreateQuicListener(IPEndPoint endpoint) ValueTask serverTask = listener.AcceptConnectionAsync(); while (retry > 0) { - clientConnection = await CreateQuicConnection(clientOptions); retry--; try { - await clientConnection.ConnectAsync().ConfigureAwait(false); + clientConnection = await CreateQuicConnection(clientOptions).ConfigureAwait(false); break; } catch (QuicException ex) when (ex.HResult == (int)SocketError.ConnectionRefused) @@ -192,17 +196,14 @@ internal ValueTask CreateQuicListener(IPEndPoint endpoint) await listener.DisposeAsync(); } - Assert.True(serverConnection.Connected); - Assert.True(clientConnection.Connected); - return (clientConnection, serverTask.Result); } internal async Task PingPong(QuicConnection client, QuicConnection server) { - using QuicStream clientStream = await client.OpenBidirectionalStreamAsync(); + await using QuicStream clientStream = await client.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); ValueTask t = clientStream.WriteAsync(s_ping); - using QuicStream serverStream = await server.AcceptStreamAsync(); + await using QuicStream serverStream = await server.AcceptInboundStreamAsync(); byte[] buffer = new byte[s_ping.Length]; int remains = s_ping.Length; @@ -241,8 +242,8 @@ internal async Task RunClientServer(Func clientFunction, F for (int i = 0; i < iterations; ++i) { (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(listener); - using (clientConnection) - using (serverConnection) + await using (clientConnection) + await using (serverConnection) { await new[] { @@ -259,8 +260,22 @@ await new[] await serverFinished.WaitAsync(); }) }.WhenAllOrAnyFailed(millisecondsTimeout); - await serverConnection.CloseAsync(ServerCloseErrorCode); - await clientConnection.CloseAsync(ClientCloseErrorCode); + try + { + await serverConnection.CloseAsync(ServerCloseErrorCode); + } + catch (ObjectDisposedException ex) + { + _output.WriteLine(ex.ToString()); + } + try + { + await clientConnection.CloseAsync(ClientCloseErrorCode); + } + catch (ObjectDisposedException ex) + { + _output.WriteLine(ex.ToString()); + } } } } @@ -272,25 +287,23 @@ internal async Task RunStreamClientServer(Func clientFunction, await RunClientServer( clientFunction: async connection => { - await using QuicStream stream = bidi ? await connection.OpenBidirectionalStreamAsync() : await connection.OpenUnidirectionalStreamAsync(); + await using QuicStream stream = bidi ? await connection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional) : await connection.OpenOutboundStreamAsync(QuicStreamType.Unidirectional); // Open(Bi|Uni)directionalStream only allocates ID. We will force stream opening // by Writing there and receiving data on the other side. await stream.WriteAsync(buffer); await clientFunction(stream); - stream.Shutdown(); - await stream.ShutdownCompleted(); + stream.CompleteWrites(); }, serverFunction: async connection => { - await using QuicStream stream = await connection.AcceptStreamAsync(); + await using QuicStream stream = await connection.AcceptInboundStreamAsync(); Assert.Equal(1, await stream.ReadAsync(buffer)); await serverFunction(stream); - stream.Shutdown(); - await stream.ShutdownCompleted(); + stream.CompleteWrites(); }, iterations, millisecondsTimeout diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/System.Net.Quic.Functional.Tests.csproj b/src/libraries/System.Net.Quic/tests/FunctionalTests/System.Net.Quic.Functional.Tests.csproj index 89f016cb962f6..cfb6c5c21e1a7 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/System.Net.Quic.Functional.Tests.csproj +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/System.Net.Quic.Functional.Tests.csproj @@ -32,6 +32,5 @@ - diff --git a/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketOptionsTests.cs b/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketOptionsTests.cs index a662dabeb0f9d..665d8208b4142 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketOptionsTests.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/ClientWebSocketOptionsTests.cs @@ -268,7 +268,7 @@ await LoopbackServer.CreateClientAndServerAsync(async proxyUri => // Send non-success error code so that SocketsHttpHandler won't retry. await connection.SendResponseAsync(statusCode: HttpStatusCode.Forbidden); - connection.Dispose(); + await connection.DisposeAsync(); })); Assert.True(connectionAccepted);