Skip to content

Instantly share code, notes, and snippets.

@GrabYourPitchforks
Created April 2, 2018 20:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save GrabYourPitchforks/fbbe14b18043d4e2452205803afac344 to your computer and use it in GitHub Desktop.
Save GrabYourPitchforks/fbbe14b18043d4e2452205803afac344 to your computer and use it in GitHub Desktop.
Validating MemoryPool<T>
/*
* !! WARNING !!
*
* COMPLETELY UNTESTED CODE
*/
using Microsoft.Win32.SafeHandles;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.ConstrainedExecution;
using System.Runtime.InteropServices;
using System.Security;
using System.Threading;
namespace System.Buffers
{
/// <summary>
/// A memory pool that performs validation on the returned <see cref="Memory{T}"/> instances.
/// </summary>
/// <remarks>
/// Every rental gets its own block from <c>VirtualAllocSafeHandle.Alloc</c>. Disposing a rental
/// decommits that block instead of recycling it, and the owning manager reports double-dispose,
/// dispose-while-pinned, use-after-dispose, double-unpin and leaked pins by throwing
/// <see cref="InvalidOperationException"/>. Only types that contain no references are supported.
/// </remarks>
public sealed class ValidatingMemoryPool<T> : MemoryPool<T>
{
    // Singleton: the only instance is the one exposed through Shared.
    private ValidatingMemoryPool() { }

    /// <summary>Gets the single shared instance of this pool.</summary>
    public static new ValidatingMemoryPool<T> Shared { get; } = new ValidatingMemoryPool<T>();

    /// <summary>Gets the largest buffer size this pool advertises.</summary>
    public override int MaxBufferSize => Int32.MaxValue;

    /// <summary>Rents a freshly allocated, validated buffer.</summary>
    /// <param name="minBufferSize">Minimum number of elements required, or -1 for the pool default.</param>
    /// <returns>An owner whose buffer holds at least <paramref name="minBufferSize"/> elements.</returns>
    /// <exception cref="InvalidOperationException"><typeparamref name="T"/> is or contains a reference type.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="minBufferSize"/> is less than -1.</exception>
    public override IMemoryOwner<T> Rent(int minBufferSize = -1)
    {
        // Parameter validation
        if (RuntimeHelpers.IsReferenceOrContainsReferences<T>())
        {
            throw new InvalidOperationException("This pool can only be used with blittable types.");
        }

        if (minBufferSize < -1)
        {
            throw new ArgumentOutOfRangeException(paramName: nameof(minBufferSize));
        }

        // Never allocate fewer elements than are needed to cover 4096 bytes; this also turns the
        // -1 (default) and 0 requests into a non-empty allocation.
        return new ValidatingMemoryManager(Math.Max(minBufferSize, 1 + (4095 / Unsafe.SizeOf<T>())));
    }

    /// <summary>Does nothing; the pool itself holds no resources.</summary>
    protected override void Dispose(bool disposing)
    {
        // no-op
    }

    /// <summary>
    /// Owner of a single rented block. Tracks outstanding pins and the disposed state in one
    /// counter so that misuse can be detected.
    /// </summary>
    private sealed unsafe class ValidatingMemoryManager : MemoryManager<T>
    {
        private readonly int _elementCount;
        private readonly VirtualAllocSafeHandle _handle;

        // Bit 0 is the 'disposed' flag; every outstanding pin adds 2.
        private int _refCount = 0;

        /// <summary>Allocates backing memory for <paramref name="elementCount"/> elements.</summary>
        /// <param name="elementCount">Number of elements to allocate; must be positive.</param>
        public ValidatingMemoryManager(int elementCount)
        {
            Debug.Assert(elementCount > 0);

            // Allocate a block of memory large enough to hold the desired number of elements.
            // Unsafe.Add on a null pointer yields elementCount * sizeof(T) as a byte count.
            _elementCount = elementCount;
            _handle = VirtualAllocSafeHandle.Alloc((IntPtr)Unsafe.Add<T>((void*)null, elementCount));

            // Don't allow the handle to be reclaimed (released) until the AppDomain terminates.
            // Our dispose method will decommit the memory, not release it.
            GCHandle.Alloc(_handle);
        }

        /// <summary>Gets the number of elements in the block; fails if this instance is disposed.</summary>
        public override int Length
        {
            get
            {
                FailIfDisposed();
                return _elementCount;
            }
        }

