Created
March 2, 2021 20:29
-
-
Save noseratio/73dd6af4e4d710cb85cd4947256e9d5e to your computer and use it in GitHub Desktop.
Fluent awaiters concept
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Fluent awaiters concept, by @noseratio | |
// https://github.com/dotnet/runtime/issues/47525#issuecomment-788373184 | |
// | |
using System; | |
using System.Runtime.CompilerServices; | |
using System.Threading; | |
using System.Threading.Tasks; | |
namespace Noseratio.Experimental | |
{ | |
public static class TaskExt
{
    /// <summary>
    /// An object that can be awaited through the awaiter returned by <see cref="GetAwaiter"/>.
    /// Every fluent wrapper in <see cref="TaskExt"/> implements this, which is what allows the
    /// wrappers to be chained (e.g. <c>task.RestoreContext().ForceAsync()</c>).
    /// </summary>
    public interface IAwaitable
    {
        /// <summary>Returns the awaiter consumed by the <c>await</c> operator.</summary>
        IAwaiter GetAwaiter();
    }

    /// <summary>
    /// A non-generic awaiter satisfying the compiler's awaiter pattern:
    /// <see cref="IsCompleted"/>, <see cref="GetResult"/>, and the
    /// <see cref="ICriticalNotifyCompletion"/> registration callbacks.
    /// </summary>
    public interface IAwaiter : ICriticalNotifyCompletion
    {
        /// <summary>When true, <c>await</c> skips registering a continuation.</summary>
        bool IsCompleted { get; }

        /// <summary>Ends the await, rethrowing any exception of the underlying operation.</summary>
        void GetResult();
    }

    /// <summary>Wraps <paramref name="this"/> in an <see cref="IgnoreContextAwaitable"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="this"/> is null.</exception>
    public static IgnoreContextAwaitable IgnoreContext(this Task @this) => new IgnoreContextAwaitable(@this);

    /// <summary>Wraps another fluent awaitable in an <see cref="IgnoreContextAwaitable"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="this"/> is null.</exception>
    public static IgnoreContextAwaitable IgnoreContext(this IAwaitable @this) => new IgnoreContextAwaitable(@this);

    /// <summary>Wraps <paramref name="this"/> in a <see cref="ForceAsyncAwaitable"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="this"/> is null.</exception>
    public static ForceAsyncAwaitable ForceAsync(this Task @this) => new ForceAsyncAwaitable(@this);

    /// <summary>Wraps another fluent awaitable in a <see cref="ForceAsyncAwaitable"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="this"/> is null.</exception>
    public static ForceAsyncAwaitable ForceAsync(this IAwaitable @this) => new ForceAsyncAwaitable(@this);

    /// <summary>Wraps <paramref name="this"/> in a <see cref="RestoreContextAwaitable"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="this"/> is null.</exception>
    public static RestoreContextAwaitable RestoreContext(this Task @this) => new RestoreContextAwaitable(@this);

    /// <summary>Wraps another fluent awaitable in a <see cref="RestoreContextAwaitable"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="this"/> is null.</exception>
    public static RestoreContextAwaitable RestoreContext(this IAwaitable @this) => new RestoreContextAwaitable(@this);

    // ---- Shared plumbing. Every wrapper holds either a Task or a parent awaiter (exactly one non-null). ----

    /// <summary>True once the wrapped task, or the parent awaiter, reports completion.</summary>
    private static bool IsSourceCompleted(Task task, IAwaiter parentAwaiter) =>
        task != null ? task.IsCompleted : parentAwaiter.IsCompleted;

    /// <summary>Observes the source's outcome, rethrowing its exception if it failed.</summary>
    private static void PropagateResult(Task task, IAwaiter parentAwaiter)
    {
        if (task != null)
        {
            task.GetAwaiter().GetResult();
        }
        else
        {
            parentAwaiter.GetResult();
        }
    }

    /// <summary>
    /// Arranges for <paramref name="continuation"/> to run once the source completes.
    /// For a task this uses ContinueWith on <see cref="TaskScheduler.Default"/> with
    /// ExecuteSynchronously (so no caller context or scheduler is captured); for a parent
    /// awaiter the continuation is handed to that awaiter's UnsafeOnCompleted.
    /// </summary>
    private static void ContinueOnCompletion(Task task, IAwaiter parentAwaiter, Action continuation)
    {
        if (task != null)
        {
            task.ContinueWith(
                continuationAction: _ => continuation(),
                cancellationToken: CancellationToken.None,
                continuationOptions: TaskContinuationOptions.ExecuteSynchronously,
                scheduler: TaskScheduler.Default);
        }
        else
        {
            parentAwaiter.UnsafeOnCompleted(continuation);
        }
    }

    /// <summary>
    /// Wraps <paramref name="continuation"/> so it runs under the ExecutionContext captured now,
    /// which is what INotifyCompletion.OnCompleted (the "safe" variant) is expected to do.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="continuation"/> is null.</exception>
    private static Action FlowExecutionContext(Action continuation)
    {
        if (continuation == null)
        {
            throw new ArgumentNullException(nameof(continuation));
        }

        ExecutionContext captured = ExecutionContext.Capture();
        if (captured == null)
        {
            // Flow is suppressed by the caller: nothing to restore.
            return continuation;
        }

        return () => ExecutionContext.Run(captured, state => ((Action)state)(), continuation);
    }

    /// <summary>
    /// Awaitable returned by <see cref="IgnoreContext(Task)"/>. For a <see cref="Task"/> source the
    /// continuation is scheduled on <see cref="TaskScheduler.Default"/> and no
    /// <see cref="SynchronizationContext"/> is captured; for a parent awaitable the continuation is
    /// handed to the parent awaiter unchanged.
    /// </summary>
    public struct IgnoreContextAwaitable : IAwaitable, IAwaiter
    {
        // Exactly one of these is non-null for a constructed instance.
        readonly Task _task;
        readonly IAwaiter _parentAwaiter;

        /// <summary>Wraps a task.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="task"/> is null.</exception>
        public IgnoreContextAwaitable(Task task)
        {
            _task = task ?? throw new ArgumentNullException(nameof(task));
            _parentAwaiter = null;
        }

        /// <summary>Wraps another awaitable, obtaining its awaiter immediately.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="awaitable"/> is null.</exception>
        public IgnoreContextAwaitable(IAwaitable awaitable)
        {
            if (awaitable == null)
            {
                throw new ArgumentNullException(nameof(awaitable));
            }

            _task = null;
            _parentAwaiter = awaitable.GetAwaiter();
        }

        /// <summary>Returns this instance (boxed) as its own awaiter.</summary>
        public IAwaiter GetAwaiter() => this;

        /// <summary>Reflects the completion state of the wrapped task or parent awaiter.</summary>
        public bool IsCompleted => IsSourceCompleted(_task, _parentAwaiter);

        /// <summary>Rethrows the source's exception, if any.</summary>
        public void GetResult() => PropagateResult(_task, _parentAwaiter);

        /// <summary>Like <see cref="UnsafeOnCompleted"/>, but runs the continuation under the current ExecutionContext.</summary>
        public void OnCompleted(Action continuation) => UnsafeOnCompleted(FlowExecutionContext(continuation));

        /// <summary>Registers <paramref name="continuation"/> to run when the source completes.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="continuation"/> is null.</exception>
        public void UnsafeOnCompleted(Action continuation)
        {
            if (continuation == null)
            {
                throw new ArgumentNullException(nameof(continuation));
            }

            ContinueOnCompletion(_task, _parentAwaiter, continuation);
        }
    }

    /// <summary>
    /// Awaitable returned by <see cref="ForceAsync(Task)"/>. If the source has already completed when
    /// the continuation is registered, the continuation runs inline. Otherwise, once the source
    /// completes, the continuation is posted to the SynchronizationContext that is current on the
    /// completing thread, or queued to the thread pool when there is none, rather than being
    /// invoked directly by the completing code.
    /// </summary>
    public struct ForceAsyncAwaitable : IAwaitable, IAwaiter
    {
        // Exactly one of these is non-null for a constructed instance.
        readonly Task _task;
        readonly IAwaiter _parentAwaiter;

        /// <summary>Wraps a task.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="task"/> is null.</exception>
        public ForceAsyncAwaitable(Task task)
        {
            _task = task ?? throw new ArgumentNullException(nameof(task));
            _parentAwaiter = null;
        }

        /// <summary>Wraps another awaitable, obtaining its awaiter immediately.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="awaitable"/> is null.</exception>
        public ForceAsyncAwaitable(IAwaitable awaitable)
        {
            if (awaitable == null)
            {
                throw new ArgumentNullException(nameof(awaitable));
            }

            _task = null;
            _parentAwaiter = awaitable.GetAwaiter();
        }

        /// <summary>Returns this instance (boxed) as its own awaiter.</summary>
        public IAwaiter GetAwaiter() => this;

        /// <summary>Always false, so that the compiler always calls <see cref="UnsafeOnCompleted"/>.</summary>
        public bool IsCompleted => false;

        /// <summary>Rethrows the source's exception, if any.</summary>
        public void GetResult() => PropagateResult(_task, _parentAwaiter);

        /// <summary>Like <see cref="UnsafeOnCompleted"/>, but runs the continuation under the current ExecutionContext.</summary>
        public void OnCompleted(Action continuation) => UnsafeOnCompleted(FlowExecutionContext(continuation));

        /// <summary>
        /// Runs <paramref name="continuation"/> inline if the source is already complete; otherwise
        /// re-dispatches it asynchronously once the source completes.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="continuation"/> is null.</exception>
        public void UnsafeOnCompleted(Action continuation)
        {
            if (continuation == null)
            {
                throw new ArgumentNullException(nameof(continuation));
            }

            // Completed at the await point: asynchrony is not forced.
            if (IsSourceCompleted(_task, _parentAwaiter))
            {
                continuation();
                return;
            }

            // Otherwise, invoke the continuation asynchronously from whatever context the
            // task's (or parent awaiter's) completion callback runs on.
            ContinueOnCompletion(_task, _parentAwaiter, () => DispatchAsynchronously(continuation));
        }

        /// <summary>Posts to the current SynchronizationContext, or queues to the thread pool if there is none.</summary>
        private static void DispatchAsynchronously(Action continuation)
        {
            SynchronizationContext context = SynchronizationContext.Current;
            if (context != null)
            {
                context.Post(_ => continuation(), null);
            }
            else
            {
                ThreadPool.UnsafeQueueUserWorkItem(_ => continuation(), null);
            }
        }
    }

    /// <summary>
    /// Awaitable returned by <see cref="RestoreContext(Task)"/>. Captures
    /// <see cref="SynchronizationContext.Current"/> when the continuation is registered and resumes
    /// there: inline if the completing thread already has that context, otherwise via Post, or via
    /// the thread pool when the captured context was null.
    /// </summary>
    public struct RestoreContextAwaitable : IAwaitable, IAwaiter
    {
        // Exactly one of these is non-null for a constructed instance.
        readonly Task _task;
        readonly IAwaiter _parentAwaiter;

        /// <summary>Wraps a task.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="task"/> is null.</exception>
        public RestoreContextAwaitable(Task task)
        {
            _task = task ?? throw new ArgumentNullException(nameof(task));
            _parentAwaiter = null;
        }

        /// <summary>Wraps another awaitable, obtaining its awaiter immediately.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="awaitable"/> is null.</exception>
        public RestoreContextAwaitable(IAwaitable awaitable)
        {
            if (awaitable == null)
            {
                throw new ArgumentNullException(nameof(awaitable));
            }

            _task = null;
            _parentAwaiter = awaitable.GetAwaiter();
        }

        /// <summary>Returns this instance (boxed) as its own awaiter.</summary>
        public IAwaiter GetAwaiter() => this;

        /// <summary>Reflects the completion state of the wrapped task or parent awaiter.</summary>
        public bool IsCompleted => IsSourceCompleted(_task, _parentAwaiter);

        /// <summary>Rethrows the source's exception, if any.</summary>
        public void GetResult() => PropagateResult(_task, _parentAwaiter);

        /// <summary>Like <see cref="UnsafeOnCompleted"/>, but runs the continuation under the current ExecutionContext.</summary>
        public void OnCompleted(Action continuation) => UnsafeOnCompleted(FlowExecutionContext(continuation));

        /// <summary>
        /// Captures the current SynchronizationContext (possibly null) and resumes
        /// <paramref name="continuation"/> on it once the source completes.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="continuation"/> is null.</exception>
        public void UnsafeOnCompleted(Action continuation)
        {
            if (continuation == null)
            {
                throw new ArgumentNullException(nameof(continuation));
            }

            SynchronizationContext captured = SynchronizationContext.Current;
            ContinueOnCompletion(_task, _parentAwaiter, () => ResumeOn(captured, continuation));
        }

        /// <summary>Runs <paramref name="continuation"/> on <paramref name="target"/>, hopping only when needed.</summary>
        private static void ResumeOn(SynchronizationContext target, Action continuation)
        {
            if (SynchronizationContext.Current == target)
            {
                continuation();
            }
            else if (target != null)
            {
                target.Post(_ => continuation(), null);
            }
            else
            {
                // The await started with no context, but completion arrived on a thread that has one:
                // leave that context via the thread pool (calling Post on the null capture would throw).
                ThreadPool.UnsafeQueueUserWorkItem(_ => continuation(), null);
            }
        }
    }
}
/// <summary>
/// Console demo for the fluent awaiters: waits five seconds through a
/// <c>RestoreContext().ForceAsync()</c> chain, then prints "Done".
/// Any exception is written to the console instead of propagating.
/// </summary>
static class Program
{
    static async Task Main(string[] args)
    {
        try
        {
            Task delay = Task.Delay(5000);
            var chain = delay.RestoreContext().ForceAsync();
            await chain;
            Console.WriteLine("Done");
        }
        catch (Exception error)
        {
            // Demo only: report the failure rather than crash the process.
            Console.WriteLine(error);
        }
    }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment