Skip to content

Instantly share code, notes, and snippets.

@jnm2
Last active August 12, 2023 15:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jnm2/c82431b2c506ea3d761b162febad4213 to your computer and use it in GitHub Desktop.
Save jnm2/c82431b2c506ea3d761b162febad4213 to your computer and use it in GitHub Desktop.
/// <summary>
/// Tracks how many references there are to a disposable object, and disposes it when there are none remaining.
/// </summary>
/// <remarks>
/// Changes to the reference count are made while holding a private lock, so the members of this class may be called
/// concurrently. The tracked object is disposed outside that lock by whichever <see cref="Release"/> call takes the
/// count to <c>0</c>. Once the count reaches <c>0</c> it never changes again; further <see cref="AddRef"/> and
/// <see cref="Release"/> calls throw.
/// </remarks>
public sealed class RefCountingDisposer
{
    private readonly IDisposable disposable;

    // Number of outstanding references. Only written while holding lockObject; 0 is terminal (all released).
    private uint refCount = 1;

    // Dedicated, private lock target so that outside code locking on this instance (or on the tracked object) cannot
    // contend with or deadlock the reference counting.
    private readonly object lockObject = new();

    /// <summary>
    /// Begins tracking an initial reference to <paramref name="disposable"/>. The reference count starts as <c>1</c>.
    /// If the next call is <see cref="Release"/>, the reference count will go to <c>0</c> and <paramref
    /// name="disposable"/> will be disposed. An additional <see cref="Release"/> call will be needed for each <see
    /// cref="AddRef"/> call, if any.
    /// </summary>
    /// <param name="disposable">
    /// The object to dispose once the last reference has been released.
    /// </param>
    /// <exception cref="ArgumentNullException">
    /// Thrown if <paramref name="disposable"/> is <see langword="null"/>.
    /// </exception>
    public RefCountingDisposer(IDisposable disposable)
    {
        this.disposable = disposable ?? throw new ArgumentNullException(nameof(disposable));
    }

    /// <summary>
    /// <para>
    /// Reflects that an additional reference to the tracked object has been made. This will require an additional call
    /// to <see cref="Release"/> before the tracked object will be disposed by this <see cref="RefCountingDisposer"/>
    /// instance.
    /// </para>
    /// <para>
    /// <see cref="InvalidOperationException"/> is thrown if all references have already been released and there is no
    /// longer anything to track.
    /// </para>
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// Thrown if all references have already been released and there is no longer anything to track.
    /// </exception>
    /// <exception cref="OverflowException">
    /// Thrown if the number of outstanding references is already <see cref="uint.MaxValue"/>. The count is left
    /// unchanged.
    /// </exception>
    public void AddRef()
    {
        lock (lockObject)
        {
            if (refCount == 0)
                throw new InvalidOperationException($"{nameof(AddRef)} must not be called after all references have been released.");

            // Checked: an unchecked uint increment at uint.MaxValue would wrap to 0, silently marking the object as
            // fully released without ever disposing it. The exception is raised before the field is written.
            refCount = checked(refCount + 1u);
        }
    }

    /// <summary>
    /// Reflects that a reference to the tracked object has been released. If the last remaining reference is released,
    /// the tracked object will be disposed and future calls to <see cref="AddRef()"/> and <see cref="Release"/> will
    /// throw <see cref="InvalidOperationException"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// Thrown if all references have already been released and there is no longer anything to track.
    /// </exception>
    public void Release()
    {
        bool dispose;
        lock (lockObject)
        {
            if (refCount == 0)
                throw new InvalidOperationException($"{nameof(Release)} must not be called after all references have been released.");

            refCount--;
            dispose = refCount == 0;
        }

        // Dispose outside the lock so the tracked object's Dispose cannot block other callers of this instance; they
        // observe a count of 0 and throw immediately instead of waiting.
        if (dispose) disposable.Dispose();
    }

    /// <summary>
    /// <para>
    /// Indicates whether all references have been released. When <see langword="true"/>, the tracked object is either
    /// disposed already or in the process of being disposed on a different thread.
    /// </para>
    /// <para>
    /// ⚠️ Subsequent calls to <see cref="AddRef"/> and <see cref="Release"/> may still throw even if this property
    /// returns <see langword="false"/>. Another thread may have executed the final <see cref="Release"/> call in the
    /// meantime.
    /// </para>
    /// </summary>
    public bool IsClosed => Volatile.Read(ref refCount) == 0;

    /// <summary>
    /// <para>
    /// Reflects that an additional reference to the tracked object has been made and returns a lease object which
    /// releases that reference when disposed. <see cref="IDisposable.Dispose"/> is idempotent and thread-safe on the
    /// returned object.
    /// </para>
    /// <para>
    /// <see cref="InvalidOperationException"/> is thrown if all references have already been released and there is no
    /// longer anything to track. If unbalanced calls to <see cref="Release"/> are made separately, <see
    /// cref="IDisposable.Dispose"/> may also throw <see cref="InvalidOperationException"/> due to all references
    /// already having been released.
    /// </para>
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// Thrown if all references have already been released and there is no longer anything to track.
    /// </exception>
    /// <exception cref="OverflowException">
    /// Thrown if the number of outstanding references is already <see cref="uint.MaxValue"/>.
    /// </exception>
    public IDisposable Lease()
    {
        AddRef();
        return new RefLease(this);
    }

    /// <summary>
    /// Releases one reference on the owning <see cref="RefCountingDisposer"/> the first time it is disposed; later
    /// <see cref="Dispose"/> calls do nothing.
    /// </summary>
    private sealed class RefLease : IDisposable
    {
        // Cleared by the first Dispose call. Interlocked.Exchange ensures only one caller ever sees the non-null value,
        // so at most one Release happens even under concurrent Dispose calls.
        private RefCountingDisposer? disposer;

        public RefLease(RefCountingDisposer disposer)
        {
            this.disposer = disposer;
        }

        public void Dispose()
        {
            Interlocked.Exchange(ref disposer, null)?.Release();
        }
    }
}
using NSubstitute;
using NUnit.Framework;
/// <summary>
/// Behavioral and thread-safety tests for <see cref="RefCountingDisposer"/>.
/// </summary>
public static class RefCountingDisposerTests
{
    [Test]
    public static void Disposable_object_is_required()
    {
        Should.Throw<ArgumentNullException>(() => new RefCountingDisposer(disposable: null!))
            .ParamName.ShouldBe("disposable");
    }

    [Test]
    public static void Initial_count_is_1()
    {
        var disposable = Substitute.For<IDisposable>();
        var disposer = new RefCountingDisposer(disposable);
        disposable.DidNotReceive().Dispose();

        disposer.Release();

        disposable.Received().Dispose();
    }

    [Test]
    public static void AddRef_calls_require_additional_Release_calls_before_disposal([Range(1, 4)] int additionalRefs)
    {
        var disposable = Substitute.For<IDisposable>();
        var disposer = new RefCountingDisposer(disposable);

        for (var i = 0; i < additionalRefs; i++)
        {
            disposer.AddRef();
        }

        for (var i = 0; i < additionalRefs; i++)
        {
            disposer.Release();
            disposable.DidNotReceive().Dispose();
        }

        disposer.Release();
        disposable.Received().Dispose();
    }

    [Test]
    public static void AddRef_is_invalid_after_final_release()
    {
        var disposer = new RefCountingDisposer(Substitute.For<IDisposable>());
        disposer.Release();

        Should.Throw<InvalidOperationException>(disposer.AddRef)
            .Message.ShouldBe("AddRef must not be called after all references have been released.");
    }

    [Test]
    public static void Release_is_invalid_after_final_release()
    {
        var disposer = new RefCountingDisposer(Substitute.For<IDisposable>());
        disposer.Release();

        Should.Throw<InvalidOperationException>(disposer.Release)
            .Message.ShouldBe("Release must not be called after all references have been released.");
    }

    [Test]
    public static void AddRef_during_dispose_is_detected_as_invalid()
    {
        var disposable = Substitute.For<IDisposable>();
        var disposer = new RefCountingDisposer(disposable);
        disposable.When(d => d.Dispose()).Do(_ =>
        {
            Should.Throw<InvalidOperationException>(disposer.AddRef)
                .Message.ShouldBe("AddRef must not be called after all references have been released.");
        });

        disposer.Release();

        // The assertion above only runs inside Dispose; without this check the test would pass vacuously if the
        // final Release never disposed the tracked object.
        disposable.Received(requiredNumberOfCalls: 1).Dispose();
    }

    [Test]
    public static void Release_during_dispose_is_detected_as_invalid()
    {
        var disposable = Substitute.For<IDisposable>();
        var disposer = new RefCountingDisposer(disposable);
        var disposeCalled = false;
        disposable.When(d => d.Dispose()).Do(_ =>
        {
            if (disposeCalled) return; // Prefer test failure rather than stack overflow which crashes test host
            disposeCalled = true;
            Should.Throw<InvalidOperationException>(disposer.Release)
                .Message.ShouldBe("Release must not be called after all references have been released.");
        });

        disposer.Release();

        // Guards against a vacuous pass if Dispose is never reached, and against a nested Release re-disposing.
        disposable.Received(requiredNumberOfCalls: 1).Dispose();
    }

    [Test]
    public static void AddRef_is_thread_safe_with_itself()
    {
        var disposable = Substitute.For<IDisposable>();
        var disposer = new RefCountingDisposer(disposable);
        const int degreeOfParallelism = 8;
        const int iterationsPerJob = 1_000_000;

        Parallel.For(
            fromInclusive: 0,
            toExclusive: degreeOfParallelism,
            _ =>
            {
                for (var i = 0; i < iterationsPerJob; i++)
                {
                    disposer.AddRef();
                }
            });

        for (var i = 0; i < degreeOfParallelism * iterationsPerJob; i++)
        {
            disposer.Release();
        }

        disposable.DidNotReceive().Dispose();
        disposer.Release();
        disposable.Received().Dispose();
    }

    [Test]
    public static void Non_final_release_is_thread_safe_with_itself()
    {
        var disposable = Substitute.For<IDisposable>();
        var disposer = new RefCountingDisposer(disposable);
        const int degreeOfParallelism = 8;
        const int iterationsPerJob = 1_000_000;

        for (var i = 0; i < degreeOfParallelism * iterationsPerJob; i++)
        {
            disposer.AddRef();
        }

        Parallel.For(
            fromInclusive: 0,
            toExclusive: degreeOfParallelism,
            _ =>
            {
                for (var i = 0; i < iterationsPerJob; i++)
                {
                    disposer.Release();
                }
            });

        disposable.DidNotReceive().Dispose();
        disposer.Release();
        disposable.Received().Dispose();
    }

    [Test]
    public static void AddRef_and_non_final_Release_are_thread_safe_with_each_other()
    {
        var disposable = Substitute.For<IDisposable>();
        var disposer = new RefCountingDisposer(disposable);
        const int degreeOfParallelism = 8;
        const int iterationsPerJob = 1_000_000;

        Parallel.For(
            fromInclusive: 0,
            toExclusive: degreeOfParallelism,
            _ =>
            {
                for (var i = 0; i < iterationsPerJob; i++)
                {
                    disposer.AddRef();
                    disposer.Release();
                }
            });

        disposable.DidNotReceive().Dispose();
        disposer.Release();
        disposable.Received().Dispose();
    }

    [Test]
    public static void Racing_final_releases_dispose_exactly_once()
    {
        var disposable = Substitute.For<IDisposable>();
        var disposer = new RefCountingDisposer(disposable);
        const int degreeOfParallelism = 8;

        for (var i = 0; i < 1_000_000; i++)
        {
            disposer.AddRef();
        }

        Parallel.For(
            fromInclusive: 0,
            toExclusive: degreeOfParallelism,
            _ =>
            {
                for (var i = 0; ; i++)
                {
                    try
                    {
                        disposer.Release();
                    }
                    catch (InvalidOperationException)
                    {
                        i.ShouldBeGreaterThan(0, "Suspiciously low iteration count, likely not running in parallel.");
                        break;
                    }
                }
            });

        disposable.Received(requiredNumberOfCalls: 1).Dispose();
    }

    [Test]
    public static async Task Disposal_is_not_done_inside_a_lock()
    {
        using var readyToLetDisposeCallReturn = new ManualResetEventSlim();
        using var insideDispose = new ManualResetEventSlim();
        var disposable = Substitute.For<IDisposable>();
        disposable.When(d => d.Dispose()).Do(_ =>
        {
            insideDispose.Set();
            readyToLetDisposeCallReturn.Wait();
        });
        var disposer = new RefCountingDisposer(disposable);

        var disposeTask = Task.Run(disposer.Release);
        try
        {
            // Bounded wait: if Release never reaches Dispose, fail instead of hanging the test run.
            insideDispose.Wait(TimeSpan.FromSeconds(5)).ShouldBeTrue("Release did not call Dispose.");

            Should.CompleteIn(
                () => Should.Throw<InvalidOperationException>(disposer.AddRef),
                TimeSpan.FromSeconds(5));
        }
        finally
        {
            // Always unblock the Dispose callback so a failed assertion does not strand a thread pool thread.
            readyToLetDisposeCallReturn.Set();
        }

        await disposeTask;
    }

    [Test]
    public static async Task Lock_cannot_be_accidentally_used_outside_the_class()
    {
        using var locksAcquired = new ManualResetEventSlim();
        using var stopAccidentalUsage = new ManualResetEventSlim();
        var disposable = Substitute.For<IDisposable>();
        var disposer = new RefCountingDisposer(disposable);

        var accidentalUsageTask = Task.Run(() =>
        {
            lock (disposer)
            {
                lock (disposable)
                {
                    locksAcquired.Set();
                    stopAccidentalUsage.Wait();
                }
            }
        });

        try
        {
            // Without this, AddRef could run before the task takes the locks and the test would pass vacuously.
            locksAcquired.Wait(TimeSpan.FromSeconds(5)).ShouldBeTrue("The accidental usage task did not acquire the locks.");

            Should.CompleteIn(disposer.AddRef, TimeSpan.FromSeconds(5));
        }
        finally
        {
            // Always let the background task exit, even when the assertion above fails.
            stopAccidentalUsage.Set();
        }

        await accidentalUsageTask;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment