Skip to content

Instantly share code, notes, and snippets.

@dfederm
Last active December 4, 2023 05:02
Show Gist options
  • Save dfederm/445e971abf5340ab4b5b9ec8ef41a460 to your computer and use it in GitHub Desktop.
Save dfederm/445e971abf5340ab4b5b9ec8ef41a460 to your computer and use it in GitHub Desktop.
Limited Parallelism Work Queue
// Demo driver: pushes NumItems work items through each queue type while at most Parallelism run at once.
const int Parallelism = 3;
const int NumItems = 20;

// Non-generic queue: each work item is a self-contained delegate.
Console.WriteLine("WorkQueue starting");
await using (WorkQueue workQueue1 = new(Parallelism))
{
    List<Task> tasks = new(NumItems);
    for (int i = 0; i < NumItems; i++)
    {
        // The for-loop variable `i` is shared across iterations, so build the message in a per-iteration
        // local; each closure then captures its own copy instead of the final value of `i`.
        string str = $"Processing Item {i}";
        tasks.Add(workQueue1.EnqueueWorkAsync(
            async cancellationToken =>
            {
                await ProcessAsync(cancellationToken);
                Console.WriteLine(str);
            },
            CancellationToken.None));
    }

    await Task.WhenAll(tasks);
}

Console.WriteLine("WorkQueue done");
Console.WriteLine();

// Generic queue: the processing function is fixed at construction; callers enqueue only inputs.
Console.WriteLine("WorkQueue<T> starting");
await using (WorkQueue<string, int> workQueue2 = new(ProcessStringAsync, Parallelism))
{
    List<Task> tasks = new(NumItems);
    for (int i = 0; i < NumItems; i++)
    {
        tasks.Add(workQueue2.EnqueueWorkAsync($"Processing Item {i}", CancellationToken.None));
    }

    await Task.WhenAll(tasks);
}

Console.WriteLine("WorkQueue<T> done");

// Processing function for the generic queue: simulates work, prints the input, and returns an int
// derived from it. The hash value itself is arbitrary; it only demonstrates returning a result.
async Task<int> ProcessStringAsync(string str, CancellationToken cancellationToken)
{
    await ProcessAsync(cancellationToken);
    Console.WriteLine(str);
    return str.GetHashCode();
}

// Simulated work: waits a random 100-499 ms (Next's upper bound is exclusive), honoring cancellation.
async Task ProcessAsync(CancellationToken cancellationToken)
{
    TimeSpan delay = TimeSpan.FromMilliseconds(Random.Shared.Next(100, 500));
    await Task.Delay(delay, cancellationToken);
}
<Project Sdk="Microsoft.NET.Sdk">
<!-- Project file for the work-queue demo: a console executable plus the WorkQueue classes. -->
<PropertyGroup>
<!-- Build an executable; the demo's entry point is the top-level statements file. -->
<OutputType>Exe</OutputType>
<!-- net8.0 is required: the work queues call CancellationTokenSource.CancelAsync, which was added in .NET 8. -->
<TargetFramework>net8.0</TargetFramework>
<!-- Implicit global usings (System, System.Collections.Generic, System.Threading, System.Threading.Tasks, ...)
     are why the source files use Console, List<T>, Task and CancellationToken without using directives. -->
<ImplicitUsings>enable</ImplicitUsings>
<!-- Enable nullable reference type annotations and warnings. -->
<Nullable>enable</Nullable>
</PropertyGroup>
</Project>
using System.Threading.Channels;
/// <summary>
/// Runs queued asynchronous work items with at most a fixed number executing concurrently.
/// Each enqueuing caller receives a task that mirrors the outcome of its own work item.
/// </summary>
public sealed class WorkQueue : IAsyncDisposable
{
    // Cancelled on dispose; every work item's token is linked to it.
    private readonly CancellationTokenSource _cancellationTokenSource;

    // Unbounded queue of pending work items, consumed by the worker tasks.
    private readonly Channel<WorkContext> _channel;

    // One consumer loop per unit of parallelism.
    private readonly Task[] _workerTasks;

    // 0 = live, 1 = disposed. Swapped with Interlocked so the shutdown sequence runs only once.
    private int _disposed;

    /// <summary>
    /// A queued work item: the delegate to run, the completion source the enqueuing caller awaits,
    /// and the token handed to the delegate.
    /// </summary>
    private readonly record struct WorkContext(Func<CancellationToken, Task> TaskFunc, TaskCompletionSource TaskCompletionSource, CancellationToken CancellationToken);

    /// <summary>
    /// Creates a queue that runs at most <see cref="Environment.ProcessorCount"/> items concurrently.
    /// </summary>
    public WorkQueue()
        : this(Environment.ProcessorCount)
    {
    }

    /// <summary>
    /// Creates a queue that runs at most <paramref name="parallelism"/> items concurrently.
    /// </summary>
    /// <param name="parallelism">Number of worker tasks; must be positive.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="parallelism"/> is zero or negative.</exception>
    public WorkQueue(int parallelism)
    {
        // With zero workers nothing would ever dequeue, so every enqueued item would wait forever.
        ArgumentOutOfRangeException.ThrowIfNegativeOrZero(parallelism);

        _cancellationTokenSource = new CancellationTokenSource();
        _channel = Channel.CreateUnbounded<WorkContext>();

        _workerTasks = new Task[parallelism];
        for (int i = 0; i < _workerTasks.Length; i++)
        {
            _workerTasks[i] = Task.Run(() => RunWorkerAsync());
        }
    }

    /// <summary>
    /// Queues <paramref name="taskFunc"/> and completes once it has run, faulted, or been cancelled.
    /// </summary>
    /// <param name="taskFunc">
    /// The work to run. It receives a token that is cancelled when either <paramref name="cancellationToken"/>
    /// is cancelled or the queue is disposed.
    /// </param>
    /// <param name="cancellationToken">Cancels this work item only.</param>
    /// <returns>A task that mirrors the outcome of <paramref name="taskFunc"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="taskFunc"/> is null.</exception>
    /// <exception cref="ObjectDisposedException">The queue has been disposed.</exception>
    /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was already cancelled.</exception>
    public async Task EnqueueWorkAsync(Func<CancellationToken, Task> taskFunc, CancellationToken cancellationToken = default)
    {
        ArgumentNullException.ThrowIfNull(taskFunc);
        ObjectDisposedException.ThrowIf(Volatile.Read(ref _disposed) != 0, this);
        cancellationToken.ThrowIfCancellationRequested();

        // Complete the caller's await on the thread pool rather than inline on the worker that sets the
        // result; otherwise arbitrary caller code would run on (and stall) a worker.
        TaskCompletionSource taskCompletionSource = new(TaskCreationOptions.RunContinuationsAsynchronously);

        // A linked source registers callbacks on both tokens. Dispose it once this item is finished so those
        // registrations do not pile up on the queue's long-lived token.
        using CancellationTokenSource? linkedTokenSource = cancellationToken.CanBeCanceled
            ? CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _cancellationTokenSource.Token)
            : null;
        CancellationToken linkedToken = linkedTokenSource?.Token ?? _cancellationTokenSource.Token;

        WorkContext context = new(taskFunc, taskCompletionSource, linkedToken);
        await _channel.Writer.WriteAsync(context, linkedToken).ConfigureAwait(false);
        await taskCompletionSource.Task.ConfigureAwait(false);
    }

    /// <summary>
    /// Cancels outstanding work, stops accepting new items, and waits until the workers have drained the
    /// queue. Items still queued at that point are completed as cancelled rather than abandoned.
    /// Subsequent calls return immediately.
    /// </summary>
    public async ValueTask DisposeAsync()
    {
        if (Interlocked.Exchange(ref _disposed, 1) != 0)
        {
            return;
        }

        await _cancellationTokenSource.CancelAsync().ConfigureAwait(false);
        _channel.Writer.Complete();
        await _channel.Reader.Completion.ConfigureAwait(false);
        await Task.WhenAll(_workerTasks).ConfigureAwait(false);
        _cancellationTokenSource.Dispose();
    }

    /// <summary>
    /// Worker loop: processes items one at a time until the channel is completed and empty.
    /// </summary>
    private async Task RunWorkerAsync()
    {
        // Deliberately not passing the cancellation token: the channel must be drained completely so every
        // enqueued item's task gets completed and no caller is left awaiting forever.
        await foreach (WorkContext context in _channel.Reader.ReadAllAsync().ConfigureAwait(false))
        {
            await ProcessWorkAsync(context).ConfigureAwait(false);
        }
    }

    /// <summary>
    /// Runs one work item and transfers its outcome (success, cancellation, or exception) to the item's
    /// completion source. Never throws.
    /// </summary>
    private static async Task ProcessWorkAsync(WorkContext context)
    {
        // Skip items whose token was cancelled while they sat in the queue (including queue disposal).
        if (context.CancellationToken.IsCancellationRequested)
        {
            context.TaskCompletionSource.TrySetCanceled(context.CancellationToken);
            return;
        }

        try
        {
            await context.TaskFunc(context.CancellationToken).ConfigureAwait(false);
            context.TaskCompletionSource.TrySetResult();
        }
        catch (OperationCanceledException ex)
        {
            context.TaskCompletionSource.TrySetCanceled(ex.CancellationToken);
        }
        catch (Exception ex)
        {
            context.TaskCompletionSource.TrySetException(ex);
        }
    }
}
using System.Threading.Channels;
/// <summary>
/// Applies a fixed asynchronous processing function to queued inputs, with at most a fixed number of
/// inputs being processed concurrently. Each enqueuing caller receives a task for its own result.
/// </summary>
/// <typeparam name="TInput">Type of the queued inputs.</typeparam>
/// <typeparam name="TResult">Type produced by the processing function.</typeparam>
public sealed class WorkQueue<TInput, TResult> : IAsyncDisposable
{
    // The function applied to every queued input.
    private readonly Func<TInput, CancellationToken, Task<TResult>> _processFunc;

    // Cancelled on dispose; every work item's token is linked to it.
    private readonly CancellationTokenSource _cancellationTokenSource;

    // Unbounded queue of pending work items, consumed by the worker tasks.
    private readonly Channel<WorkContext> _channel;

    // One consumer loop per unit of parallelism.
    private readonly Task[] _workerTasks;

    // 0 = live, 1 = disposed. Swapped with Interlocked so the shutdown sequence runs only once.
    private int _disposed;

    /// <summary>
    /// A queued work item: the input, the completion source the enqueuing caller awaits, and the token
    /// handed to the processing function.
    /// </summary>
    private readonly record struct WorkContext(TInput Input, TaskCompletionSource<TResult> TaskCompletionSource, CancellationToken CancellationToken);

    /// <summary>
    /// Creates a queue that runs at most <see cref="Environment.ProcessorCount"/> items concurrently.
    /// </summary>
    /// <param name="processFunc">Function applied to each enqueued input.</param>
    public WorkQueue(Func<TInput, CancellationToken, Task<TResult>> processFunc)
        : this(processFunc, Environment.ProcessorCount)
    {
    }

    /// <summary>
    /// Creates a queue that runs at most <paramref name="parallelism"/> items concurrently.
    /// </summary>
    /// <param name="processFunc">Function applied to each enqueued input.</param>
    /// <param name="parallelism">Number of worker tasks; must be positive.</param>
    /// <exception cref="ArgumentNullException"><paramref name="processFunc"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="parallelism"/> is zero or negative.</exception>
    public WorkQueue(Func<TInput, CancellationToken, Task<TResult>> processFunc, int parallelism)
    {
        ArgumentNullException.ThrowIfNull(processFunc);

        // With zero workers nothing would ever dequeue, so every enqueued item would wait forever.
        ArgumentOutOfRangeException.ThrowIfNegativeOrZero(parallelism);

        _processFunc = processFunc;
        _cancellationTokenSource = new CancellationTokenSource();
        _channel = Channel.CreateUnbounded<WorkContext>();

        _workerTasks = new Task[parallelism];
        for (int i = 0; i < _workerTasks.Length; i++)
        {
            _workerTasks[i] = Task.Run(() => RunWorkerAsync());
        }
    }

    /// <summary>
    /// Queues <paramref name="input"/> for processing and completes with the processing function's result.
    /// </summary>
    /// <param name="input">The value to process.</param>
    /// <param name="cancellationToken">
    /// Cancels this work item only. It cancels the item if it is still queued, and it is forwarded to the
    /// processing function if the item is already running.
    /// </param>
    /// <returns>A task that mirrors the outcome of processing <paramref name="input"/>.</returns>
    /// <exception cref="ObjectDisposedException">The queue has been disposed.</exception>
    /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was already cancelled.</exception>
    public async Task<TResult> EnqueueWorkAsync(TInput input, CancellationToken cancellationToken = default)
    {
        ObjectDisposedException.ThrowIf(Volatile.Read(ref _disposed) != 0, this);
        cancellationToken.ThrowIfCancellationRequested();

        // Complete the caller's await on the thread pool rather than inline on the worker that sets the
        // result; otherwise arbitrary caller code would run on (and stall) a worker.
        TaskCompletionSource<TResult> taskCompletionSource = new(TaskCreationOptions.RunContinuationsAsynchronously);

        // A linked source registers callbacks on both tokens. Dispose it once this item is finished so those
        // registrations do not pile up on the queue's long-lived token.
        using CancellationTokenSource? linkedTokenSource = cancellationToken.CanBeCanceled
            ? CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _cancellationTokenSource.Token)
            : null;
        CancellationToken linkedToken = linkedTokenSource?.Token ?? _cancellationTokenSource.Token;

        WorkContext context = new(input, taskCompletionSource, linkedToken);
        await _channel.Writer.WriteAsync(context, linkedToken).ConfigureAwait(false);
        return await taskCompletionSource.Task.ConfigureAwait(false);
    }

    /// <summary>
    /// Cancels outstanding work, stops accepting new items, and waits until the workers have drained the
    /// queue. Items still queued at that point are completed as cancelled rather than abandoned.
    /// Subsequent calls return immediately.
    /// </summary>
    public async ValueTask DisposeAsync()
    {
        if (Interlocked.Exchange(ref _disposed, 1) != 0)
        {
            return;
        }

        await _cancellationTokenSource.CancelAsync().ConfigureAwait(false);
        _channel.Writer.Complete();
        await _channel.Reader.Completion.ConfigureAwait(false);
        await Task.WhenAll(_workerTasks).ConfigureAwait(false);
        _cancellationTokenSource.Dispose();
    }

    /// <summary>
    /// Worker loop: processes items one at a time until the channel is completed and empty.
    /// </summary>
    private async Task RunWorkerAsync()
    {
        // Deliberately not passing the cancellation token: the channel must be drained completely so every
        // enqueued item's task gets completed and no caller is left awaiting forever.
        await foreach (WorkContext context in _channel.Reader.ReadAllAsync().ConfigureAwait(false))
        {
            await ProcessWorkAsync(context).ConfigureAwait(false);
        }
    }

    /// <summary>
    /// Processes one item and transfers the outcome (result, cancellation, or exception) to the item's
    /// completion source. Never throws.
    /// </summary>
    private async Task ProcessWorkAsync(WorkContext context)
    {
        // Use the item's own token, which is linked to both the caller's token and the queue's token.
        // Using only the queue token would ignore the caller's cancellation entirely.
        CancellationToken cancellationToken = context.CancellationToken;

        // Skip items whose token was cancelled while they sat in the queue (including queue disposal).
        if (cancellationToken.IsCancellationRequested)
        {
            context.TaskCompletionSource.TrySetCanceled(cancellationToken);
            return;
        }

        try
        {
            TResult result = await _processFunc(context.Input, cancellationToken).ConfigureAwait(false);
            context.TaskCompletionSource.TrySetResult(result);
        }
        catch (OperationCanceledException ex)
        {
            context.TaskCompletionSource.TrySetCanceled(ex.CancellationToken);
        }
        catch (Exception ex)
        {
            context.TaskCompletionSource.TrySetException(ex);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment