Skip to content

Instantly share code, notes, and snippets.

@StephenCleary
Created October 23, 2023 20:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save StephenCleary/eae3cb18188258b700581b20acab548d to your computer and use it in GitHub Desktop.
Save StephenCleary/eae3cb18188258b700581b20acab548d to your computer and use it in GitHub Desktop.
A System.Threading.Channels.Channel<T> that supports custom bounds (more complex than simply counting items)
using System.Threading.Channels;
using Nito.AsyncEx;
/// <summary>
/// Capacity policy used by <c>CustomBoundedChannel&lt;T&gt;</c>: it is told about every item that
/// enters or leaves the channel and reports whether the buffered content counts as empty or full.
/// </summary>
/// <typeparam name="T">The item type. Declared contravariant, so a policy written for a base type can serve a derived one.</typeparam>
/// <remarks>
/// All members must be safe to call while under lock: the channel invokes them while holding its
/// internal lock, so implementations should be quick and must not call back into the channel.
/// </remarks>
public interface ICustomBounds<in T>
{
    /// <summary>Accounts for <paramref name="item"/> having been added to the buffered content.</summary>
    void Add(T item);

    /// <summary>Accounts for <paramref name="item"/> having been removed from the buffered content.</summary>
    void Subtract(T item);

    /// <summary>Whether the buffered content is currently considered empty.</summary>
    bool IsEmpty { get; }

    /// <summary>Whether the buffered content is currently considered full (writers must wait).</summary>
    bool IsFull { get; }
}
/// <summary>
/// A <see cref="Channel{T}"/> whose capacity is decided by an <see cref="ICustomBounds{T}"/>
/// (for example the total size of the buffered items) instead of a plain item count.
/// </summary>
/// <typeparam name="T">The type of items flowing through the channel.</typeparam>
/// <remarks>
/// Items are stored in an internal unbounded channel. Every mutation of that channel and every call
/// into the bounds object happens while holding one <see cref="AsyncLock"/>. Two condition variables
/// tied to that lock wake waiting writers (space freed or channel completed) and waiting readers
/// (item written or channel completed). Readers are released whenever the internal channel actually
/// holds an item, so items the bounds weigh as zero (e.g. empty strings) are still delivered.
/// </remarks>
public sealed class CustomBoundedChannel<T> : Channel<T>
{
    /// <summary>Creates the channel.</summary>
    /// <param name="options">Options for the internal unbounded channel that stores the items.</param>
    /// <param name="bounds">Tracks the buffered items and decides when the channel is full; only called under the channel's lock.</param>
    /// <exception cref="ArgumentNullException"><paramref name="bounds"/> (or <paramref name="options"/>) is null.</exception>
    public CustomBoundedChannel(UnboundedChannelOptions options, ICustomBounds<T> bounds)
    {
        _bounds = bounds ?? throw new ArgumentNullException(nameof(bounds));
        _channel = Channel.CreateUnbounded<T>(options);
        _mutex = new();
        _completedOrNotFull = new(_mutex);
        _completedOrNotEmpty = new(_mutex);
        Reader = new ChannelReader(this);
        Writer = new ChannelWriter(this);
    }

    private readonly ICustomBounds<T> _bounds;
    private readonly Channel<T> _channel;
    private readonly AsyncLock _mutex;
    private readonly AsyncConditionVariable _completedOrNotFull;
    private readonly AsyncConditionVariable _completedOrNotEmpty;

    // Set under _mutex by TryComplete. The internal reader's Completion only finishes once the
    // remaining items are drained, so it cannot tell writers that writing has ended.
    private bool _writerCompleted;

    // Must be called with _mutex held. Stores the item, charges it to the bounds and wakes readers.
    private bool TryEnqueue(T item)
    {
        // The internal channel is unbounded, so it only rejects items once it has been completed.
        if (!_channel.Writer.TryWrite(item))
            return false;
        _bounds.Add(item);
        // Wake every waiting reader: a WaitToReadAsync waiter may take the wakeup without reading,
        // and the item may weigh nothing according to the bounds.
        _completedOrNotEmpty.NotifyAll();
        return true;
    }

    // Must be called with _mutex held. Removes one item, credits the bounds and wakes writers if space opened up.
    private bool TryDequeue([System.Diagnostics.CodeAnalysis.MaybeNullWhen(false)] out T item)
    {
        if (!_channel.Reader.TryRead(out item))
            return false;
        _bounds.Subtract(item);
        // One removal can free room for several writers (bounds need not count items), so wake them all.
        if (!_bounds.IsFull)
            _completedOrNotFull.NotifyAll();
        return true;
    }

    /// <summary>Writer side: enforces the custom bounds before forwarding items to the internal channel.</summary>
    private sealed class ChannelWriter : ChannelWriter<T>
    {
        public ChannelWriter(CustomBoundedChannel<T> parent)
        {
            _parent = parent;
        }

        /// <inheritdoc/>
        public override bool TryWrite(T item)
        {
            using var key = _parent._mutex.Lock();
            if (_parent._bounds.IsFull)
                return false;
            return _parent.TryEnqueue(item);
        }

        /// <inheritdoc/>
        public override bool TryComplete(Exception? error = null)
        {
            using var key = _parent._mutex.Lock();
            if (!_parent._channel.Writer.TryComplete(error))
                return false;
            _parent._writerCompleted = true;
            // Release everyone so they can observe completion.
            _parent._completedOrNotEmpty.NotifyAll();
            _parent._completedOrNotFull.NotifyAll();
            return true;
        }

        /// <summary>Waits until there is room to write; returns false once the writer has been completed.</summary>
        public override async ValueTask<bool> WaitToWriteAsync(CancellationToken cancellationToken = default)
        {
            using var key = await _parent._mutex.LockAsync(cancellationToken).ConfigureAwait(false);
            while (true)
            {
                if (_parent._writerCompleted)
                    return false;
                if (!_parent._bounds.IsFull)
                    return true;
                // Releases the lock while waiting and re-acquires it before returning.
                await _parent._completedOrNotFull.WaitAsync(cancellationToken).ConfigureAwait(false);
            }
        }

        /// <summary>Waits for room and writes the item; throws <see cref="ChannelClosedException"/> once the writer has been completed.</summary>
        public override async ValueTask WriteAsync(T item, CancellationToken cancellationToken = default)
        {
            using var key = await _parent._mutex.LockAsync(cancellationToken).ConfigureAwait(false);
            while (true)
            {
                if (_parent._writerCompleted)
                    throw new ChannelClosedException();
                if (!_parent._bounds.IsFull)
                {
                    // A failure here means the internal channel is closed; report it instead of spinning.
                    if (!_parent.TryEnqueue(item))
                        throw new ChannelClosedException();
                    return;
                }
                await _parent._completedOrNotFull.WaitAsync(cancellationToken).ConfigureAwait(false);
            }
        }

        private readonly CustomBoundedChannel<T> _parent;
    }

    /// <summary>Reader side: takes items from the internal channel and credits them back to the bounds.</summary>
    private sealed class ChannelReader : ChannelReader<T>
    {
        public ChannelReader(CustomBoundedChannel<T> parent)
        {
            _parent = parent;
        }

        /// <inheritdoc/>
        public override bool TryRead([System.Diagnostics.CodeAnalysis.MaybeNullWhen(false)] out T item)
        {
            using var key = _parent._mutex.Lock();
            return _parent.TryDequeue(out item);
        }

        /// <summary>Waits until an item is buffered; returns false once the channel is completed and drained.</summary>
        public override async ValueTask<bool> WaitToReadAsync(CancellationToken cancellationToken = default)
        {
            using var key = await _parent._mutex.LockAsync(cancellationToken).ConfigureAwait(false);
            while (true)
            {
                // Readiness comes from the items actually buffered, consistent with TryRead.
                if (_parent._channel.Reader.TryPeek(out _))
                    return true;
                if (_parent._channel.Reader.Completion.IsCompleted)
                    return false;
                await _parent._completedOrNotEmpty.WaitAsync(cancellationToken).ConfigureAwait(false);
            }
        }

        /// <summary>Waits for and removes an item; throws <see cref="ChannelClosedException"/> once the channel is completed and drained.</summary>
        public override async ValueTask<T> ReadAsync(CancellationToken cancellationToken = default)
        {
            using var key = await _parent._mutex.LockAsync(cancellationToken).ConfigureAwait(false);
            while (true)
            {
                if (_parent.TryDequeue(out var result))
                    return result;
                if (_parent._channel.Reader.Completion.IsCompleted)
                    throw new ChannelClosedException();
                await _parent._completedOrNotEmpty.WaitAsync(cancellationToken).ConfigureAwait(false);
            }
        }

        public override Task Completion => _parent._channel.Reader.Completion;
        public override bool CanCount => _parent._channel.Reader.CanCount;
        public override int Count => _parent._channel.Reader.Count;
        public override bool CanPeek => _parent._channel.Reader.CanPeek;
        public override bool TryPeek([System.Diagnostics.CodeAnalysis.MaybeNullWhen(false)] out T item) => _parent._channel.Reader.TryPeek(out item);

        private readonly CustomBoundedChannel<T> _parent;
    }
}
@StephenCleary
Copy link
Author

StephenCleary commented Oct 23, 2023

Example: a channel of strings bounded by the total length (in characters) of the buffered strings rather than by how many strings it holds:

/// <summary>
/// Example bounds for a channel of strings: the channel is full once the buffered strings hold at
/// least a given number of characters, no matter how many strings that is.
/// </summary>
public sealed class MyCustomStringBounds : ICustomBounds<string>
{
  /// <summary>Character budget used when none is specified (the original example's value).</summary>
  public const int DefaultMaxCharCount = 8192;

  /// <param name="maxCharCount">Number of buffered characters at which the channel reports full; must be positive.</param>
  /// <exception cref="ArgumentOutOfRangeException"><paramref name="maxCharCount"/> is zero or negative, which would block every writer forever.</exception>
  public MyCustomStringBounds(int maxCharCount = DefaultMaxCharCount)
  {
    if (maxCharCount <= 0)
      throw new ArgumentOutOfRangeException(nameof(maxCharCount), maxCharCount, "The character budget must be positive.");
    _maxCharCount = maxCharCount;
  }

  public bool IsEmpty => _charCount == 0;
  // ">=" rather than "==": a single long string can push the total past the budget.
  public bool IsFull => _charCount >= _maxCharCount;
  // A null string weighs nothing, so Add/Subtract cannot throw after the item is already buffered.
  public void Add(string item) => _charCount += item?.Length ?? 0;
  public void Subtract(string item) => _charCount -= item?.Length ?? 0;

  private readonly int _maxCharCount;
  private int _charCount;
}

// Items live in an unbounded channel configured by `options`; `bounds` caps the total buffered characters.
Channel<string> myStringChannel = new CustomBoundedChannel<string>(options: new UnboundedChannelOptions(), bounds: new MyCustomStringBounds());

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment