Skip to content

Instantly share code, notes, and snippets.

@jnm2
Created August 24, 2019 22:37
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jnm2/7cd52dce0b7f761c9bd496dabdeeccaa to your computer and use it in GitHub Desktop.
Save jnm2/7cd52dce0b7f761c9bd496dabdeeccaa to your computer and use it in GitHub Desktop.
AsyncParallelQueue and tests
using System;
using System.Collections;
using System.Collections.Generic;
using System.Threading.Tasks;
internal sealed partial class AsyncParallelQueue<T>
{
    /// <summary>
    /// <para>
    /// The purpose of this class is to separate what must be done inside a lock (everything this class does) from
    /// what may or must be done outside the lock (everything <see cref="AsyncParallelQueue{T}"/> does).
    /// </para>
    /// <para>
    /// Do not lock on objects of this type from outside the type. It locks on itself by design.
    /// </para>
    /// <para>
    /// A lock-free implementation is not possible in this case due to the need to use <see cref="source"/> from
    /// multiple threads in a thread-safe manner.
    /// </para>
    /// <para>
    /// Each public method returns a <see cref="NextOperation"/> describing work that the caller must perform after
    /// the lock has been released.
    /// </para>
    /// </summary>
    private sealed class AtomicOperations
    {
        /// <summary>
        /// The enumerator obtained from the source sequence. Only accessed while holding the lock.
        /// </summary>
        private readonly IEnumerator<Task<T>> source;

        /// <summary>
        /// Finished tasks, including tasks synthesized to carry enumerator exceptions. Each one is stored at the
        /// index at which it was started.
        /// </summary>
        private ArrayBuilder<Task<T>> completedTasks;

        /// <summary>
        /// Set when <see cref="IEnumerator.MoveNext"/> returns <see langword="false"/> or throws. Once set, the
        /// enumerator is not advanced again.
        /// </summary>
        private bool reachedEndOfEnumerator;

        /// <summary>
        /// Set by <see cref="OnCancel"/>. Once set, the enumerator is not advanced again.
        /// </summary>
        private bool cancelFurtherEnumeration;

        /// <summary>
        /// Used to detect reentry within the same lock due to user code triggering <see cref="OnCancel"/> during a
        /// call to <see cref="IEnumerator.MoveNext"/> or <see cref="IEnumerator{T}.Current"/> (while inside <see
        /// cref="TryStartNext"/>).
        /// </summary>
        /// <remarks>
        /// NOTE(review): only <see cref="OnCancel"/> consults this flag. If user code inside MoveNext/Current
        /// synchronously completes a task that was already subscribed, <see cref="OnTaskCompleted"/> may re-enter
        /// on the same thread and advance the enumerator reentrantly. Confirm whether that case needs guarding.
        /// </remarks>
        private bool isCallingMoveNextOrCurrent;

        /// <summary>
        /// Ensures that multiple calls to <see cref="OnCancel"/> or <see cref="TryStartNext"/> cannot result in
        /// more than one call being assigned the responsibility of handling the completion.
        /// </summary>
        private bool completionHandled;

        /// <summary>
        /// The number of tasks started so far. This is also the output index that the next enumerated task will
        /// occupy.
        /// </summary>
        private int startedTaskCount;

        /// <summary>
        /// Used to determine when all started tasks have completed.
        /// </summary>
        private int completedTaskCount;

        /// <summary>
        /// Calls <see cref="IEnumerable{T}.GetEnumerator"/> on <paramref name="source"/> immediately. Exceptions
        /// thrown by <c>GetEnumerator</c> propagate to the caller.
        /// </summary>
        public AtomicOperations(IEnumerable<Task<T>> source)
        {
            this.source = source.GetEnumerator();

            // Presumably presizes the buffer when the source can report its count without enumerating;
            // TODO confirm CommonUtils.TryGetCollectionCount never enumerates the source.
            completedTasks = new ArrayBuilder<Task<T>>(
                initialCapacity: CommonUtils.TryGetCollectionCount(source, out var count) ? count : 0);
        }

        /// <summary>
        /// Starts the next task from the enumerator. If enumeration has ended or been canceled, instead tries to
        /// complete once every started task has finished.
        /// </summary>
        public NextOperation TryStartNext()
        {
            lock (this)
            {
                return TryStartNextOrTryComplete();
            }
        }

        /// <summary>
        /// Stops further enumeration and completes if no started task is still running.
        /// </summary>
        public NextOperation OnCancel()
        {
            lock (this)
            {
                cancelFurtherEnumeration = true;

                if (isCallingMoveNextOrCurrent)
                {
                    // TryStartNextOrTryComplete is the current caller of the user code that called this method by
                    // canceling the token. It's too soon to complete because user code hasn't returned to
                    // TryStartNextOrTryComplete yet, so the resulting task hasn't been seen yet.
                    // Because cancelFurtherEnumeration is set, TryStartNextOrTryComplete will call TryComplete when
                    // the time is right.
                    return NextOperation.None;
                }

                return TryComplete();
            }
        }

        /// <summary>
        /// Records a finished task at its output index. Then either starts one replacement task (preserving the
        /// degree of parallelism) or tries to complete.
        /// </summary>
        public NextOperation OnTaskCompleted(Task<T> completedTask, int taskIndex)
        {
            lock (this)
            {
                completedTaskCount++;
                completedTasks[taskIndex] = completedTask;
                return TryStartNextOrTryComplete();
            }
        }

        /// <summary>Only call from within a lock.</summary>
        private NextOperation TryStartNextOrTryComplete()
        {
            if (reachedEndOfEnumerator || cancelFurtherEnumeration)
                return TryComplete();

            var taskIndex = startedTaskCount;

            isCallingMoveNextOrCurrent = true;
            try
            {
                reachedEndOfEnumerator = !source.MoveNext();
            }
            catch (Exception ex)
            {
                // Treat a throwing MoveNext as an already-faulted task occupying the next slot, and stop enumerating.
                startedTaskCount++;
                completedTaskCount++;
                completedTasks[taskIndex] = Task.FromException<T>(ex);
                reachedEndOfEnumerator = true;
            }
            finally
            {
                isCallingMoveNextOrCurrent = false;
            }

            if (reachedEndOfEnumerator) return TryComplete();

            startedTaskCount++;

            Task<T> task;
            isCallingMoveNextOrCurrent = true;
            try
            {
                // Even though MoveNext may have caused cancelFurtherEnumeration to become true, assume that
                // the task is already started even if we don't access Current and that therefore we should
                // observe it.
                task = source.Current;
            }
            catch (Exception ex)
            {
                task = Task.FromException<T>(ex);
            }
            finally
            {
                isCallingMoveNextOrCurrent = false;
            }

            if (task is null)
                task = Task.FromException<T>(new InvalidOperationException("The source task enumerator returned a null instance."));

            return NextOperation.Subscribe(task, taskIndex);
        }

        /// <summary>Only call from within a lock.</summary>
        private NextOperation TryComplete()
        {
            if (completionHandled || completedTaskCount != startedTaskCount)
                return NextOperation.None;

            // Claim completion before running user code. source.Dispose() may cancel the token, and cancellation
            // callbacks run synchronously, so OnCancel can re-enter this (reentrant) lock on the same thread. If the
            // flag were set only afterwards, that nested call would see matching counts, dispose the enumerator a
            // second time, and hand out a second Complete operation.
            completionHandled = true;

            try
            {
                source.Dispose();
            }
            catch (Exception ex)
            {
                // Surface a throwing Dispose as one more faulted task so it is reported with the other exceptions.
                var taskIndex = startedTaskCount;
                startedTaskCount++;
                completedTaskCount++;
                completedTasks[taskIndex] = Task.FromException<T>(ex);
            }

            // Cancellation only counts if it stopped enumeration before the enumerator reported its end.
            return NextOperation.Complete(canceled: !reachedEndOfEnumerator, completedTasks.MoveToArraySegment());
        }
    }
}
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
internal sealed partial class AsyncParallelQueue<T>
{
    private readonly AtomicOperations atomicOperations;
    private readonly TaskCompletionSource<ImmutableArray<T>> taskCompletionSource = new TaskCompletionSource<ImmutableArray<T>>();

    /// <summary>
    /// <para>
    /// The registration of <see cref="OnCancel"/> with the constructor's cancellation token.
    /// </para>
    /// <para>
    /// It is disposed in <see cref="Complete"/>, so that a long-lived token does not keep this instance reachable
    /// after the result is known. Otherwise the instance, its enumerator and every collected task would stay alive.
    /// </para>
    /// <para>
    /// It remains <see langword="default"/> when the token was already canceled at construction time.
    /// </para>
    /// </summary>
    private CancellationTokenRegistration cancellationRegistration;

    /// <summary>
    /// Initializes a new instance of the <see cref="AsyncParallelQueue{T}"/> class and begins enumerating all tasks
    /// from the <paramref name="source"/> parameter in the background.
    /// </summary>
    /// <param name="source">
    /// <para>
    /// <see cref="IEnumerator.MoveNext"/> will be called each time a new task needs to be started in order to reach
    /// the specified level of parallelism. This means that the enumerable you provide must create tasks on demand.
    /// </para>
    /// <para>
    /// For example, <c>data.Select(async d => await FooAsync(d))</c>. <b>Do not</b> use <c>.ToList()</c> or
    /// otherwise eagerly buffer the tasks into a list, or else all the tasks will be started in parallel with an
    /// infinite degree of parallelism before you even call <see cref="AsyncParallelQueue{T}"/>.
    /// </para>
    /// <para>
    /// Only one thread will interact with the enumerator at a time.
    /// </para>
    /// </param>
    /// <param name="degreeOfParallelism">
    /// The maximum number of incomplete tasks enumerated from the <paramref name="source"/> parameter at a time.
    /// </param>
    /// <param name="cancellationToken">
    /// If enumeration has not already ended, prevents further calls to <see cref="IEnumerator.MoveNext"/> and
    /// causes <see cref="WaitAllAsync"/> to result in cancellation instead of success once the already-started
    /// tasks complete. If it is already canceled when the constructor runs, <paramref name="source"/> is not
    /// enumerated at all.
    /// </param>
    /// <exception cref="ArgumentNullException">
    /// Thrown when <paramref name="source"/> is <see langword="null"/>.
    /// </exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// Thrown when <paramref name="degreeOfParallelism"/> is less than one.
    /// </exception>
    public AsyncParallelQueue(IEnumerable<Task<T>> source, int degreeOfParallelism, CancellationToken cancellationToken)
    {
        if (source is null)
            throw new ArgumentNullException(nameof(source));
        if (degreeOfParallelism < 1)
            throw new ArgumentOutOfRangeException(nameof(degreeOfParallelism), degreeOfParallelism, "Degree of parallelism must be greater than or equal to one.");

        if (cancellationToken.IsCancellationRequested)
        {
            taskCompletionSource.SetCanceled();
        }
        else
        {
            atomicOperations = new AtomicOperations(source);

            // atomicOperations must be assigned before registering. If the token was canceled after the check
            // above (for example by GetEnumerator), Register runs OnCancel synchronously, and OnCancel uses it.
            cancellationRegistration = cancellationToken.Register(OnCancel);

            // Each call starts at most one task. Each completion then starts at most one replacement, so at most
            // degreeOfParallelism tasks are in flight at once.
            for (var i = 0; i < degreeOfParallelism; i++)
            {
                DoNextOperation(atomicOperations.TryStartNext());
            }
        }
    }

    /// <summary>
    /// <para>
    /// Asynchronously waits for all enumerated tasks and returns a collection of the task results in the same order
    /// that the tasks were enumerated.
    /// </para>
    /// <para>
    /// If any task fails or the enumerator misbehaves, the result will be a failed task aggregating all inner
    /// exceptions once all tasks are no longer running. Otherwise, if any task is externally canceled or the
    /// cancellation token passed to the constructor is canceled before the enumerator ends, the result will be a
    /// canceled task once all tasks are no longer running.
    /// </para>
    /// </summary>
    public Task<ImmutableArray<T>> WaitAllAsync() => taskCompletionSource.Task;

    private void OnCancel()
    {
        DoNextOperation(atomicOperations.OnCancel());
    }

    /// <summary>
    /// Acts on an operation produced inside the <see cref="AtomicOperations"/> lock. Always called after that lock
    /// has been released.
    /// </summary>
    private void DoNextOperation(NextOperation nextOperation)
    {
        if (nextOperation.IsSubscribe(out var task, out var taskIndex))
        {
            // Pass TaskScheduler.Default explicitly. Otherwise ContinueWith uses TaskScheduler.Current, which could
            // be a UI or otherwise restricted scheduler inherited from whoever constructed this instance.
            task.ContinueWith(
                OnTaskCompleted,
                taskIndex,
                CancellationToken.None,
                TaskContinuationOptions.ExecuteSynchronously,
                TaskScheduler.Default);
        }
        else if (nextOperation.IsComplete(out var canceled, out var completedTasks))
        {
            Complete(canceled, completedTasks);
        }
        else
        {
            RuntimeAssert.That(nextOperation.IsNone, "All possible operations must be handled.");
        }
    }

    private void OnTaskCompleted(Task<T> completedTask, object state)
    {
        DoNextOperation(atomicOperations.OnTaskCompleted(completedTask, taskIndex: (int)state));
    }

    /// <summary>
    /// Publishes the final outcome. Precedence: any fault, then any cancellation (external or via the token),
    /// then success with the results in enumeration order.
    /// </summary>
    private void Complete(bool canceled, ArraySegment<Task<T>> completedTasks)
    {
        // The outcome no longer depends on the token, so unregister the callback. Dispose on a default registration
        // is a no-op. Dispose does not wait when called from within the callback itself (the OnCancel path).
        cancellationRegistration.Dispose();

        var exceptions = new List<Exception>();
        var anyTaskWasCancelledExternally = false;
        var results = ImmutableArray.CreateBuilder<T>(completedTasks.Count);

        foreach (var completedTask in completedTasks)
        {
            switch (completedTask.Status)
            {
                case TaskStatus.RanToCompletion:
                    results.Add(completedTask.Result);
                    break;
                case TaskStatus.Canceled:
                    anyTaskWasCancelledExternally = true;
                    break;
                case TaskStatus.Faulted:
                    exceptions.AddRange(completedTask.Exception.InnerExceptions);
                    break;
            }
        }

        if (exceptions.Count > 0)
        {
            taskCompletionSource.TrySetException(exceptions);
        }
        else if (anyTaskWasCancelledExternally || canceled)
        {
            taskCompletionSource.SetCanceled();
        }
        else
        {
            // Every task ran to completion, so the builder is exactly at capacity, as MoveToImmutable requires.
            taskCompletionSource.SetResult(results.MoveToImmutable());
        }
    }
}
using System;
using System.Threading.Tasks;
internal sealed partial class AsyncParallelQueue<T>
{
    /// <summary>
    /// <para>
    /// A tagged union describing what must happen once the <see cref="AtomicOperations"/> lock has been released.
    /// It holds one of three cases: nothing, attach a continuation to a task, or publish the final outcome.
    /// </para>
    /// <para>
    /// Values are read while holding that lock, but must not be acted on inside it.
    /// </para>
    /// </summary>
    private readonly struct NextOperation
    {
        /// <summary>Discriminator. The zero value is <see cref="Kind.None"/>, so <c>default</c> means "nothing".</summary>
        private enum Kind
        {
            None = 0,
            Subscribe = 1,
            Complete = 2,
        }

        private readonly Kind kind;

        // Payload for Kind.Subscribe.
        private readonly Task<T> task;
        private readonly int taskIndex;

        // Payload for Kind.Complete.
        private readonly bool canceled;
        private readonly ArraySegment<Task<T>> completedTasks;

        private NextOperation(Kind kind, Task<T> task, int taskIndex, bool canceled, ArraySegment<Task<T>> completedTasks)
        {
            this.kind = kind;
            this.task = task;
            this.taskIndex = taskIndex;
            this.canceled = canceled;
            this.completedTasks = completedTasks;
        }

        /// <summary>No follow-up work. Identical to <c>default(NextOperation)</c>.</summary>
        public static NextOperation None => default;

        public bool IsNone => kind == Kind.None;

        /// <summary>Attach a completion handler to <paramref name="task"/>, whose result belongs at <paramref name="taskIndex"/>.</summary>
        public static NextOperation Subscribe(Task<T> task, int taskIndex) =>
            new NextOperation(Kind.Subscribe, task, taskIndex, canceled: false, completedTasks: default);

        /// <summary>Reads the subscribe payload. The outputs are the stored fields regardless of the case.</summary>
        public bool IsSubscribe(out Task<T> task, out int taskIndex)
        {
            task = this.task;
            taskIndex = this.taskIndex;
            return kind == Kind.Subscribe;
        }

        /// <summary>Publish the final outcome built from <paramref name="completedTasks"/>.</summary>
        public static NextOperation Complete(bool canceled, ArraySegment<Task<T>> completedTasks) =>
            new NextOperation(Kind.Complete, task: null, taskIndex: 0, canceled, completedTasks);

        /// <summary>Reads the complete payload. The outputs are the stored fields regardless of the case.</summary>
        public bool IsComplete(out bool canceled, out ArraySegment<Task<T>> completedTasks)
        {
            canceled = this.canceled;
            completedTasks = this.completedTasks;
            return kind == Kind.Complete;
        }
    }
}
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using NSubstitute;
using NUnit.Framework;
using Shouldly;
/// <summary>
/// Tests for the <c>DoAllParallelAsync</c> extension methods. These are defined elsewhere and are presumably backed
/// by <c>AsyncParallelQueue&lt;T&gt;</c>; confirm against their implementation.
/// </summary>
public static class AsyncParallelQueueTests
{
    [Test]
    public static void Empty_source_should_complete_synchronously_with_empty_result()
    {
        var resultTask = Enumerable.Empty<int>().DoAllParallelAsync<int, int>(
            _ => throw new AssertionException("Selector should not be invoked"),
            degreeOfParallelism: 2,
            CancellationToken.None);

        resultTask.ShouldCompleteSynchronously().ShouldBeEmpty();
    }

    [Test]
    public static void Source_must_not_be_null()
    {
        // degreeOfParallelism is also invalid here; the null source must still be the argument that is reported.
        var ex = Should.Throw<ArgumentNullException>(() =>
            ((IEnumerable<Task<int>>)null).DoAllParallelAsync(
                degreeOfParallelism: -1,
                CancellationToken.None));

        ex.ParamName.ShouldBe("source");

        ex = Should.Throw<ArgumentNullException>(() =>
            ((IEnumerable<int>)null).DoAllParallelAsync<int, int>(
                _ => throw new AssertionException("Selector should not be invoked"),
                degreeOfParallelism: -1,
                CancellationToken.None));

        ex.ParamName.ShouldBe("source");
    }

    [Test]
    public static void Selector_must_not_be_null()
    {
        var ex = Should.Throw<ArgumentNullException>(() =>
            Enumerable.Empty<int>().DoAllParallelAsync<int, int>(
                asyncSelector: null,
                degreeOfParallelism: 2,
                CancellationToken.None));

        ex.ParamName.ShouldBe("asyncSelector");
    }

    [Test]
    public static void Degree_of_parallelism_must_not_be_negative()
    {
        var ex = Should.Throw<ArgumentOutOfRangeException>(() =>
            Enumerable.Empty<Task<int>>().DoAllParallelAsync(
                degreeOfParallelism: -1,
                CancellationToken.None));

        ex.ParamName.ShouldBe("degreeOfParallelism");
        ex.ActualValue.ShouldBe(-1);

        ex = Should.Throw<ArgumentOutOfRangeException>(() =>
            Enumerable.Empty<int>().DoAllParallelAsync<int, int>(
                _ => throw new AssertionException("Selector should not be invoked"),
                degreeOfParallelism: -1,
                CancellationToken.None));

        ex.ParamName.ShouldBe("degreeOfParallelism");
        ex.ActualValue.ShouldBe(-1);
    }

    [Test]
    public static void GetEnumerator_is_called_exactly_once()
    {
        var enumerable = Substitute.For<IEnumerable<Task<int>>>();
        enumerable.GetEnumerator().Returns(Enumerable.Empty<Task<int>>().GetEnumerator());

        enumerable.DoAllParallelAsync(degreeOfParallelism: 2, CancellationToken.None).ShouldCompleteSynchronously();

        enumerable.ReceivedCalls().ShouldHaveSingleItem();
    }

    [Test]
    public static async Task Enumerator_is_accessed_by_only_one_thread_at_a_time()
    {
        var enumerable = Substitute.For<IEnumerable<Task<int>>>();
        enumerable.GetEnumerator().Returns(new ThreadSafetyCheckingEnumerator());

        await enumerable
            .DoAllParallelAsync(degreeOfParallelism: 10, CancellationToken.None)
            .WithTimeout(TimeSpan.FromSeconds(10));
    }

    /// <summary>
    /// Yields 100 short-delay tasks and fails if <see cref="MoveNext"/> or <see cref="Current"/> is entered while
    /// another call to either is still in progress.
    /// </summary>
    private sealed class ThreadSafetyCheckingEnumerator : IEnumerator<Task<int>>
    {
        private int currentValue;
        private int concurrentCallerCount;

        public Task<int> Current
        {
            get
            {
                var callerCount = Interlocked.Increment(ref concurrentCallerCount);
                try
                {
                    // The check is inside the try so that the counter is restored even when it fails. That keeps
                    // one overlap from making every later call fail as well.
                    if (callerCount != 1) Assert.Fail("Current was accessed concurrently with another enumerator call.");

                    return Task.Delay(1).ContinueWith(_ => currentValue);
                }
                finally
                {
                    Interlocked.Decrement(ref concurrentCallerCount);
                }
            }
        }

        object IEnumerator.Current => throw new NotImplementedException();

        public void Dispose()
        {
        }

        public bool MoveNext()
        {
            var callerCount = Interlocked.Increment(ref concurrentCallerCount);
            try
            {
                if (callerCount != 1) Assert.Fail("MoveNext was called concurrently with another enumerator call.");

                if (currentValue >= 100) return false;
                currentValue++;
                return true;
            }
            finally
            {
                Interlocked.Decrement(ref concurrentCallerCount);
            }
        }

        public void Reset()
        {
            throw new NotImplementedException();
        }
    }

    [Test]
    public static void Completes_synchronously_when_all_work_completes_synchronously()
    {
        var task = Enumerable.Range(1, 10).DoAllParallelAsync(
            n =>
                n % 3 == 0 ? Task.FromResult(42) :
                n % 3 == 1 ? Stubs.Task.Canceled<int>() :
                Task.FromException<int>(new Exception()),
            degreeOfParallelism: 2,
            CancellationToken.None);

        task.IsCompleted.ShouldBeTrue();
    }

    [Test]
    public static void All_results_are_collected_in_order()
    {
        var sources = Enumerable.Range(1, 100).Select(n => new TaskCompletionSource<int>()).ToList();

        var resultTask = sources.DoAllParallelAsync(
            s => s.Task,
            degreeOfParallelism: 10,
            CancellationToken.None);

        for (var i = 0; i < sources.Count; i++)
            sources[i].SetResult(i + 1);

        var result = resultTask.ShouldCompleteSynchronously();
        result.ShouldBe(Enumerable.Range(1, 100));
    }

    [Test]
    public static void Externally_canceled_task_results_in_canceled_result()
    {
        var sources = Enumerable.Range(1, 100).Select(n => new TaskCompletionSource<int>()).ToList();

        var resultTask = sources.DoAllParallelAsync(
            s => s.Task,
            degreeOfParallelism: 10,
            CancellationToken.None);

        for (var i = 1; i < sources.Count; i++)
            sources[i].SetResult(i + 1);

        sources[0].SetCanceled();

        resultTask.IsCanceled.ShouldBeTrue();
    }

    [Test]
    public static void Fault_overrides_success()
    {
        var exception = new Exception();

        var resultTask = new[]
        {
            Task.FromException<int>(exception),
            Task.FromResult(42)
        }.DoAllParallelAsync(degreeOfParallelism: 2, CancellationToken.None);

        resultTask.ShouldBeFaulted()
            .InnerExceptions.ShouldHaveSingleItem().ShouldBeSameAs(exception);
    }

    [Test]
    public static void Fault_overrides_external_cancellation()
    {
        var exception = new Exception();

        var resultTask = new[]
        {
            Task.FromException<int>(exception),
            Stubs.Task.Canceled<int>()
        }.DoAllParallelAsync(degreeOfParallelism: 2, CancellationToken.None);

        resultTask.ShouldBeFaulted()
            .InnerExceptions.ShouldHaveSingleItem().ShouldBeSameAs(exception);
    }

    [Test]
    public static void Fault_overrides_internal_cancellation()
    {
        var exception = new Exception();

        using (var cancelSource = new CancellationTokenSource())
        {
            var resultTask = new Func<Task<int>>[]
            {
                () =>
                {
                    cancelSource.Cancel();
                    return Task.FromException<int>(exception);
                },
                () => Task.FromException<int>(new Exception("Should not be called")),
            }.DoAllParallelAsync(f => f.Invoke(), degreeOfParallelism: 2, cancelSource.Token);

            resultTask.ShouldBeFaulted()
                .InnerExceptions.ShouldHaveSingleItem().ShouldBeSameAs(exception);
        }
    }

    [Test]
    public static async Task Cancellation_from_another_thread_prevents_further_enumeration()
    {
        using (var cancelSource = new CancellationTokenSource())
        {
            var enumerationDidContinue = false;

            await new Func<Task<int>>[]
            {
                async () =>
                {
                    await Task.Run(cancelSource.Cancel);
                    return 42;
                },
                () =>
                {
                    enumerationDidContinue = true;
                    return Task.FromResult(42);
                }
            }.DoAllParallelAsync(f => f.Invoke(), degreeOfParallelism: 1, cancelSource.Token)
                .ShouldCancelAsync(TimeSpan.FromSeconds(10));

            enumerationDidContinue.ShouldBeFalse();
        }
    }

    [Test]
    public static void Cancellation_from_Current_prevents_further_enumeration()
    {
        using (var cancelSource = new CancellationTokenSource())
        {
            var enumerationDidContinue = false;

            var enumerator = Substitute.For<IEnumerator<Task<int>>>();
            enumerator.MoveNext().Returns(
                _ => true,
                _ =>
                {
                    enumerationDidContinue = true;
                    return false;
                });
            enumerator.Current.Returns(_ =>
            {
                cancelSource.Cancel();
                return Task.FromResult(42);
            });

            var enumerable = Substitute.For<IEnumerable<Task<int>>>();
            enumerable.GetEnumerator().Returns(enumerator);

            enumerable.DoAllParallelAsync(degreeOfParallelism: 1, cancelSource.Token)
                .IsCanceled.ShouldBeTrue();

            enumerationDidContinue.ShouldBeFalse();
        }
    }

    [Test]
    public static void Calling_with_canceled_token_prevents_GetEnumerator_call()
    {
        using (var cancelSource = new CancellationTokenSource())
        {
            var enumerable = Substitute.For<IEnumerable<Task<int>>>();
            enumerable.GetEnumerator().Returns(_ => throw new AssertionException("GetEnumerator should not be called"));

            cancelSource.Cancel();

            enumerable.DoAllParallelAsync(degreeOfParallelism: 10, cancelSource.Token)
                .IsCanceled.ShouldBeTrue();
        }
    }

    [Test]
    public static void Cancellation_from_GetEnumerator_prevents_any_enumeration()
    {
        using (var cancelSource = new CancellationTokenSource())
        {
            var enumerationDidContinue = false;

            var enumerator = Substitute.For<IEnumerator<Task<int>>>();
            enumerator.MoveNext().Returns(
                _ =>
                {
                    enumerationDidContinue = true;
                    return false;
                });

            var enumerable = Substitute.For<IEnumerable<Task<int>>>();
            enumerable.GetEnumerator().Returns(_ =>
            {
                cancelSource.Cancel();
                return enumerator;
            });

            enumerable.DoAllParallelAsync(degreeOfParallelism: 10, cancelSource.Token)
                .IsCanceled.ShouldBeTrue();

            enumerationDidContinue.ShouldBeFalse();
        }
    }

    [Test]
    public static void Cancellation_from_MoveNext_prevents_further_enumeration()
    {
        using (var cancelSource = new CancellationTokenSource())
        {
            var enumerationDidContinue = false;

            var enumerator = Substitute.For<IEnumerator<Task<int>>>();
            enumerator.MoveNext().Returns(
                _ =>
                {
                    cancelSource.Cancel();
                    return true;
                },
                _ =>
                {
                    enumerationDidContinue = true;
                    return false;
                });
            enumerator.Current.Returns(42);

            var enumerable = Substitute.For<IEnumerable<Task<int>>>();
            enumerable.GetEnumerator().Returns(enumerator);

            enumerable.DoAllParallelAsync(degreeOfParallelism: 1, cancelSource.Token)
                .IsCanceled.ShouldBeTrue();

            enumerationDidContinue.ShouldBeFalse();
        }
    }

    [Test]
    public static void Cancellation_from_MoveNext_still_accesses_Current_task_and_waits_for_it()
    {
        using (var cancelSource = new CancellationTokenSource())
        {
            var enumerator = Substitute.For<IEnumerator<Task<int>>>();
            enumerator.MoveNext().Returns(
                _ =>
                {
                    cancelSource.Cancel();
                    return true;
                },
                _ => false);

            var source = new TaskCompletionSource<int>();
            enumerator.Current.Returns(source.Task);

            var enumerable = Substitute.For<IEnumerable<Task<int>>>();
            enumerable.GetEnumerator().Returns(enumerator);

            var resultTask = enumerable.DoAllParallelAsync(degreeOfParallelism: 1, cancelSource.Token);
            resultTask.IsCompleted.ShouldBeFalse();

            var exception = new Exception();
            source.SetException(exception);

            resultTask.ShouldBeFaulted().InnerExceptions.ShouldHaveSingleItem().ShouldBeSameAs(exception);
        }
    }

    [Test]
    public static void Cancellation_from_Dispose_is_ignored()
    {
        using (var cancelSource = new CancellationTokenSource())
        {
            var enumerator = Substitute.For<IEnumerator<Task<int>>>();
            enumerator.MoveNext().Returns(true, false);
            enumerator.Current.Returns(42);
            enumerator.When(e => e.Dispose()).Do(_ => cancelSource.Cancel());

            var enumerable = Substitute.For<IEnumerable<Task<int>>>();
            enumerable.GetEnumerator().Returns(enumerator);

            enumerable.DoAllParallelAsync(degreeOfParallelism: 1, cancelSource.Token)
                .ShouldCompleteSynchronously();
        }
    }

    [Test]
    public static void All_exceptions_are_collected_in_order()
    {
        var source1 = new TaskCompletionSource<int>();
        var source2 = new TaskCompletionSource<int>();

        var resultTask = new[] { source1, source2 }.DoAllParallelAsync(
            s => s.Task,
            degreeOfParallelism: 2,
            CancellationToken.None);

        // Complete out of order; the aggregated exceptions must still follow enumeration order.
        source2.SetException(new[] { new Exception(), new Exception() });
        source1.SetException(new[] { new Exception(), new Exception() });

        resultTask.IsFaulted.ShouldBeTrue();
        resultTask.Exception.InnerExceptions.ShouldBe(
            source1.Task.Exception.InnerExceptions.Concat(
                source2.Task.Exception.InnerExceptions));
    }

    [Test]
    public static async Task InvalidOperationException_is_created_for_null_tasks()
    {
        var source1 = new TaskCompletionSource<int>();
        var source2 = new TaskCompletionSource<int>();
        var source3 = new TaskCompletionSource<int>();

        var resultTask = new[] { source1, null, source2, null, source3 }.DoAllParallelAsync(
            s => s?.Task,
            degreeOfParallelism: 3,
            CancellationToken.None);

        source3.SetException(new[] { new Exception(), new Exception() });
        source2.SetException(new[] { new Exception(), new Exception() });
        source1.SetException(new[] { new Exception(), new Exception() });

        var ex = await resultTask.ShouldFaultAsync(TimeSpan.FromSeconds(10));

        // Layout: source1 (2), null (1), source2 (2), null (1), source3 (2).
        ex.InnerExceptions.Count.ShouldBe(8);
        ex.InnerExceptions[2].ShouldBeOfType<InvalidOperationException>();
        ex.InnerExceptions[5].ShouldBeOfType<InvalidOperationException>();
    }

    [Test]
    public static async Task Exception_accessing_Current_is_treated_like_a_faulted_task()
    {
        var exception = new Exception();

        var enumerator = Substitute.For<IEnumerator<Task<int>>>();
        enumerator.MoveNext().Returns(true, false);
        enumerator.Current.Returns<Task<int>>(_ => throw exception);

        var enumerable = Substitute.For<IEnumerable<Task<int>>>();
        enumerable.GetEnumerator().Returns(enumerator);

        var ex = await enumerable
            .DoAllParallelAsync(
                degreeOfParallelism: 2,
                CancellationToken.None)
            .ShouldFaultAsync(TimeSpan.FromSeconds(10));

        ex.InnerExceptions.ShouldHaveSingleItem().ShouldBeSameAs(exception);
    }

    [Test]
    public static async Task Exception_from_MoveNext_cancels_further_enumeration()
    {
        var exception = new Exception();

        var enumerator = Substitute.For<IEnumerator<Task<int>>>();
        enumerator.MoveNext().Returns(_ => throw exception, _ => false);

        var enumerable = Substitute.For<IEnumerable<Task<int>>>();
        enumerable.GetEnumerator().Returns(enumerator);

        var ex = await enumerable
            .DoAllParallelAsync(
                degreeOfParallelism: 2,
                CancellationToken.None)
            .ShouldFaultAsync(TimeSpan.FromSeconds(10));

        ex.InnerExceptions.ShouldHaveSingleItem().ShouldBeSameAs(exception);

        // Only the throwing MoveNext call may have happened (Dispose is expected and excluded).
        enumerator.ReceivedCalls()
            .Where(c => c.GetMethodInfo().Name != nameof(enumerator.Dispose))
            .ShouldHaveSingleItem();
    }

    [Test]
    public static async Task Exception_from_Dispose_is_treated_like_a_faulted_task()
    {
        var exception = new Exception();

        var enumerator = Substitute.For<IEnumerator<Task<int>>>();
        enumerator.MoveNext().Returns(_ => false);
        enumerator.When(e => e.Dispose()).Throw(exception);

        var enumerable = Substitute.For<IEnumerable<Task<int>>>();
        enumerable.GetEnumerator().Returns(enumerator);

        var ex = await enumerable
            .DoAllParallelAsync(
                degreeOfParallelism: 2,
                CancellationToken.None)
            .ShouldFaultAsync(TimeSpan.FromSeconds(10));

        ex.InnerExceptions.ShouldHaveSingleItem().ShouldBeSameAs(exception);
    }

    [Test]
    public static void Exception_from_GetEnumerator_should_not_be_handled()
    {
        var exception = new Exception();

        var enumerable = Substitute.For<IEnumerable<Task<int>>>();
        enumerable.GetEnumerator().Returns(_ => throw exception);

        var ex = Should.Throw<Exception>(() => enumerable.DoAllParallelAsync(
            degreeOfParallelism: 2,
            CancellationToken.None));

        ex.ShouldBeSameAs(exception);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment