Skip to content

Instantly share code, notes, and snippets.

@StephenCleary
Last active April 15, 2024 10:22
Show Gist options
  • Save StephenCleary/39a2cd0aa3c705a984a4dbbea8275fe9 to your computer and use it in GitHub Desktop.
Save StephenCleary/39a2cd0aa3c705a984a4dbbea8275fe9 to your computer and use it in GitHub Desktop.
Asynchronous cache
using System;
using System.Collections.Generic;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;
/// <summary>
/// Provides an asynchronous cache with exactly-once creation method semantics and flexible cache entries.
/// Each value is stored in the underlying <see cref="IMemoryCache"/> as a <see cref="TaskCompletionSource{TResult}"/>,
/// so all callers asking for the same key share the same <see cref="Task{TResult}"/>.
/// </summary>
public sealed class AsyncCache
{
    private readonly object _mutex = new object();
    private readonly IMemoryCache _cache;
    private readonly ILogger<AsyncCache> _logger;

    /// <summary>
    /// Creates an asynchronous cache wrapping an existing memory cache.
    /// </summary>
    /// <param name="cache">The memory cache.</param>
    /// <param name="logger">The logger used for diagnostics.</param>
    /// <exception cref="ArgumentNullException"><paramref name="cache"/> or <paramref name="logger"/> is <c>null</c>.</exception>
    public AsyncCache(IMemoryCache cache, ILogger<AsyncCache> logger)
    {
        _cache = cache ?? throw new ArgumentNullException(nameof(cache));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
    }

    /// <summary>
    /// Removes an item from the cache.
    /// </summary>
    /// <param name="key">The key of the item.</param>
    public void Remove(object key)
    {
        using var _ = _logger.BeginDataScope(new {cacheKey = key});
        _logger.LogTrace("Removing entry.");
        _cache.Remove(key);
    }

    /// <summary>
    /// Removes a specific future from the cache.
    /// </summary>
    /// <typeparam name="T">The type of the item.</typeparam>
    /// <param name="key">The key of the item.</param>
    /// <param name="value">The future that has to match the entry.</param>
    /// <returns><c>true</c> if the entry for <paramref name="key"/> held exactly <paramref name="value"/> and was removed; otherwise <c>false</c>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="value"/> is <c>null</c>.</exception>
    public bool TryRemove<T>(object key, Task<T> value)
    {
        _ = value ?? throw new ArgumentNullException(nameof(value));
        using var __ = _logger.BeginDataScope(new {cacheKey = key, taskId = value.Id});
        lock (_mutex)
        {
            var existingTask = _cache.TryGetValue(key, out TaskCompletionSource<T> tcs) ? tcs.Task : null;
            if (existingTask != value)
            {
                if (existingTask == null)
                    _logger.LogTrace("Attempted to remove entry, but it was already removed.");
                else
                    _logger.LogTrace("Attempted to remove entry, but it had already been replaced by {existingTaskId}.", existingTask.Id);
                return false;
            }
            _logger.LogTrace("Removing entry.");
            _cache.Remove(key);
            return true;
        }
    }

    /// <summary>
    /// Removes a specific value from the cache.
    /// </summary>
    /// <typeparam name="T">The type of the item.</typeparam>
    /// <param name="key">The key of the item.</param>
    /// <param name="value">The entry.</param>
    /// <returns>
    /// <c>true</c> if the entry for <paramref name="key"/> had completed successfully with a value equal
    /// (per <see cref="object.Equals(object, object)"/>) to <paramref name="value"/> and was removed; otherwise <c>false</c>.
    /// </returns>
    public bool TryRemove<T>(object key, T value)
    {
        using var __ = _logger.BeginDataScope(new { cacheKey = key });
        lock (_mutex)
        {
            var existingTask = _cache.TryGetValue(key, out TaskCompletionSource<T> tcs) ? tcs.Task : null;
            if (existingTask == null || !existingTask.IsCompletedSuccessfully || !object.Equals(existingTask.Result, value))
            {
                if (existingTask == null)
                    _logger.LogTrace("Attempted to remove entry, but it was already removed.");
                else if (!existingTask.IsCompletedSuccessfully)
                    // The entry is still present; it just has no successful value to compare against yet (or ever).
                    _logger.LogTrace("Attempted to remove entry, but entry {existingTaskId} has not completed successfully.", existingTask.Id);
                else
                    _logger.LogTrace("Attempted to remove entry, but it had already been replaced by {existingTaskId}.", existingTask.Id);
                return false;
            }
            _logger.LogTrace("Removing entry.");
            _cache.Remove(key);
            return true;
        }
    }

    /// <summary>
    /// Attempts to retrieve an item from the cache.
    /// </summary>
    /// <typeparam name="T">The type of the item.</typeparam>
    /// <param name="key">The key of the item.</param>
    /// <param name="task">On return, contains a future item, or <c>null</c> if no entry of this type exists.</param>
    /// <returns><c>true</c> if an entry was found; otherwise <c>false</c>.</returns>
    public bool TryGet<T>(object key, out Task<T>? task)
    {
        using var _ = _logger.BeginDataScope(new {cacheKey = key});
        task = _cache.TryGetValue(key, out TaskCompletionSource<T> tcs) ? tcs.Task : null;
        if (task == null)
            _logger.LogTrace("Attempted to retrieve entry, but it was not found.");
        else
            _logger.LogTrace("Retrieved entry {taskId}.", task.Id);
        return task != null;
    }

    /// <summary>
    /// Atomically retrieves or creates a cache item.
    /// </summary>
    /// <typeparam name="T">The type of the item.</typeparam>
    /// <param name="key">The key of the item.</param>
    /// <param name="create">An asynchronous creation method. This method will only be invoked once. The creation method may control the cache entry behavior for the resulting value by using its <see cref="ICacheEntry"/> parameter; the <see cref="ICacheEntry.Value"/> member is ignored, but all other members are honored.</param>
    /// <returns>A future item.</returns>
    public Task<T> GetOrCreateAsync<T>(object key, Func<ICacheEntry, Task<T>> create)
    {
        using var _ = _logger.BeginDataScope(new {cacheKey = key});
        TaskCompletionSource<T> tcs;
        CancellationTokenSource cts;
        lock (_mutex)
        {
            if (_cache.TryGetValue(key, out tcs))
            {
                _logger.LogTrace("GetOrCreate found existing entry {taskId}.", tcs.Task.Id);
                return tcs.Task;
            }

            // Publish a placeholder entry holding the not-yet-completed TCS, so concurrent callers share it.
            // Cancelling cts expires the placeholder; that is how a failed creation evicts it.
            // Disposing `entry` (at the end of this lock block) commits it to the cache.
            tcs = new TaskCompletionSource<T>();
            using var entry = _cache.CreateEntry(key).SetSize(1);
#pragma warning disable CA2000 // Dispose objects before losing scope
            cts = new CancellationTokenSource();
#pragma warning restore CA2000 // Dispose objects before losing scope
            entry.Value = tcs;
            entry.RegisterPostEvictionCallback((_, __, ___, ____) => cts.Dispose());
            entry.AddExpirationToken(new CancellationChangeToken(cts.Token));
            _logger.LogTrace("GetOrCreate creating new entry {taskId}.", tcs.Task.Id);
        }

        // The creation method runs outside the lock so user code never executes while _mutex is held.
        InvokeAndPropagateCompletion(create, key, tcs, cts);
        return tcs.Task;
    }

    /// <summary>
    /// Invokes the creation method and (asynchronously) updates the cache entry with the results.
    /// - If the function succeeds synchronously, the cache entry is updated and the TCS completed by the time this method returns.
    /// - If the function fails synchronously, the cache entry is removed and the TCS faulted by the time this method returns.
    /// - If the function succeeds asynchronously, the cache entry is updated when the function completes *if* the cache entry has not changed by that time.
    /// - If the function faults asynchronously, the cache entry is removed when the function completes *if* the cache entry has not changed by that time.
    /// This method never lets an exception escape (it is <c>async void</c>), and it always completes <paramref name="tcs"/>.
    /// </summary>
    /// <typeparam name="T">The type of object created by the <paramref name="create"/> method.</typeparam>
    /// <param name="create">The creation method, which may update the cache entry set when the creation method completes. The <see cref="ICacheEntry.Value"/> member is ignored, but all other members are honored.</param>
    /// <param name="key">The key of the entry; used to create the replacement cache entry and to locate the current one.</param>
    /// <param name="tcs">The task completion source currently in the cache entry. This method will (eventually) complete this task completion source.</param>
    /// <param name="cts">The cancellation token source used to evict the current cache entry if the creation method fails.</param>
    private async void InvokeAndPropagateCompletion<T>(Func<ICacheEntry, Task<T>> create, object key, TaskCompletionSource<T> tcs, CancellationTokenSource cts)
    {
        try
        {
            // Create the replacement entry *inside* this async method rather than in GetOrCreateAsync.
            // Some MemoryCache versions record a newly created entry as the ambient "current" entry in an AsyncLocal
            // (to link nested entries). An async method restores its caller's ExecutionContext when it returns to
            // the caller, so that ambient state stays scoped here instead of leaking into the caller of
            // GetOrCreateAsync, where unrelated entries would get linked to this pending one.
            // This entry is only disposed (which commits it) when creation succeeds.
            var cacheEntry = _cache.CreateEntry(key).SetSize(1);

            // Asynchronously create the value. ConfigureAwait(false): nothing below needs the caller's context,
            // and not capturing it avoids deadlocks for callers that block on the returned task.
            var result = await create(cacheEntry).ConfigureAwait(false);

            // Atomically:
            // - Check to see if we're still the one in the cache, and
            // - If we are, update the cache entry with a new one having the same TCS value, but including new expiration and other settings from the creation method.
            lock (_mutex)
            {
                // This check is necessary to avoid a race condition where our entry has been removed and re-created.
                // In that case, there will be a cache entry but it will not be our cache entry, so we should not replace it.
                // Rather, we'll leave the cache as-is (without our entry) and just complete our listeners (via TrySetResult, below).
                if (_cache.TryGetValue(key, out TaskCompletionSource<T> existingTcs) && existingTcs == tcs)
                {
                    _logger.LogTrace("GetOrCreate creation completed successfully; updating entry {taskId}.", tcs.Task.Id);
                    using (cacheEntry)
                        cacheEntry.Value = tcs;
                }
                else
                {
                    if (existingTcs == null)
                        _logger.LogTrace("GetOrCreate creation completed successfully, but entry {taskId} has been removed.", tcs.Task.Id);
                    else
                        _logger.LogTrace("GetOrCreate creation completed successfully, but entry {taskId} has been replaced by entry {replacementTaskId}.", tcs.Task.Id, existingTcs.Task.Id);
                }
            }

            // Propagate the result to any listeners.
            tcs.TrySetResult(result);
        }
        catch (OperationCanceledException oce)
        {
            _logger.LogTrace("GetOrCreate creation cancelled; removing entry {taskId}.", tcs.Task.Id);
            EvictPendingEntry(cts, tcs.Task.Id);
            // Propagate the cancellation to any listeners.
            if (oce.CancellationToken.IsCancellationRequested)
                tcs.TrySetCanceled(oce.CancellationToken);
            else
                tcs.TrySetCanceled();
        }
#pragma warning disable CA1031 // Do not catch general exception types
        catch (Exception ex)
#pragma warning restore CA1031 // Do not catch general exception types
        {
            _logger.LogTrace("GetOrCreate creation failed; removing entry {taskId}.", tcs.Task.Id);
            EvictPendingEntry(cts, tcs.Task.Id);
            // Propagate the exception to any listeners.
            tcs.TrySetException(ex);
        }
    }

    /// <summary>
    /// Evicts the placeholder cache entry by cancelling its expiration token. Never throws, so callers can always go on
    /// to complete the pending task completion source.
    /// </summary>
    /// <param name="cts">The cancellation token source registered as the placeholder entry's expiration token.</param>
    /// <param name="taskId">The id of the pending task, for logging.</param>
    private void EvictPendingEntry(CancellationTokenSource cts, int taskId)
    {
        try
        {
            cts.Cancel();
        }
        catch (ObjectDisposedException)
        {
            // The entry was already evicted; its post-eviction callback disposed the CTS. Nothing left to remove.
        }
#pragma warning disable CA1031 // Do not catch general exception types
        catch (Exception ex)
#pragma warning restore CA1031 // Do not catch general exception types
        {
            // Cancel() runs registered callbacks (e.g. the cache's expiration-token handler) synchronously and reports
            // their failures as an AggregateException. Swallow it: if it escaped, the TCS would never be completed
            // and the exception would be raised from an async void method.
            _logger.LogWarning(ex, "GetOrCreate failed to evict entry {taskId}.", taskId);
        }
    }
}
@molinch
Copy link

molinch commented Jan 20, 2022

/// <summary>
/// Creates a cache entry for <paramref name="key"/> in <paramref name="cache"/> and wraps it in a
/// <c>SafeCacheEntry</c> (defined elsewhere, not shown here).
/// </summary>
/// <remarks>
/// Per the gist author's reply, routing creation through an async local function avoids "linked cache entries".
/// NOTE(review): the presumed mechanism is that an async method restores its caller's ExecutionContext when it
/// returns, so any AsyncLocal "current entry" state set by <c>CreateEntry</c> does not flow back to the caller.
/// Confirm this against the Microsoft.Extensions.Caching.Memory version in use.
/// </remarks>
/// <param name="cache">The memory cache to create the entry in.</param>
/// <param name="key">The key of the entry.</param>
/// <returns>The wrapped, not-yet-committed cache entry.</returns>
public static ICacheEntry Create(IMemoryCache cache, object key)
        {
            // The local function contains no await, so its task is already complete here and GetResult()
            // does not actually block.
            return AsyncCreateEntry().GetAwaiter().GetResult();

            // CS1998 ("async method lacks await") is suppressed on purpose: the async keyword is needed only
            // for its execution-context isolation, not for any asynchrony.
#pragma warning disable 1998
            async Task<ICacheEntry> AsyncCreateEntry() => new SafeCacheEntry(cache.CreateEntry(key));
#pragma warning restore 1998
        }

What's the advantage of this construct over directly returning `new SafeCacheEntry(cache.CreateEntry(key))`?

@StephenCleary
Copy link
Author

It avoids linked cache entries, a nice undocumented "feature" that completely breaks asynchronous caches.

@molinch
Copy link

molinch commented Jan 21, 2022

Thanks @StephenCleary, it's really obscure!

@molinch
Copy link

molinch commented Jan 21, 2022

Wrote a very simple unit test you may/may not want to add:

/// <summary>
/// Tests for <c>AsyncCache</c>.
/// </summary>
public class AsyncCacheTests
{
    /// <summary>
    /// Many concurrent GetOrCreateAsync calls for the same key must invoke the creation delegate exactly once.
    /// </summary>
    [Fact]
    public async Task ShouldNotRunFactoryDelegateMoreThanOnce()
    {
        // Arrange
        using var memoryCache = new MemoryCache(new MemoryCacheOptions());
        // AsyncCache requires a logger; NullLogger discards all output.
        var asyncCache = new AsyncCache(memoryCache, Microsoft.Extensions.Logging.Abstractions.NullLogger<AsyncCache>.Instance);
        var factoryCalls = 0;

        // Act
        await Parallel.ForEachAsync(Enumerable.Range(0, 50), async (_, _) =>
        {
            await asyncCache.GetOrCreateAsync("key", async cacheEntry =>
            {
                // A short delay keeps the first creation pending while the other callers arrive.
                await Task.Delay(TimeSpan.FromMilliseconds(100));
                cacheEntry.SetAbsoluteExpiration(TimeSpan.FromDays(1));
                return Interlocked.Increment(ref factoryCalls);
            });
        });

        // Assert
        Assert.Equal(1, factoryCalls);
    }
}

@StephenCleary
Copy link
Author

Thanks! I do have a bunch of unit tests (just not included in this gist); I'll add this one, too.

@cocowalla
Copy link

It would be ideal if the async methods could return ValueTask<T> instead of Task<T>, but I guess that's not possible as-is because of the use of TaskCompletionSource?

@StephenCleary
Copy link
Author

@cocowalla I don't see the need for ValueTask<T> here. GetOrCreateAsync returns a Task<T>, but it's a cached instance; it's not creating a new one on each call.

It may be possible to have the create delegate return a ValueTask<T>, but I'm not sure how much that would improve performance-wise. It would only be called once per cached value.

@theodorzoulias
Copy link

try { cts.Cancel(); } catch (ObjectDisposedException) { }

This line makes me nervous. The Cancel call might throw an AggregateException containing all the exceptions thrown by the callbacks registered on the associated CancellationToken. If that happens, the tcs will never be completed, and a deadlock will most likely occur.

@StephenCleary
Copy link
Author

@theodorzoulias Thanks! I believe you're correct, and that this can actually happen with a custom memory cache underlying this one. I'll change it to catch Exception instead since we always want to continue.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment