Created August 24, 2019 22:37
Save jnm2/7cd52dce0b7f761c9bd496dabdeeccaa to your computer and use it in GitHub Desktop.
AsyncParallelQueue and tests
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections; | |
using System.Collections.Generic; | |
using System.Threading.Tasks; | |
internal sealed partial class AsyncParallelQueue<T>
{
    /// <summary>
    /// <para>
    /// The purpose of this class is to separate what must be done inside a lock (everything this class does) from
    /// what may or must be done outside the lock (everything <see cref="AsyncParallelQueue{T}"/> does).
    /// </para>
    /// <para>
    /// All mutable state is guarded by a private lock object, so no code outside this class can take part in its
    /// locking. The C# <c>lock</c> statement is reentrant: user code invoked while the lock is held (enumerator
    /// calls, cancellation callbacks, inline task continuations) can re-enter this class on the same thread, and
    /// the flags below exist to cope with that.
    /// </para>
    /// <para>
    /// A lock-free implementation is not possible in this case due to the need to use <see cref="source"/> from
    /// multiple threads in a thread-safe manner.
    /// </para>
    /// </summary>
    private sealed class AtomicOperations
    {
        /// <summary>
        /// Guards every other field. A private object is used instead of <c>this</c> so that no other code can lock
        /// on the same monitor.
        /// </summary>
        private readonly object gate = new object();

        private readonly IEnumerator<Task<T>> source;
        private ArrayBuilder<Task<T>> completedTasks;
        private bool reachedEndOfEnumerator;
        private bool cancelFurtherEnumeration;

        /// <summary>
        /// Used to detect reentry within the same lock due to user code triggering <see cref="OnCancel"/> during a
        /// call to <see cref="IEnumerator.MoveNext"/> or <see cref="IEnumerator{T}.Current"/> (while inside <see
        /// cref="TryStartNext"/>).
        /// </summary>
        private bool isCallingMoveNextOrCurrent;

        /// <summary>
        /// Ensures that multiple calls to <see cref="OnCancel"/>, <see cref="OnTaskCompleted"/> or
        /// <see cref="TryStartNext"/> cannot result in more than one call being assigned the responsibility of
        /// handling the completion.
        /// </summary>
        private bool completionHandled;

        /// <summary>
        /// The number of output slots handed out so far (enumerated tasks plus the faulted tasks synthesized for
        /// enumerator exceptions). This is also the index the next slot's result will occupy in the output array.
        /// </summary>
        private int startedTaskCount;

        /// <summary>
        /// Compared against <see cref="startedTaskCount"/> to determine when all started tasks have completed.
        /// </summary>
        private int completedTaskCount;

        public AtomicOperations(IEnumerable<Task<T>> source)
        {
            this.source = source.GetEnumerator();

            completedTasks = new ArrayBuilder<Task<T>>(
                initialCapacity: CommonUtils.TryGetCollectionCount(source, out var count) ? count : 0);
        }

        /// <summary>
        /// Enumerates and returns the next task to subscribe to, or the completion operation if enumeration is over
        /// and nothing is still running.
        /// </summary>
        public NextOperation TryStartNext()
        {
            lock (gate)
            {
                return TryStartNextOrTryComplete();
            }
        }

        /// <summary>
        /// Stops further enumeration and, if nothing is still running, returns the completion operation.
        /// </summary>
        public NextOperation OnCancel()
        {
            lock (gate)
            {
                cancelFurtherEnumeration = true;

                if (isCallingMoveNextOrCurrent)
                {
                    // TryStartNextOrTryComplete is the current caller of the user code that called this method by
                    // canceling the token. It's too soon to complete because user code hasn't returned to
                    // TryStartNextOrTryComplete yet, so the resulting task hasn't been seen yet.
                    // Because cancelFurtherEnumeration is set, TryStartNextOrTryComplete will call TryComplete when
                    // the time is right.
                    return NextOperation.None;
                }

                return TryComplete();
            }
        }

        /// <summary>
        /// Records a finished task in its output slot, then either starts the next task or completes.
        /// </summary>
        public NextOperation OnTaskCompleted(Task<T> completedTask, int taskIndex)
        {
            lock (gate)
            {
                // NOTE(review): if user code running inside MoveNext or Current synchronously completes an earlier
                // task and that task's continuation runs inline, this method re-enters TryStartNextOrTryComplete on
                // the same thread while the outer call is still in progress. Only OnCancel is guarded against that
                // (via isCallingMoveNextOrCurrent); the nested call would reuse the outer call's taskIndex.
                completedTaskCount++;
                completedTasks[taskIndex] = completedTask;
                return TryStartNextOrTryComplete();
            }
        }

        /// <summary>Only call from within a lock.</summary>
        private NextOperation TryStartNextOrTryComplete()
        {
            if (reachedEndOfEnumerator || cancelFurtherEnumeration)
                return TryComplete();

            var taskIndex = startedTaskCount;

            isCallingMoveNextOrCurrent = true;
            try
            {
                reachedEndOfEnumerator = !source.MoveNext();
            }
            catch (Exception ex)
            {
                // Treat a throwing MoveNext as a faulted task in the next slot and stop enumerating.
                startedTaskCount++;
                completedTaskCount++;
                completedTasks[taskIndex] = Task.FromException<T>(ex);
                reachedEndOfEnumerator = true;
            }
            finally
            {
                isCallingMoveNextOrCurrent = false;
            }

            if (reachedEndOfEnumerator) return TryComplete();

            startedTaskCount++;

            Task<T> task;
            isCallingMoveNextOrCurrent = true;
            try
            {
                // Even though MoveNext may have caused cancelFurtherEnumeration to become true, assume that
                // the task is already started even if we don't access Current and that therefore we should
                // observe it.
                task = source.Current;
            }
            catch (Exception ex)
            {
                task = Task.FromException<T>(ex);
            }
            finally
            {
                isCallingMoveNextOrCurrent = false;
            }

            if (task is null)
                task = Task.FromException<T>(new InvalidOperationException("The source task enumerator returned a null instance."));

            return NextOperation.Subscribe(task, taskIndex);
        }

        /// <summary>Only call from within a lock.</summary>
        private NextOperation TryComplete()
        {
            if (completionHandled || completedTaskCount != startedTaskCount)
                return NextOperation.None;

            // Claim the completion before running user code in Dispose. Dispose may cancel the token, which
            // re-enters this class on the same thread via OnCancel (the lock is reentrant). If the flag were set
            // afterwards, that nested call would dispose the enumerator a second time and hand out a second
            // Complete operation, completing the result task twice.
            completionHandled = true;

            try
            {
                source.Dispose();
            }
            catch (Exception ex)
            {
                // Treat a throwing Dispose as one more faulted task appended after everything else.
                var taskIndex = startedTaskCount;
                startedTaskCount++;
                completedTaskCount++;
                completedTasks[taskIndex] = Task.FromException<T>(ex);
            }

            return NextOperation.Complete(canceled: !reachedEndOfEnumerator, completedTasks.MoveToArraySegment());
        }
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections; | |
using System.Collections.Generic; | |
using System.Collections.Immutable; | |
using System.Linq; | |
using System.Threading; | |
using System.Threading.Tasks; | |
internal sealed partial class AsyncParallelQueue<T>
{
    private readonly AtomicOperations atomicOperations;
    private readonly TaskCompletionSource<ImmutableArray<T>> taskCompletionSource = new TaskCompletionSource<ImmutableArray<T>>();

    /// <summary>
    /// Initializes a new instance of the <see cref="AsyncParallelQueue{T}"/> class and begins enumerating all tasks
    /// from the <paramref name="source"/> parameter in the background.
    /// </summary>
    /// <param name="source">
    /// <para>
    /// <see cref="IEnumerator.MoveNext"/> will be called each time a new task needs to be started in order to reach
    /// the specified level of parallelism. This means that the enumerable you provide must create tasks on demand.
    /// </para>
    /// <para>
    /// For example, <c>data.Select(async d => await FooAsync(d))</c>. <b>Do not</b> use <c>.ToList()</c> or
    /// otherwise eagerly buffer the tasks into a list, or else all the tasks will be started in parallel with an
    /// infinite degree of parallelism before you even call <see cref="AsyncParallelQueue{T}"/>.
    /// </para>
    /// <para>
    /// Only one thread will interact with the enumerator at a time.
    /// </para>
    /// </param>
    /// <param name="degreeOfParallelism">
    /// The maximum number of incomplete tasks enumerated from the <paramref name="source"/> parameter at a time.
    /// Must be at least one.
    /// </param>
    /// <param name="cancellationToken">
    /// If enumeration has not already ended, prevents further calls to <see cref="IEnumerator.MoveNext"/> and
    /// causes <see cref="WaitAllAsync"/> to result in cancellation instead of success once the already-started
    /// tasks complete. If the token is already canceled, <paramref name="source"/> is never enumerated.
    /// </param>
    /// <exception cref="ArgumentNullException">
    /// Thrown when <paramref name="source"/> is <see langword="null"/>.
    /// </exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// Thrown when <paramref name="degreeOfParallelism"/> is less than one.
    /// </exception>
    public AsyncParallelQueue(IEnumerable<Task<T>> source, int degreeOfParallelism, CancellationToken cancellationToken)
    {
        if (source is null)
            throw new ArgumentNullException(nameof(source));

        if (degreeOfParallelism < 1)
            throw new ArgumentOutOfRangeException(nameof(degreeOfParallelism), degreeOfParallelism, "Degree of parallelism must be greater than or equal to one.");

        if (cancellationToken.IsCancellationRequested)
        {
            taskCompletionSource.SetCanceled();
        }
        else
        {
            atomicOperations = new AtomicOperations(source);

            if (cancellationToken.CanBeCanceled)
            {
                var registration = cancellationToken.Register(OnCancel);

                // Unregister once the result is known. Otherwise a long-lived token would keep this instance (and
                // everything it references) reachable for as long as the token lives.
                taskCompletionSource.Task.ContinueWith(
                    _ => registration.Dispose(),
                    CancellationToken.None,
                    TaskContinuationOptions.ExecuteSynchronously,
                    TaskScheduler.Default);
            }

            for (var i = 0; i < degreeOfParallelism; i++)
            {
                DoNextOperation(atomicOperations.TryStartNext());
            }
        }
    }

    /// <summary>
    /// <para>
    /// Asynchronously waits for all enumerated tasks and returns a collection of the task results in the same order
    /// that the tasks were enumerated.
    /// </para>
    /// <para>
    /// If any task fails or the enumerator misbehaves, the result will be a failed task aggregating all inner
    /// exceptions once all tasks are no longer running. Otherwise, if any task is externally canceled or the
    /// cancellation token passed to the constructor is canceled before the enumerator ends, the result will be a
    /// canceled task once all tasks are no longer running.
    /// </para>
    /// </summary>
    public Task<ImmutableArray<T>> WaitAllAsync() => taskCompletionSource.Task;

    private void OnCancel()
    {
        DoNextOperation(atomicOperations.OnCancel());
    }

    /// <summary>
    /// Acts, outside the lock, on an operation decided inside <see cref="AtomicOperations"/>.
    /// </summary>
    private void DoNextOperation(NextOperation nextOperation)
    {
        if (nextOperation.IsSubscribe(out var task, out var taskIndex))
        {
            // Pass TaskScheduler.Default explicitly; otherwise the ambient TaskScheduler.Current (possibly a UI or
            // custom scheduler) would be used to run the bookkeeping continuation.
            task.ContinueWith(
                OnTaskCompleted,
                taskIndex,
                CancellationToken.None,
                TaskContinuationOptions.ExecuteSynchronously,
                TaskScheduler.Default);
        }
        else if (nextOperation.IsComplete(out var canceled, out var completedTasks))
        {
            Complete(canceled, completedTasks);
        }
        else
        {
            RuntimeAssert.That(nextOperation.IsNone, "All possible operations must be handled.");
        }
    }

    private void OnTaskCompleted(Task<T> completedTask, object state)
    {
        DoNextOperation(atomicOperations.OnTaskCompleted(completedTask, taskIndex: (int)state));
    }

    /// <summary>
    /// Sets the final result. Precedence: any fault wins, then any cancellation (external or requested), then
    /// success with the results in enumeration order.
    /// </summary>
    private void Complete(bool canceled, ArraySegment<Task<T>> completedTasks)
    {
        var exceptions = new List<Exception>();
        var anyTaskWasCancelledExternally = false;
        var results = ImmutableArray.CreateBuilder<T>(completedTasks.Count);

        foreach (var completedTask in completedTasks)
        {
            switch (completedTask.Status)
            {
                case TaskStatus.RanToCompletion:
                    results.Add(completedTask.Result);
                    break;
                case TaskStatus.Canceled:
                    anyTaskWasCancelledExternally = true;
                    break;
                case TaskStatus.Faulted:
                    exceptions.AddRange(completedTask.Exception.InnerExceptions);
                    break;
            }
        }

        if (exceptions.Any())
            taskCompletionSource.TrySetException(exceptions);
        else if (anyTaskWasCancelledExternally || canceled)
            taskCompletionSource.SetCanceled();
        else
            // MoveToImmutable requires Count == capacity, which holds here because every task ran to completion.
            taskCompletionSource.SetResult(results.MoveToImmutable());
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Threading.Tasks; | |
internal sealed partial class AsyncParallelQueue<T>
{
    /// <summary>
    /// A tagged union produced inside the lock held by <see cref="AtomicOperations"/> and consumed after that lock
    /// has been released. It is one of three cases: nothing to do, subscribe to a newly enumerated task, or complete
    /// the queue with the collected tasks.
    /// </summary>
    private readonly struct NextOperation
    {
        /// <summary>Which case of the union this value holds. The default value is <see cref="Kind.None"/>.</summary>
        private enum Kind
        {
            None = 0,
            Subscribe = 1,
            Complete = 2,
        }

        private readonly Kind kind;
        private readonly Task<T> task;
        private readonly int taskIndex;
        private readonly bool canceled;
        private readonly ArraySegment<Task<T>> completedTasks;

        private NextOperation(Kind kind, Task<T> task, int taskIndex, bool canceled, ArraySegment<Task<T>> completedTasks)
        {
            this.kind = kind;
            this.task = task;
            this.taskIndex = taskIndex;
            this.canceled = canceled;
            this.completedTasks = completedTasks;
        }

        /// <summary>The "nothing to do" case.</summary>
        public static NextOperation None => default;

        public bool IsNone => kind == Kind.None;

        /// <summary>Subscribe to <paramref name="task"/>, whose result belongs at <paramref name="taskIndex"/>.</summary>
        public static NextOperation Subscribe(Task<T> task, int taskIndex) =>
            new NextOperation(Kind.Subscribe, task, taskIndex, default, default);

        /// <summary>
        /// Always assigns both out parameters (to default values when this is not the subscribe case).
        /// </summary>
        public bool IsSubscribe(out Task<T> task, out int taskIndex)
        {
            task = this.task;
            taskIndex = this.taskIndex;
            return kind == Kind.Subscribe;
        }

        /// <summary>Complete the queue using the tasks gathered so far.</summary>
        public static NextOperation Complete(bool canceled, ArraySegment<Task<T>> completedTasks) =>
            new NextOperation(Kind.Complete, default, default, canceled, completedTasks);

        /// <summary>
        /// Always assigns both out parameters (to default values when this is not the complete case).
        /// </summary>
        public bool IsComplete(out bool canceled, out ArraySegment<Task<T>> completedTasks)
        {
            canceled = this.canceled;
            completedTasks = this.completedTasks;
            return kind == Kind.Complete;
        }
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections; | |
using System.Collections.Generic; | |
using System.Linq; | |
using System.Threading; | |
using System.Threading.Tasks; | |
using NSubstitute; | |
using NUnit.Framework; | |
using Shouldly; | |
public static class AsyncParallelQueueTests
{
    [Test]
    public static void Empty_source_should_complete_synchronously_with_empty_result()
    {
        var resultTask = Enumerable.Empty<int>().DoAllParallelAsync<int, int>(
            _ => throw new AssertionException("Selector should not be invoked"),
            degreeOfParallelism: 2,
            CancellationToken.None);

        resultTask.ShouldCompleteSynchronously().ShouldBeEmpty();
    }

    [Test]
    public static void Source_must_not_be_null()
    {
        // degreeOfParallelism is also invalid here, so these calls also verify that source is validated first.
        var ex = Should.Throw<ArgumentNullException>(() =>
            ((IEnumerable<Task<int>>)null).DoAllParallelAsync(
                degreeOfParallelism: -1,
                CancellationToken.None));

        ex.ParamName.ShouldBe("source");

        ex = Should.Throw<ArgumentNullException>(() =>
            ((IEnumerable<int>)null).DoAllParallelAsync<int, int>(
                _ => throw new AssertionException("Selector should not be invoked"),
                degreeOfParallelism: -1,
                CancellationToken.None));

        ex.ParamName.ShouldBe("source");
    }

    [Test]
    public static void Selector_must_not_be_null()
    {
        var ex = Should.Throw<ArgumentNullException>(() =>
            Enumerable.Empty<int>().DoAllParallelAsync<int, int>(
                asyncSelector: null,
                degreeOfParallelism: 2,
                CancellationToken.None));

        ex.ParamName.ShouldBe("asyncSelector");
    }

    // The constructor rejects every value below one, so zero is covered as well as negative values.
    [TestCase(-1)]
    [TestCase(0)]
    public static void Degree_of_parallelism_must_be_at_least_one(int degreeOfParallelism)
    {
        var ex = Should.Throw<ArgumentOutOfRangeException>(() =>
            Enumerable.Empty<Task<int>>().DoAllParallelAsync(
                degreeOfParallelism: degreeOfParallelism,
                CancellationToken.None));

        ex.ParamName.ShouldBe("degreeOfParallelism");
        ex.ActualValue.ShouldBe(degreeOfParallelism);

        ex = Should.Throw<ArgumentOutOfRangeException>(() =>
            Enumerable.Empty<int>().DoAllParallelAsync<int, int>(
                _ => throw new AssertionException("Selector should not be invoked"),
                degreeOfParallelism: degreeOfParallelism,
                CancellationToken.None));

        ex.ParamName.ShouldBe("degreeOfParallelism");
        ex.ActualValue.ShouldBe(degreeOfParallelism);
    }

    [Test]
    public static void GetEnumerator_is_called_exactly_once()
    {
        var enumerable = Substitute.For<IEnumerable<Task<int>>>();
        enumerable.GetEnumerator().Returns(Enumerable.Empty<Task<int>>().GetEnumerator());

        enumerable.DoAllParallelAsync(degreeOfParallelism: 2, CancellationToken.None).ShouldCompleteSynchronously();

        enumerable.ReceivedCalls().ShouldHaveSingleItem();
    }

    [Test]
    public static async Task Enumerator_is_accessed_by_only_one_thread_at_a_time()
    {
        var enumerable = Substitute.For<IEnumerable<Task<int>>>();
        enumerable.GetEnumerator().Returns(new ThreadSafetyCheckingEnumerator());

        await enumerable
            .DoAllParallelAsync(degreeOfParallelism: 10, CancellationToken.None)
            .WithTimeout(TimeSpan.FromSeconds(10));
    }

    /// <summary>
    /// Yields 100 short-delay tasks and fails the assertion if MoveNext or Current is ever entered while another
    /// call to either is still in progress.
    /// </summary>
    private sealed class ThreadSafetyCheckingEnumerator : IEnumerator<Task<int>>
    {
        private int currentValue;
        private int concurrentCallerCount;

        public Task<int> Current
        {
            get
            {
                if (Interlocked.Increment(ref concurrentCallerCount) != 1) Assert.Fail();
                try
                {
                    return Task.Delay(1).ContinueWith(_ => currentValue);
                }
                finally
                {
                    Interlocked.Decrement(ref concurrentCallerCount);
                }
            }
        }

        object IEnumerator.Current => throw new NotImplementedException();

        public void Dispose()
        {
        }

        public bool MoveNext()
        {
            if (Interlocked.Increment(ref concurrentCallerCount) != 1) Assert.Fail();
            try
            {
                if (currentValue >= 100) return false;
                currentValue++;
                return true;
            }
            finally
            {
                Interlocked.Decrement(ref concurrentCallerCount);
            }
        }

        public void Reset()
        {
            throw new NotImplementedException();
        }
    }

    [Test]
    public static void Completes_synchronously_when_all_work_completes_synchronously()
    {
        var task = Enumerable.Range(1, 10).DoAllParallelAsync(
            n =>
                n % 3 == 0 ? Task.FromResult(42) :
                n % 3 == 1 ? Stubs.Task.Canceled<int>() :
                Task.FromException<int>(new Exception()),
            degreeOfParallelism: 2,
            CancellationToken.None);

        task.IsCompleted.ShouldBeTrue();
    }

    [Test]
    public static void All_results_are_collected_in_order()
    {
        var sources = Enumerable.Range(1, 100).Select(n => new TaskCompletionSource<int>()).ToList();

        var resultTask = sources.DoAllParallelAsync(
            s => s.Task,
            degreeOfParallelism: 10,
            CancellationToken.None);

        for (var i = 0; i < sources.Count; i++)
            sources[i].SetResult(i + 1);

        var result = resultTask.ShouldCompleteSynchronously();

        result.ShouldBe(Enumerable.Range(1, 100));
    }

    [Test]
    public static void Externally_canceled_task_results_in_canceled_result()
    {
        var sources = Enumerable.Range(1, 100).Select(n => new TaskCompletionSource<int>()).ToList();

        var resultTask = sources.DoAllParallelAsync(
            s => s.Task,
            degreeOfParallelism: 10,
            CancellationToken.None);

        for (var i = 1; i < sources.Count; i++)
            sources[i].SetResult(i + 1);

        sources[0].SetCanceled();

        resultTask.IsCanceled.ShouldBeTrue();
    }

    [Test]
    public static void Fault_overrides_success()
    {
        var exception = new Exception();

        var resultTask = new[]
        {
            Task.FromException<int>(exception),
            Task.FromResult(42)
        }.DoAllParallelAsync(degreeOfParallelism: 2, CancellationToken.None);

        resultTask.ShouldBeFaulted()
            .InnerExceptions.ShouldHaveSingleItem().ShouldBeSameAs(exception);
    }

    [Test]
    public static void Fault_overrides_external_cancellation()
    {
        var exception = new Exception();

        var resultTask = new[]
        {
            Task.FromException<int>(exception),
            Stubs.Task.Canceled<int>()
        }.DoAllParallelAsync(degreeOfParallelism: 2, CancellationToken.None);

        resultTask.ShouldBeFaulted()
            .InnerExceptions.ShouldHaveSingleItem().ShouldBeSameAs(exception);
    }

    [Test]
    public static void Fault_overrides_internal_cancellation()
    {
        var exception = new Exception();

        using (var cancelSource = new CancellationTokenSource())
        {
            var resultTask = new Func<Task<int>>[]
            {
                () =>
                {
                    cancelSource.Cancel();
                    return Task.FromException<int>(exception);
                },
                () => Task.FromException<int>(new Exception("Should not be called")),
            }.DoAllParallelAsync(f => f.Invoke(), degreeOfParallelism: 2, cancelSource.Token);

            resultTask.ShouldBeFaulted()
                .InnerExceptions.ShouldHaveSingleItem().ShouldBeSameAs(exception);
        }
    }

    [Test]
    public static async Task Cancellation_from_another_thread_prevents_further_enumeration()
    {
        using (var cancelSource = new CancellationTokenSource())
        {
            var enumerationDidContinue = false;

            await new Func<Task<int>>[]
            {
                async () =>
                {
                    await Task.Run(cancelSource.Cancel);
                    return 42;
                },
                () =>
                {
                    enumerationDidContinue = true;
                    return Task.FromResult(42);
                }
            }.DoAllParallelAsync(f => f.Invoke(), degreeOfParallelism: 1, cancelSource.Token)
                .ShouldCancelAsync(TimeSpan.FromSeconds(10));

            enumerationDidContinue.ShouldBeFalse();
        }
    }

    [Test]
    public static void Cancellation_from_Current_prevents_further_enumeration()
    {
        using (var cancelSource = new CancellationTokenSource())
        {
            var enumerationDidContinue = false;

            var enumerator = Substitute.For<IEnumerator<Task<int>>>();
            enumerator.MoveNext().Returns(
                _ => true,
                _ =>
                {
                    enumerationDidContinue = true;
                    return false;
                });

            enumerator.Current.Returns(_ =>
            {
                cancelSource.Cancel();
                return Task.FromResult(42);
            });

            var enumerable = Substitute.For<IEnumerable<Task<int>>>();
            enumerable.GetEnumerator().Returns(enumerator);

            enumerable.DoAllParallelAsync(degreeOfParallelism: 1, cancelSource.Token)
                .IsCanceled.ShouldBeTrue();

            enumerationDidContinue.ShouldBeFalse();
        }
    }

    [Test]
    public static void Calling_with_canceled_token_prevents_GetEnumerator_call()
    {
        using (var cancelSource = new CancellationTokenSource())
        {
            var enumerable = Substitute.For<IEnumerable<Task<int>>>();
            enumerable.GetEnumerator().Returns(_ => throw new AssertionException("GetEnumerator should not be called"));

            cancelSource.Cancel();

            enumerable.DoAllParallelAsync(degreeOfParallelism: 10, cancelSource.Token)
                .IsCanceled.ShouldBeTrue();
        }
    }

    [Test]
    public static void Cancellation_from_GetEnumerator_prevents_any_enumeration()
    {
        using (var cancelSource = new CancellationTokenSource())
        {
            var enumerationDidContinue = false;

            var enumerator = Substitute.For<IEnumerator<Task<int>>>();
            enumerator.MoveNext().Returns(
                _ =>
                {
                    enumerationDidContinue = true;
                    return false;
                });

            var enumerable = Substitute.For<IEnumerable<Task<int>>>();
            enumerable.GetEnumerator().Returns(_ =>
            {
                cancelSource.Cancel();
                return enumerator;
            });

            enumerable.DoAllParallelAsync(degreeOfParallelism: 10, cancelSource.Token)
                .IsCanceled.ShouldBeTrue();

            enumerationDidContinue.ShouldBeFalse();
        }
    }

    [Test]
    public static void Cancellation_from_MoveNext_prevents_further_enumeration()
    {
        using (var cancelSource = new CancellationTokenSource())
        {
            var enumerationDidContinue = false;

            var enumerator = Substitute.For<IEnumerator<Task<int>>>();
            enumerator.MoveNext().Returns(
                _ =>
                {
                    cancelSource.Cancel();
                    return true;
                },
                _ =>
                {
                    enumerationDidContinue = true;
                    return false;
                });

            enumerator.Current.Returns(42);

            var enumerable = Substitute.For<IEnumerable<Task<int>>>();
            enumerable.GetEnumerator().Returns(enumerator);

            enumerable.DoAllParallelAsync(degreeOfParallelism: 1, cancelSource.Token)
                .IsCanceled.ShouldBeTrue();

            enumerationDidContinue.ShouldBeFalse();
        }
    }

    [Test]
    public static void Cancellation_from_MoveNext_still_accesses_Current_task_and_waits_for_it()
    {
        using (var cancelSource = new CancellationTokenSource())
        {
            var enumerator = Substitute.For<IEnumerator<Task<int>>>();
            enumerator.MoveNext().Returns(
                _ =>
                {
                    cancelSource.Cancel();
                    return true;
                },
                _ => false);

            var source = new TaskCompletionSource<int>();
            enumerator.Current.Returns(source.Task);

            var enumerable = Substitute.For<IEnumerable<Task<int>>>();
            enumerable.GetEnumerator().Returns(enumerator);

            var resultTask = enumerable.DoAllParallelAsync(degreeOfParallelism: 1, cancelSource.Token);

            resultTask.IsCompleted.ShouldBeFalse();

            var exception = new Exception();
            source.SetException(exception);

            resultTask.ShouldBeFaulted().InnerExceptions.ShouldHaveSingleItem().ShouldBeSameAs(exception);
        }
    }

    [Test]
    public static void Cancellation_from_Dispose_is_ignored()
    {
        using (var cancelSource = new CancellationTokenSource())
        {
            var enumerator = Substitute.For<IEnumerator<Task<int>>>();
            enumerator.MoveNext().Returns(true, false);
            enumerator.Current.Returns(42);
            enumerator.When(e => e.Dispose()).Do(_ => cancelSource.Cancel());

            var enumerable = Substitute.For<IEnumerable<Task<int>>>();
            enumerable.GetEnumerator().Returns(enumerator);

            enumerable.DoAllParallelAsync(degreeOfParallelism: 1, cancelSource.Token)
                .ShouldCompleteSynchronously();
        }
    }

    [Test]
    public static void All_exceptions_are_collected_in_order()
    {
        var source1 = new TaskCompletionSource<int>();
        var source2 = new TaskCompletionSource<int>();

        var resultTask = new[] { source1, source2 }.DoAllParallelAsync(
            s => s.Task,
            degreeOfParallelism: 2,
            CancellationToken.None);

        // Complete out of order to show that the aggregate follows enumeration order, not completion order.
        source2.SetException(new[] { new Exception(), new Exception() });
        source1.SetException(new[] { new Exception(), new Exception() });

        resultTask.IsFaulted.ShouldBeTrue();

        resultTask.Exception.InnerExceptions.ShouldBe(
            source1.Task.Exception.InnerExceptions.Concat(
                source2.Task.Exception.InnerExceptions));
    }

    [Test]
    public static async Task InvalidOperationException_is_created_for_null_tasks()
    {
        var source1 = new TaskCompletionSource<int>();
        var source2 = new TaskCompletionSource<int>();
        var source3 = new TaskCompletionSource<int>();

        var resultTask = new[] { source1, null, source2, null, source3 }.DoAllParallelAsync(
            s => s?.Task,
            degreeOfParallelism: 3,
            CancellationToken.None);

        source3.SetException(new[] { new Exception(), new Exception() });
        source2.SetException(new[] { new Exception(), new Exception() });
        source1.SetException(new[] { new Exception(), new Exception() });

        var ex = await resultTask.ShouldFaultAsync(TimeSpan.FromSeconds(10));

        // Layout: source1 (0, 1), null (2), source2 (3, 4), null (5), source3 (6, 7).
        ex.InnerExceptions.Count.ShouldBe(8);
        ex.InnerExceptions[2].ShouldBeOfType<InvalidOperationException>();
        ex.InnerExceptions[5].ShouldBeOfType<InvalidOperationException>();
    }

    [Test]
    public static async Task Exception_accessing_Current_is_treated_like_a_faulted_task()
    {
        var exception = new Exception();

        var enumerator = Substitute.For<IEnumerator<Task<int>>>();
        enumerator.MoveNext().Returns(true, false);
        enumerator.Current.Returns<Task<int>>(_ => throw exception);

        var enumerable = Substitute.For<IEnumerable<Task<int>>>();
        enumerable.GetEnumerator().Returns(enumerator);

        var ex = await enumerable
            .DoAllParallelAsync(
                degreeOfParallelism: 2,
                CancellationToken.None)
            .ShouldFaultAsync(TimeSpan.FromSeconds(10));

        ex.InnerExceptions.ShouldHaveSingleItem().ShouldBeSameAs(exception);
    }

    [Test]
    public static async Task Exception_from_MoveNext_cancels_further_enumeration()
    {
        var exception = new Exception();

        var enumerator = Substitute.For<IEnumerator<Task<int>>>();
        enumerator.MoveNext().Returns(_ => throw exception, _ => false);

        var enumerable = Substitute.For<IEnumerable<Task<int>>>();
        enumerable.GetEnumerator().Returns(enumerator);

        var ex = await enumerable
            .DoAllParallelAsync(
                degreeOfParallelism: 2,
                CancellationToken.None)
            .ShouldFaultAsync(TimeSpan.FromSeconds(10));

        ex.InnerExceptions.ShouldHaveSingleItem().ShouldBeSameAs(exception);

        enumerator.ReceivedCalls()
            .Where(c => c.GetMethodInfo().Name != nameof(enumerator.Dispose))
            .ShouldHaveSingleItem();
    }

    [Test]
    public static async Task Exception_from_Dispose_is_treated_like_a_faulted_task()
    {
        var exception = new Exception();

        var enumerator = Substitute.For<IEnumerator<Task<int>>>();
        enumerator.MoveNext().Returns(_ => false);
        enumerator.When(e => e.Dispose()).Throw(exception);

        var enumerable = Substitute.For<IEnumerable<Task<int>>>();
        enumerable.GetEnumerator().Returns(enumerator);

        var ex = await enumerable
            .DoAllParallelAsync(
                degreeOfParallelism: 2,
                CancellationToken.None)
            .ShouldFaultAsync(TimeSpan.FromSeconds(10));

        ex.InnerExceptions.ShouldHaveSingleItem().ShouldBeSameAs(exception);
    }

    [Test]
    public static void Exception_from_GetEnumerator_should_not_be_handled()
    {
        var exception = new Exception();

        var enumerable = Substitute.For<IEnumerable<Task<int>>>();
        enumerable.GetEnumerator().Returns(_ => throw exception);

        var ex = Should.Throw<Exception>(() => enumerable.DoAllParallelAsync(
            degreeOfParallelism: 2,
            CancellationToken.None));

        ex.ShouldBeSameAs(exception);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.