diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs index 006dd15aa8da1..312043e56d477 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs @@ -87,7 +87,7 @@ public void GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRoot() } [Fact] - public async void GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRoot_IL_Replacement() + public async Task GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRoot_IL_Replacement() { // Arrange var serviceCollection = new ServiceCollection(); diff --git a/src/libraries/Microsoft.Extensions.Hosting/tests/UnitTests/OptionsBuilderExtensionsTests.cs b/src/libraries/Microsoft.Extensions.Hosting/tests/UnitTests/OptionsBuilderExtensionsTests.cs index 26b61d0a2d761..74c2a1c0bfe92 100644 --- a/src/libraries/Microsoft.Extensions.Hosting/tests/UnitTests/OptionsBuilderExtensionsTests.cs +++ b/src/libraries/Microsoft.Extensions.Hosting/tests/UnitTests/OptionsBuilderExtensionsTests.cs @@ -246,7 +246,7 @@ private async Task ValidateOnStart_AddEagerValidation_DoesValidationWhenHostStar } [Fact] - private async void CanValidateOptionsEagerly_AddOptionsWithValidateOnStart_IValidateOptions() + private async Task CanValidateOptionsEagerly_AddOptionsWithValidateOnStart_IValidateOptions() { var hostBuilder = CreateHostBuilder(services => services.AddOptionsWithValidateOnStart() diff --git a/src/libraries/Microsoft.Extensions.Http/tests/Microsoft.Extensions.Http.Tests/Logging/HttpClientLoggerTest.cs b/src/libraries/Microsoft.Extensions.Http/tests/Microsoft.Extensions.Http.Tests/Logging/HttpClientLoggerTest.cs index fb1d94310d5ca..6fb8e2dd8f7a9 100644 --- a/src/libraries/Microsoft.Extensions.Http/tests/Microsoft.Extensions.Http.Tests/Logging/HttpClientLoggerTest.cs +++ b/src/libraries/Microsoft.Extensions.Http/tests/Microsoft.Extensions.Http.Tests/Logging/HttpClientLoggerTest.cs @@ -162,7 +162,7 @@ private void AssertCounters(TestCountingLogger testLogger, int requestCount, boo [InlineData(false, true)] [InlineData(true, false)] [InlineData(true, true)] - public async void CustomLogger_LogsCorrectEvents_Sync(bool requestSuccessful, bool asyncSecondCall) + public async Task CustomLogger_LogsCorrectEvents_Sync(bool requestSuccessful, bool asyncSecondCall) { var serviceCollection = new ServiceCollection(); serviceCollection.AddTransient(_ => diff --git a/src/libraries/System.ComponentModel.TypeConverter/tests/TypeDescriptorTests.cs b/src/libraries/System.ComponentModel.TypeConverter/tests/TypeDescriptorTests.cs index b5468a6583f46..193dcba83359a 100644 --- a/src/libraries/System.ComponentModel.TypeConverter/tests/TypeDescriptorTests.cs +++ b/src/libraries/System.ComponentModel.TypeConverter/tests/TypeDescriptorTests.cs @@ -1395,7 +1395,7 @@ public static IEnumerable GetConverter_ByMultithread_ReturnsExpected_T [Theory] [MemberData(nameof(GetConverter_ByMultithread_ReturnsExpected_TestData))] - public async void GetConverter_ByMultithread_ReturnsExpected(Type typeForGetConverter, Type expectedConverterType) + public async Task GetConverter_ByMultithread_ReturnsExpected(Type typeForGetConverter, Type expectedConverterType) { TypeConverter[] actualConverters = await Task.WhenAll( Enumerable.Range(0, 100).Select(_ => @@ -1415,7 +1415,7 @@ public static IEnumerable GetConverterWithAddProvider_ByMultithread_Su [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsReflectionEmitSupported))] // Mock will try to JIT [MemberData(nameof(GetConverterWithAddProvider_ByMultithread_Success_TestData))] - public async void GetConverterWithAddProvider_ByMultithread_Success(Type typeForGetConverter, Type expectedConverterType) + public async Task GetConverterWithAddProvider_ByMultithread_Success(Type typeForGetConverter, Type expectedConverterType) { TypeConverter[] actualConverters = await Task.WhenAll( Enumerable.Range(0, 200).Select(_ => diff --git a/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/WinHttpHandlerTest.cs b/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/WinHttpHandlerTest.cs index f204e21536bee..08e2850e5b0f5 100644 --- a/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/WinHttpHandlerTest.cs +++ b/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/WinHttpHandlerTest.cs @@ -98,7 +98,7 @@ public async Task SendAsync_SlowServerAndCancel_ThrowsTaskCanceledException() [OuterLoop] [Fact] - public async void SendAsync_SlowServerRespondsAfterDefaultReceiveTimeout_ThrowsHttpRequestException() + public async Task SendAsync_SlowServerRespondsAfterDefaultReceiveTimeout_ThrowsHttpRequestException() { var handler = new WinHttpHandler(); using (var client = new HttpClient(handler)) diff --git a/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/Task.cs b/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/Task.cs index e6d08b71f5881..b8a2196ec69b3 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/Task.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/Task.cs @@ -17,6 +17,7 @@ using System.Runtime.ExceptionServices; using System.Runtime.InteropServices; using System.Runtime.Versioning; +using System.Threading.Tasks.Sources; namespace System.Threading.Tasks { @@ -6659,6 +6660,191 @@ public static Task> WhenAny(IEnumerable> ta WhenAny>(tasks); #endregion + #region WhenEach + /// Creates an that will yield the supplied tasks as those tasks complete. + /// The task to iterate through when completed. + /// An for iterating through the supplied tasks. + /// + /// The supplied tasks will become available to be output via the enumerable once they've completed. The exact order + /// in which the tasks will become available is not defined. + /// + /// is null. + /// contains a null. + public static IAsyncEnumerable WhenEach(params Task[] tasks) + { + ArgumentNullException.ThrowIfNull(tasks); + return WhenEach((ReadOnlySpan)tasks); + } + + /// + public static IAsyncEnumerable WhenEach(ReadOnlySpan tasks) => // TODO https://github.com/dotnet/runtime/issues/77873: Add params + WhenEachState.Iterate(WhenEachState.Create(tasks)); + + /// + public static IAsyncEnumerable WhenEach(IEnumerable tasks) => + WhenEachState.Iterate(WhenEachState.Create(tasks)); + + /// + public static IAsyncEnumerable> WhenEach(params Task[] tasks) + { + ArgumentNullException.ThrowIfNull(tasks); + return WhenEach((ReadOnlySpan>)tasks); + } + + /// + public static IAsyncEnumerable> WhenEach(ReadOnlySpan> tasks) => // TODO https://github.com/dotnet/runtime/issues/77873: Add params + WhenEachState.Iterate>(WhenEachState.Create(ReadOnlySpan.CastUp(tasks))); + + /// + public static IAsyncEnumerable> WhenEach(IEnumerable> tasks) => + WhenEachState.Iterate>(WhenEachState.Create(tasks)); + + /// Object used by to store its state. + private sealed class WhenEachState : Queue, IValueTaskSource, ITaskCompletionAction + { + /// Implementation backing the ValueTask used to wait for the next task to be available. + /// This is a mutable struct. Do not make it readonly. + private ManualResetValueTaskSourceCore _waitForNextCompletedTask = new() { RunContinuationsAsynchronously = true }; // _waitForNextCompletedTask.Set is called while holding a lock + /// 0 if this has never been used in an iteration; 1 if it has. + /// This is used to ensure we only ever iterate through the tasks once. + private int _enumerated; + + /// Called at the beginning of the iterator to assume ownership of the state. + /// true if the caller owns the state; false if the caller should end immediately. + public bool TryStart() => Interlocked.Exchange(ref _enumerated, 1) == 0; + + /// Gets or sets the number of tasks that haven't yet been yielded. + public int Remaining { get; set; } + + void ITaskCompletionAction.Invoke(Task completingTask) + { + lock (this) + { + // Enqueue the task into the queue. If the Count is now 1, we transitioned from + // empty to non-empty, which means we need to signal the MRVTSC, as the consumer + // could be waiting on a ValueTask representing a completed task being available. + Enqueue(completingTask); + if (Count == 1) + { + Debug.Assert(_waitForNextCompletedTask.GetStatus(_waitForNextCompletedTask.Version) == ValueTaskSourceStatus.Pending); + _waitForNextCompletedTask.SetResult(default); + } + } + } + bool ITaskCompletionAction.InvokeMayRunArbitraryCode => false; + + // Delegate to _waitForNextCompletedTask for IValueTaskSource implementation. + void IValueTaskSource.GetResult(short token) => _waitForNextCompletedTask.GetResult(token); + ValueTaskSourceStatus IValueTaskSource.GetStatus(short token) => _waitForNextCompletedTask.GetStatus(token); + void IValueTaskSource.OnCompleted(Action continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags) => + _waitForNextCompletedTask.OnCompleted(continuation, state, token, flags); + + /// Creates a from the specified tasks. + public static WhenEachState? Create(ReadOnlySpan tasks) + { + WhenEachState? waiter = null; + + if (tasks.Length != 0) + { + waiter = new(); + foreach (Task task in tasks) + { + if (task is null) + { + ThrowHelper.ThrowArgumentException(ExceptionResource.Task_MultiTaskContinuation_NullTask, ExceptionArgument.tasks); + } + + waiter.Remaining++; + task.AddCompletionAction(waiter); + } + } + + return waiter; + } + + /// + public static WhenEachState? Create(IEnumerable tasks) + { + ArgumentNullException.ThrowIfNull(tasks); + + WhenEachState? waiter = null; + + IEnumerator e = tasks.GetEnumerator(); + if (e.MoveNext()) + { + waiter = new(); + do + { + Task task = e.Current; + if (task is null) + { + ThrowHelper.ThrowArgumentException(ExceptionResource.Task_MultiTaskContinuation_NullTask, ExceptionArgument.tasks); + } + + waiter.Remaining++; + task.AddCompletionAction(waiter); + } + while (e.MoveNext()); + } + + return waiter; + } + + /// Iterates through the tasks represented by the provided waiter. + public static async IAsyncEnumerable Iterate(WhenEachState? waiter, [EnumeratorCancellation] CancellationToken cancellationToken = default) where T : Task + { + // The enumerable could have GetAsyncEnumerator called on it multiple times. As we're dealing with Tasks that + // only ever transition from non-completed to completed, re-enumeration doesn't have much benefit, so we take + // advantage of the optimizations possible by not supporting that and simply have the semantics that, no matter + // how many times the enumerable is enumerated, every task is yielded only once. The original GetAsyncEnumerator + // call will give back all the tasks, and all subsequent iterations will be empty. + if (waiter?.TryStart() is not true) + { + yield break; + } + + // Loop until we've yielded all tasks. + while (waiter.Remaining > 0) + { + // Either get the next completed task from the queue, or get a + // ValueTask with which to wait for the next task to complete. + Task? next; + ValueTask waitTask = default; + lock (waiter) + { + // Reset the MRVTSC if it was signaled, then try to dequeue a task and + // either return one we got or return a ValueTask that will be signaled + // when the next completed task is available. + waiter._waitForNextCompletedTask.Reset(); + if (!waiter.TryDequeue(out next)) + { + waitTask = new(waiter, waiter._waitForNextCompletedTask.Version); + } + } + + // If we got a completed Task, yield it. + if (next is not null) + { + cancellationToken.ThrowIfCancellationRequested(); + waiter.Remaining--; + yield return (T)next; + continue; + } + + // If we have a cancellation token and the ValueTask isn't already completed, + // get a Task from the ValueTask so we can use WaitAsync to make the wait cancelable. + // Otherwise, just await the ValueTask directly. We don't need to be concerned + // about suppressing exceptions, as the ValueTask is only ever completed successfully. + if (cancellationToken.CanBeCanceled && !waitTask.IsCompleted) + { + waitTask = new ValueTask(waitTask.AsTask().WaitAsync(cancellationToken)); + } + await waitTask.ConfigureAwait(false); + } + } + } + #endregion + internal static Task CreateUnwrapPromise(Task outerTask, bool lookForOce) { Debug.Assert(outerTask != null); diff --git a/src/libraries/System.Runtime/ref/System.Runtime.cs b/src/libraries/System.Runtime/ref/System.Runtime.cs index e8f88bb1ce1f6..ef4bc568fb56f 100644 --- a/src/libraries/System.Runtime/ref/System.Runtime.cs +++ b/src/libraries/System.Runtime/ref/System.Runtime.cs @@ -15335,6 +15335,12 @@ public static void WaitAll(System.Threading.Tasks.Task[] tasks, System.Threading public static System.Threading.Tasks.Task> WhenAny(System.Collections.Generic.IEnumerable> tasks) { throw null; } public static System.Threading.Tasks.Task> WhenAny(System.Threading.Tasks.Task task1, System.Threading.Tasks.Task task2) { throw null; } public static System.Threading.Tasks.Task> WhenAny(params System.Threading.Tasks.Task[] tasks) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable WhenEach(System.Collections.Generic.IEnumerable tasks) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable WhenEach(params System.Threading.Tasks.Task[] tasks) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable WhenEach(System.ReadOnlySpan tasks) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable> WhenEach(System.Collections.Generic.IEnumerable> tasks) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable> WhenEach(params System.Threading.Tasks.Task[] tasks) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable> WhenEach(System.ReadOnlySpan> tasks) { throw null; } public static System.Runtime.CompilerServices.YieldAwaitable Yield() { throw null; } } public static partial class TaskAsyncEnumerableExtensions diff --git a/src/libraries/System.Runtime/tests/System.Threading.Tasks.Tests/MethodCoverage.cs b/src/libraries/System.Runtime/tests/System.Threading.Tasks.Tests/MethodCoverage.cs index eca5c0c92e203..2634e8340ef9e 100644 --- a/src/libraries/System.Runtime/tests/System.Threading.Tasks.Tests/MethodCoverage.cs +++ b/src/libraries/System.Runtime/tests/System.Threading.Tasks.Tests/MethodCoverage.cs @@ -924,5 +924,118 @@ public static void Task_WhenAll_TwoTasks_WakesOnBothCompletionWithExceptionAndCa Assert.Equal(e1, twa.Exception?.InnerException); } } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public void Task_WhenEach_NullsTriggerExceptions() + { + AssertExtensions.Throws("tasks", () => Task.WhenEach((Task[])null)); + AssertExtensions.Throws("tasks", () => Task.WhenEach((Task[])null)); + AssertExtensions.Throws("tasks", () => Task.WhenEach((IEnumerable)null)); + AssertExtensions.Throws("tasks", () => Task.WhenEach((IEnumerable>)null)); + + AssertExtensions.Throws("tasks", () => Task.WhenEach((Task[])[null])); + AssertExtensions.Throws("tasks", () => Task.WhenEach((ReadOnlySpan)[null])); + AssertExtensions.Throws("tasks", () => Task.WhenEach((IEnumerable)[null])); + AssertExtensions.Throws("tasks", () => Task.WhenEach((Task[])[null])); + AssertExtensions.Throws("tasks", () => Task.WhenEach((ReadOnlySpan>)[null])); + AssertExtensions.Throws("tasks", () => Task.WhenEach((IEnumerable>)[null])); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public async Task Task_WhenEach_EmptyInputsCompleteImmediately() + { + Assert.False(await Task.WhenEach((Task[])[]).GetAsyncEnumerator().MoveNextAsync()); + Assert.False(await Task.WhenEach((ReadOnlySpan)[]).GetAsyncEnumerator().MoveNextAsync()); + Assert.False(await Task.WhenEach((IEnumerable)[]).GetAsyncEnumerator().MoveNextAsync()); + Assert.False(await Task.WhenEach((Task[])[]).GetAsyncEnumerator().MoveNextAsync()); + Assert.False(await Task.WhenEach((ReadOnlySpan>)[]).GetAsyncEnumerator().MoveNextAsync()); + Assert.False(await Task.WhenEach((IEnumerable>)[]).GetAsyncEnumerator().MoveNextAsync()); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public async Task Task_WhenEach_TasksOnlyEnumerableOnce() + { + IAsyncEnumerable[] enumerables = + [ + Task.WhenEach((Task[])[Task.CompletedTask, Task.CompletedTask]), + Task.WhenEach((ReadOnlySpan)[Task.CompletedTask, Task.CompletedTask]), + Task.WhenEach((IEnumerable)[Task.CompletedTask, Task.CompletedTask]), + Task.WhenEach((Task[])[Task.FromResult(0), Task.FromResult(0)]), + Task.WhenEach((ReadOnlySpan>)[Task.FromResult(0), Task.FromResult(0)]), + Task.WhenEach((IEnumerable>)[Task.FromResult(0), Task.FromResult(0)]), + ]; + + foreach (IAsyncEnumerable e in enumerables) + { + IAsyncEnumerator e1 = e.GetAsyncEnumerator(); + IAsyncEnumerator e2 = e.GetAsyncEnumerator(); + IAsyncEnumerator e3 = e.GetAsyncEnumerator(); + + Assert.True(await e1.MoveNextAsync()); + Assert.False(await e2.MoveNextAsync()); + Assert.False(await e3.MoveNextAsync()); + + int count = 0; + do + { + count++; + } + while (await e1.MoveNextAsync()); + Assert.Equal(2, count); + + Assert.False(await e.GetAsyncEnumerator().MoveNextAsync()); + } + } + + [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + [InlineData(0)] + [InlineData(1)] + [InlineData(2)] + [InlineData(3)] + [InlineData(4)] + [InlineData(5)] + public async Task Task_WhenEach_IteratesThroughCompleteAndIncompleteTasks(int mode) + { + TaskCompletionSource tcs1 = new(), tcs2 = new(), tcs3 = new(); + Task[] array = [Task.FromResult(1), tcs1.Task, Task.FromResult(2), tcs2.Task, Task.FromResult(3), tcs3.Task]; + + IAsyncEnumerable tasks = mode switch + { + 0 => Task.WhenEach((ReadOnlySpan)array), + 1 => Task.WhenEach((Task[])array), + 2 => Task.WhenEach((IEnumerable)array), + 3 => Task.WhenEach((ReadOnlySpan>)array), + 4 => Task.WhenEach((Task[])array), + _ => Task.WhenEach((IEnumerable>)array), + }; + + Assert.NotNull(tasks); + + IAsyncEnumerator e = tasks.GetAsyncEnumerator(); + Assert.NotNull(tasks); + + ValueTask moveNext; + + for (int i = 1; i <= 3; i++) + { + moveNext = e.MoveNextAsync(); + Assert.True(moveNext.IsCompletedSuccessfully); + Assert.True(moveNext.Result); + Assert.Same(Task.FromResult(i), e.Current); + } + + foreach (TaskCompletionSource tcs in new[] { tcs2, tcs1, tcs3 }) + { + moveNext = e.MoveNextAsync(); + Assert.False(moveNext.IsCompleted); + tcs.SetResult(42); + Assert.True(await moveNext); + Assert.Same(tcs.Task, e.Current); + } + + moveNext = e.MoveNextAsync(); + Assert.True(moveNext.IsCompletedSuccessfully); + Assert.False(moveNext.Result); + } } } diff --git a/src/libraries/System.Text.Json/tests/Common/PropertyVisibilityTests.cs b/src/libraries/System.Text.Json/tests/Common/PropertyVisibilityTests.cs index 44c702025f5d1..f957db15a96b3 100644 --- a/src/libraries/System.Text.Json/tests/Common/PropertyVisibilityTests.cs +++ b/src/libraries/System.Text.Json/tests/Common/PropertyVisibilityTests.cs @@ -245,7 +245,7 @@ public async Task Ignore_BasePublicPropertyIgnored_ConflictWithDerivedPrivate() } [Fact] - public async void Ignore_BasePublicPropertyIgnored_ConflictWithDerivedPublicPropertyIgnored() + public async Task Ignore_BasePublicPropertyIgnored_ConflictWithDerivedPublicPropertyIgnored() { var obj = new ClassWithIgnoredPublicPropertyAndNewSlotPublicAndIgnoredToo();