Skip to content

Instantly share code, notes, and snippets.

@BastianBlokland
Last active October 30, 2019 14:28
Show Gist options
  • Save BastianBlokland/cd01667cc63b1871bd57bd975c49c2e2 to your computer and use it in GitHub Desktop.
Save BastianBlokland/cd01667cc63b1871bd57bd975c49c2e2 to your computer and use it in GitHub Desktop.
General purpose thread-safe event channel class for distributing events based on type.

General purpose thread-safe event channel class for distributing events based on type.

Uses the SynchronizedEvent<T> class from an earlier gist.

This builds on the features of the SynchronizedEvent<T> class and adds distributing events based on type. This is useful when, for example, you want to have a publisher for a generic type (like IMessage) and then allow subscribing to derived types (like NameChangedMessage).

Example:

class Program
{
    interface IMessage
    {
    }

    class NameChangedMessage : IMessage
    {
        public NameChangedMessage(string newName) => this.NewName = newName;

        public string NewName { get; }
    }

    class ItemPurchasedMessage : IMessage
    {
        public ItemPurchasedMessage(int itemId) => this.ItemId = itemId;

        public int ItemId { get; }
    }

    /// <summary>
    /// Demonstrates both the awaitable and the callback style of listening on a channel.
    /// </summary>
    static async Task Main()
    {
        var networkReceiver = new NetworkReceiver();

        // Async style:
        var nameChangedMessage = await networkReceiver.MessageChannel.WaitAsync<NameChangedMessage>();
        Console.WriteLine($"Got name changed async: '{nameChangedMessage.NewName}'");

        // Callback style:
        networkReceiver.MessageChannel.Subscribe<NameChangedMessage>(OnNameChanged);
        networkReceiver.MessageChannel.Subscribe<ItemPurchasedMessage>(OnItemPurchased);

        Console.ReadKey();

        // The method group creates a new delegate, but it compares equal to the one used when
        // subscribing (same target and method), so it acts as the same subscription token.
        networkReceiver.MessageChannel.Unsubscribe<NameChangedMessage>(OnNameChanged);
        networkReceiver.MessageChannel.Unsubscribe<ItemPurchasedMessage>(OnItemPurchased);
    }

    static void OnNameChanged(NameChangedMessage msg) => Console.WriteLine($"Name changed to: '{msg.NewName}'");
    static void OnItemPurchased(ItemPurchasedMessage msg) => Console.WriteLine($"Item purchased: '{msg.ItemId}'");

    /// <summary>
    /// Simulated network receiver: publishes a randomly chosen message on its channel roughly
    /// once per second from a background loop.
    /// </summary>
    class NetworkReceiver : IExceptionHandler
    {
        readonly SynchronizedChannel<IMessage> messageChannel;

        // One shared instance instead of a new 'Random' per message. Only the receive loop uses it
        // and that loop awaits each message before requesting the next, so access is sequential.
        readonly Random random = new Random();

        // Kept so the loop task stays referenced; this example does not observe its exceptions.
        readonly Task receiveLoop;

        public NetworkReceiver()
        {
            // 'storeUnobservedData' keeps messages that arrive while nobody is listening.
            this.messageChannel = new SynchronizedChannel<IMessage>(
                exceptionHandler: this,
                storeUnobservedData: true);
            this.receiveLoop = Task.Run(ReceiveLoopAsync);
        }

        public IReadOnlySynchronizedChannel<IMessage> MessageChannel => this.messageChannel;

        void IExceptionHandler.Handle(Exception e) => Console.Error.Write(e);

        async Task ReceiveLoopAsync()
        {
            while (true)
            {
                var message = await GetMessageAsync();
                this.messageChannel.Invoke(message);
            }
        }

        async Task<IMessage> GetMessageAsync()
        {
            await Task.Delay(1000);
            return this.random.NextDouble() > .5 ?
                (IMessage)new ItemPurchasedMessage(1337) :
                new NameChangedMessage("John");
        }
    }
}

SynchronizedChannel<TBase>:

using System;
using System.Collections.Concurrent;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;

/// <summary>
/// Thread-safe channel for receiving events. The channel has a base type of the events you want
/// to receive, for example: 'IEvent', and you can then listen for events that are derived from
/// that base type, for example 'NameChangedEvent'.
/// </summary>
/// <typeparam name="TBase">Base type that events need to derive from.</typeparam>
public interface IReadOnlySynchronizedChannel<TBase>
{
    /// <summary>
    /// Wait for an event of type T to be fired.
    /// </summary>
    /// <param name="cancelToken">Token to be able to cancel the task.</param>
    /// <typeparam name="T">
    /// Type of event to receive; a type derived from <typeparamref name="TBase"/>, not
    /// <typeparamref name="TBase"/> itself.
    /// </typeparam>
    /// <returns>Task that completes when an event of type T is received or is cancelled.</returns>
    Task<T> WaitAsync<T>(CancellationToken cancelToken = default)
        where T : class, TBase;

    /// <summary>
    /// Subscribe to events of type T.
    /// </summary>
    /// <param name="action">Action to invoke when an event of type T is fired.</param>
    /// <param name="subscriptionToken">
    /// Token to use for unsubscribing; if none is provided then 'action' itself is the token.
    /// </param>
    /// <param name="callOnCapturedContext">
    /// Should the action only be called on the SynchronizationContext that was active when
    /// subscribing.
    /// </param>
    /// <typeparam name="T">
    /// Type of events to listen for; a type derived from <typeparamref name="TBase"/>, not
    /// <typeparamref name="TBase"/> itself.
    /// </typeparam>
    void Subscribe<T>(Action<T> action, object subscriptionToken = null, bool callOnCapturedContext = true)
        where T : class, TBase;

    /// <summary>
    /// Unsubscribe from events of type T.
    /// </summary>
    /// <param name="action">
    /// Action that was subscribed without an explicit subscriptionToken (and therefore acts as
    /// its own token).
    /// </param>
    /// <typeparam name="T">Type of events to unsubscribe from.</typeparam>
    /// <returns>True if successfully unsubscribed, otherwise False.</returns>
    bool Unsubscribe<T>(Action<T> action)
        where T : class, TBase;

    /// <summary>
    /// Unsubscribe from events of type T.
    /// </summary>
    /// <param name="subscriptionToken">Token that was used for subscribing.</param>
    /// <typeparam name="T">Type of events to unsubscribe from.</typeparam>
    /// <returns>True if successfully unsubscribed, otherwise False.</returns>
    bool Unsubscribe<T>(object subscriptionToken)
        where T : class, TBase;
}

/// <summary>
/// Thread-safe channel for receiving events. The channel has a base type of the events you want
/// to receive, for example: 'IEvent', and you can then listen for events that are derived from
/// that base type, for example 'NameChangedEvent'.
/// </summary>
/// <remarks>
/// Events are routed on their exact runtime type: invoking an event whose runtime type is 'B'
/// only reaches listeners registered for 'B', even when 'B' derives from another listened-to type.
/// </remarks>
/// <typeparam name="TBase">Base type that events need to derive from.</typeparam>
public sealed class SynchronizedChannel<TBase> : IReadOnlySynchronizedChannel<TBase>, IDisposable
    where TBase : class
{
    // One event per exact event type. Entries are only ever removed by 'Dispose'.
    private readonly ConcurrentDictionary<Type, SynchronizedEvent<TBase>> events =
        new ConcurrentDictionary<Type, SynchronizedEvent<TBase>>();

    private readonly IExceptionHandler exceptionHandler;
    private readonly bool storeUnobservedData;
    private readonly bool allowSynchronousContinuations;
    private readonly Type baseType;

    // Non-zero once 'Dispose' has been called.
    private volatile int disposeCount;

    /// <summary>
    /// Initializes a new instance of the <see cref="SynchronizedChannel{TBase}"/> class.
    /// </summary>
    /// <param name="exceptionHandler">Handler for dealing with exceptions during callback invoke.</param>
    /// <param name="storeUnobservedData">Should data be stored when there is no-one listening.</param>
    /// <param name="allowSynchronousContinuations">
    /// Are task-continuations allowed to execute synchronously.
    /// Use with caution, for more info see docs on <see cref="SynchronizedEvent{T}"/>.
    /// </param>
    /// <exception cref="ArgumentNullException">When <paramref name="exceptionHandler"/> is null.</exception>
    public SynchronizedChannel(
        IExceptionHandler exceptionHandler,
        bool storeUnobservedData = true,
        bool allowSynchronousContinuations = false)
    {
        this.exceptionHandler = exceptionHandler ?? throw new ArgumentNullException(nameof(exceptionHandler));
        this.storeUnobservedData = storeUnobservedData;
        this.allowSynchronousContinuations = allowSynchronousContinuations;
        this.baseType = typeof(TBase);
    }

    /// <inheritdoc/>
    public Task<T> WaitAsync<T>(CancellationToken cancelToken = default)
        where T : class, TBase
    {
        var type = this.GetListenType<T>();
        this.ThrowIfDisposed();

        // If cancellation is already requested then early out.
        if (cancelToken.IsCancellationRequested)
            return Task.FromCanceled<T>(cancelToken);

        var waitTask = this.GetEvent(type).WaitAsync(cancelToken);

        /* Note: The 'ContinueWith' is required to get a task of the expected type. Unfortunately
        this has to allocate a new task, but we reuse a cached delegate and (attempt to) execute it
        synchronously so the overhead should be minimal. Because of 'OnlyOnRanToCompletion' the
        returned task is cancelled whenever the underlying wait does not run to completion. */
        var result = waitTask.ContinueWith(
            CastHelper<T>.ContinueCast,
            waitTask.AsyncState,
            TaskContinuationOptions.OnlyOnRanToCompletion | TaskContinuationOptions.ExecuteSynchronously);

        // It's important to preserve the 'AsyncState' as 'SynchronizedEvent' uses it for optimizations.
        Debug.Assert(result.AsyncState == waitTask.AsyncState, "AsyncState was not perserved");

        return result;
    }

    /// <inheritdoc/>
    public void Subscribe<T>(Action<T> action, object subscriptionToken = null, bool callOnCapturedContext = true)
        where T : class, TBase
    {
        if (action is null)
            throw new ArgumentNullException(nameof(action));

        var type = this.GetListenType<T>();
        this.ThrowIfDisposed();

        // The typed action travels as state so a single cached, non-capturing delegate can be used.
        this.GetEvent(type).Subscribe(
            action: CastHelper<T>.CastInvoke,
            state: action,
            subscriptionToken: subscriptionToken ?? action,
            callOnCapturedContext);
    }

    /// <inheritdoc/>
    public bool Unsubscribe<T>(Action<T> action)
        where T : class, TBase
    {
        if (action is null)
            throw new ArgumentNullException(nameof(action));

        var type = this.GetListenType<T>();

        // No need to unsubscribe when we've disposed.
        if (this.disposeCount != 0)
            return false;

        // Without an event for this type nothing was ever subscribed; don't create one just to find out.
        return this.events.TryGetValue(type, out var evt) && evt.Unsubscribe(action);
    }

    /// <inheritdoc/>
    public bool Unsubscribe<T>(object subscriptionToken)
        where T : class, TBase
    {
        if (subscriptionToken is null)
            throw new ArgumentNullException(nameof(subscriptionToken));

        var type = this.GetListenType<T>();

        // No need to unsubscribe when we've disposed.
        if (this.disposeCount != 0)
            return false;

        // Without an event for this type nothing was ever subscribed; don't create one just to find out.
        return this.events.TryGetValue(type, out var evt) && evt.Unsubscribe(subscriptionToken);
    }

    /// <summary>
    /// Invoke an event.
    /// </summary>
    /// <remarks>
    /// The runtime type of data needs to derive from 'TBase' but cannot be 'TBase' itself. Only
    /// listeners registered for that exact runtime type receive it.
    /// </remarks>
    /// <param name="data">Event to invoke.</param>
    /// <exception cref="ArgumentNullException">When <paramref name="data"/> is null.</exception>
    /// <exception cref="ArgumentException">When the runtime type of data is 'TBase'.</exception>
    /// <exception cref="ObjectDisposedException">When the channel has been disposed.</exception>
    public void Invoke(TBase data)
    {
        if (data is null)
            throw new ArgumentNullException(nameof(data));

        var type = data.GetType();
        if (type == this.baseType)
            throw new ArgumentException("Type of data cannot be equal to channel base-type", nameof(data));

        this.ThrowIfDisposed();

        this.GetEvent(type).Invoke(data);
    }

    /// <inheritdoc/>
    public void Dispose()
    {
        // Using 'Interlocked.Increment' to make sure we only dispose once even when called concurrently.
        if (Interlocked.Increment(ref this.disposeCount) == 1)
        {
            // Dispose all events.
            foreach (var evt in this.events)
                evt.Value.Dispose();

            this.events.Clear();
        }
    }

    /// <summary>Throws <see cref="ObjectDisposedException"/> once the channel has been disposed.</summary>
    private void ThrowIfDisposed()
    {
        if (this.disposeCount != 0)
            throw new ObjectDisposedException(nameof(SynchronizedChannel<TBase>));
    }

    /// <summary>Returns typeof(T), rejecting the channel base type which cannot be listened to directly.</summary>
    private Type GetListenType<T>()
        where T : class, TBase
    {
        var type = typeof(T);
        if (type == this.baseType)
            throw new ArgumentException("T cannot be equal to channel base-type", nameof(T));

        return type;
    }

    /// <summary>Gets the event for the given exact type, creating it on first use.</summary>
    private SynchronizedEvent<TBase> GetEvent(Type type)
    {
        // First attempt to get a previously created event.
        if (this.events.TryGetValue(type, out var evt))
            return evt;

        // Create a new event.
        var created = new SynchronizedEvent<TBase>(
            this.exceptionHandler,
            this.storeUnobservedData,
            this.allowSynchronousContinuations);

        if (!this.events.TryAdd(type, created))
        {
            /* There was a race condition and another event was created right before, so we discard
            the one we've just created and use the other thread's one. */
            created.Dispose();

            if (this.events.TryGetValue(type, out evt))
                return evt;

            /* Entries are only removed by 'Dispose', so losing the race and then not finding the
            winner means the channel was disposed while we were creating a new event. */
            this.ThrowIfDisposed();
            throw new InvalidOperationException("Failed to create a new event");
        }

        /* 'Dispose' can run between the caller's disposed-check and the add above. If it already
        enumerated the dictionary, the new event would never be disposed (and its waiters never
        released), so dispose it here. The full fence orders our add before the re-read of
        'disposeCount', pairing with the 'Interlocked.Increment' in 'Dispose'. If 'Dispose' did
        see the event it gets disposed twice, which assumes 'SynchronizedEvent.Dispose' ignores
        repeated calls as the 'IDisposable' contract requires. */
        Thread.MemoryBarrier();
        if (this.disposeCount != 0)
        {
            created.Dispose();
            throw new ObjectDisposedException(nameof(SynchronizedChannel<TBase>));
        }

        return created;
    }

    private static class CastHelper<T>
        where T : class, TBase
    {
        /* Cached (and readonly) so no delegate is allocated per call. Events are keyed by exact
        runtime type, so data routed to these delegates is always a 'T'. */
        internal static readonly Action<TBase, object> CastInvoke = (TBase data, object target) =>
        {
            Debug.Assert(target != null && target is Action<T>, "Invalid target received");
            ((Action<T>)target).Invoke((T)data);
        };

        internal static readonly Func<Task<TBase>, object, T> ContinueCast = (Task<TBase> task, object state) =>
        {
            Debug.Assert(task != null && task.IsCompleted, "Invalid task");
            return (T)task.Result;
        };
    }
}

Extension method for waiting for either an event of type A or an event of type B:

using System;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;

/// <summary>
/// Extensions for the <see cref="IReadOnlySynchronizedChannel{TBase}"/> interface.
/// </summary>
public static class ReadOnlySynchronizedChannelExtensions
{
    /// <summary>
    /// Wait for an event of type TA or of type TB.
    /// </summary>
    /// <remarks>
    /// When one of the two events is received, the wait for the other one is cancelled. This relies
    /// on the channel being backed by <see cref="SynchronizedEvent{T}"/>.
    /// </remarks>
    /// <param name="channel">Channel to listen for the events on.</param>
    /// <param name="cancelToken">Token to be able to cancel the task.</param>
    /// <typeparam name="TBase">Base type of the events to wait for.</typeparam>
    /// <typeparam name="TA">Type of event to wait for.</typeparam>
    /// <typeparam name="TB">Type of event to wait for.</typeparam>
    /// <returns>Task that completes when an event of either type is received or is cancelled.</returns>
    /// <exception cref="ArgumentNullException">When <paramref name="channel"/> is null.</exception>
    public static Task<TBase> WaitAnyAsync<TBase, TA, TB>(
        this IReadOnlySynchronizedChannel<TBase> channel, CancellationToken cancelToken = default)
        where TBase : class
        where TA : class, TBase
        where TB : class, TBase
    {
        if (channel is null)
            throw new ArgumentNullException(nameof(channel));

        var handle = new WaitAnyHandle<TBase, TA, TB>(channel, cancelToken);
        return handle.ConstructWaitForAny();
    }

    /// <summary>
    /// Holds the two pending waits so a single cached continuation can pick the winner and cancel
    /// the loser.
    /// </summary>
    private sealed class WaitAnyHandle<TBase, TA, TB>
        where TBase : class
        where TA : class, TBase
        where TB : class, TBase
    {
        // Cached (and readonly) so no delegate is allocated per call; the handle is passed as state.
        private static readonly Func<Task<Task>, object, TBase> getResult = (Task<Task> task, object state) =>
        {
            Debug.Assert(task != null, "No finishing task received");
            Debug.Assert(task.IsCompleted, "Finishing task was not completed");
            Debug.Assert(state != null && state is WaitAnyHandle<TBase, TA, TB>, "Invalid state received");

            var handle = (WaitAnyHandle<TBase, TA, TB>)state;
            var finishingTask = task.Result;

            Debug.Assert(finishingTask == handle.waitA || finishingTask == handle.waitB, "Unexpected finishing task");

            /* We throw here as otherwise 'Result' would throw and wrap the
            'OperationCanceledException' in a 'System.AggregateException', which is more annoying
            for the user to catch. */
            if (finishingTask.IsCanceled)
                throw new OperationCanceledException();

            /* This is kinda cheating as it assumes that 'IReadOnlySynchronizedChannel<TBase>' is
            always backed by 'SynchronizedEvent<T>', but doing it like this avoids us having to
            create two more 'CancellationTokenSource' (1 for us and 1 linked wrapper).
            'GetAwaiter().GetResult()' is used so any fault surfaces unwrapped. */
            if (finishingTask == handle.waitA)
            {
                SynchronizedEvent<TBase>.TryCancel(handle.waitB);
                return handle.waitA.GetAwaiter().GetResult();
            }
            else
            {
                SynchronizedEvent<TBase>.TryCancel(handle.waitA);
                return handle.waitB.GetAwaiter().GetResult();
            }
        };

        private readonly Task<TA> waitA;
        private readonly Task<TB> waitB;

        internal WaitAnyHandle(IReadOnlySynchronizedChannel<TBase> channel, CancellationToken cancelToken)
        {
            Debug.Assert(channel != null, "Null channel received");

            // Start listening for both events.
            this.waitA = channel.WaitAsync<TA>(cancelToken);
            try
            {
                this.waitB = channel.WaitAsync<TB>(cancelToken);
            }
            catch
            {
                /* If the second wait cannot be started (e.g. TB is the base type or the channel is
                disposed), don't leave the first one registered: it would otherwise silently consume
                the next TA event. Already-completed tasks (e.g. pre-cancelled) need no cleanup. */
                if (!this.waitA.IsCompleted)
                    SynchronizedEvent<TBase>.TryCancel(this.waitA);

                throw;
            }
        }

        /// <summary>Builds the task that completes with whichever event arrives first.</summary>
        internal Task<TBase> ConstructWaitForAny()
        {
            return Task.WhenAny(this.waitA, this.waitB).
                ContinueWith(getResult, state: this, TaskContinuationOptions.ExecuteSynchronously);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment