Skip to content

Instantly share code, notes, and snippets.

@drub0y
Last active December 2, 2022 15:36
Show Gist options
  • Star 6 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save drub0y/5681496 to your computer and use it in GitHub Desktop.
Save drub0y/5681496 to your computer and use it in GitHub Desktop.
A .NET TPL TaskScheduler implementation that spins up STA threads running a Dispatcher message pump, which enables the execution of logic that leverages System.Windows.* components. Why? Well, without this, using System.Windows.* components will create a new Dispatcher instance on the currently executing thread, and since that thread is a) not…
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Threading;
using System.Collections.Concurrent;
using System.Windows.Threading;
namespace HackedBrain.Utilities
{
/// <summary>
/// A <see cref="TaskScheduler"/> that executes tasks on a fixed set of dedicated STA threads, each of which
/// runs its own <see cref="Dispatcher"/> message loop.
/// </summary>
/// <remarks>
/// Tasks are kept in a single shared queue. Every queued task posts one "try to execute a task" operation to
/// one of the worker dispatchers (an idle one when available, otherwise the next one in round-robin order).
/// Whichever worker runs such an operation takes the oldest task from the shared queue, so a task does not
/// necessarily run on the worker that was signalled for it.
/// </remarks>
public sealed class DispatcherTaskScheduler : TaskScheduler, IDisposable
{
    #region Fields

    private static readonly Lazy<DispatcherTaskScheduler> DefaultSchedulerInstance = new Lazy<DispatcherTaskScheduler>(() => new DispatcherTaskScheduler());

    private readonly DispatcherWorkerThreadDetails[] dispatcherWorkerThreadDetails;

    // A linked list (rather than a Queue) so that TryDequeue can remove an arbitrary task for inline execution.
    // This instance also serves as the lock object guarding all access to the list.
    private readonly LinkedList<Task> taskQueue = new LinkedList<Task>();

    private readonly int maxDegreeOfParallelism;
    private readonly Action tryExecuteAction;
    private int nextRoundRobinSelectedWorkerThreadIndex;

    // 0 while the scheduler is usable, 1 once Dispose has been called. Only accessed through Interlocked.
    private int disposed;

    #endregion

    #region Constructors

    /// <summary>
    /// Creates a scheduler with one worker thread per processor (<see cref="Environment.ProcessorCount"/>).
    /// </summary>
    public DispatcherTaskScheduler() : this(Environment.ProcessorCount)
    {
    }

    /// <summary>
    /// Creates a scheduler with the given number of worker threads. All worker threads are started, and their
    /// dispatchers created, before the constructor returns.
    /// </summary>
    /// <param name="maxDegreeOfParallelism">The number of worker threads to start; must be at least 1.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="maxDegreeOfParallelism"/> is less than 1.</exception>
    public DispatcherTaskScheduler(int maxDegreeOfParallelism)
    {
        if(maxDegreeOfParallelism < 1)
        {
            throw new ArgumentOutOfRangeException("maxDegreeOfParallelism", "The value for maxDegreeOfParallelism cannot be less than 1.");
        }

        this.maxDegreeOfParallelism = maxDegreeOfParallelism;
        this.tryExecuteAction = new Action(this.TryExecuteTaskFromQueue);
        this.dispatcherWorkerThreadDetails = new DispatcherWorkerThreadDetails[maxDegreeOfParallelism];

        for(int threadIndex = 0; threadIndex < maxDegreeOfParallelism; threadIndex++)
        {
            this.dispatcherWorkerThreadDetails[threadIndex] = this.StartNewDispatcherWorkerThread();
        }
    }

    #endregion

    #region Type specific properties

    /// <summary>
    /// Gets a lazily created, process-wide scheduler instance. Hides <see cref="TaskScheduler.Default"/>.
    /// </summary>
    public static new DispatcherTaskScheduler Default
    {
        get
        {
            return DispatcherTaskScheduler.DefaultSchedulerInstance.Value;
        }
    }

    #endregion

    #region Base class overrides

    /// <summary>
    /// Gets the number of worker threads this scheduler was created with.
    /// </summary>
    public override int MaximumConcurrencyLevel
    {
        get
        {
            return this.maxDegreeOfParallelism;
        }
    }

    /// <summary>
    /// Returns a snapshot of the tasks that are queued but have not been picked up yet.
    /// </summary>
    /// <exception cref="NotSupportedException">The queue lock could not be taken without blocking.</exception>
    protected override IEnumerable<Task> GetScheduledTasks()
    {
        // This method exists for debugger support, so it must never block: if the queue is busy,
        // report that the list is unavailable right now instead of waiting for the lock.
        bool lockTaken = false;

        try
        {
            Monitor.TryEnter(this.taskQueue, ref lockTaken);

            if(!lockTaken)
            {
                throw new NotSupportedException();
            }

            return this.taskQueue.ToArray();
        }
        finally
        {
            if(lockTaken)
            {
                Monitor.Exit(this.taskQueue);
            }
        }
    }

    /// <summary>
    /// Adds the task to the shared queue and signals one worker dispatcher to go and process the queue.
    /// </summary>
    /// <param name="task">The task to queue.</param>
    /// <exception cref="ObjectDisposedException">The scheduler has been disposed.</exception>
    protected override void QueueTask(Task task)
    {
        // Once the dispatchers have been shut down a queued task could never run, so fail loudly instead
        this.ThrowIfDisposed();

        lock(this.taskQueue)
        {
            this.taskQueue.AddLast(task);
        }

        DispatcherWorkerThreadDetails targetWorker;

        // Look for an idle thread to dispatch the work to first
        if(!this.TryFindIdleWorkerThread(out targetWorker))
        {
            // No idle threads, use round robin to fairly distribute
            targetWorker = this.GetNextWorkerThreadViaRoundRobin();
        }

        targetWorker.Dispatcher.BeginInvoke(this.tryExecuteAction, DispatcherPriority.Normal);
    }

    /// <summary>
    /// Removes a task from the shared queue so that it can be executed inline or abandoned.
    /// </summary>
    /// <param name="task">The task to remove.</param>
    /// <returns><c>true</c> if the task was still queued and has been removed; otherwise <c>false</c>.</returns>
    protected override bool TryDequeue(Task task)
    {
        lock(this.taskQueue)
        {
            return this.taskQueue.Remove(task);
        }
    }

    /// <summary>
    /// Executes the task on the calling thread when that thread is an STA thread that owns a Dispatcher.
    /// </summary>
    /// <param name="task">The task to execute.</param>
    /// <param name="taskWasPreviouslyQueued">Whether the task has already been passed to <see cref="QueueTask"/>.</param>
    /// <returns><c>true</c> if the task was executed inline; otherwise <c>false</c>.</returns>
    protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued)
    {
        Thread currentThread = Thread.CurrentThread;

        // Only inline on threads that offer the same environment as the workers: STA with a Dispatcher.
        // A thread can own a Dispatcher without being STA, so the apartment state is checked explicitly.
        if(Dispatcher.FromThread(currentThread) == null
            ||
            currentThread.GetApartmentState() != ApartmentState.STA)
        {
            return false;
        }

        // A previously queued task may only run inline if it can still be pulled out of the queue; otherwise
        // a worker has already taken it. Pulling it out lets a worker that waits on such a task run it itself
        // rather than block until some other worker gets to it (which may never happen with a single worker).
        if(taskWasPreviouslyQueued && !this.TryDequeue(task))
        {
            return false;
        }

        return this.TryExecuteTask(task);
    }

    #endregion

    #region IDisposable implementation

    /// <summary>
    /// Shuts down every worker dispatcher and waits for the worker threads to exit.
    /// </summary>
    /// <remarks>
    /// Calling this method more than once has no further effect. The shutdown is requested with
    /// <see cref="DispatcherPriority.Send"/>, so tasks that are still queued at this point may never be executed.
    /// </remarks>
    public void Dispose()
    {
        // Only the first call performs the shutdown
        if(Interlocked.Exchange(ref this.disposed, 1) != 0)
        {
            return;
        }

        // Tell all the dispatchers to shut down first so that they wind down in parallel
        foreach(DispatcherWorkerThreadDetails dispatcherThreadDetails in this.dispatcherWorkerThreadDetails)
        {
            dispatcherThreadDetails.Dispatcher.BeginInvokeShutdown(DispatcherPriority.Send);
        }

        Thread currentThread = Thread.CurrentThread;

        // Wait for the Dispatcher message pumps to exit
        foreach(DispatcherWorkerThreadDetails dispatcherThreadDetails in this.dispatcherWorkerThreadDetails)
        {
            // If Dispose is called from a task running on a worker, that worker cannot join itself
            if(dispatcherThreadDetails.Thread != currentThread)
            {
                dispatcherThreadDetails.Thread.Join();
            }
        }
    }

    #endregion

    #region Helper methods

    /// <summary>
    /// Throws <see cref="ObjectDisposedException"/> if <see cref="Dispose"/> has been called.
    /// </summary>
    private void ThrowIfDisposed()
    {
        // CompareExchange with identical value and comparand is used as an atomic read of the flag
        if(Interlocked.CompareExchange(ref this.disposed, 0, 0) != 0)
        {
            throw new ObjectDisposedException("DispatcherTaskScheduler");
        }
    }

    /// <summary>
    /// Starts a background STA thread that runs a Dispatcher message loop and returns once that thread
    /// has created its Dispatcher.
    /// </summary>
    private DispatcherWorkerThreadDetails StartNewDispatcherWorkerThread()
    {
        object startupGate = new object();
        Dispatcher workerDispatcher = null;

        Thread dispatcherThread = new Thread(() =>
        {
            // Reading CurrentDispatcher creates the Dispatcher for this thread
            Dispatcher currentDispatcher = Dispatcher.CurrentDispatcher;

            // Hand the Dispatcher to the starting thread, which is waiting on the gate
            lock(startupGate)
            {
                workerDispatcher = currentDispatcher;
                Monitor.Pulse(startupGate);
            }

            // Start the dispatcher message loop so that it begins processing messages on this thread
            Dispatcher.Run();
        });

        dispatcherThread.Name = "DispatcherTaskScheduler Worker Thread";
        dispatcherThread.SetApartmentState(ApartmentState.STA);
        dispatcherThread.IsBackground = true;
        dispatcherThread.Start();

        // Wait for the worker to publish its Dispatcher instead of polling for it
        lock(startupGate)
        {
            while(workerDispatcher == null)
            {
                Monitor.Wait(startupGate);
            }
        }

        return new DispatcherWorkerThreadDetails(dispatcherThread, workerDispatcher);
    }

    /// <summary>
    /// Runs on a worker dispatcher: takes the oldest task from the shared queue, if any, and executes it.
    /// </summary>
    private void TryExecuteTaskFromQueue()
    {
        Task task = null;

        Debug.WriteLine("{0} #{1} attempting to grab work off the task queue...", Thread.CurrentThread.Name, Thread.CurrentThread.ManagedThreadId);

        // Acquire the task queue lock
        lock(this.taskQueue)
        {
            // Attempt to get a new task from the queue to process
            if(this.taskQueue.Count > 0)
            {
                task = this.taskQueue.First.Value;
                this.taskQueue.RemoveFirst();
            }
        }

        if(task != null)
        {
            // Attempt to execute the task
            this.TryExecuteTask(task);
        }
        else
        {
            Debug.WriteLine("{0} #{1} didn't find a task to work on which would indicate another thread must have beat it to the dequeue and there was no work left.", Thread.CurrentThread.Name, Thread.CurrentThread.ManagedThreadId);
        }
    }

    /// <summary>
    /// Finds the first worker whose dispatcher is not currently flagged as having work.
    /// </summary>
    /// <param name="targetWorker">The idle worker, or <c>null</c> when every worker has work.</param>
    /// <returns><c>true</c> if an idle worker was found; otherwise <c>false</c>.</returns>
    private bool TryFindIdleWorkerThread(out DispatcherWorkerThreadDetails targetWorker)
    {
        targetWorker = null;

        for(int threadIndex = 0; threadIndex < this.dispatcherWorkerThreadDetails.Length; threadIndex++)
        {
            DispatcherWorkerThreadDetails details = this.dispatcherWorkerThreadDetails[threadIndex];

            if(!details.HasWork)
            {
                targetWorker = details;
                break;
            }
        }

        return targetWorker != null;
    }

    /// <summary>
    /// Returns the next worker in round-robin order, advancing the shared index without taking a lock.
    /// </summary>
    private DispatcherWorkerThreadDetails GetNextWorkerThreadViaRoundRobin()
    {
        int lastLocal;
        int nextLocal;

        // Retry until this thread is the one that moved the index from lastLocal to nextLocal
        do
        {
            lastLocal = this.nextRoundRobinSelectedWorkerThreadIndex;
            nextLocal = lastLocal + 1;

            if(nextLocal == this.dispatcherWorkerThreadDetails.Length)
            {
                nextLocal = 0;
            }
        }
        while(Interlocked.CompareExchange(ref this.nextRoundRobinSelectedWorkerThreadIndex, nextLocal, lastLocal) != lastLocal);

        return this.dispatcherWorkerThreadDetails[lastLocal];
    }

    #endregion

    #region Nested types

    /// <summary>
    /// Holds a worker thread, its Dispatcher and a flag tracking whether that Dispatcher has work pending.
    /// </summary>
    private sealed class DispatcherWorkerThreadDetails
    {
        public readonly Thread Thread;
        public readonly Dispatcher Dispatcher;

        // Written from the dispatcher hook callbacks and read by threads queueing tasks
        public volatile bool HasWork;

        /// <summary>
        /// Records the worker thread and its Dispatcher and starts tracking the Dispatcher's activity.
        /// </summary>
        /// <param name="thread">The worker thread.</param>
        /// <param name="dispatcher">The Dispatcher that was created on <paramref name="thread"/>.</param>
        public DispatcherWorkerThreadDetails(Thread thread, Dispatcher dispatcher)
        {
            this.Thread = thread;
            this.Dispatcher = dispatcher;

            // Hook the events that tell us when the Dispatcher is actually doing something
            DispatcherHooks dispatcherHooks = dispatcher.Hooks;
            dispatcherHooks.DispatcherInactive += (s, a) => this.HasWork = false;
            dispatcherHooks.OperationPosted += (s, a) => this.HasWork = true;
        }
    }

    #endregion
}
}
@maxoizs
Copy link

maxoizs commented Sep 8, 2016

I'm not good at task factories or threading. I was trying to use your code, but I found a memory leak. Can you provide a unit test as an example, please?
I used the bug example from Microsoft Connect for this issue as a test:

public void TestDispatcherWithTaskScheduler()
{
    // Number of schedulers created so far; also read by the scheduled work for progress output.
    int iterations = 0;
    object gate = new object();

    // Set by the scheduled work when it throws; ends the loop below.
    bool failed = false;

    Action createControl = () =>
    {
        try
        {
            var richTextBox = new RichTextBox();

            lock (gate)
            {
                // Print the counter whenever it is a multiple of 100.
                if (iterations % 100 == 0)
                {
                    Console.WriteLine(iterations);
                }
            }
        }
        catch (Exception ex)
        {
            Console.WriteLine(ex.ToString());
            failed = true;
        }
    };

    // Create a scheduler, start the work on it and dispose the scheduler again, until a failure is flagged.
    for (; !failed; )
    {
        iterations++;

        using (var scheduler = new DispatcherTaskScheduler())
        {
            new TaskFactory(scheduler).StartNew(createControl);
        }
    }

    Console.WriteLine(iterations);
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment