Skip to content

Instantly share code, notes, and snippets.

@theodorzoulias
Last active May 3, 2022 09:07
Show Gist options
  • Save theodorzoulias/715a19e0dc69bd23143826e23d826a83 to your computer and use it in GitHub Desktop.
Save theodorzoulias/715a19e0dc69bd23143826e23d826a83 to your computer and use it in GitHub Desktop.
PressureAwareUnboundedChannel -- https://stackoverflow.com/a/69284386/11178549
using System;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using System.Threading.Tasks.Dataflow;
namespace CustomChannels
{
/**
<summary>
An unbounded <see cref="Channel{T}"/> that reports back-pressure to subscribers.
When the number of buffered items rises above the high-pressure threshold, every
subscriber is notified with <c>true</c>; when it falls back to the low-pressure
threshold or below, every subscriber is notified with <c>false</c>. Counts strictly
between the two thresholds do not change the current state (hysteresis).
</summary>
<remarks>
Notifications for a given subscriber run one at a time on the subscriber's
<see cref="TaskScheduler"/> and are coalesced: if the pressure state changes while a
handler is still running, the subscriber is afterwards told only the latest state,
and nothing at all if the latest state equals the one it was last told. After the
writer completes, or after a subscription is unsubscribed, no new notifications are
started for that subscriber.
</remarks>
<example>
<code>
var channel = new PressureAwareUnboundedChannel{Item}(500, 1000);
var subscription = channel.SubscribeForPressureNotifications(underPressure =>
{
    if (underPressure) Producer.Pause(); else Producer.Resume();
});
// At this point the Producer is owned by the channel
//...
channel.Writer.Complete();
await channel.Reader.Completion;
await subscription.UnsubscribeAsync();
// At this point the Producer is no longer owned by the channel
</code>
</example>
*/
public sealed class PressureAwareUnboundedChannel<T> : Channel<T>
{
    private readonly Channel<T> _channel;
    private readonly int _highPressureThreshold;
    private readonly int _lowPressureThreshold;
    // Dedicated lock object. It guards the mutable fields below and the mutable
    // state of every Subscription, so pressure transitions and subscription changes
    // are observed in one consistent order.
    private readonly object _locker = new object();
    private ImmutableArray<Subscription> _subscribers;
    private bool _writerCompleted = false;
    private bool _underPressure = false;
    private int _count = 0;

    /// <summary>
    /// A handle to a pressure-notification subscription.
    /// </summary>
    public interface ISubscription
    {
        /// <summary>
        /// Removes the subscription so that no new notifications are started for it.
        /// </summary>
        /// <returns>
        /// A task that completes when the subscription's notification block has
        /// completed, i.e. after any handler invocation that was already running has
        /// returned. The task is faulted if a handler threw.
        /// </returns>
        public Task UnsubscribeAsync();
    }

    private sealed class Subscription : ISubscription
    {
        private readonly PressureAwareUnboundedChannel<T> _parent;
        private readonly Func<bool, Task> _action;
        // Executes drain passes one at a time; posted items are only wake-up signals.
        private readonly ActionBlock<bool> _block;

        // The fields below are guarded by _parent._locker.
        private bool _desired;   // Latest pressure state reported by the parent.
        private bool _emitted;   // Last state handed to _action.
        private bool _scheduled; // A drain pass is queued or running in _block.
        private bool _completed; // No further notifications may be started.

        public Subscription(PressureAwareUnboundedChannel<T> parent,
            Func<bool, Task> action, TaskScheduler scheduler = null)
        {
            _parent = parent;
            _action = action;
            // Unbounded capacity: a wake-up signal is never rejected while the block
            // is alive, so a state change can't be dropped because a handler is busy.
            _block = new ActionBlock<bool>(_ => DrainAsync(), new ExecutionDataflowBlockOptions
            {
                TaskScheduler = scheduler ?? TaskScheduler.Default
            });
        }

        // Emits the latest desired state until it matches the last emitted one, then
        // clears _scheduled. The check and the flag reset happen under the parent lock,
        // so a concurrent Post either is seen by this loop or schedules a new pass.
        private async Task DrainAsync()
        {
            while (true)
            {
                bool next;
                lock (_parent._locker)
                {
                    if (_completed || _desired == _emitted)
                    {
                        _scheduled = false;
                        return;
                    }
                    next = _desired;
                    _emitted = next;
                }
                await _action(next);
            }
        }

        /// <summary>
        /// Records the parent's new pressure state and ensures a drain pass will
        /// observe it. Must be called while holding the parent's lock.
        /// </summary>
        public void Post(bool underPressure)
        {
            Debug.Assert(Monitor.IsEntered(_parent._locker));
            if (_completed)
            {
                return;
            }
            _desired = underPressure;
            if (!_scheduled)
            {
                // Post fails only if the block has already completed or faulted.
                _scheduled = _block.Post(true);
            }
        }

        /// <summary>
        /// Stops starting new notifications and completes the notification block.
        /// Must be called while holding the parent's lock.
        /// </summary>
        public void Complete()
        {
            Debug.Assert(Monitor.IsEntered(_parent._locker));
            _completed = true;
            _block.Complete();
        }

        public Task UnsubscribeAsync()
        {
            lock (_parent._locker)
            {
                int index = _parent._subscribers.IndexOf(this);
                if (index >= 0)
                {
                    _parent._subscribers = _parent._subscribers.RemoveAt(index);
                    Complete();
                }
                return _block.Completion;
            }
        }
    }

    /// <summary>
    /// Subscribes an asynchronous handler that receives <c>true</c> when the channel
    /// comes under pressure and <c>false</c> when the pressure is released.
    /// A subscriber added while the channel is already under pressure is not told the
    /// current state; it is notified starting from the next transition into pressure.
    /// </summary>
    /// <param name="action">The handler; invocations never overlap.</param>
    /// <param name="scheduler">
    /// The scheduler that runs the handler; <see cref="TaskScheduler.Default"/> if null.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="action"/> is null.</exception>
    public ISubscription SubscribeForPressureNotifications(Func<bool, Task> action,
        TaskScheduler scheduler = null)
    {
        if (action is null)
        {
            throw new ArgumentNullException(nameof(action));
        }
        lock (_locker)
        {
            var subscriber = new Subscription(this, action, scheduler);
            if (_writerCompleted)
            {
                // Nothing will ever be reported; hand back an already-completed subscription.
                subscriber.Complete();
            }
            else
            {
                _subscribers = _subscribers.Add(subscriber);
            }
            return subscriber;
        }
    }

    /// <summary>
    /// Synchronous-handler overload of
    /// <see cref="SubscribeForPressureNotifications(Func{bool, Task}, TaskScheduler)"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="action"/> is null.</exception>
    public ISubscription SubscribeForPressureNotifications(Action<bool> action,
        TaskScheduler scheduler = null)
    {
        if (action is null)
        {
            throw new ArgumentNullException(nameof(action));
        }
        return SubscribeForPressureNotifications(
            e => { action(e); return Task.CompletedTask; }, scheduler);
    }

    /// <summary>
    /// Creates the channel.
    /// </summary>
    /// <param name="lowPressureThreshold">
    /// Pressure is released when the item count drops to this value or below. Must be non-negative.
    /// </param>
    /// <param name="highPressureThreshold">
    /// Pressure is signaled when the item count exceeds this value. Must be at least
    /// <paramref name="lowPressureThreshold"/>.
    /// </param>
    /// <param name="eventTaskScheduler">Currently unused.</param>
    public PressureAwareUnboundedChannel(int lowPressureThreshold,
        int highPressureThreshold, TaskScheduler eventTaskScheduler = null)
    {
        if (highPressureThreshold < lowPressureThreshold)
        {
            throw new ArgumentOutOfRangeException(nameof(highPressureThreshold));
        }
        if (lowPressureThreshold < 0)
        {
            throw new ArgumentOutOfRangeException(nameof(lowPressureThreshold));
        }
        _highPressureThreshold = highPressureThreshold;
        _lowPressureThreshold = lowPressureThreshold;
        _channel = Channel.CreateUnbounded<T>();
        _subscribers = ImmutableArray.Create<Subscription>();
        this.Writer = new ChannelWriter(this);
        this.Reader = new ChannelReader(this);
    }

    // Forwards to the inner channel's writer and reports successful writes/completion.
    private sealed class ChannelWriter : ChannelWriter<T>
    {
        private readonly PressureAwareUnboundedChannel<T> _parent;

        public ChannelWriter(PressureAwareUnboundedChannel<T> parent)
            => _parent = parent;

        public override bool TryComplete(Exception error = null)
        {
            bool success = _parent._channel.Writer.TryComplete(error);
            if (success)
            {
                _parent.Complete();
            }
            return success;
        }

        public override bool TryWrite(T item)
        {
            bool success = _parent._channel.Writer.TryWrite(item);
            if (success)
            {
                _parent.SignalWriteOrRead(1);
            }
            return success;
        }

        public override ValueTask<bool> WaitToWriteAsync(
            CancellationToken cancellationToken = default)
            => _parent._channel.Writer.WaitToWriteAsync(cancellationToken);
    }

    // Forwards to the inner channel's reader and reports successful reads.
    private sealed class ChannelReader : ChannelReader<T>
    {
        private readonly PressureAwareUnboundedChannel<T> _parent;

        public ChannelReader(PressureAwareUnboundedChannel<T> parent)
            => _parent = parent;

        public override Task Completion => _parent._channel.Reader.Completion;

        public override bool CanCount => true;

        // The inner channel is updated before the count, so a read may be accounted
        // for before the write it consumed; clamp the transient negative value.
        public override int Count => Math.Max(0, Volatile.Read(ref _parent._count));

        public override bool TryRead(out T item)
        {
            bool success = _parent._channel.Reader.TryRead(out item);
            if (success)
            {
                _parent.SignalWriteOrRead(-1);
            }
            return success;
        }

        public override ValueTask<bool> WaitToReadAsync(
            CancellationToken cancellationToken = default)
            => _parent._channel.Reader.WaitToReadAsync(cancellationToken);
    }

    // Called once the writer completes: completes every subscription (no further
    // notifications are started) and empties the subscriber list. Idempotent.
    private void Complete()
    {
        lock (_locker)
        {
            if (_writerCompleted)
            {
                return;
            }
            _writerCompleted = true;
            foreach (var subscriber in _subscribers)
            {
                subscriber.Complete();
            }
            _subscribers = _subscribers.Clear();
        }
    }

    // Adjusts the item count by countDelta and, when the new count flips the
    // pressure state, forwards the new state to every subscriber.
    private void SignalWriteOrRead(int countDelta)
    {
        lock (_locker)
        {
            _count += countDelta;
            bool underPressure;
            if (_count > _highPressureThreshold)
            {
                underPressure = true;
            }
            else if (_count <= _lowPressureThreshold)
            {
                underPressure = false;
            }
            else
            {
                return; // Inside the hysteresis band: keep the current state.
            }
            if (underPressure == _underPressure)
            {
                return;
            }
            _underPressure = underPressure;
            foreach (var subscriber in _subscribers)
            {
                subscriber.Post(underPressure);
            }
        }
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment