Skip to content

Instantly share code, notes, and snippets.

@bradphelan
Created December 5, 2016 07:42
Show Gist options
  • Save bradphelan/cb4f484fbf6a7f9829de0dd52036fd63 to your computer and use it in GitHub Desktop.
Save bradphelan/cb4f484fbf6a7f9829de0dd52036fd63 to your computer and use it in GitHub Desktop.
// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
// ParallelExtensionsExtras: https://code.msdn.microsoft.com/ParExtSamples
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
namespace Roslyn.Test.Utilities
{
/// <summary>
/// Provides a <see cref="TaskScheduler"/> that runs queued tasks on a fixed set of dedicated
/// background threads, each configured as a single-threaded apartment (STA).
/// </summary>
public sealed class StaTaskScheduler : TaskScheduler, IDisposable
{
    /// <summary>Gets a StaTaskScheduler for the current AppDomain.</summary>
    /// <remarks>We use a count of 1, because the editor ends up re-using <see cref="System.Windows.Threading.DispatcherObject"/>
    /// instances between tests, so we need to always use the same thread for our Sta tests.</remarks>
    public static StaTaskScheduler DefaultSta { get; } = new StaTaskScheduler(1);

    /// <summary>
    /// Stores the queued tasks to be executed by our pool of STA threads.
    /// Set to <c>null</c> by <see cref="Dispose"/>.
    /// </summary>
    private BlockingCollection<Task> _tasks;

    /// <summary>The STA threads used by the scheduler.</summary>
    private readonly ImmutableArray<Thread> _threads;

    /// <summary>Gets the STA threads owned by this scheduler.</summary>
    public ImmutableArray<Thread> Threads => _threads;

    /// <summary>Initializes a new instance of the StaTaskScheduler class with the specified concurrency level.</summary>
    /// <param name="numberOfThreads">The number of threads that should be created and used by this scheduler.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="numberOfThreads"/> is less than 1.</exception>
    public StaTaskScheduler(int numberOfThreads)
    {
        if (numberOfThreads < 1)
        {
            throw new ArgumentOutOfRangeException(nameof(numberOfThreads));
        }

        // Workers hold the collection directly instead of re-reading the field, so clearing
        // _tasks in Dispose can never pull the queue out from under a running worker.
        var tasks = new BlockingCollection<Task>();
        _tasks = tasks;

        _threads = Enumerable.Range(0, numberOfThreads).Select(i =>
        {
            var thread = new Thread(() =>
            {
                // Continually get the next task and try to execute it.
                // This continues until CompleteAdding has been called and no more tasks remain.
                foreach (var t in tasks.GetConsumingEnumerable())
                {
                    if (!TryExecuteTask(t))
                    {
                        // NOTE(review): on a scheduler with more than one thread, a task inlined on one
                        // worker may be dequeued here while still running, so this can fire spuriously.
                        System.Diagnostics.Debug.Assert(t.IsCompleted, "Can't run, not completed");
                    }
                }
            });
            thread.Name = "STATestThread";
            thread.IsBackground = true;
            thread.SetApartmentState(ApartmentState.STA);
            return thread;
        }).ToImmutableArray();

        // Start only after _threads is assigned, since TryExecuteTaskInline reads it.
        foreach (var thread in _threads)
        {
            thread.Start();
        }
    }

    /// <summary>Returns the task queue, or throws if the scheduler has been disposed.</summary>
    /// <exception cref="ObjectDisposedException">The scheduler has been disposed.</exception>
    private BlockingCollection<Task> GetTasksOrThrow()
    {
        var tasks = _tasks;
        if (tasks == null)
        {
            throw new ObjectDisposedException(nameof(StaTaskScheduler));
        }

        return tasks;
    }

    /// <summary>Queues a Task to be executed by this scheduler.</summary>
    /// <param name="task">The task to be executed.</param>
    /// <exception cref="ObjectDisposedException">The scheduler has been disposed.</exception>
    protected override void QueueTask(Task task)
    {
        // Push it into the blocking collection of tasks
        GetTasksOrThrow().Add(task);
    }

    /// <summary>Provides a list of the scheduled tasks for the debugger to consume.</summary>
    /// <returns>A snapshot of all tasks currently queued; empty once the scheduler has been disposed.</returns>
    protected override IEnumerable<Task> GetScheduledTasks()
    {
        var tasks = _tasks;
        return tasks == null ? Enumerable.Empty<Task>() : tasks.ToArray();
    }

    /// <summary>Determines whether a Task may be inlined.</summary>
    /// <param name="task">The task to be executed.</param>
    /// <param name="taskWasPreviouslyQueued">Whether the task was previously queued.</param>
    /// <returns>true if the task was successfully inlined; otherwise, false.</returns>
    /// <remarks>
    /// Inlining is allowed only on this scheduler's own threads (all of which are STA). Checking merely
    /// for an STA apartment would let a task run on some unrelated STA thread, breaking the guarantee
    /// that work queued here always executes on <see cref="Threads"/>.
    /// </remarks>
    protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued)
    {
        return _threads.Contains(Thread.CurrentThread) && TryExecuteTask(task);
    }

    /// <summary>Gets the maximum concurrency level supported by this scheduler.</summary>
    public override int MaximumConcurrencyLevel => _threads.Length;

    /// <summary>
    /// Cleans up the scheduler by indicating that no more tasks will be queued, then blocks until
    /// the worker threads have drained the queue and exited.
    /// </summary>
    /// <remarks>
    /// When called from one of the scheduler's own threads, that thread is not joined (it would wait
    /// on itself forever). It finishes draining the queue on its own, and the underlying collection is
    /// left undisposed because that thread is still consuming it. Subsequent calls are no-ops.
    /// </remarks>
    public void Dispose()
    {
        var tasks = _tasks;
        if (tasks == null)
        {
            return;
        }

        // Indicate that no new tasks will be coming in
        tasks.CompleteAdding();

        // Wait for all threads to finish processing tasks, except the calling thread itself.
        var current = Thread.CurrentThread;
        var calledFromWorker = false;
        foreach (var thread in _threads)
        {
            if (thread == current)
            {
                calledFromWorker = true;
                continue;
            }

            thread.Join();
        }

        // Cleanup
        _tasks = null;
        if (!calledFromWorker)
        {
            tasks.Dispose();
        }
    }

    /// <summary>Returns whether any tasks are currently waiting in the queue.</summary>
    /// <returns>true if at least one task is queued; otherwise, false.</returns>
    /// <exception cref="InvalidOperationException">
    /// The scheduler does not have exactly one thread, or the caller is not that thread.
    /// </exception>
    /// <exception cref="ObjectDisposedException">The scheduler has been disposed.</exception>
    public bool IsAnyQueued()
    {
        if (_threads.Length != 1 || _threads[0] != Thread.CurrentThread)
        {
            throw new InvalidOperationException("Operation invalid in this context");
        }

        return GetTasksOrThrow().Count > 0;
    }
}
}
using System;
using System.Reactive;
using System.Reactive.Disposables;
using System.Threading;
using System.Threading.Tasks;
using System.Windows;
using System.Windows.Threading;
using Roslyn.Test.Utilities;
namespace SpecHelper.Wpf
{
/// <summary>
/// Test helpers for running work on the shared STA test thread and for displaying
/// WPF windows for the duration of a test.
/// </summary>
public static class STAThread
{
    /// <summary>
    /// Factory that schedules every task on <see cref="StaTaskScheduler.DefaultSta"/> and
    /// prevents child tasks from attaching to the tasks it creates.
    /// </summary>
    private static readonly TaskFactory s_staFactory = new TaskFactory(
        CancellationToken.None,
        TaskCreationOptions.DenyChildAttach,
        TaskContinuationOptions.None,
        StaTaskScheduler.DefaultSta);

    /// <summary>
    /// Shows <paramref name="window"/> and returns a handle whose disposal closes the window and
    /// then shuts down the dispatcher of the thread performing the disposal.
    /// </summary>
    /// <param name="window">The window to show.</param>
    /// <returns>An <see cref="IDisposable"/> that closes the window and shuts down the current dispatcher.</returns>
    public static IDisposable ShowTemporarily(this Window window)
    {
        window.Show();

        Action teardown = delegate
        {
            window.Close();
            Dispatcher.CurrentDispatcher.InvokeShutdown();
        };

        return Disposable.Create(teardown);
    }

    /// <summary>Invokes an asynchronous delegate on the STA test thread.</summary>
    /// <param name="func">Delegate producing the task to run; it is invoked on the STA thread.</param>
    /// <returns>A task that completes when the task returned by <paramref name="func"/> completes.</returns>
    public static Task Run(this Func<Task> func)
    {
        Task<Task> outer = s_staFactory.StartNew(func);
        return outer.Unwrap();
    }

    /// <summary>Invokes an asynchronous, value-producing delegate on the STA test thread.</summary>
    /// <typeparam name="T">The result type of the inner task.</typeparam>
    /// <param name="func">Delegate producing the task to run; it is invoked on the STA thread.</param>
    /// <returns>A task that completes with the result of the task returned by <paramref name="func"/>.</returns>
    public static Task<T> Run<T>(this Func<Task<T>> func)
    {
        Task<Task<T>> outer = s_staFactory.StartNew(func);
        return outer.Unwrap();
    }

    /// <summary>Invokes a synchronous action on the STA test thread.</summary>
    /// <param name="func">The action to run.</param>
    /// <returns>A task that completes when the action has run.</returns>
    public static Task Run(this Action func) => s_staFactory.StartNew(func);

    /// <summary>Invokes a synchronous, value-producing delegate on the STA test thread.</summary>
    /// <typeparam name="T">The result type.</typeparam>
    /// <param name="func">The delegate to run.</param>
    /// <returns>A task that completes with the delegate's return value.</returns>
    public static Task<T> Run<T>(this Func<T> func) => s_staFactory.StartNew(func);
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment