Skip to content

Instantly share code, notes, and snippets.

@NDiiong
Forked from gongdo/TaskExtensions.cs
Created April 25, 2022 06:53
Show Gist options
  • Save NDiiong/8e6b2c8ff2eb11cd618ed9dccc134111 to your computer and use it in GitHub Desktop.
Save NDiiong/8e6b2c8ff2eb11cd618ed9dccc134111 to your computer and use it in GitHub Desktop.
Task extensions for parallelism
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
namespace Gist
{
    public static class TaskExtensions
    {
        /// <summary>
        /// Creates a task that will complete when all of the <see cref="Task{TResult}"/> objects in an enumerable collection have completed,
        /// allowing at most <paramref name="maxConcurrency"/> of them to be incomplete at the same time.
        /// </summary>
        /// <typeparam name="TResult">The type of the result produced by each task.</typeparam>
        /// <param name="tasks">The tasks to wait on for completion.</param>
        /// <param name="maxConcurrency">The maximum number of tasks to run at the same time.</param>
        /// <returns>The results of the non-null tasks, in the order they were enumerated.</returns>
        /// <remarks>
        /// Null elements are skipped instead of causing an exception.
        /// If <paramref name="maxConcurrency"/> is less than 1, every task is awaited without throttling.
        /// Throttling works by not advancing the enumeration while <paramref name="maxConcurrency"/> tasks are still
        /// incomplete. It therefore only limits concurrency when the sequence creates its tasks lazily (for example a
        /// deferred <c>Select</c>); tasks that already exist before enumeration are not held back.
        /// As with <c>Task.WhenAll</c>, a faulted or canceled task makes the returned task faulted or canceled,
        /// and the tasks' own exceptions are propagated (not wrapped in an extra <see cref="AggregateException"/>).
        /// </remarks>
        /// <exception cref="ArgumentNullException">The tasks argument was null.</exception>
        public static async Task<IEnumerable<TResult>> WhenAll<TResult>(this IEnumerable<Task<TResult>> tasks, int maxConcurrency)
        {
            if (tasks == null)
            {
                throw new ArgumentNullException(nameof(tasks));
            }

            if (maxConcurrency < 1)
            {
                return await Task.WhenAll(tasks.Where(t => t != null));
            }

            var gate = new SemaphoreSlim(maxConcurrency);
            var started = new List<Task<TResult>>();
            using (var enumerator = tasks.Where(t => t != null).GetEnumerator())
            {
                while (true)
                {
                    // Acquire a slot *before* pulling the next element, so a lazily produced
                    // sequence never has more than maxConcurrency incomplete tasks at once.
                    await gate.WaitAsync();
                    if (!enumerator.MoveNext())
                    {
                        break;
                    }

                    started.Add(ReleaseWhenCompleted(enumerator.Current, gate));
                }
            }

            try
            {
                return await Task.WhenAll(started);
            }
            finally
            {
                // Every wrapper has released the gate by now, so disposing is safe. If enumeration threw above,
                // this is skipped on purpose: tasks still in flight will call Release on the gate later.
                gate.Dispose();
            }
        }

        /// <summary>
        /// Creates a task that will complete when all of the <see cref="Task"/> objects in an enumerable collection have completed,
        /// allowing at most <paramref name="maxConcurrency"/> of them to be incomplete at the same time.
        /// </summary>
        /// <param name="tasks">The tasks to wait on for completion.</param>
        /// <param name="maxConcurrency">The maximum number of tasks to run at the same time.</param>
        /// <remarks>
        /// Null elements are skipped instead of causing an exception.
        /// If <paramref name="maxConcurrency"/> is less than 1, every task is awaited without throttling.
        /// Throttling works by not advancing the enumeration while <paramref name="maxConcurrency"/> tasks are still
        /// incomplete. It therefore only limits concurrency when the sequence creates its tasks lazily (for example a
        /// deferred <c>Select</c>); tasks that already exist before enumeration are not held back.
        /// As with <c>Task.WhenAll</c>, a faulted or canceled task makes the returned task faulted or canceled;
        /// failures are never silently discarded.
        /// </remarks>
        /// <exception cref="ArgumentNullException">The tasks argument was null.</exception>
        public static async Task WhenAll(this IEnumerable<Task> tasks, int maxConcurrency)
        {
            if (tasks == null)
            {
                throw new ArgumentNullException(nameof(tasks));
            }

            if (maxConcurrency < 1)
            {
                await Task.WhenAll(tasks.Where(t => t != null));
                return;
            }

            var gate = new SemaphoreSlim(maxConcurrency);
            var started = new List<Task>();
            using (var enumerator = tasks.Where(t => t != null).GetEnumerator())
            {
                while (true)
                {
                    // Acquire a slot *before* pulling the next element, so a lazily produced
                    // sequence never has more than maxConcurrency incomplete tasks at once.
                    await gate.WaitAsync();
                    if (!enumerator.MoveNext())
                    {
                        break;
                    }

                    started.Add(ReleaseWhenCompleted(enumerator.Current, gate));
                }
            }

            try
            {
                await Task.WhenAll(started);
            }
            finally
            {
                // Every wrapper has released the gate by now, so disposing is safe. If enumeration threw above,
                // this is skipped on purpose: tasks still in flight will call Release on the gate later.
                gate.Dispose();
            }
        }

        /// <summary>
        /// Awaits <paramref name="task"/> and then releases one slot of <paramref name="gate"/>, whatever the outcome.
        /// </summary>
        /// <returns>A task that mirrors the result, (first) exception or cancellation of <paramref name="task"/>.</returns>
        private static async Task<TResult> ReleaseWhenCompleted<TResult>(Task<TResult> task, SemaphoreSlim gate)
        {
            try
            {
                // Awaiting (instead of reading Task.Result) rethrows the original exception unwrapped,
                // and an OperationCanceledException leaves this wrapper canceled rather than faulted.
                return await task.ConfigureAwait(false);
            }
            finally
            {
                gate.Release();
            }
        }

        /// <summary>
        /// Awaits <paramref name="task"/> and then releases one slot of <paramref name="gate"/>, whatever the outcome.
        /// </summary>
        /// <returns>A task that mirrors the (first) exception or cancellation of <paramref name="task"/>.</returns>
        private static async Task ReleaseWhenCompleted(Task task, SemaphoreSlim gate)
        {
            try
            {
                // Awaiting keeps the failure observable to the caller instead of discarding it.
                await task.ConfigureAwait(false);
            }
            finally
            {
                gate.Release();
            }
        }
    }
}
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Xunit;
namespace Gist.Tests
{
    /// <summary>
    /// Tests for the throttled <c>WhenAll</c> extension methods.
    /// </summary>
    /// <remarks>
    /// The work items run on several threads at once, so every shared counter or collection here is updated
    /// with Interlocked/Volatile or a concurrent collection; plain ++/-- or List&lt;T&gt;.Add would race and
    /// make the assertions flaky.
    /// </remarks>
    public class TaskExtensionsTest
    {
        [Theory]
        [InlineData(1)]
        [InlineData(5)]
        [InlineData(10)]
        public async Task WhenAll_runs_only_specific_number_of_tasks_at_a_time(int maxConcurrency)
        {
            var running = 0;
            var observed = new System.Collections.Concurrent.ConcurrentBag<int>();
            // The async lambda's code before the first await runs when the sequence is enumerated,
            // which is exactly the point the extension throttles.
            var tasks = Enumerable.Range(0, maxConcurrency * 5).Select(async i =>
            {
                observed.Add(System.Threading.Interlocked.Increment(ref running));
                await Task.Delay(15);
                observed.Add(System.Threading.Volatile.Read(ref running));
                System.Threading.Interlocked.Decrement(ref running);
            });

            await tasks.WhenAll(maxConcurrency);

            // Two observations per task prove every task actually ran.
            Assert.Equal(maxConcurrency * 10, observed.Count);
            Assert.All(observed, r => Assert.InRange(r, 1, maxConcurrency));
        }

        [Theory]
        [InlineData(1)]
        [InlineData(5)]
        [InlineData(10)]
        public async Task WhenAll_with_result_runs_only_specific_number_of_tasks_at_a_time(int maxConcurrency)
        {
            var running = 0;
            var tasks = Enumerable.Range(0, maxConcurrency * 5).Select(async i =>
            {
                System.Threading.Interlocked.Increment(ref running);
                await Task.Delay(15);
                var seen = System.Threading.Volatile.Read(ref running);
                System.Threading.Interlocked.Decrement(ref running);
                return seen;
            });

            var results = await tasks.WhenAll(maxConcurrency);

            Assert.Equal(maxConcurrency * 5, results.Count());
            Assert.All(results, r => Assert.InRange(r, 1, maxConcurrency));
        }

        [Fact]
        public async Task WhenAll_ignores_null_task()
        {
            var count = 0;
            var tasks = new Task[]
            {
                Task.Run(() => System.Threading.Interlocked.Increment(ref count)),
                Task.Run(() => System.Threading.Interlocked.Increment(ref count)),
                null
            };

            await tasks.WhenAll(3);

            Assert.Equal(2, count);
        }

        [Fact]
        public async Task WhenAll_with_result_ignores_null_task()
        {
            var tasks = new Task<int>[]
            {
                Task.Run(() => 1),
                Task.Run(() => 1),
                null
            };

            var results = await tasks.WhenAll(3);

            // Exactly the two non-null results, not just "whatever came back equals 1".
            Assert.Equal(new[] { 1, 1 }, results);
        }

        [Fact]
        public async Task WhenAll_has_guard_clause()
        {
            Task[] tasks = null;
            await Assert.ThrowsAsync<System.ArgumentNullException>(
                () => tasks.WhenAll(3));
        }

        [Fact]
        public async Task WhenAll_with_result_has_guard_clause()
        {
            Task<int>[] tasks = null;
            await Assert.ThrowsAsync<System.ArgumentNullException>(
                () => tasks.WhenAll(3));
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment