Skip to content

Instantly share code, notes, and snippets.

@neuecc
Created May 7, 2020 11:13
Show Gist options
  • Save neuecc/9cef2b79a0828796306fd835a0641fbb to your computer and use it in GitHub Desktop.
Save neuecc/9cef2b79a0828796306fd835a0641fbb to your computer and use it in GitHub Desktop.
using System;
using System.Threading;
namespace Cysharp.Threading.Tasks
{
/// <summary>
/// An asynchronous sequence whose enumerators advance via <see cref="UniTask{T}"/>.
/// </summary>
/// <typeparam name="T">Element type (covariant).</typeparam>
public interface IUniTaskAsyncEnumerable<out T>
{
    /// <summary>Creates a new enumerator over this sequence.</summary>
    /// <param name="cancellationToken">
    /// Token handed to the enumerator; omitted means <see cref="CancellationToken.None"/>.
    /// </param>
    IUniTaskAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default(CancellationToken));
}
/// <summary>
/// Pull-based asynchronous cursor over an <see cref="IUniTaskAsyncEnumerable{T}"/>.
/// Released through <see cref="IUniTaskAsyncDisposable.DisposeAsync"/>.
/// </summary>
/// <typeparam name="T">Element type (covariant).</typeparam>
public interface IUniTaskAsyncEnumerator<out T> : IUniTaskAsyncDisposable
{
    /// <summary>
    /// Attempts to advance the cursor; the task's result tells whether <see cref="Current"/>
    /// now holds a new element.
    /// </summary>
    UniTask<bool> MoveNextAsync();

    /// <summary>The element at the cursor's current position.</summary>
    T Current { get; }
}
/// <summary>
/// Asynchronous counterpart of <see cref="IDisposable"/>: release work is awaited as a <see cref="UniTask"/>.
/// </summary>
public interface IUniTaskAsyncDisposable
{
    /// <summary>Releases whatever the implementer holds.</summary>
    /// <returns>A task representing the disposal work.</returns>
    UniTask DisposeAsync();
}
/// <summary>
/// LINQ-style operators over <see cref="IUniTaskAsyncEnumerable{T}"/>.
/// </summary>
public static class UniTaskAsyncEnumerable
{
    /// <summary>
    /// Projects each element of <paramref name="source"/> through <paramref name="selector"/>.
    /// The returned sequence is deferred: <paramref name="source"/> is only enumerated when the
    /// result is enumerated.
    /// </summary>
    /// <typeparam name="TSource">Element type of the source sequence.</typeparam>
    /// <typeparam name="TResult">Element type produced by <paramref name="selector"/>.</typeparam>
    /// <param name="source">Sequence to project.</param>
    /// <param name="selector">Projection applied to every source element.</param>
    /// <returns>A sequence of projected elements.</returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="source"/> or <paramref name="selector"/> is null.
    /// </exception>
    public static IUniTaskAsyncEnumerable<TResult> Select<TSource, TResult>(this IUniTaskAsyncEnumerable<TSource> source, Func<TSource, TResult> selector)
    {
        // Validate eagerly so misuse is reported at the call site instead of as a
        // NullReferenceException on (or inside the task of) the first MoveNextAsync.
        if (source == null)
        {
            throw new ArgumentNullException(nameof(source));
        }
        if (selector == null)
        {
            throw new ArgumentNullException(nameof(selector));
        }

        return new Cysharp.Threading.Tasks.Linq.Select<TSource, TResult>(source, selector);
    }
}
}
namespace Cysharp.Threading.Tasks.Linq
{
/// <summary>
/// Base class for single-source async operators.
/// Each <see cref="MoveNextAsync"/> call pulls one step from the source enumerator and passes
/// its outcome to <see cref="MoveNextCore"/>. It then completes the task it returned, which is
/// backed by this instance (via <see cref="IUniTaskSource{T}"/>) and a per-call reset of the
/// completion source.
/// </summary>
/// <typeparam name="TSource">Element type of the source sequence.</typeparam>
/// <typeparam name="TResult">Element type exposed through <see cref="Current"/>.</typeparam>
public abstract class AsyncEnumeratorBase<TSource, TResult> : IUniTaskAsyncEnumerator<TResult>, IUniTaskSource<bool>
{
    // Cached once per closed generic type so MoveNextAsync does not convert the method group
    // into a fresh delegate every time it has to wait on the source.
    static readonly Action<object> moveNextCallbackDelegate = MoveNextCallBack;

    readonly IUniTaskAsyncEnumerable<TSource> source;
    readonly CancellationToken cancellationToken;

    // NOTE(review): Reset/TrySet* are called on this field directly; it is presumably a mutable
    // struct, so it must stay a non-readonly field — confirm before adding 'readonly'.
    UniTaskCompletionSourceCore<bool> completionSource;

    // Created lazily on the first MoveNextAsync call.
    IUniTaskAsyncEnumerator<TSource> enumerator;

    // Awaiter for the source step currently in flight; read again once that step finishes.
    UniTask<bool>.Awaiter sourceMoveNext;

    /// <param name="source">Sequence this operator pulls from.</param>
    /// <param name="cancellationToken">
    /// Passed to the source's GetAsyncEnumerator and checked after each step.
    /// </param>
    protected AsyncEnumeratorBase(IUniTaskAsyncEnumerable<TSource> source, CancellationToken cancellationToken)
    {
        this.source = source;
        this.cancellationToken = cancellationToken;
    }

    /// <summary>Operator-specific logic, run once per finished source step.</summary>
    /// <param name="sourceHasCurrent">
    /// Result of the source's MoveNextAsync; when true, <see cref="SourceCurrent"/> is valid.
    /// </param>
    /// <returns>The value the task returned by <see cref="MoveNextAsync"/> completes with.</returns>
    protected abstract bool MoveNextCore(bool sourceHasCurrent);

    /// <summary>The source enumerator's current element.</summary>
    protected TSource SourceCurrent => enumerator.Current;

    /// <summary>The element produced by the most recent successful step; set by derived operators.</summary>
    public TResult Current { get; protected set; }

    /// <summary>
    /// Advances by one source step.
    /// Every failure is surfaced through the returned task, including synchronous failures
    /// while obtaining the source enumerator or starting its MoveNextAsync.
    /// </summary>
    public UniTask<bool> MoveNextAsync()
    {
        completionSource.Reset();

        try
        {
            if (enumerator == null)
            {
                enumerator = source.GetAsyncEnumerator(cancellationToken);
            }
            sourceMoveNext = enumerator.MoveNextAsync().GetAwaiter();
        }
        catch (Exception ex)
        {
            // Previously these escaped MoveNextAsync directly, unlike every other failure path.
            completionSource.TrySetException(ex);
            return new UniTask<bool>(this, completionSource.Version);
        }

        if (sourceMoveNext.IsCompleted)
        {
            // Fast path: the source step already finished, so complete synchronously.
            CompleteFromSource();
        }
        else
        {
            sourceMoveNext.SourceOnCompleted(moveNextCallbackDelegate, this);
        }

        return new UniTask<bool>(this, completionSource.Version);
    }

    // Continuation registered on the source awaiter; state is the enumerator instance.
    static void MoveNextCallBack(object state)
    {
        ((AsyncEnumeratorBase<TSource, TResult>)state).CompleteFromSource();
    }

    // Reads the finished source step, runs MoveNextCore, and completes completionSource.
    // The sync fast path and the continuation both use it so they cannot drift apart.
    void CompleteFromSource()
    {
        bool result;
        try
        {
            result = MoveNextCore(sourceMoveNext.GetResult());
        }
        catch (Exception ex)
        {
            completionSource.TrySetException(ex);
            return;
        }

        // Cancellation is checked after MoveNextCore has run.
        // A canceled step is therefore reported as canceled even if Current was already updated.
        if (cancellationToken.IsCancellationRequested)
        {
            completionSource.TrySetCanceled(cancellationToken);
        }
        else
        {
            completionSource.TrySetResult(result);
        }
    }

    /// <summary>
    /// Disposes the source enumerator if one was created. If a derived class needs extra
    /// cleanup, override this and call <c>base.DisposeAsync()</c>.
    /// </summary>
    public virtual UniTask DisposeAsync()
    {
        if (enumerator != null)
        {
            return enumerator.DisposeAsync();
        }
        return default;
    }

    // IUniTaskSource<bool>: each member forwards to completionSource with the caller's token.

    public bool GetResult(short token)
    {
        return completionSource.GetResult(token);
    }

    public UniTaskStatus GetStatus(short token)
    {
        return completionSource.GetStatus(token);
    }

    public void OnCompleted(Action<object> continuation, object state, short token)
    {
        completionSource.OnCompleted(continuation, state, token);
    }

    public UniTaskStatus UnsafeGetStatus()
    {
        return completionSource.UnsafeGetStatus();
    }

    void IUniTaskSource.GetResult(short token)
    {
        completionSource.GetResult(token);
    }
}
/// <summary>
/// Deferred projection operator behind <c>UniTaskAsyncEnumerable.Select</c>.
/// Every call to <see cref="GetAsyncEnumerator"/> creates a fresh enumerator that maps source
/// elements through the selector.
/// </summary>
internal class Select<TSource, TResult> : IUniTaskAsyncEnumerable<TResult>
{
    readonly IUniTaskAsyncEnumerable<TSource> upstream;
    readonly Func<TSource, TResult> map;

    public Select(IUniTaskAsyncEnumerable<TSource> source, Func<TSource, TResult> selector)
    {
        upstream = source;
        map = selector;
    }

    public IUniTaskAsyncEnumerator<TResult> GetAsyncEnumerator(CancellationToken cancellationToken = default)
        => new Enumerator(upstream, map, cancellationToken);

    // Per-enumeration state; pulling from the source and completing tasks is handled by AsyncEnumeratorBase.
    class Enumerator : AsyncEnumeratorBase<TSource, TResult>
    {
        readonly Func<TSource, TResult> projector;

        public Enumerator(IUniTaskAsyncEnumerable<TSource> source, Func<TSource, TResult> selector, CancellationToken cancellationToken)
            : base(source, cancellationToken)
        {
            projector = selector;
        }

        // Maps the current source element when one exists; otherwise reports that no element was produced.
        protected override bool MoveNextCore(bool sourceHasCurrent)
        {
            if (!sourceHasCurrent)
            {
                return false;
            }

            Current = projector(SourceCurrent);
            return true;
        }
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment