Last active
July 27, 2022 16:41
-
-
Save rquackenbush/96157ec6940ea1dd33774670cdc4786c to your computer and use it in GitHub Desktop.
A simple batching extension method for IAsyncEnumerable<T>.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.Runtime.CompilerServices; | |
namespace Extensions.Core; | |
public static class AsyncEnumerableExtensions
{
    /// <summary>
    /// Groups the elements of an asynchronous sequence into arrays of at most
    /// <paramref name="batchSize"/> elements, preserving source order.
    /// </summary>
    /// <typeparam name="T">The element type of <paramref name="source"/>.</typeparam>
    /// <param name="source">The sequence to batch.</param>
    /// <param name="batchSize">The maximum number of elements per batch; must be at least 1.</param>
    /// <param name="cancellationToken">
    /// Token that stops the batching. Once cancellation is observed, enumeration ends
    /// without throwing and any partially filled batch is discarded.
    /// </param>
    /// <returns>
    /// A sequence of full batches, followed by one shorter batch if the number of source
    /// elements is not a multiple of <paramref name="batchSize"/>. An empty source yields no batches.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="batchSize"/> is less than 1.</exception>
    public static IAsyncEnumerable<T[]> BatchAsync<T>(
        this IAsyncEnumerable<T> source,
        int batchSize,
        CancellationToken cancellationToken = default)
    {
        // Validate here, outside the async iterator, so bad arguments fail at the call site
        // instead of being deferred until the first MoveNextAsync.
        ArgumentNullException.ThrowIfNull(source);
        if (batchSize < 1)
        {
            throw new ArgumentOutOfRangeException(nameof(batchSize), batchSize, "Batch size must be at least 1.");
        }

        return BatchIteratorAsync(source, batchSize, cancellationToken);
    }

    // Async iterator that performs the batching once the arguments have been validated.
    // [EnumeratorCancellation] combines the token passed to BatchAsync with any token
    // supplied later through WithCancellation on the returned sequence.
    private static async IAsyncEnumerable<T[]> BatchIteratorAsync<T>(
        IAsyncEnumerable<T> source,
        int batchSize,
        [EnumeratorCancellation] CancellationToken cancellationToken)
    {
        var batch = new List<T>(batchSize);
        await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            // Stop quietly (no exception) if cancellation was requested while awaiting the source.
            if (cancellationToken.IsCancellationRequested)
            {
                yield break;
            }

            batch.Add(item);
            if (batch.Count >= batchSize)
            {
                yield return batch.ToArray();
                batch.Clear();
            }
        }

        // Emit the trailing partial batch only if cancellation has not been requested,
        // matching the in-loop rule that nothing more is produced once cancellation is observed.
        if (batch.Count > 0 && !cancellationToken.IsCancellationRequested)
        {
            yield return batch.ToArray();
        }
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.Runtime.CompilerServices; | |
using Shouldly; | |
namespace Extensions.Tests; | |
/// <summary>
/// Tests for the <c>BatchAsync</c> extension and the <c>RangeAsync</c> test source.
/// </summary>
public class AsyncEnumerableTests
{
    [Fact]
    public async Task SimpleBatches()
    {
        var batches = await CollectAsync(RangeAsync(0, 10).BatchAsync(4));

        batches.ShouldBe(new List<int[]>
        {
            new[] { 0, 1, 2, 3 },
            new[] { 4, 5, 6, 7 },
            new[] { 8, 9 },
        });
    }

    [Fact]
    public async Task ExactMultipleHasNoTrailingBatch()
    {
        var batches = await CollectAsync(RangeAsync(0, 8).BatchAsync(4));

        batches.ShouldBe(new List<int[]>
        {
            new[] { 0, 1, 2, 3 },
            new[] { 4, 5, 6, 7 },
        });
    }

    [Fact]
    public async Task EmptySourceYieldsNoBatches()
    {
        var batches = await CollectAsync(RangeAsync(0, 0).BatchAsync(4));

        batches.ShouldBeEmpty();
    }

    [Fact]
    public async Task WithCancellation()
    {
        var batches = new List<int[]>();
        using var cts = new CancellationTokenSource();

        // Cancel as soon as the first batch arrives; no further batches may follow.
        await foreach (var batch in RangeAsync(0, 10).BatchAsync(4, cts.Token))
        {
            batches.Add(batch);
            cts.Cancel();
        }

        batches.ShouldBe(new List<int[]>
        {
            new[] { 0, 1, 2, 3 },
        });
    }

    [Fact]
    public async Task SimpleCancellation()
    {
        using var cts = new CancellationTokenSource();
        var results = new List<int>();

        // Cancel after the second item; RangeAsync should then stop without throwing.
        await foreach (var item in RangeAsync(0, 10, cts.Token))
        {
            results.Add(item);
            if (results.Count == 2)
            {
                cts.Cancel();
            }
        }

        results.ShouldBe(new List<int> { 0, 1 });
    }

    // Drains an async sequence into a list.
    private static async Task<List<T>> CollectAsync<T>(IAsyncEnumerable<T> source)
    {
        var items = new List<T>();
        await foreach (var item in source)
        {
            items.Add(item);
        }

        return items;
    }

    // Yields start..start+count-1 asynchronously, stopping quietly once cancellation is requested.
    // Task.Yield keeps each step asynchronous without introducing real wall-clock delays.
    private static async IAsyncEnumerable<int> RangeAsync(
        int start,
        int count,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        for (int i = 0; i < count; i++)
        {
            if (cancellationToken.IsCancellationRequested)
            {
                yield break;
            }

            await Task.Yield();
            yield return start + i;
        }
    }
}
public static class AsyncEnumerableExtensions
{
    /// <summary>
    /// Groups the elements of an asynchronous sequence into arrays of at most
    /// <paramref name="batchSize"/> elements, preserving source order.
    /// </summary>
    /// <typeparam name="T">The element type of <paramref name="source"/>.</typeparam>
    /// <param name="source">The sequence to batch.</param>
    /// <param name="batchSize">The maximum number of elements per batch; must be at least 1.</param>
    /// <param name="cancellationToken">
    /// Token that stops the batching. Once cancellation is observed, enumeration ends
    /// without throwing and any partially filled batch is discarded.
    /// </param>
    /// <returns>
    /// A sequence of full batches, followed by one shorter batch if the number of source
    /// elements is not a multiple of <paramref name="batchSize"/>. An empty source yields no batches.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="batchSize"/> is less than 1.</exception>
    public static IAsyncEnumerable<T[]> BatchAsync<T>(
        this IAsyncEnumerable<T> source,
        int batchSize,
        CancellationToken cancellationToken = default)
    {
        // Validate here, outside the async iterator, so bad arguments fail at the call site
        // instead of being deferred until the first MoveNextAsync.
        ArgumentNullException.ThrowIfNull(source);
        if (batchSize < 1)
        {
            throw new ArgumentOutOfRangeException(nameof(batchSize), batchSize, "Batch size must be at least 1.");
        }

        return BatchIteratorAsync(source, batchSize, cancellationToken);
    }

    // Async iterator that performs the batching once the arguments have been validated.
    // [EnumeratorCancellation] combines the token passed to BatchAsync with any token
    // supplied later through WithCancellation on the returned sequence.
    private static async IAsyncEnumerable<T[]> BatchIteratorAsync<T>(
        IAsyncEnumerable<T> source,
        int batchSize,
        [EnumeratorCancellation] CancellationToken cancellationToken)
    {
        var batch = new List<T>(batchSize);
        await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            // Stop quietly (no exception) if cancellation was requested while awaiting the source.
            if (cancellationToken.IsCancellationRequested)
            {
                yield break;
            }

            batch.Add(item);
            if (batch.Count >= batchSize)
            {
                yield return batch.ToArray();
                batch.Clear();
            }
        }

        // Emit the trailing partial batch only if cancellation has not been requested,
        // matching the in-loop rule that nothing more is produced once cancellation is observed.
        if (batch.Count > 0 && !cancellationToken.IsCancellationRequested)
        {
            yield return batch.ToArray();
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment