Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Task.WhenEach to process tasks as they complete #100316

Merged
merged 6 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
using System.Runtime.ExceptionServices;
using System.Runtime.InteropServices;
using System.Runtime.Versioning;
using System.Threading.Tasks.Sources;

namespace System.Threading.Tasks
{
Expand Down Expand Up @@ -6659,6 +6660,193 @@ public static Task<Task<TResult>> WhenAny<TResult>(IEnumerable<Task<TResult>> ta
WhenAny<Task<TResult>>(tasks);
#endregion

#region WhenEach
/// <summary>Creates an <see cref="IAsyncEnumerable{T}"/> that will yield the supplied tasks as those tasks complete.</summary>
/// <param name="tasks">The task to iterate through when completed.</param>
/// <returns>An <see cref="IAsyncEnumerable{T}"/> for iterating through the supplied tasks.</returns>
/// <remarks>
/// The supplied tasks will become available to be output via the enumerable once they've completed. The exact order
/// in which the tasks will become available is not defined.
/// </remarks>
/// <exception cref="ArgumentNullException"><paramref name="tasks"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="tasks"/> contains a null.</exception>
public static IAsyncEnumerable<Task> WhenEach(params Task[] tasks)
{
ArgumentNullException.ThrowIfNull(tasks);
return WhenEach((ReadOnlySpan<Task>)tasks);
}

/// <inheritdoc cref="WhenEach(Task[])"/>
public static IAsyncEnumerable<Task> WhenEach(ReadOnlySpan<Task> tasks) => // TODO https://github.com/dotnet/runtime/issues/77873: Add params
WhenEachState.Iterate<Task>(WhenEachState.Create(tasks));

/// <inheritdoc cref="WhenEach(Task[])"/>
public static IAsyncEnumerable<Task> WhenEach(IEnumerable<Task> tasks) =>
WhenEachState.Iterate<Task>(WhenEachState.Create(tasks));

/// <inheritdoc cref="WhenEach(Task[])"/>
public static IAsyncEnumerable<Task<TResult>> WhenEach<TResult>(params Task<TResult>[] tasks)
{
ArgumentNullException.ThrowIfNull(tasks);
return WhenEach((ReadOnlySpan<Task<TResult>>)tasks);
}

/// <inheritdoc cref="WhenEach(Task[])"/>
public static IAsyncEnumerable<Task<TResult>> WhenEach<TResult>(ReadOnlySpan<Task<TResult>> tasks) => // TODO https://github.com/dotnet/runtime/issues/77873: Add params
WhenEachState.Iterate<Task<TResult>>(WhenEachState.Create(ReadOnlySpan<Task>.CastUp(tasks)));

/// <inheritdoc cref="WhenEach(Task[])"/>
public static IAsyncEnumerable<Task<TResult>> WhenEach<TResult>(IEnumerable<Task<TResult>> tasks) =>
WhenEachState.Iterate<Task<TResult>>(WhenEachState.Create(tasks));

/// <summary>Object used by <see cref="Iterate"/> to store its state.</summary>
private sealed class WhenEachState : Queue<Task>, IValueTaskSource, ITaskCompletionAction
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
{
/// <summary>Implementation backing the ValueTask used to wait for the next task to be available.</summary>
/// <remarks>This is a mutable struct. Do not make it readonly.</remarks>
private ManualResetValueTaskSourceCore<bool> _waitForNextCompletedTask = new() { RunContinuationsAsynchronously = true }; // _waitForNextCompletedTask.Set is called while holding a lock
/// <summary>0 if this has never been used in an iteration; 1 if it has.</summary>
/// <remarks>This is used to ensure we only ever iterate through the tasks once.</remarks>
private int _enumerated;

/// <summary>Called at the beginning of the iterator to assume ownership of the state.</summary>
/// <returns>true if the caller owns the state; false if the caller should end immediately.</returns>
public bool TryStart() => Interlocked.Exchange(ref _enumerated, 1) == 0;

/// <summary>Gets or sets the number of tasks that haven't yet been yielded.</summary>
public int Remaining { get; set; }

void ITaskCompletionAction.Invoke(Task completingTask)
{
lock (this)
{
// Enqueue the task into the queue. If the Count is now 1, we transitioned from
// empty to non-empty, which means we need to signal the MRVTSC, as the consumer
// could be waiting on a ValueTask representing a completed task being available.
Enqueue(completingTask);
if (Count == 1)
{
Debug.Assert(_waitForNextCompletedTask.GetStatus(_waitForNextCompletedTask.Version) == ValueTaskSourceStatus.Pending);
_waitForNextCompletedTask.SetResult(default);
}
}
}
bool ITaskCompletionAction.InvokeMayRunArbitraryCode => false;

// Delegate to _waitForNextCompletedTask for IValueTaskSource implementation.
void IValueTaskSource.GetResult(short token) => _waitForNextCompletedTask.GetResult(token);
ValueTaskSourceStatus IValueTaskSource.GetStatus(short token) => _waitForNextCompletedTask.GetStatus(token);
void IValueTaskSource.OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags) =>
_waitForNextCompletedTask.OnCompleted(continuation, state, token, flags);

/// <summary>Creates a <see cref="WhenEachState"/> from the specified tasks.</summary>
public static WhenEachState? Create(ReadOnlySpan<Task> tasks)
{
WhenEachState? waiter = null;

if (tasks.Length != 0)
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
{
waiter = new();
foreach (Task task in tasks)
{
if (task is null)
{
ThrowHelper.ThrowArgumentException(ExceptionResource.Task_MultiTaskContinuation_NullTask, ExceptionArgument.tasks);
}

waiter.Remaining++;
task.AddCompletionAction(waiter);
}
}

return waiter;
}

/// <inheritdoc cref="Create(ReadOnlySpan{Task})"/>
public static WhenEachState? Create(IEnumerable<Task> tasks)
{
ArgumentNullException.ThrowIfNull(tasks);

WhenEachState? waiter = null;

IEnumerator<Task> e = tasks.GetEnumerator();
if (e.MoveNext())
{
waiter = new();
do
{
Task task = e.Current;
if (task is null)
{
ThrowHelper.ThrowArgumentException(ExceptionResource.Task_MultiTaskContinuation_NullTask, ExceptionArgument.tasks);
}

waiter.Remaining++;
task.AddCompletionAction(waiter);
}
while (e.MoveNext());
}

return waiter;
}

/// <summary>Iterates through the tasks represented by the provided waiter.</summary>
public static async IAsyncEnumerable<T> Iterate<T>(WhenEachState? waiter, [EnumeratorCancellation] CancellationToken cancellationToken = default) where T : Task
{
// The enumerable could have GetAsyncEnumerator called on it multiple times. As we're dealing with Tasks that
// only ever transition from non-completed to completed, re-enumeration doesn't have much benefit, so we take
// advantage of the optimizations possible by not supporting that and simply have the semantics that, no matter
// how many times the enumerable is enumerated, every task is yielded only once. The original GetAsyncEnumerator
// call will give back all the tasks, and all subsequent iterations will be empty.
if (waiter?.TryStart() is true)
{
// Loop until we've yielded all tasks.
while (waiter.Remaining > 0)
{
// Either get the next completed task from the queue, or get a
// ValueTask with which to wait for the next task to complete.
Task? next;
ValueTask waitTask = default;
lock (waiter)
{
// Reset the MRVTSC if it was signaled, then try to dequeue a task and
// either return one we got or return a ValueTask that will be signaled
// when the next completed task is available.
waiter._waitForNextCompletedTask.Reset();
if (!waiter.TryDequeue(out next))
{
waitTask = new(waiter, waiter._waitForNextCompletedTask.Version);
}
}

// If we got a completed Task, yield it.
if (next is not null)
{
cancellationToken.ThrowIfCancellationRequested();
waiter.Remaining--;
yield return (T)next;
continue;
}

// Otherwise, wait.
if (cancellationToken.CanBeCanceled && !waitTask.IsCompleted)
{
// If we have a cancellation token and the ValueTask isn't already completed,
// get a Task from the ValueTask so we can use WaitAsync to make the wait cancelable.
await waitTask.AsTask().WaitAsync(cancellationToken).ConfigureAwait(false);
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
}
else
{
// Otherwise, just await the ValueTask directly. We don't need to be concerned
// about suppressing exceptions, as the ValueTask is only ever completed successfully.
await waitTask.ConfigureAwait(false);
}
}
}
}
}
#endregion

internal static Task<TResult> CreateUnwrapPromise<TResult>(Task outerTask, bool lookForOce)
{
Debug.Assert(outerTask != null);
Expand Down
6 changes: 6 additions & 0 deletions src/libraries/System.Runtime/ref/System.Runtime.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15335,6 +15335,12 @@ public static void WaitAll(System.Threading.Tasks.Task[] tasks, System.Threading
public static System.Threading.Tasks.Task<System.Threading.Tasks.Task<TResult>> WhenAny<TResult>(System.Collections.Generic.IEnumerable<System.Threading.Tasks.Task<TResult>> tasks) { throw null; }
public static System.Threading.Tasks.Task<System.Threading.Tasks.Task<TResult>> WhenAny<TResult>(System.Threading.Tasks.Task<TResult> task1, System.Threading.Tasks.Task<TResult> task2) { throw null; }
public static System.Threading.Tasks.Task<System.Threading.Tasks.Task<TResult>> WhenAny<TResult>(params System.Threading.Tasks.Task<TResult>[] tasks) { throw null; }
public static System.Collections.Generic.IAsyncEnumerable<System.Threading.Tasks.Task> WhenEach(System.Collections.Generic.IEnumerable<System.Threading.Tasks.Task> tasks) { throw null; }
public static System.Collections.Generic.IAsyncEnumerable<System.Threading.Tasks.Task> WhenEach(params System.Threading.Tasks.Task[] tasks) { throw null; }
public static System.Collections.Generic.IAsyncEnumerable<System.Threading.Tasks.Task> WhenEach(System.ReadOnlySpan<System.Threading.Tasks.Task> tasks) { throw null; }
public static System.Collections.Generic.IAsyncEnumerable<System.Threading.Tasks.Task<TResult>> WhenEach<TResult>(System.Collections.Generic.IEnumerable<System.Threading.Tasks.Task<TResult>> tasks) { throw null; }
public static System.Collections.Generic.IAsyncEnumerable<System.Threading.Tasks.Task<TResult>> WhenEach<TResult>(params System.Threading.Tasks.Task<TResult>[] tasks) { throw null; }
public static System.Collections.Generic.IAsyncEnumerable<System.Threading.Tasks.Task<TResult>> WhenEach<TResult>(System.ReadOnlySpan<System.Threading.Tasks.Task<TResult>> tasks) { throw null; }
public static System.Runtime.CompilerServices.YieldAwaitable Yield() { throw null; }
}
public static partial class TaskAsyncEnumerableExtensions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -924,5 +924,118 @@ public static void Task_WhenAll_TwoTasks_WakesOnBothCompletionWithExceptionAndCa
Assert.Equal(e1, twa.Exception?.InnerException);
}
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
public void Task_WhenEach_NullsTriggerExceptions()
{
AssertExtensions.Throws<ArgumentNullException>("tasks", () => Task.WhenEach((Task[])null));
AssertExtensions.Throws<ArgumentNullException>("tasks", () => Task.WhenEach((Task<int>[])null));
AssertExtensions.Throws<ArgumentNullException>("tasks", () => Task.WhenEach((IEnumerable<Task>)null));
AssertExtensions.Throws<ArgumentNullException>("tasks", () => Task.WhenEach((IEnumerable<Task<int>>)null));

AssertExtensions.Throws<ArgumentException>("tasks", () => Task.WhenEach((Task[])[null]));
AssertExtensions.Throws<ArgumentException>("tasks", () => Task.WhenEach((ReadOnlySpan<Task>)[null]));
AssertExtensions.Throws<ArgumentException>("tasks", () => Task.WhenEach((IEnumerable<Task>)[null]));
AssertExtensions.Throws<ArgumentException>("tasks", () => Task.WhenEach((Task<int>[])[null]));
AssertExtensions.Throws<ArgumentException>("tasks", () => Task.WhenEach((ReadOnlySpan<Task<int>>)[null]));
AssertExtensions.Throws<ArgumentException>("tasks", () => Task.WhenEach((IEnumerable<Task<int>>)[null]));
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
public async Task Task_WhenEach_EmptyInputsCompleteImmediately()
{
Assert.False(await Task.WhenEach((Task[])[]).GetAsyncEnumerator().MoveNextAsync());
Assert.False(await Task.WhenEach((ReadOnlySpan<Task>)[]).GetAsyncEnumerator().MoveNextAsync());
Assert.False(await Task.WhenEach((IEnumerable<Task>)[]).GetAsyncEnumerator().MoveNextAsync());
Assert.False(await Task.WhenEach((Task<int>[])[]).GetAsyncEnumerator().MoveNextAsync());
Assert.False(await Task.WhenEach((ReadOnlySpan<Task<int>>)[]).GetAsyncEnumerator().MoveNextAsync());
Assert.False(await Task.WhenEach((IEnumerable<Task<int>>)[]).GetAsyncEnumerator().MoveNextAsync());
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
public async Task Task_WhenEach_TasksOnlyEnumerableOnce()
{
IAsyncEnumerable<Task>[] enumerables =
[
Task.WhenEach((Task[])[Task.CompletedTask, Task.CompletedTask]),
Task.WhenEach((ReadOnlySpan<Task>)[Task.CompletedTask, Task.CompletedTask]),
Task.WhenEach((IEnumerable<Task>)[Task.CompletedTask, Task.CompletedTask]),
Task.WhenEach((Task<int>[])[Task.FromResult(0), Task.FromResult(0)]),
Task.WhenEach((ReadOnlySpan<Task<int>>)[Task.FromResult(0), Task.FromResult(0)]),
Task.WhenEach((IEnumerable<Task<int>>)[Task.FromResult(0), Task.FromResult(0)]),
];

foreach (IAsyncEnumerable<Task> e in enumerables)
{
IAsyncEnumerator<Task> e1 = e.GetAsyncEnumerator();
IAsyncEnumerator<Task> e2 = e.GetAsyncEnumerator();
IAsyncEnumerator<Task> e3 = e.GetAsyncEnumerator();

Assert.True(await e1.MoveNextAsync());
Assert.False(await e2.MoveNextAsync());
Assert.False(await e3.MoveNextAsync());

int count = 0;
do
{
count++;
}
while (await e1.MoveNextAsync());
Assert.Equal(2, count);

Assert.False(await e.GetAsyncEnumerator().MoveNextAsync());
}
}

[ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
[InlineData(0)]
[InlineData(1)]
[InlineData(2)]
[InlineData(3)]
[InlineData(4)]
[InlineData(5)]
public async void Task_WhenEach_IteratesThroughCompleteAndIncompleteTasks(int mode)
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
{
TaskCompletionSource<int> tcs1 = new(), tcs2 = new(), tcs3 = new();
Task<int>[] array = [Task.FromResult(1), tcs1.Task, Task.FromResult(2), tcs2.Task, Task.FromResult(3), tcs3.Task];

IAsyncEnumerable<Task> tasks = mode switch
{
0 => Task.WhenEach((ReadOnlySpan<Task>)array),
1 => Task.WhenEach((Task[])array),
2 => Task.WhenEach((IEnumerable<Task>)array),
3 => Task.WhenEach((ReadOnlySpan<Task<int>>)array),
4 => Task.WhenEach((Task<int>[])array),
_ => Task.WhenEach((IEnumerable<Task<int>>)array),
};

Assert.NotNull(tasks);

IAsyncEnumerator<Task> e = tasks.GetAsyncEnumerator();
Assert.NotNull(tasks);

ValueTask<bool> moveNext;

for (int i = 1; i <= 3; i++)
{
moveNext = e.MoveNextAsync();
Assert.True(moveNext.IsCompletedSuccessfully);
Assert.True(moveNext.Result);
Assert.Same(Task.FromResult(i), e.Current);
}

foreach (TaskCompletionSource<int> tcs in new[] { tcs2, tcs1, tcs3 })
{
moveNext = e.MoveNextAsync();
Assert.False(moveNext.IsCompleted);
tcs.SetResult(42);
Assert.True(await moveNext);
Assert.Same(tcs.Task, e.Current);
}

moveNext = e.MoveNextAsync();
Assert.True(moveNext.IsCompletedSuccessfully);
Assert.False(moveNext.Result);
}
}
}
Loading