        /// <summary>Gets the memory for this block; fails if this instance is disposed.</summary>
        public override Memory<T> Memory
        {
            get
            {
                FailIfDisposed();
                return base.Memory;
            }
        }

        /// <summary>
        /// Marks this instance disposed and decommits its memory. Fails on a second dispose or
        /// when a pin is still outstanding.
        /// </summary>
        protected override void Dispose(bool disposing)
        {
            // Only the transition 0 -> 1 (no pins, not disposed -> disposed) is legal.
            int existingRefCount = Interlocked.CompareExchange(ref _refCount, 1, 0);
            if ((existingRefCount & 1) == 1)
            {
                Fail("Attempt to double-dispose this instance.");
            }
            else if (existingRefCount != 0)
            {
                Fail("Attempt to dispose an instance that has an outstanding call to Pin().");
            }

            // Fail always throws, so this is reached only for the legal transition.
            _handle.Decommit();
        }

        /// <summary>
        /// Reports a validation failure: logs to an attached debugger, then throws.
        /// </summary>
        /// <param name="message">Description of the misuse that was detected.</param>
        /// <exception cref="InvalidOperationException">Always thrown.</exception>
        [MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.NoOptimization)]
        private static void Fail(string message)
        {
            if (Debugger.IsAttached)
            {
                Debugger.Log(0, null, message);
            }

            try
            {
                throw new InvalidOperationException(message);
            }
            finally
            {
                GC.KeepAlive(message); // to make sure it's present in debugger 'locals' window
            }
        }

        /// <summary>Fails if the 'disposed' bit of the counter is set.</summary>
        private void FailIfDisposed()
        {
            if ((Volatile.Read(ref _refCount) & 1) == 1)
            {
                Fail("Attempt to access a disposed object.");
            }
        }

        /// <summary>Returns a span over the whole block; fails if this instance is disposed.</summary>
        public override Span<T> GetSpan()
        {
            FailIfDisposed();

            bool refAdded = false;
            try
            {
                _handle.DangerousAddRef(ref refAdded);
                return new Span<T>((void*)_handle.DangerousGetHandle(), _elementCount);
            }
            finally
            {
                if (refAdded)
                {
                    _handle.DangerousRelease();
                }
            }
        }

        /// <summary>
        /// Registers a pin and returns a handle pointing at the requested element.
        /// </summary>
        /// <param name="elementIndex">Offset, in elements, of the element the handle should point at.</param>
        /// <returns>A handle whose disposal releases the pin through a dedicated pinnable object.</returns>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="elementIndex"/> is negative or greater than the element count.
        /// </exception>
        public override MemoryHandle Pin(int elementIndex = 0)
        {
            // An index equal to the element count is accepted so that an empty slice positioned
            // at the very end of the block can still be pinned.
            if ((uint)elementIndex > (uint)_elementCount)
            {
                throw new ArgumentOutOfRangeException(paramName: nameof(elementIndex));
            }

            bool safeHandleRefAdded = false;
            try
            {
                Interlocked.Add(ref _refCount, 2); // this won't throw
                FailIfDisposed(); // Interlocked call introduces thread synchronization, so check disposed after
                _handle.DangerousAddRef(ref safeHandleRefAdded); // this *shouldn't* throw

                // Honor elementIndex: the returned pointer must address the requested element,
                // not the start of the block.
                return new MemoryHandle(
                    pointer: Unsafe.Add<T>((void*)_handle.DangerousGetHandle(), elementIndex),
                    handle: default,
                    pinnable: new ValidatingPinnable(this)); // this allocation might throw (OOM)
            }
            catch
            {
                // Undo the pin registration made above before propagating the failure.
                Interlocked.Add(ref _refCount, -2);
                throw;
            }
            finally
            {
                if (safeHandleRefAdded)
                {
                    _handle.DangerousRelease();
                }
            }
        }

        /// <summary>Always reports that no array backs this memory; fails if disposed.</summary>
        protected override bool TryGetArray(out ArraySegment<T> segment)
        {
            FailIfDisposed();
            segment = default;
            return false;
        }

        /// <summary>Never expected to run; unpinning goes through <see cref="ValidatingPinnable"/>.</summary>
        public override void Unpin()
        {
            // We pass a different 'IPinnable' instance to the MemoryHandle ctor,
            // so this Unpin method should never be called.
            Fail("This method should never be called.");
        }

        /// <summary>
        /// Per-pin token handed to <see cref="MemoryHandle"/>. Detects a pin that is released
        /// twice and, through its finalizer, a pin that is never released.
        /// </summary>
        private sealed class ValidatingPinnable : IPinnable
        {
            // Set to null by the first Unpin() call.
            private ValidatingMemoryManager _manager;

            public ValidatingPinnable(ValidatingMemoryManager manager)
            {
                _manager = manager;
            }

            // Runs only when Unpin() was never called (Unpin suppresses finalization).
            // NOTE(review): Fail throws from the finalizer; presumably the resulting crash is the
            // intended way to surface the leak - confirm.
            ~ValidatingPinnable()
            {
                if (Debugger.IsAttached)
                {
                    Debugger.Log(0, null, "Leak detected: Unpin() never called.");
                    Debugger.Break();
                }

                Fail("Leak detected: Unpin() never called.");
            }

            /// <summary>Never expected to run; pins are created by the owning manager.</summary>
            public MemoryHandle Pin(int elementIndex)
            {
                Fail("This method should never be called.");
                return default;
            }

            /// <summary>Releases the pin this token represents; fails when called a second time.</summary>
            public void Unpin()
            {
                GC.SuppressFinalize(this);

                // Atomically claim the manager so that only the first caller releases the pin.
                var manager = Interlocked.Exchange(ref _manager, null);
                if (manager != null)
                {
                    // Use the captured local: the field has just been set to null.
                    Interlocked.Add(ref manager._refCount, -2);
                }
                else
                {
                    Fail("Unpin() called twice; should only have been called once.");
                }
            }
        }
    }
}
/// <summary>
/// Safe handle that owns a block of memory obtained from <c>VirtualAlloc</c>. The block can be
/// decommitted while the handle is alive and is released when the handle is released.
/// </summary>
internal sealed class VirtualAllocSafeHandle : SafeHandleZeroOrMinusOneIsInvalid
{
    // Size of the allocation in bytes, exactly as requested from VirtualAlloc.
    private readonly IntPtr _cb;

    private VirtualAllocSafeHandle(IntPtr cb)
        : base(ownsHandle: true)
    {
        _cb = cb;
    }

    /// <summary>Allocates (reserves + commits) the specified number of bytes of memory.</summary>
    /// <param name="cb">Number of bytes to allocate.</param>
    /// <returns>A handle owning the new read/write block.</returns>
    /// <remarks>On failure, throws the exception that corresponds to the last Win32 error.</remarks>
    public static VirtualAllocSafeHandle Alloc(IntPtr cb)
    {
        VirtualAllocSafeHandle retVal = new VirtualAllocSafeHandle(cb);
        bool success = false;

        // The allocation and the hand-off to the safe handle happen inside a 'finally' of a
        // prepared region so the two steps are not separated by an asynchronous exception.
        RuntimeHelpers.PrepareConstrainedRegions();
        try
        {
        }
        finally
        {
            IntPtr handle = UnsafeNativeMethods.VirtualAlloc(
                lpAddress: IntPtr.Zero,
                dwSize: cb,
                flAllocationType: VirtualAllocAllocationType.MEM_RESERVE | VirtualAllocAllocationType.MEM_COMMIT,
                flProtect: VirtualAllocMemoryProtection.PAGE_READWRITE);
            if (handle != IntPtr.Zero)
            {
                retVal.SetHandle(handle);
                success = true;
            }
        }

        if (!success)
        {
            Marshal.ThrowExceptionForHR(Marshal.GetHRForLastWin32Error());
        }

        return retVal;
    }

    /// <summary>Decommits this block of memory without releasing it.</summary>
    /// <remarks>
    /// The block is first switched to no-access and then decommitted; a failure of either call
    /// throws the exception that corresponds to the last Win32 error.
    /// </remarks>
    public void Decommit()
    {
        bool refAdded = false;
        try
        {
            DangerousAddRef(ref refAdded);

            if (!UnsafeNativeMethods.VirtualProtect(DangerousGetHandle(), _cb, VirtualAllocMemoryProtection.PAGE_NOACCESS, out _))
            {
                Marshal.ThrowExceptionForHR(Marshal.GetHRForLastWin32Error());
            }

            if (!UnsafeNativeMethods.VirtualFree(DangerousGetHandle(), _cb, VirtualAllocAllocationType.MEM_DECOMMIT))
            {
                Marshal.ThrowExceptionForHR(Marshal.GetHRForLastWin32Error());
            }
        }
        finally
        {
            // Release only a reference that was actually taken. Releasing when refAdded is false
            // would drop a reference this method never added, and skipping the release when it is
            // true would leave the handle's reference count permanently raised.
            if (refAdded)
            {
                DangerousRelease();
            }
        }
    }

    /// <summary>Releases the whole block back to the operating system.</summary>
    /// <returns><c>true</c> if <c>VirtualFree</c> reported success.</returns>
    [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)]
    protected override bool ReleaseHandle()
    {
        return UnsafeNativeMethods.VirtualFree(handle, IntPtr.Zero, VirtualAllocAllocationType.MEM_RELEASE);
    }
}
/// <summary>
/// P/Invoke declarations for the kernel32 virtual-memory functions used by this file.
/// All three record the last Win32 error for retrieval through <see cref="Marshal"/>.
/// </summary>
[SuppressUnmanagedCodeSecurity]
internal static class UnsafeNativeMethods
{
    private const string Kernel32 = "kernel32.dll";
    private const DllImportSearchPath Kernel32SearchPath = DllImportSearchPath.System32;

    /// <summary>Declaration of the native <c>VirtualAlloc</c> function.</summary>
    /// <remarks>https://msdn.microsoft.com/en-us/library/windows/desktop/aa366887(v=vs.85).aspx</remarks>
    [DefaultDllImportSearchPaths(Kernel32SearchPath)]
    [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)]
    [DllImport(Kernel32, CallingConvention = CallingConvention.Winapi, SetLastError = true)]
    public static extern IntPtr VirtualAlloc(
        IntPtr lpAddress,
        IntPtr dwSize,
        VirtualAllocAllocationType flAllocationType,
        VirtualAllocMemoryProtection flProtect);

    /// <summary>Declaration of the native <c>VirtualFree</c> function.</summary>
    /// <remarks>https://msdn.microsoft.com/en-us/library/windows/desktop/aa366892(v=vs.85).aspx</remarks>
    [DefaultDllImportSearchPaths(Kernel32SearchPath)]
    [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)]
    [DllImport(Kernel32, CallingConvention = CallingConvention.Winapi, SetLastError = true)]
    [return: MarshalAs(UnmanagedType.Bool)] // same as the default marshalling, spelled out
    public static extern bool VirtualFree(
        IntPtr lpAddress,
        IntPtr dwSize,
        VirtualAllocAllocationType dwFreeType);

    /// <summary>Declaration of the native <c>VirtualProtect</c> function.</summary>
    /// <remarks>https://msdn.microsoft.com/en-us/library/windows/desktop/aa366898(v=vs.85).aspx</remarks>
    [DefaultDllImportSearchPaths(Kernel32SearchPath)]
    [DllImport(Kernel32, CallingConvention = CallingConvention.Winapi, SetLastError = true)]
    [return: MarshalAs(UnmanagedType.Bool)] // same as the default marshalling, spelled out
    public static extern bool VirtualProtect(
        IntPtr lpAddress,
        IntPtr dwSize,
        VirtualAllocMemoryProtection flNewProtect,
        out VirtualAllocMemoryProtection lpflOldProtect);
}
/// <summary>
/// Allocation-type and free-type flags passed to <c>VirtualAlloc</c> and <c>VirtualFree</c>.
/// </summary>
[Flags]
internal enum VirtualAllocAllocationType : uint
{
    /// <summary>Commit pages; combined with <see cref="MEM_RESERVE"/> when allocating.</summary>
    MEM_COMMIT = 1u << 12,

    /// <summary>Reserve address space; combined with <see cref="MEM_COMMIT"/> when allocating.</summary>
    MEM_RESERVE = 1u << 13,

    /// <summary>Decommit pages without releasing the block.</summary>
    MEM_DECOMMIT = 1u << 14,

    /// <summary>Release the whole block.</summary>
    MEM_RELEASE = 1u << 15,
}
/// <summary>
/// Page-protection values passed to <c>VirtualAlloc</c> and <c>VirtualProtect</c>.
/// </summary>
internal enum VirtualAllocMemoryProtection : uint
{
    /// <summary>Protection applied to a block just before it is decommitted.</summary>
    PAGE_NOACCESS = 0x00000001,

    /// <summary>Protection requested for newly allocated blocks.</summary>
    PAGE_READWRITE = 0x00000004,
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment