
@Brokolis
Created February 24, 2023 10:29
Limiting the number of tests executed in parallel in xUnit

In reference to: xunit/xunit#2003

Summary

I've run into this problem and decided to try to fix it myself. This is my code, which adds an opt-in way to declare which tests (or rather, which test collections) have which concurrency requirements.

Note: my understanding of xUnit and how it works might be wrong, and my code most likely contains bugs. I made it work for myself; if you want to use my code, try to understand it first and adapt it to your use case.

In xUnit, parallelization happens at the collection level: collections are executed concurrently, while the tests within the same collection always run one after the other. What I implemented is a way to mark which collections should have their concurrency limited and, optionally, by how much (you can specify how many collections may be executed concurrently). This addresses the concurrency problems discussed in the issue while (hopefully) staying fully backwards compatible with how xUnit works, since it only adds an optional feature.

How to "install" my code

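Just copy-paste the classes LimitConcurrencyAttribute, LimitedConcurrencyAssemblyRunner, LimitedConcurrencyTestFrameworkExecutor and LimitedConcurrencyTestFramework into your test project and adapt them to your needs. The code is written for .NET 7, but it should be fairly easy to port to older .NET versions. When porting, keep in mind that the files rely on the implicit global usings (System, System.Linq, System.Threading.Tasks, etc.) that the .NET 6+ project templates enable, on C# 10 features such as file-scoped namespaces, and on Enumerable.MinBy, which was added in .NET 6.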

How to enable concurrency limits

The easiest way to enable concurrency limits is to change the xUnit test framework on the test assembly to LimitedConcurrencyTestFramework (put this line in any .cs file, like AssemblyInfo.cs):

// TODO: change the type name and assembly name to match your use case
[assembly: TestFramework("xUnitLimitedConcurrencyTestFramework.LimitedConcurrencyTestFramework", "xUnitLimitedConcurrencyTestFramework")]

Note: don't forget to change the full type name (first argument) and the assembly name (second argument) to match your project.

The LimitConcurrencyAttribute does not do anything by itself; it only takes effect when the tests are run by LimitedConcurrencyAssemblyRunner. There is no direct way to swap out just the assembly runner, so I created LimitedConcurrencyTestFrameworkExecutor and LimitedConcurrencyTestFramework to plug it in.

How do I know concurrency limits are enabled?

LimitedConcurrencyTestFramework and LimitedConcurrencyAssemblyRunner write out diagnostic messages while they are running. To see them (for example, when running dotnet test), you have to enable xUnit diagnostic messages, for example by adding an xunit.runner.json file to the project:

{
  "$schema": "https://xunit.net/schema/current/xunit.runner.schema.json",
  "diagnosticMessages": true
}

and making sure it gets copied to the output directory by adding this to your test project's .csproj:

<ItemGroup>
  <None Update="xunit.runner.json" CopyToOutputDirectory="Always" />
</ItemGroup>

Running the example code, with diagnostic messages enabled, I get this output:

$ dotnet test
  Determining projects to restore...
  All projects are up-to-date for restore.
  xUnitLimitedConcurrencyTestFramework -> /home/mantas/projects/xUnitLimitedConcurrencyTestFramework/xUnitLimitedConcurrencyTestFramework/bin/Debug/net7.0/xUnitLimitedConcurrencyTestFramework.dll
Test run for /home/mantas/projects/xUnitLimitedConcurrencyTestFramework/xUnitLimitedConcurrencyTestFramework/bin/Debug/net7.0/xUnitLimitedConcurrencyTestFramework.dll (.NETCoreApp,Version=v7.0)
Microsoft (R) Test Execution Command Line Tool Version 17.4.0+c02ece877c62577810f893c44279ce79af820112 (x64)
Copyright (c) Microsoft Corporation.  All rights reserved.

Starting test execution, please wait...
A total of 1 test files matched the specified pattern.
[xUnit.net 00:00:00.59] xUnitLimitedConcurrencyTestFramework: Using xUnitLimitedConcurrencyTestFramework.LimitedConcurrencyTestFramework.
[xUnit.net 00:00:00.67] xUnitLimitedConcurrencyTestFramework: [xUnitLimitedConcurrencyTestFramework.LimitedConcurrencyAssemblyRunner] Setting up the synchronization context...
[xUnit.net 00:00:00.68] xUnitLimitedConcurrencyTestFramework: [xUnitLimitedConcurrencyTestFramework.LimitedConcurrencyAssemblyRunner] Collecting collections...
[xUnit.net 00:00:00.71] xUnitLimitedConcurrencyTestFramework: [xUnitLimitedConcurrencyTestFramework.LimitedConcurrencyAssemblyRunner] Executing collections...
[xUnit.net 00:00:00.71] xUnitLimitedConcurrencyTestFramework: [xUnitLimitedConcurrencyTestFramework.LimitedConcurrencyAssemblyRunner] Starting test collection 1/7...
[xUnit.net 00:00:00.71] xUnitLimitedConcurrencyTestFramework: [xUnitLimitedConcurrencyTestFramework.LimitedConcurrencyAssemblyRunner] Starting test collection 2/7...
[xUnit.net 00:00:00.71] xUnitLimitedConcurrencyTestFramework: [xUnitLimitedConcurrencyTestFramework.LimitedConcurrencyAssemblyRunner] Starting test collection 3/7...
[xUnit.net 00:00:01.79] xUnitLimitedConcurrencyTestFramework: [xUnitLimitedConcurrencyTestFramework.LimitedConcurrencyAssemblyRunner] Starting test collection 4/7...
[xUnit.net 00:00:02.80] xUnitLimitedConcurrencyTestFramework: [xUnitLimitedConcurrencyTestFramework.LimitedConcurrencyAssemblyRunner] Starting test collection 5/7...
[xUnit.net 00:00:03.81] xUnitLimitedConcurrencyTestFramework: [xUnitLimitedConcurrencyTestFramework.LimitedConcurrencyAssemblyRunner] Starting test collection 6/7...
[xUnit.net 00:00:05.82] xUnitLimitedConcurrencyTestFramework: [xUnitLimitedConcurrencyTestFramework.LimitedConcurrencyAssemblyRunner] Starting test collection 7/7...
[xUnit.net 00:00:05.82] xUnitLimitedConcurrencyTestFramework: [xUnitLimitedConcurrencyTestFramework.LimitedConcurrencyAssemblyRunner] Waiting for the last 1 test collections to end...

Passed!  - Failed:     0, Passed:     8, Skipped:     0, Total:     8, Duration: 5 s - xUnitLimitedConcurrencyTestFramework.dll (net7.0)

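The line "Using xUnitLimitedConcurrencyTestFramework.LimitedConcurrencyTestFramework." indicates that the concurrency limits are enabled. The other messages are diagnostics that show how the runner schedules the collections: in this run, collections 1 to 3 start together, while collections 4 to 7 are started one by one as earlier collections finish.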
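To make the log above easier to follow, here is a small simulation of the scheduling rule used by LimitedConcurrencyAssemblyRunner. A collection may start only if the strictest limit among the running collections and its own limit both allow one more running collection. Collections are sorted by limit, highest first. It does not use xUnit, and everything in it is an assumption modelled on the example project: every test takes 1 s, MaxParallelThreads is taken to be 8, and the input order of the three collections limited to 1 was chosen so that the two-test class runs sixth, which is what the timestamps suggest. `CanStart` is a hypothetical stand-in for `CanQueueCollectionWithTargetConcurrency`. Under these assumptions the collections start at 0, 0, 0, 1, 2, 3 and 5 s, which matches the pattern in the log.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Assumed inputs: (name, limit, duration in s). 8 stands in for MaxParallelThreads; every test takes 1 s.
var collections = new List<(string Name, int? Limit, double Duration)>
{
    ("TestClassWithoutCollection", 8, 1),
    ("TestClassWithoutCollectionWithLimits", 2, 1),
    ("CollectionWithoutLimits", 8, 1),
    ("CollectionWithParallelizationDisabled", 1, 1),
    ("TestClassWithoutCollectionWithLimitedTests", 1, 2), // two tests, run one after the other
    ("CollectionWithParallelizationDisabledUsingLimit", 1, 1),
    ("CollectionWithParallelizationLimits", 4, 1),
};

// Highest limit first (null = unlimited); OrderByDescending is stable, so ties keep their input order.
var ordered = collections.OrderByDescending(x => x.Limit is > 0 ? x.Limit.Value : int.MaxValue).ToList();

var running = new List<(int? Limit, double End)>();
var startTimes = new List<double>();
double now = 0;
foreach (var c in ordered)
{
    while (!CanStart(running, c.Limit))
    {
        // Like awaiting Task.WhenAny: jump to the earliest finish and retire every finished collection.
        now = running.Min(r => r.End);
        running.RemoveAll(r => r.End <= now);
    }
    startTimes.Add(now);
    running.Add((c.Limit, now + c.Duration));
}
Console.WriteLine(string.Join(", ", startTimes)); // 0, 0, 0, 1, 2, 3, 5

// Hypothetical stand-in for CanQueueCollectionWithTargetConcurrency.
static bool CanStart(List<(int? Limit, double End)> running, int? targetLimit)
{
    // The strictest (smallest positive) limit among running collections caps how many may run at once.
    int strictest = running
        .Select(r => r.Limit is > 0 ? r.Limit.Value : int.MaxValue)
        .DefaultIfEmpty(int.MaxValue)
        .Min();
    if (strictest <= running.Count) return false;
    // The new collection's own limit must also allow running.Count + 1 collections.
    return targetLimit is null || targetLimit > running.Count;
}
```

Durations are simulated, so nothing actually waits; changing the inputs is a cheap way to predict how a different mix of limits would be scheduled.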

How to limit concurrency

To limit concurrency, I created LimitConcurrencyAttribute. You can place the attribute on (in order of precedence, highest first):

  1. Test methods ([Fact], [Theory], etc.)
  2. Test classes
  3. Collection definitions ([CollectionDefinition])
  4. Assembly ([assembly: ...])

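LimitConcurrencyAttribute also has a MaxConcurrentCollections property (set through the constructor, e.g. [LimitConcurrency(4)]). It determines how many collections in total can be executed concurrently while the marked collection is running. If no value is given ([LimitConcurrency]), it defaults to xUnit's MaxParallelThreads setting.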

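To determine the concurrency limit of a single collection, the limit of each test case (test method) in the collection is determined first, and then the lowest of those values is used. Test cases without a limit are ignored, so a collection is unlimited only if none of its test cases is limited. A collection whose [CollectionDefinition] sets DisableParallelization = true always gets a limit of 1.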

To determine the concurrency limit of a single test case (test method), LimitConcurrencyAttribute is looked up following the precedence rules defined above, and the first one found is used. For test classes and collection definitions, base classes are searched as well.
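A minimal sketch of the per-test-case lookup, written as plain C# so it runs without xUnit. Each level (method, class, collection definition, assembly) is modelled as a hypothetical (Present, Value) pair instead of a real attribute lookup, and base-class scanning is left out. `ResolveTestCaseLimit` is a hypothetical helper. The fallback for a missing value follows the XML doc comment on MaxConcurrentCollections (null or less than 1 means MaxParallelThreads), and the value 8 for MaxParallelThreads is an assumption.

```csharp
using System;

// Hypothetical model of one lookup level: Present = the level has [LimitConcurrency],
// Value = its MaxConcurrentCollections (null when the attribute has no argument).
var none = (Present: false, Value: (int?)null);
var noValue = (Present: true, Value: (int?)null);

// Method [LimitConcurrency(2)] wins over collection definition [LimitConcurrency(4)].
Console.WriteLine(ResolveTestCaseLimit(8, (true, 2), none, (true, 4), none)); // 2
// A bare [assembly: LimitConcurrency] falls back to MaxParallelThreads (assumed to be 8 here).
Console.WriteLine(ResolveTestCaseLimit(8, none, none, none, noValue)); // 8

// Levels are given in precedence order: method, class, collection definition, assembly.
// Returns null when concurrency is not limited.
static int? ResolveTestCaseLimit(int maxParallelThreads, params (bool Present, int? Value)[] levels)
{
    foreach (var (present, value) in levels)
    {
        if (!present) continue; // no attribute here, try the next (lower-precedence) level
        // The first attribute found wins; a missing value (or one below 1) means MaxParallelThreads.
        int limit = value is >= 1 ? value.Value : maxParallelThreads;
        // A non-positive MaxParallelThreads (e.g. -1) means "not limited".
        return limit <= 0 ? (int?)null : limit;
    }
    return null; // no attribute at any level
}
```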

For example, if a [CollectionDefinition("collection1")] class also has [LimitConcurrency(4)], but a test class in that collection ([Collection("collection1")]) has a [Fact] method with [LimitConcurrency(2)], the concurrency limit for the whole collection "collection1" will be 2.
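The per-collection rule (smallest limit among the test cases, null meaning not limited, DisableParallelization forcing 1) can be sketched the same way. The input is the already-resolved limit of each test case in the collection. `ResolveCollectionLimit` is a hypothetical helper, and the `{ 4, 2, 4 }` array is a hypothetical "collection1" from the example, with two tests that only inherit [LimitConcurrency(4)] from the collection definition and one [Fact] with its own [LimitConcurrency(2)].

```csharp
using System;
using System.Collections.Generic;

// Hypothetical "collection1": two tests inherit 4 from the collection definition, one [Fact] has 2.
int? collection1 = ResolveCollectionLimit(new int?[] { 4, 2, 4 });
Console.WriteLine(collection1); // 2

// A collection's limit is the smallest positive limit among its test cases;
// the result is null (not limited) only if none of them is limited.
static int? ResolveCollectionLimit(IEnumerable<int?> testCaseLimits, bool disableParallelization = false)
{
    // [CollectionDefinition(..., DisableParallelization = true)] forces a limit of 1.
    if (disableParallelization) return 1;

    int? strictest = null;
    foreach (var limit in testCaseLimits)
    {
        // Skip unlimited (null) and non-positive entries, keep the smallest remaining value.
        if (limit is > 0 && (strictest is null || limit < strictest))
        {
            strictest = limit;
        }
    }
    return strictest;
}
```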

Customization

Don't be afraid to read through LimitedConcurrencyAssemblyRunner and adapt it more closely to your use case.

If you just want to customize how concurrency limits are determined, you can change the LimitedConcurrencyAssemblyRunner.GetMaxConcurrentCollectionsCount method. It computes the limit for a single test case; the limits of all test cases in a collection are then combined as described above. By default, that method only looks for LimitConcurrencyAttribute, but it can be customized, for example, to check whether the test method's class implements an interface like IClassFixture<WebApplicationFactory<Startup>>.
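One way such a customization could detect a fixture interface is plain reflection on the test class. This sketch assumes you already have the runtime System.Type of the test class (obtaining it from xUnit's ITypeInfo is not shown). `ImplementsOpenGeneric` and `LimitFor` are hypothetical helpers, the "limit of 1" policy is only an example, and the demo uses BCL interfaces (IEquatable<>, IComparable<>) in place of IClassFixture<> so it runs without xUnit.

```csharp
using System;
using System.Linq;

// In a test project the open generic would be something like typeof(IClassFixture<>).
Console.WriteLine(ImplementsOpenGeneric(typeof(string), typeof(IEquatable<>))); // True
Console.WriteLine(LimitFor(typeof(int), typeof(IComparable<>)));                  // 1

// True if `type` implements any closed version of the given open generic interface.
static bool ImplementsOpenGeneric(Type type, Type openGenericInterface) =>
    type.GetInterfaces().Any(i => i.IsGenericType && i.GetGenericTypeDefinition() == openGenericInterface);

// Hypothetical policy: classes with such a fixture get a limit of 1, all others stay unlimited (null).
static int? LimitFor(Type testClass, Type openGenericInterface) =>
    ImplementsOpenGeneric(testClass, openGenericInterface) ? 1 : (int?)null;
```

To match one specific closed type instead, such as IClassFixture<WebApplicationFactory<Startup>>, `typeof(IClassFixture<WebApplicationFactory<Startup>>).IsAssignableFrom(testClass)` is enough.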

using Xunit;
using xUnitLimitedConcurrencyTestFramework;

/* Example usage */

// TODO: change the type name and assembly name to match your use case
[assembly: TestFramework("xUnitLimitedConcurrencyTestFramework.LimitedConcurrencyTestFramework", "xUnitLimitedConcurrencyTestFramework")]

// limit concurrent collections in the whole assembly to MaxParallelThreads
[assembly: LimitConcurrency]
// limit concurrent collections in the whole assembly to 4
// [assembly: LimitConcurrency(maxConcurrentCollections: 4)]

namespace xUnitLimitedConcurrencyTestFramework;

public static class Consts
{
    public static readonly TimeSpan DefaultTestSleep = TimeSpan.FromSeconds(1);
}

/* Collection definitions */

[CollectionDefinition(nameof(CollectionWithoutLimits))]
public class CollectionWithoutLimits
{
}

[CollectionDefinition(nameof(CollectionWithParallelizationDisabled), DisableParallelization = true)]
public class CollectionWithParallelizationDisabled
{
}

[CollectionDefinition(nameof(CollectionWithParallelizationDisabledUsingLimit))]
// attribute on collection definition overrides assembly-level attribute
[LimitConcurrency(maxConcurrentCollections: 1)] // in this runner, equivalent to `DisableParallelization = true`
public class CollectionWithParallelizationDisabledUsingLimit
{
}

[CollectionDefinition(nameof(CollectionWithParallelizationLimits))]
[LimitConcurrency(maxConcurrentCollections: 4)]
public class CollectionWithParallelizationLimits
{
}

/* Test classes */

public class TestClassWithoutCollection
{
    [Fact]
    public async Task Test()
    {
        await Task.Delay(Consts.DefaultTestSleep);
    }
}

// attribute on test class overrides collection definition- and assembly-level attributes
[LimitConcurrency(2)]
public class TestClassWithoutCollectionWithLimits
{
    [Fact]
    public async Task Test()
    {
        await Task.Delay(Consts.DefaultTestSleep);
    }
}

// Each test class without [Collection] forms its own collection.
// This one gets a concurrency limit of 1, because Test1 has the lowest limit of its tests.
public class TestClassWithoutCollectionWithLimitedTests
{
    // attribute on test method overrides class-level attribute
    [Fact]
    [LimitConcurrency(1)]
    public async Task Test1()
    {
        await Task.Delay(Consts.DefaultTestSleep);
    }

    [Fact]
    [LimitConcurrency(2)]
    public async Task Test2()
    {
        await Task.Delay(Consts.DefaultTestSleep);
    }
}

[Collection(nameof(CollectionWithoutLimits))]
public class TestClassWithCollectionWithoutLimits
{
    [Fact]
    public async Task Test()
    {
        await Task.Delay(Consts.DefaultTestSleep);
    }
}

[Collection(nameof(CollectionWithParallelizationDisabled))]
public class TestClassWithCollectionWithParallelizationDisabled
{
    [Fact]
    public async Task Test()
    {
        await Task.Delay(Consts.DefaultTestSleep);
    }
}

[Collection(nameof(CollectionWithParallelizationDisabledUsingLimit))]
public class TestClassWithCollectionWithParallelizationDisabledUsingLimit
{
    [Fact]
    public async Task Test()
    {
        await Task.Delay(Consts.DefaultTestSleep);
    }
}

[Collection(nameof(CollectionWithParallelizationLimits))]
public class TestClassWithCollectionWithParallelizationLimits
{
    [Fact]
    public async Task Test()
    {
        await Task.Delay(Consts.DefaultTestSleep);
    }
}

namespace xUnitLimitedConcurrencyTestFramework;

/// <summary>
/// <para>
/// Marks a test, a class with tests, a collection definition or the whole assembly to be limited in terms of how many individual collections can be executed concurrently.
/// </para>
///
/// <para>
/// By default xUnit will run as many collections as possible concurrently (without considering <c>MaxParallelThreads</c> or any other options).
/// This attribute marks collections so that only a limited number of collections can be executed concurrently.
/// </para>
/// </summary>
[AttributeUsage(AttributeTargets.Method | AttributeTargets.Class | AttributeTargets.Assembly)]
public class LimitConcurrencyAttribute : Attribute
{
    /// <summary>
    /// Declares how many collections can be running concurrently while the collections marked by this attribute are running.
    /// If the value is <c>null</c> or less than <c>1</c>, it defaults to <c>MaxParallelThreads</c>.
    /// </summary>
    public int? MaxConcurrentCollections { get; }

    public LimitConcurrencyAttribute()
    {
        MaxConcurrentCollections = null;
    }

    public LimitConcurrencyAttribute(int maxConcurrentCollections)
    {
        MaxConcurrentCollections = maxConcurrentCollections;
    }
}

using Xunit;
using Xunit.Abstractions;
using Xunit.Sdk;

namespace xUnitLimitedConcurrencyTestFramework;

/// <summary>
/// Limits collection concurrency (<see cref="LimitConcurrencyAttribute"/>).
/// </summary>
public class LimitedConcurrencyAssemblyRunner : XunitTestAssemblyRunner
{
    public LimitedConcurrencyAssemblyRunner(
        ITestAssembly testAssembly,
        IEnumerable<IXunitTestCase> testCases,
        IMessageSink diagnosticMessageSink,
        IMessageSink executionMessageSink,
        ITestFrameworkExecutionOptions executionOptions
    )
        : base(testAssembly, testCases, diagnosticMessageSink, executionMessageSink, executionOptions)
    {
    }

    protected override async Task<RunSummary> RunTestCollectionsAsync(IMessageBus messageBus, CancellationTokenSource cancellationTokenSource)
    {
        return await RunTestCollectionsByLimitingAsync(messageBus, cancellationTokenSource);
    }

    private async Task<RunSummary> RunTestCollectionsByLimitingAsync(IMessageBus messageBus, CancellationTokenSource cancellationTokenSource)
    {
        // Adapted from Xunit.Sdk.XunitTestAssemblyRunner.RunTestCollectionsAsync.

        // If parallelization is disabled, don't do anything.
        if (ExecutionOptions.DisableParallelizationOrDefault())
        {
            DiagnosticMessageSink.OnMessage(new DiagnosticMessage($"[{typeof(LimitedConcurrencyAssemblyRunner)}] Parallelization is disabled, running tests normally..."));
            return await base.RunTestCollectionsAsync(messageBus, cancellationTokenSource);
        }

        DiagnosticMessageSink.OnMessage(new DiagnosticMessage($"[{typeof(LimitedConcurrencyAssemblyRunner)}] Setting up the synchronization context..."));

        // hack: running base.RunTestCollectionsAsync without any test cases sets up the underlying synchronization context
        var testCases = TestCases;
        TestCases = Array.Empty<IXunitTestCase>();
        await base.RunTestCollectionsAsync(messageBus, cancellationTokenSource);
        TestCases = testCases;

        Func<Func<Task<RunSummary>>, Task<RunSummary>> collectionRunner;
        if (SynchronizationContext.Current is not null)
        {
            var scheduler = TaskScheduler.FromCurrentSynchronizationContext();
            collectionRunner = code => Task.Factory
                .StartNew(code, cancellationTokenSource.Token, TaskCreationOptions.DenyChildAttach | TaskCreationOptions.HideScheduler, scheduler)
                .Unwrap();
        }
        else
        {
            collectionRunner = code => Task.Run(code, cancellationTokenSource.Token);
        }

        DiagnosticMessageSink.OnMessage(new DiagnosticMessage($"[{typeof(LimitedConcurrencyAssemblyRunner)}] Collecting collections..."));

        var runTestCollectionTasks = new List<(Func<Task<RunSummary>> RunTestCollection, int? Concurrency)>();
        foreach (var (testCollection, collectionTestCases) in OrderTestCollections())
        {
            var runTestCollection = () => RunTestCollectionAsync(messageBus, testCollection, collectionTestCases, cancellationTokenSource);

            int? collectionConcurrency;
            var attr = testCollection.CollectionDefinition?.GetCustomAttributes(typeof(CollectionDefinitionAttribute)).SingleOrDefault();
            if (attr?.GetNamedArgument<bool>(nameof(CollectionDefinitionAttribute.DisableParallelization)) is true)
            {
                collectionConcurrency = 1;
            }
            else
            {
                // use the smallest limit of all test cases in the collection (null = not limited)
                collectionConcurrency = collectionTestCases
                    .Select(GetMaxConcurrentCollectionsCount)
                    .Append(null)
                    .Distinct()
                    .MinBy(maxConcurrentCollectionsCount => maxConcurrentCollectionsCount is > 0 ? maxConcurrentCollectionsCount.Value : int.MaxValue);
            }

            runTestCollectionTasks.Add((runTestCollection, collectionConcurrency));
        }

        // order test collections so that the collections with the biggest concurrency run first
        // preserve the order of collections with the same concurrency in relation to each other
        runTestCollectionTasks = runTestCollectionTasks
            .Select((c, index) => (c, index))
            .OrderByDescending(c => c.c.Concurrency is > 0 ? c.c.Concurrency : int.MaxValue)
            .ThenBy(c => c.index)
            .Select(c => c.c)
            .ToList();

        DiagnosticMessageSink.OnMessage(new DiagnosticMessage($"[{typeof(LimitedConcurrencyAssemblyRunner)}] Executing collections..."));

        var runningCollections = new List<(Task<RunSummary> Task, int? Concurrency)>();
        var totalSummary = new RunSummary();
        var totalCollections = runTestCollectionTasks.Count;
        var collectionNumber = 1;
        foreach (var (runTestCollection, concurrency) in runTestCollectionTasks)
        {
            await WaitForAndCompleteCollectionsAsync(runningCollections, concurrency, totalSummary);
            DiagnosticMessageSink.OnMessage(
                new DiagnosticMessage($"[{typeof(LimitedConcurrencyAssemblyRunner)}] Starting test collection {collectionNumber}/{totalCollections}...")
            );
            runningCollections.Add((collectionRunner(runTestCollection), concurrency));
            collectionNumber++;
        }

        DiagnosticMessageSink.OnMessage(
            new DiagnosticMessage($"[{typeof(LimitedConcurrencyAssemblyRunner)}] Waiting for the last {runningCollections.Count} test collections to end...")
        );

        // wait for all remaining collections: a target concurrency of 1 can only be queued once nothing is running
        await WaitForAndCompleteCollectionsAsync(runningCollections, 1, totalSummary);

        return totalSummary;
    }

    private async ValueTask WaitForAndCompleteCollectionsAsync(
        ICollection<(Task<RunSummary> Task, int? Concurrency)> runningCollections,
        int? targetConcurrency,
        RunSummary totalSummary
    )
    {
        while (!CanQueueCollectionWithTargetConcurrency())
        {
            await Task.WhenAny(runningCollections.Select(runningCollection => runningCollection.Task));
            var finishedCollections = runningCollections.Where(runningCollection => runningCollection.Task.IsCompleted).ToArray();
            foreach (var finishedCollection in finishedCollections)
            {
                totalSummary.Aggregate(await finishedCollection.Task);
                runningCollections.Remove(finishedCollection);
            }
        }

        bool CanQueueCollectionWithTargetConcurrency()
        {
            var currentRunningCollectionsConcurrency = runningCollections
                .Select(runningCollection => runningCollection.Concurrency)
                .Append(null)
                .MinBy(concurrency => concurrency is > 0 ? concurrency.Value : int.MaxValue);
            if (currentRunningCollectionsConcurrency is not null && currentRunningCollectionsConcurrency <= runningCollections.Count)
            {
                // currently running collections do not allow any additional collections to be executed concurrently
                return false;
            }

            // we have room for at least one more collection to be executed concurrently,
            // check whether the target concurrency allows for runningCollections.Count + 1 concurrent collections to be executed
            return targetConcurrency is null || targetConcurrency > runningCollections.Count;
        }
    }

    /// <summary>
    /// Returns how many collections can be executed concurrently together with this test case.
    /// </summary>
    /// <param name="testCase">The test case.</param>
    /// <returns><c>null</c> if concurrency should not be limited, a positive integer indicating how many collections can be executed concurrently otherwise.</returns>
    private int? GetMaxConcurrentCollectionsCount(IXunitTestCase testCase)
    {
        // precedence: test method, test class (and its base classes), collection definition (and its base classes), assembly
        var limitConcurrencyAttributeInfo =
            testCase.TestMethod.Method.GetCustomAttributes(typeof(LimitConcurrencyAttribute)).SingleOrDefault()
            ?? GetAttributeOfClassOrBaseClasses<LimitConcurrencyAttribute>(testCase.TestMethod.TestClass.Class)
            ?? GetAttributeOfClassOrBaseClasses<LimitConcurrencyAttribute>(testCase.TestMethod.TestClass.TestCollection.CollectionDefinition)
            ?? testCase.TestMethod.TestClass.TestCollection.TestAssembly.Assembly.GetCustomAttributes(typeof(LimitConcurrencyAttribute)).SingleOrDefault();
        if (limitConcurrencyAttributeInfo is null) return null;

        var maxConcurrentCollections = limitConcurrencyAttributeInfo.GetNamedArgument<int?>(nameof(LimitConcurrencyAttribute.MaxConcurrentCollections));

        // as documented on LimitConcurrencyAttribute.MaxConcurrentCollections: null or less than 1 defaults to MaxParallelThreads
        if (maxConcurrentCollections is null or < 1)
        {
            maxConcurrentCollections = ExecutionOptions.MaxParallelThreadsOrDefault();
        }

        // a non-positive MaxParallelThreads (e.g. -1 for "unlimited") means concurrency is not limited
        return maxConcurrentCollections <= 0 ? null : maxConcurrentCollections;
    }

    private static IAttributeInfo? GetAttributeOfClassOrBaseClasses<TAttribute>(ITypeInfo? typeInfo)
    {
        try
        {
            while (typeInfo is not null)
            {
                var attribute = typeInfo.GetCustomAttributes(typeof(TAttribute)).SingleOrDefault();
                if (attribute is not null) return attribute;
                typeInfo = typeInfo.BaseType;
            }

            return null;
        }
        catch
        {
            // xUnit does not support base type being null so exceptions occur when the top of the hierarchy is reached
            return null;
        }
    }
}

using System.Reflection;
using Xunit.Abstractions;
using Xunit.Sdk;

namespace xUnitLimitedConcurrencyTestFramework;

/// <summary>
/// This xUnit test framework implementation limits how many test collections can be executed concurrently (<see cref="LimitConcurrencyAttribute"/>).
/// </summary>
public class LimitedConcurrencyTestFramework : XunitTestFramework
{
    public LimitedConcurrencyTestFramework(IMessageSink messageSink) : base(messageSink)
    {
        messageSink.OnMessage(new DiagnosticMessage($"Using {typeof(LimitedConcurrencyTestFramework)}."));
    }

    protected override ITestFrameworkExecutor CreateExecutor(AssemblyName assemblyName)
    {
        return new LimitedConcurrencyTestFrameworkExecutor(assemblyName, SourceInformationProvider, DiagnosticMessageSink);
    }
}

using System.Reflection;
using Xunit.Abstractions;
using Xunit.Sdk;

namespace xUnitLimitedConcurrencyTestFramework;

/// <summary>
/// Selectively limits the concurrency of specific collections (<see cref="LimitConcurrencyAttribute"/>).
/// </summary>
public class LimitedConcurrencyTestFrameworkExecutor : XunitTestFrameworkExecutor
{
    public LimitedConcurrencyTestFrameworkExecutor(AssemblyName assemblyName, ISourceInformationProvider sourceInformationProvider, IMessageSink diagnosticMessageSink)
        : base(assemblyName, sourceInformationProvider, diagnosticMessageSink)
    {
    }

    // the overridden method returns void, so this override has to be async void
    protected override async void RunTestCases(IEnumerable<IXunitTestCase> testCases, IMessageSink executionMessageSink, ITestFrameworkExecutionOptions executionOptions)
    {
        using var assemblyRunner = new LimitedConcurrencyAssemblyRunner(TestAssembly, testCases, DiagnosticMessageSink, executionMessageSink, executionOptions);
        await assemblyRunner.RunAsync();
    }
}