Skip to content

Instantly share code, notes, and snippets.

@aelij
Last active July 5, 2019 21:26
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save aelij/5d046b86bfca13fb682c411852d08cfd to your computer and use it in GitHub Desktop.
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Security;
using System.Threading;
using System.Threading.Tasks;
namespace AsyncEnumeratorGenerator
{
/// <summary>
/// Demo driver for the custom async-enumerator method builder: prints a yielded sequence from
/// several consumers in parallel, then prints a filtered range until a cancellation token fires.
/// </summary>
internal class Program
{
    private static void Main(string[] args)
    {
        var task = Run();
        try
        {
            // Blocking is acceptable in this console entry point.
            task.Wait();
        }
        catch (Exception ex)
        {
            Console.WriteLine(task.Status);
            Console.WriteLine(ex);
        }
    }

    /// <summary>Runs the parallel-consumers demo followed by the cancellation demo.</summary>
    private static async Task Run()
    {
        // run in parallel: every Print call asks the enumerable for its own enumerator
        var enumerable = GetValuesAsync();
        await Task.WhenAll(Enumerable.Range(0, 10).Select(i => Print(enumerable)));

        // cancellation: the token is canceled after 100 ms while the filtered range is being printed.
        // The source is disposed so its pending CancelAfter timer is released.
        using (var cts = new CancellationTokenSource())
        {
            cts.CancelAfter(100);
            await Print(Where(Where(AsyncEnumerable.Range(0, 1000000), i => i > 100, cts.Token), i => i % 2 == 0, cts.Token));
        }
    }

    /// <summary>Writes every element of <paramref name="enumerable"/> to the console.</summary>
    /// <param name="enumerable">Sequence to print.</param>
    private static async Task Print(AsyncEnumerable<int> enumerable)
    {
        // Dispose the enumerator once iteration ends, including when MoveNext faults or is canceled.
        using (var enumerator = enumerable.GetEnumerator())
        {
            while (await enumerator.MoveNext())
            {
                Console.WriteLine(enumerator.Current);
            }
        }
    }

    /// <summary>Creates an enumerable whose enumerators each run <see cref="GetValuesAsyncEnumerator"/> afresh.</summary>
    public static AsyncEnumerable<int> GetValuesAsync()
    {
        return new AsyncEnumerable<int>(GetValuesAsyncEnumerator);
    }

    /// <summary>Yields 0..4, waiting 200 ms before each value.</summary>
    private static async AsyncEnumerator<int> GetValuesAsyncEnumerator()
    {
        for (int i = 0; i < 5; i++)
        {
            await Task.Delay(200);

            // Copy the loop variable: the lambda runs on the thread pool, and a captured
            // for-loop variable is shared across iterations.
            var value = i;
            await Task.Run(() => value).YieldReturn();
        }

        return default(int); // dummy value: the builder ignores the result and reports end-of-sequence
    }

    /// <summary>
    /// Lazily filters <paramref name="enumerable"/>, yielding only elements that satisfy
    /// <paramref name="predicate"/>.
    /// </summary>
    /// <param name="enumerable">Source sequence; must not be null.</param>
    /// <param name="predicate">Filter; must not be null.</param>
    /// <param name="cancellationToken">Passed to every MoveNext call on the source.</param>
    /// <exception cref="ArgumentNullException"><paramref name="enumerable"/> or <paramref name="predicate"/> is null.</exception>
    public static AsyncEnumerable<T> Where<T>(IAsyncEnumerable<T> enumerable, Func<T, bool> predicate, CancellationToken cancellationToken)
    {
        // Validate eagerly so misuse is reported here rather than as a faulted MoveNext later.
        if (enumerable == null)
        {
            throw new ArgumentNullException(nameof(enumerable));
        }

        if (predicate == null)
        {
            throw new ArgumentNullException(nameof(predicate));
        }

        return new AsyncEnumerable<T>(() => WhereEnumerator<T>(enumerable, predicate, cancellationToken));
    }

    /// <summary>Async enumerator body for <see cref="Where{T}"/>.</summary>
    private static async AsyncEnumerator<T> WhereEnumerator<T>(IAsyncEnumerable<T> enumerable, Func<T, bool> predicate, CancellationToken cancellationToken)
    {
        // The source enumerator is disposed once the source is exhausted or MoveNext throws.
        using (var enumerator = enumerable.GetEnumerator())
        {
            while (await enumerator.MoveNext(cancellationToken))
            {
                if (predicate(enumerator.Current))
                {
                    // we can use a value task or a simple value awaitable here instead
                    await Task.FromResult(enumerator.Current).YieldReturn();
                }
            }
        }

        return default(T); // dummy value: ignored by the builder
    }
}
/// <summary>
/// Provides the <c>YieldReturn()</c> extension used inside async enumerator methods.
/// </summary>
public static class YieldReturnExtensions
{
    /// <summary>
    /// Wraps <paramref name="task"/> in an awaitable. When it is awaited inside an async method
    /// returning AsyncEnumerator&lt;T&gt; (with matching T), the task's result becomes the next element.
    /// </summary>
    /// <param name="task">Task producing the value to yield.</param>
    /// <returns>The awaitable wrapper.</returns>
    public static YieldReturnAwaitable<T> YieldReturn<T>(this Task<T> task) => new YieldReturnAwaitable<T>(task);
}
/// <summary>
/// Awaitable returned by <c>YieldReturn()</c>. Inside an async method whose builder is
/// AsyncEnumeratorTaskMethodBuilder&lt;TResult&gt;, awaiting it publishes the task's result as the
/// next element of the sequence instead of simply continuing.
/// </summary>
/// <typeparam name="TResult">Type of the yielded value.</typeparam>
public struct YieldReturnAwaitable<TResult>
{
    private readonly Task<TResult> _task;

    /// <summary>Wraps the task whose result will be yielded.</summary>
    /// <param name="task">The task producing the value; must not be null.</param>
    /// <exception cref="ArgumentNullException"><paramref name="task"/> is null.</exception>
    public YieldReturnAwaitable(Task<TResult> task)
    {
        if (task == null)
        {
            throw new ArgumentNullException(nameof(task));
        }

        _task = task;
    }

    /// <summary>Returns the awaiter used by the compiler-generated state machine.</summary>
    /// <exception cref="InvalidOperationException">
    /// This is a default-initialized awaitable that was never given a task.
    /// </exception>
    public YieldReturnAwaiter GetAwaiter()
    {
        if (_task == null)
        {
            throw new InvalidOperationException("YieldReturnAwaitable was not initialized with a task.");
        }

        return new YieldReturnAwaiter(_task);
    }

    /// <summary>
    /// Awaiter that always reports "not completed", so the method builder gets control at every yield.
    /// </summary>
    public struct YieldReturnAwaiter : ICriticalNotifyCompletion, INotifyCompletion
    {
        private readonly TaskAwaiter<TResult> _awaiter;

        internal YieldReturnAwaiter(Task<TResult> task)
        {
            _awaiter = task.GetAwaiter();
        }

        // If this returned true, the state machine would skip AwaitOnCompleted and we
        // would not be able to yield. The builder checks IsCompletedInternal instead
        // to take a synchronous fast path.
        public bool IsCompleted => false;

        // Whether the wrapped task has actually completed.
        internal bool IsCompletedInternal => _awaiter.IsCompleted;

        [SecuritySafeCritical]
        public void OnCompleted(Action continuation) => _awaiter.OnCompleted(continuation);

        [SecurityCritical]
        public void UnsafeOnCompleted(Action continuation) => _awaiter.UnsafeOnCompleted(continuation);

        /// <summary>Returns the wrapped task's result, or rethrows its exception.</summary>
        public TResult GetResult() => _awaiter.GetResult();
    }
}
/// <summary>
/// Async enumerable backed by a factory delegate; every GetEnumerator call invokes the factory.
/// </summary>
/// <typeparam name="T">Element type.</typeparam>
public sealed class AsyncEnumerable<T> : IAsyncEnumerable<T>
{
    private readonly Func<IAsyncEnumerator<T>> _getEnumerator;

    /// <summary>Creates an enumerable from an enumerator factory.</summary>
    /// <param name="getEnumerator">Factory invoked on every GetEnumerator call; must not be null.</param>
    /// <exception cref="ArgumentNullException"><paramref name="getEnumerator"/> is null.</exception>
    public AsyncEnumerable(Func<IAsyncEnumerator<T>> getEnumerator)
    {
        // Fail at construction instead of with a NullReferenceException on the first GetEnumerator.
        if (getEnumerator == null)
        {
            throw new ArgumentNullException(nameof(getEnumerator));
        }

        _getEnumerator = getEnumerator;
    }

    /// <summary>Returns the enumerator produced by invoking the factory.</summary>
    public IAsyncEnumerator<T> GetEnumerator() => _getEnumerator();
}
/// <summary>
/// Task-like return type of async enumerator methods. Each MoveNext call resumes the underlying
/// async state machine until it yields a value (via YieldReturn) or the method finishes.
/// </summary>
/// <typeparam name="T">Element type.</typeparam>
public sealed class AsyncEnumerator<T> : IAsyncEnumerator<T>
{
    // Body for the pre-canceled task created in MoveNext; that task is never started, so this never runs.
    private static readonly Func<bool> _emptyFunc = () => false;
    private readonly AsyncEnumeratorTaskMethodBuilder<T> _builder;

    internal AsyncEnumerator(AsyncEnumeratorTaskMethodBuilder<T> builder)
    {
        _builder = builder;
    }

    /// <summary>The most recently yielded value (default(T) before the first yield).</summary>
    public T Current => _builder._current;

    /// <summary>Holds no resources; present to satisfy the enumerator interface.</summary>
    public void Dispose()
    {
    }

    /// <summary>Resumes the async method until it yields the next element or finishes.</summary>
    /// <param name="cancellationToken">If already canceled, the method is not resumed.</param>
    /// <returns>
    /// A task that completes with <c>true</c> when a value was yielded and <c>false</c> when the
    /// method has returned. It faults when the method throws and is canceled on cancellation.
    /// </returns>
    /// <exception cref="InvalidOperationException">A previous MoveNext call has not completed yet.</exception>
    public Task<bool> MoveNext(CancellationToken cancellationToken)
    {
        if (cancellationToken.IsCancellationRequested)
        {
            // same as Task.FromCanceled: a task constructed with an already-canceled token
            // starts out in the Canceled state.
            return new Task<bool>(_emptyFunc, cancellationToken);
        }

        var previous = _builder._tcs;
        if (previous != null)
        {
            var previousTask = previous.Task;
            if (!previousTask.IsCompleted)
            {
                // The state machine is still suspended at a pending await; resuming it now would
                // run it concurrently with the continuation that is already scheduled.
                throw new InvalidOperationException("MoveNext cannot be called while a previous MoveNext operation is still pending.");
            }

            if (previousTask.Status == TaskStatus.RanToCompletion && !previousTask.Result)
            {
                // Only SetResult completes the source with false, so the async method has already
                // returned. Report the end of the sequence again instead of re-driving a finished
                // state machine.
                return previousTask;
            }
        }

        // Keep a local so the returned task is the one created for this call.
        var tcs = new TaskCompletionSource<bool>();
        _builder._tcs = tcs;
        _builder._stateMachine.MoveNext();
        return tcs.Task;
    }

    [EditorBrowsable(EditorBrowsableState.Never)]
    public static AsyncEnumeratorTaskMethodBuilder<T> CreateAsyncMethodBuilder() => AsyncEnumeratorTaskMethodBuilder<T>.Create();
}
/// <summary>
/// Async method builder for methods returning AsyncEnumerator&lt;T&gt;.
/// Start only stores the state machine; AsyncEnumerator&lt;T&gt;.MoveNext creates a fresh _tcs and resumes it.
/// - A YieldReturn await publishes a value into _current and completes _tcs with true.
/// - Returning from the method completes _tcs with false.
/// - An exception faults _tcs (cancels it for OperationCanceledException).
/// </summary>
public sealed class AsyncEnumeratorTaskMethodBuilder<T>
{
// Awaiter of the current await when it is a YieldReturn awaiter for T; null for any other await.
private YieldReturnAwaitable<T>.YieldReturnAwaiter? _yieldReturnAwaiter;
// State machine stored in Start (boxed if the compiler generated a struct); resumed by MoveNext.
internal IAsyncStateMachine _stateMachine;
// Last yielded value, exposed through AsyncEnumerator<T>.Current.
internal T _current;
// Completion source for the pending MoveNext call; replaced by AsyncEnumerator<T>.MoveNext on each call.
internal TaskCompletionSource<bool> _tcs;
// Factory exposed through AsyncEnumerator<T>.CreateAsyncMethodBuilder (presumably what the prototype compiler calls).
public static AsyncEnumeratorTaskMethodBuilder<T> Create() => new AsyncEnumeratorTaskMethodBuilder<T>();
// Deliberately does not call MoveNext: the method body runs lazily, on the first MoveNext of the enumerator.
public void Start<TStateMachine>(ref TStateMachine stateMachine) where TStateMachine : IAsyncStateMachine
{
_stateMachine = stateMachine;
}
// No-op: the state machine captured in Start is used instead.
public void SetStateMachine(IAsyncStateMachine stateMachine)
{
}
// Called when the async method returns; signals end of sequence to the pending MoveNext.
public void SetResult(T result)
{
// ignore the result value
if (_tcs != null)
{
_tcs.TrySetResult(false);
}
}
// Propagates a failure to the pending MoveNext: OperationCanceledException cancels it, anything else faults it.
// Also called from InvokeMoveNext when a yielded task's GetResult throws.
public void SetException(Exception exception)
{
if (_tcs != null)
{
if (exception is OperationCanceledException)
{
_tcs.TrySetCanceled();
}
else
{
_tcs.TrySetException(exception);
}
}
}
// The object handed back to the caller of the async method; each access creates a new wrapper over this builder.
public AsyncEnumerator<T> Task => new AsyncEnumerator<T>(this);
public void AwaitOnCompleted<TAwaiter, TStateMachine>(ref TAwaiter awaiter, ref TStateMachine stateMachine)
where TAwaiter : INotifyCompletion
where TStateMachine : IAsyncStateMachine
{
Await(ref awaiter, ref stateMachine);
}
// NOTE(review): this "unsafe" path still ends up calling awaiter.OnCompleted (which may flow the
// ExecutionContext itself) while Await also captures and restores the ExecutionContext - confirm intended.
[SecuritySafeCritical]
public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(ref TAwaiter awaiter, ref TStateMachine stateMachine)
where TAwaiter : ICriticalNotifyCompletion
where TStateMachine : IAsyncStateMachine
{
Await(ref awaiter, ref stateMachine);
}
// Common await handling. The stateMachine argument is not used; _stateMachine from Start is resumed instead.
private void Await<TAwaiter, TStateMachine>(ref TAwaiter awaiter, ref TStateMachine stateMachine)
where TAwaiter : INotifyCompletion
where TStateMachine : IAsyncStateMachine
{
// Only a YieldReturn awaiter for exactly T matches; a YieldReturn of another type yields null here
// and is treated as a normal await (its value is not published - there is no compile-time check).
_yieldReturnAwaiter = awaiter as YieldReturnAwaitable<T>.YieldReturnAwaiter?;
if (_yieldReturnAwaiter?.IsCompletedInternal == true)
{
// Fast path: the yielded task is already done, so publish its value synchronously.
InvokeMoveNext();
}
else
{
// Resume (normal await) or publish (yield) once the awaited operation completes,
// under the ExecutionContext captured here.
var runner = new MoveNextRunner(ExecutionContext.Capture(), _stateMachine, this);
awaiter.OnCompleted(() => runner.Run());
}
}
// Continuation for an await. A normal await resumes the state machine. A yield publishes the value
// and completes the pending MoveNext without resuming; the state machine resumes on the next MoveNext.
internal void InvokeMoveNext()
{
if (_yieldReturnAwaiter == null)
{
// this is a "normal" await - just continue async execution
_stateMachine.MoveNext();
return;
}
try
{
// GetResult will be called again by the async state machine (oh, well :)
_current = _yieldReturnAwaiter.Value.GetResult();
// NOTE(review): _tcs is created without TaskCreationOptions.RunContinuationsAsynchronously,
// so the consumer's continuation may run inline here.
_tcs.TrySetResult(true);
}
catch (Exception ex)
{
SetException(ex);
}
}
// Runs InvokeMoveNext on the builder, inside the captured ExecutionContext when one was captured.
private sealed class MoveNextRunner
{
private readonly ExecutionContext _context;
// NOTE(review): stored but never read in this class; Run goes through _builder instead.
private readonly IAsyncStateMachine _stateMachine;
// Lazily cached callback, shared by all runners of this closed generic type.
[SecurityCritical]
private static ContextCallback _invokeMoveNext;
private readonly AsyncEnumeratorTaskMethodBuilder<T> _builder;
[SecurityCritical]
internal MoveNextRunner(ExecutionContext context, IAsyncStateMachine stateMachine, AsyncEnumeratorTaskMethodBuilder<T> builder)
{
_context = context;
_stateMachine = stateMachine;
_builder = builder;
}
// Invoked once from the awaiter's continuation; disposes the captured context afterwards.
[SecuritySafeCritical]
internal void Run()
{
if (_context != null)
{
try
{
ContextCallback contextCallback = _invokeMoveNext;
if (contextCallback == null)
{
contextCallback = (_invokeMoveNext = new ContextCallback(InvokeMoveNext));
}
ExecutionContext.Run(_context, contextCallback, _builder);
return;
}
finally
{
_context.Dispose();
}
}
// ExecutionContext.Capture returned null (e.g. flow suppressed): run directly.
_builder.InvokeMoveNext();
}
// ContextCallback adapter: the state argument is the builder passed to ExecutionContext.Run.
[SecurityCritical]
private static void InvokeMoveNext(object builder)
{
((AsyncEnumeratorTaskMethodBuilder<T>)builder).InvokeMoveNext();
}
}
}
}
@aelij
Copy link
Author

aelij commented Aug 10, 2016

Stuff to review:

  • ExecutionContext capturing
  • Cancellation
  • Should we clone the async state machine each time we execute GetEnumerator()?
  • Should we throw if we YieldReturn() on the wrong type? (there's no compile-time check)

@aelij
Copy link
Author

aelij commented Aug 11, 2016

Updated version:

  • The async method now generates an AsyncEnumerator<T> instead of an enumerable.
    • I've concluded that it isn't really possible to turn the async method itself into a factory: the builder field is instantiated once inside the state machine and cannot be replaced. (Reusing it works as long as you don't run the enumerators in parallel, but that's not an acceptable restriction IMO.)
    • So now we use two methods, one that the compiler generates - the enumerator, and another "factory" method - for the enumerable.
  • Cancellation should now work as expected.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment