Skip to content

Instantly share code, notes, and snippets.

@jnm2
Last active February 1, 2020 20:21
Show Gist options
  • Save jnm2/ab5624b10efd1ae6fbd6aa8f081a0ec9 to your computer and use it in GitHub Desktop.
Save jnm2/ab5624b10efd1ae6fbd6aa8f081a0ec9 to your computer and use it in GitHub Desktop.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
partial class AmbientTasks
{
    /// <summary>
    /// Bookkeeping for a single ambient scope: counts tracked work that is still in flight, buffers exceptions that
    /// no handler suppressed, and produces the task handed out by <c>WaitAllAsync</c>.
    /// </summary>
    private sealed class AmbientTaskContext
    {
        private readonly Action<Exception> exceptionHandler;

        /// <summary>
        /// Doubles as a lockable object for all access to <see cref="exceptions"/>, <see cref="currentTaskCount"/>, and <see cref="waitAllSource"/>.
        /// </summary>
        private readonly List<Exception> exceptions = new List<Exception>();

        private int currentTaskCount;

        private TaskCompletionSource<object> waitAllSource;

        public AmbientTaskContext(Action<Exception> exceptionHandler) => this.exceptionHandler = exceptionHandler;

        /// <summary>
        /// Offers <paramref name="exception"/> to this scope's handler, buffering whatever cannot be handled.
        /// </summary>
        /// <returns>
        /// <see langword="false"/> when the scope has no handler: the exception is buffered and the caller should
        /// let it propagate. Otherwise <see langword="true"/>. If the handler throws, the original exception is
        /// buffered and the handler is offered its own exception once. If it throws again, both handler exceptions
        /// are buffered too.
        /// </returns>
        public bool RecordAndTrySuppress(Exception exception)
        {
            if (exceptionHandler is null)
            {
                BufferException(exception);
                return false;
            }

            try
            {
                exceptionHandler.Invoke(exception);
            }
            catch (Exception handlerException)
            {
                BufferException(exception);

                try
                {
                    exceptionHandler.Invoke(handlerException);
                }
                catch (Exception secondHandlerException)
                {
                    // Both additions happen under one lock so they stay adjacent in the buffer.
                    lock (exceptions)
                    {
                        exceptions.Add(handlerException);
                        exceptions.Add(secondHandlerException);
                    }
                }
            }

            return true;
        }

        /// <summary>Appends a single exception to the buffer under the shared lock.</summary>
        private void BufferException(Exception exception)
        {
            lock (exceptions) exceptions.Add(exception);
        }

        /// <summary>Registers one more unit of in-flight work; throws <see cref="OverflowException"/> on overflow.</summary>
        public void StartTask()
        {
            lock (exceptions)
                currentTaskCount = checked(currentTaskCount + 1);
        }

        /// <summary>
        /// Retires one unit of in-flight work. When the count reaches zero and someone is waiting, completes the
        /// pending wait with the buffered exceptions (faulted) or successfully (none buffered).
        /// </summary>
        /// <exception cref="InvalidOperationException">More calls to this method than to <see cref="StartTask"/>.</exception>
        public void EndTask()
        {
            TaskCompletionSource<object> completion;
            Exception[] pending;

            lock (exceptions)
            {
                var remaining = currentTaskCount - 1;
                if (remaining < 0)
                    throw new InvalidOperationException($"More calls to {nameof(EndTask)} than {nameof(StartTask)}.");

                currentTaskCount = remaining;

                // Either work is still in flight, or no one is waiting: nothing to complete yet.
                if (remaining != 0 || waitAllSource is null) return;

                completion = waitAllSource;
                waitAllSource = null;
                pending = exceptions.ToArray();
                exceptions.Clear();
            }

            // Do not set the source inside the lock. Arbitrary user continuations may have been set
            // on the task since it was previously returned from WaitAllAsync, and executing
            // arbitrary user code within a lock is a Very Bad Idea™.
            if (pending.Length == 0)
                completion.SetResult(null);
            else
                completion.SetException(pending);
        }

        /// <summary>
        /// Returns a task that completes when all in-flight work has ended. If nothing is in flight, the returned
        /// task is already faulted with the buffered exceptions (and the buffer is cleared), or already successful.
        /// </summary>
        public Task WaitAllAsync()
        {
            lock (exceptions)
            {
                if (waitAllSource is null && currentTaskCount > 0)
                    waitAllSource = new TaskCompletionSource<object>();

                if (waitAllSource != null) return waitAllSource.Task;

                if (exceptions.Count == 0) return Task.CompletedTask;

                // Freshly created, so no continuations can be attached yet; completing it under the lock is safe.
                var faulted = new TaskCompletionSource<object>();
                faulted.SetException(exceptions);
                exceptions.Clear();
                return faulted.Task;
            }
        }
    }
}
using System;
using System.ComponentModel;
using System.Runtime.ExceptionServices;
using System.Threading;
using System.Threading.Tasks;
/// <summary>
/// Enables scoped completion tracking and error handling of tasks as an alternative to <c>async void</c>.
/// Easy to produce and consume, and test-friendly.
/// </summary>
public static partial class AmbientTasks
{
    private static readonly AsyncLocal<AmbientTaskContext> Context = new AsyncLocal<AmbientTaskContext>();

    // NOTE(review): without a prior BeginContext call, the default context is created lazily and stored in the
    // AsyncLocal of whichever execution flow first touches it. A value set inside a child async flow does not flow
    // back to its caller, so confirm this lazy per-flow scoping is intended.
    private static AmbientTaskContext CurrentContext => Context.Value ?? (Context.Value = new AmbientTaskContext(exceptionHandler: null));

    /// <summary>
    /// <para>
    /// Replaces the current async-local scope with a new scope which has its own exception handler and isolated set
    /// of tracked tasks.
    /// </para>
    /// <para>If <paramref name="exceptionHandler"/> is <see langword="null"/>, exceptions will be left uncaught. In
    /// the case of tracked <see cref="Task"/> objects, the exception will be rethrown on the synchronization
    /// context which began tracking it.
    /// </para>
    /// </summary>
    public static void BeginContext(Action<Exception> exceptionHandler = null)
    {
        Context.Value = new AmbientTaskContext(exceptionHandler);
    }

    /// <summary>
    /// Waits until all tracked tasks are complete. Any exceptions that were not handled, including exceptions
    /// thrown by an exception handler, will be included as inner exceptions of the <see cref="Task.Exception"/>
    /// property.
    /// </summary>
    public static Task WaitAllAsync() => CurrentContext.WaitAllAsync();

    /// <summary>
    /// <para>
    /// Begins tracking a <see cref="Task"/> so that any exception is handled and so that <see cref="WaitAllAsync"/>
    /// waits for its completion.
    /// </para>
    /// <para>
    /// Once passed to this method, a task’s exception will never be unobserved. If the task faults or is already
    /// faulted and an exception handler is currently registered (see <see cref="BeginContext"/>), the handler will
    /// receive the task’s <see cref="AggregateException"/>. If no handler has been registered, the <see
    /// cref="AggregateException"/> will be rethrown on the <see cref="SynchronizationContext"/> that was current
    /// when <see cref="Add"/> was called. (If there was no synchronization context, it will be rethrown immediately
    /// by a continuation requesting <see cref="TaskContinuationOptions.ExecuteSynchronously"/> on the default task
    /// scheduler.)
    /// </para>
    /// <para>
    /// If the task is already faulted, this handling happens synchronously within this call. In that case, with no
    /// handler and no synchronization context, this method itself rethrows the exception. A <see langword="null"/>,
    /// canceled, or successfully completed task is ignored.
    /// </para>
    /// </summary>
    public static void Add(Task task)
    {
        switch (task?.Status)
        {
            case null:
            case TaskStatus.Canceled:
            case TaskStatus.RanToCompletion:
                break;

            case TaskStatus.Faulted:
            {
                // OnTaskCompleted always finishes with EndTask, so the task must be registered first. Otherwise the
                // count goes negative (InvalidOperationException) or another tracked task is ended early, which would
                // let WaitAllAsync complete before that task actually finishes.
                var context = CurrentContext;
                context.StartTask();
                OnTaskCompleted(task, state: (context, SynchronizationContext.Current));
                break;
            }

            default:
            {
                var context = CurrentContext;
                context.StartTask();

                // An explicit scheduler avoids implicitly picking up TaskScheduler.Current (CA2008).
                task.ContinueWith(
                    OnTaskCompleted,
                    state: (context, SynchronizationContext.Current),
                    CancellationToken.None,
                    TaskContinuationOptions.ExecuteSynchronously,
                    TaskScheduler.Default);
                break;
            }
        }
    }

    /// <summary>
    /// Completion callback for a tracked task. <paramref name="state"/> is the
    /// (<see cref="AmbientTaskContext"/>, <see cref="SynchronizationContext"/>) pair captured by <see cref="Add"/>.
    /// Faults are offered to the scope's handler. Unhandled faults are rethrown, either synchronously (no
    /// synchronization context) or by posting to the captured context. The task is always retired from the scope.
    /// </summary>
    private static void OnTaskCompleted(Task completedTask, object state)
    {
        var (context, addSynchronizationContext) = ((AmbientTaskContext, SynchronizationContext))state;

        try
        {
            // Send AggregateException to registered handler
            if (completedTask.IsFaulted && !context.RecordAndTrySuppress(completedTask.Exception))
            {
                var exceptionInfo = ExceptionDispatchInfo.Capture(completedTask.Exception);

                if (addSynchronizationContext is null)
                    OnTaskFaultWithoutHandler(exceptionInfo);
                else
                    addSynchronizationContext.Post(OnTaskFaultWithoutHandler, state: exceptionInfo);
            }
        }
        finally
        {
            // Must run even when the exception is rethrown synchronously above. Otherwise the task count never
            // returns to zero and WaitAllAsync never completes.
            context.EndTask();
        }
    }

    /// <summary>Rethrows the captured exception, preserving its original stack trace.</summary>
    private static void OnTaskFaultWithoutHandler(object state)
    {
        ((ExceptionDispatchInfo)state).Throw();
    }

    /// <summary>
    /// <para>
    /// Executes the specified delegate on the current <see cref="SynchronizationContext"/> while tracking so that
    /// any exception is handled and so that <see cref="WaitAllAsync"/> waits for its completion.
    /// </para>
    /// <para>
    /// A default <see cref="SynchronizationContext"/> is installed if the current one is <see langword="null"/>.
    /// </para>
    /// <para>
    /// If an exception handler has been registered (see <see cref="BeginContext"/>), any exception will be caught
    /// and routed to the handler instead of <see cref="WaitAllAsync"/>. If no handler has been registered, the
    /// exception will not be caught even though it will be recorded and thrown by
    /// <see cref="WaitAllAsync"/>.
    /// </para>
    /// </summary>
    public static void Post(SendOrPostCallback d, object state)
    {
        // Install a default synchronization context if one does not exist
        Post(AsyncOperationManager.SynchronizationContext, d, state);
    }

    /// <summary>
    /// <para>
    /// Executes the specified delegate on <paramref name="synchronizationContext"/> while tracking so that any
    /// exception is handled and so that <see cref="WaitAllAsync"/> waits for its completion. If
    /// <paramref name="d"/> is <see langword="null"/>, nothing is posted.
    /// </para>
    /// <para>
    /// <see cref="ArgumentNullException"/> is thrown if <paramref name="synchronizationContext"/> is <see
    /// langword="null"/>.
    /// </para>
    /// <para>
    /// If an exception handler has been registered (see <see cref="BeginContext"/>), any exception will be
    /// caught and routed to the handler instead of <see cref="WaitAllAsync"/>. If no handler has been registered,
    /// the exception will not be caught even though it will be recorded and thrown by
    /// <see cref="WaitAllAsync"/>.
    /// </para>
    /// </summary>
    public static void Post(SynchronizationContext synchronizationContext, SendOrPostCallback d, object state)
    {
        if (synchronizationContext is null)
            throw new ArgumentNullException(nameof(synchronizationContext));

        if (d is null) return;

        var context = CurrentContext;
        context.StartTask();
        synchronizationContext.Post(OnPost, (context, d, state));
    }

    /// <summary>Runs a posted <see cref="SendOrPostCallback"/> and retires it from its scope.</summary>
    private static void OnPost(object state)
    {
        var (context, d, invokeState) = ((AmbientTaskContext, SendOrPostCallback, object))state;
        try
        {
            d.Invoke(invokeState);
        }
        catch (Exception ex) when (context.RecordAndTrySuppress(ex))
        {
            // The filter recorded the exception. It is swallowed only when a handler accepted it.
        }
        finally
        {
            context.EndTask();
        }
    }

    /// <summary>
    /// <para>
    /// Executes the specified delegate on the current <see cref="SynchronizationContext"/> while tracking so that
    /// any exception is handled and so that <see cref="WaitAllAsync"/> waits for its completion.
    /// </para>
    /// <para>
    /// A default <see cref="SynchronizationContext"/> is installed if the current one is <see langword="null"/>.
    /// </para>
    /// <para>
    /// If an exception handler has been registered (see <see cref="BeginContext"/>), any exception will be caught
    /// and routed to the handler instead of <see cref="WaitAllAsync"/>. If no handler has been registered, the
    /// exception will not be caught even though it will be recorded and thrown by
    /// <see cref="WaitAllAsync"/>.
    /// </para>
    /// </summary>
    public static void Post(Action postCallbackAction)
    {
        // Install a default synchronization context if one does not exist
        Post(AsyncOperationManager.SynchronizationContext, postCallbackAction);
    }

    /// <summary>
    /// <para>
    /// Executes the specified delegate on <paramref name="synchronizationContext"/> while tracking so that any
    /// exception is handled and so that <see cref="WaitAllAsync"/> waits for its completion. If
    /// <paramref name="postCallbackAction"/> is <see langword="null"/>, nothing is posted.
    /// </para>
    /// <para>
    /// <see cref="ArgumentNullException"/> is thrown if <paramref name="synchronizationContext"/> is <see
    /// langword="null"/>.
    /// </para>
    /// <para>
    /// If an exception handler has been registered (see <see cref="BeginContext"/>), any exception will be
    /// caught and routed to the handler instead of <see cref="WaitAllAsync"/>. If no handler has been registered,
    /// the exception will not be caught even though it will be recorded and thrown by
    /// <see cref="WaitAllAsync"/>.
    /// </para>
    /// </summary>
    public static void Post(SynchronizationContext synchronizationContext, Action postCallbackAction)
    {
        if (synchronizationContext is null)
            throw new ArgumentNullException(nameof(synchronizationContext));

        if (postCallbackAction is null) return;

        var context = CurrentContext;
        context.StartTask();
        synchronizationContext.Post(OnPostAction, (context, postCallbackAction));
    }

    /// <summary>Runs a posted <see cref="Action"/> and retires it from its scope.</summary>
    private static void OnPostAction(object state)
    {
        var (context, action) = ((AmbientTaskContext, Action))state;
        try
        {
            action.Invoke();
        }
        catch (Exception ex) when (context.RecordAndTrySuppress(ex))
        {
            // The filter recorded the exception. It is swallowed only when a handler accepted it.
        }
        finally
        {
            context.EndTask();
        }
    }
}
@shadowbane1000
Copy link

@jnm2 From what I've seen so far, this looks great! I changed my DontWait() task extension method to push into this, and it was a clean changeover. What license would you put on it, so I know whether I can use it?

public static class TaskExtensions
{
  /// <summary>Hands <paramref name="task"/> to <c>AmbientTasks.Add</c> instead of awaiting it.</summary>
  public static void DontWait(this Task task) => AmbientTasks.Add(task);
}

Usage now looks like this:
Task.Delay(100).DontWait();

A GitHub project might not be a bad idea, but it's up to you. The tests would be nice to have if I end up being able to use it.

@jnm2
Copy link
Author

jnm2 commented Apr 1, 2019

@shadowbane1000 I'd like to use the MIT license. Let me see about getting this moved over to a repository. Would you use a package if I published it to NuGet?

Gist comments don't notify yet, so feel free to @ me somewhere (e.g. https://gitter.im/jnm2).

@jnm2
Copy link
Author

jnm2 commented Apr 1, 2019

The current name of your extension method might obscure the fact that the task might later be waited on via AmbientTasks.WaitAllAsync.
Is directly using AmbientTasks.Add(...) not as good in your estimation as an extension method?

@shadowbane1000
Copy link

Sorry for the delayed response. I was out of town for a week.
Personally, I like the extension method, but I guess it's a matter of personal preference. In my case, I want it to look very fire-and-forget, and simple. I guess it could be renamed to something more like AddToAmbientTasks, though I would probably still use DontWait. LOL. But I don't know that I would make an extension method part of your package. It's just preference.

Yes, if it were a NuGet package, I would use it. I don't know that there would be many other users without some way of finding people with a need.

@jnm2
Copy link
Author

jnm2 commented Jun 23, 2019

@shadowbane1000 All the tests are written! CI builds are now available at https://www.myget.org/feed/ambienttasks/package/nuget/AmbientTasks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment