Created April 2, 2018 20:21
-
-
Save GrabYourPitchforks/fbbe14b18043d4e2452205803afac344 to your computer and use it in GitHub Desktop.
Validating MemoryPool<T>
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
 * !! WARNING !!
 *
 * COMPLETELY UNTESTED CODE
 */
using Microsoft.Win32.SafeHandles; | |
using System.Diagnostics; | |
using System.Runtime.CompilerServices; | |
using System.Runtime.ConstrainedExecution; | |
using System.Runtime.InteropServices; | |
using System.Security; | |
using System.Threading; | |
namespace System.Buffers | |
{ | |
/// <summary> | |
/// A memory pool that performs validation on the returned <see cref="Memory{T}"/> instances. | |
/// </summary> | |
/// <summary>
/// A memory pool that performs validation on the returned <see cref="Memory{T}"/> instances.
/// </summary>
/// <remarks>
/// Every rented buffer gets its own VirtualAlloc'd block. Disposing the owner decommits the block
/// but never releases its address range, so a stale pointer or span should fault on access instead
/// of reading or writing memory that has since been reused. API misuse (use after dispose, double
/// dispose, dispose while pinned, double unpin) is reported by throwing
/// <see cref="InvalidOperationException"/>. Windows-only: relies on the kernel32 virtual-memory APIs.
/// </remarks>
/// <typeparam name="T">Element type; must not be or contain object references.</typeparam>
public sealed class ValidatingMemoryPool<T> : MemoryPool<T>
{
    // Singleton: the only instance is exposed through Shared.
    private ValidatingMemoryPool() { }

    /// <summary>Gets the singleton instance of this pool.</summary>
    public new static ValidatingMemoryPool<T> Shared { get; } = new ValidatingMemoryPool<T>();

    /// <inheritdoc/>
    public override int MaxBufferSize => Int32.MaxValue;

    /// <summary>
    /// Returns a new owner wrapping a freshly allocated block of at least
    /// <paramref name="minBufferSize"/> elements.
    /// </summary>
    /// <param name="minBufferSize">
    /// Minimum number of elements, or -1 for the default. Requests smaller than roughly one
    /// 4 KiB page's worth of elements are rounded up to that size.
    /// </param>
    /// <exception cref="InvalidOperationException"><typeparamref name="T"/> is or contains references.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="minBufferSize"/> is less than -1.</exception>
    public override IMemoryOwner<T> Rent(int minBufferSize = -1)
    {
        // Parameter validation
        if (RuntimeHelpers.IsReferenceOrContainsReferences<T>())
        {
            // The buffer lives in native memory the GC does not track, so it must not hold references.
            throw new InvalidOperationException("This pool can only be used with blittable types.");
        }
        if (minBufferSize < -1)
        {
            throw new ArgumentOutOfRangeException(paramName: nameof(minBufferSize));
        }

        // (1 + 4095 / sizeof(T)) elements always span at least 4096 bytes, which also guarantees at
        // least one element is allocated for the -1 default and for 0.
        return new ValidatingMemoryManager(Math.Max(minBufferSize, 1 + (4095 / Unsafe.SizeOf<T>())));
    }

    /// <summary>No-op: the pool itself holds no resources; each rented block is owned by its manager.</summary>
    protected override void Dispose(bool disposing)
    {
        // no-op
    }

    /// <summary>
    /// Owns a single VirtualAlloc'd block and enforces the <see cref="MemoryManager{T}"/> lifetime rules on it.
    /// </summary>
    private unsafe sealed class ValidatingMemoryManager : MemoryManager<T>
    {
        private readonly int _elementCount;
        private readonly VirtualAllocSafeHandle _handle;

        // Bit 0 is the 'disposed' flag; the remaining bits count outstanding Pin() calls in steps of 2.
        private int _refCount = 0;

        public ValidatingMemoryManager(int elementCount)
        {
            Debug.Assert(elementCount > 0);

            // Allocate a block of memory large enough to hold the desired number of elements.
            _elementCount = elementCount;

            // Compute the byte size with 64-bit arithmetic. In a 32-bit process the IntPtr(long)
            // constructor throws OverflowException instead of silently wrapping to a smaller size
            // (which would let the returned span run past the end of the allocation).
            _handle = VirtualAllocSafeHandle.Alloc(new IntPtr((long)elementCount * Unsafe.SizeOf<T>()));

            // Don't allow the handle to be reclaimed (released) until the AppDomain terminates.
            // Our dispose method will decommit the memory, not release it.
            // The GCHandle is intentionally never freed.
            GCHandle.Alloc(_handle);
        }

        public override int Length
        {
            get
            {
                FailIfDisposed();
                return _elementCount;
            }
        }

        public override Memory<T> Memory
        {
            get
            {
                FailIfDisposed();
                return base.Memory;
            }
        }

        /// <summary>
        /// Marks the instance disposed and decommits its memory. Fails (throws) on double dispose, or
        /// if any Pin() is still outstanding; in the latter case the instance is left undisposed.
        /// </summary>
        protected override void Dispose(bool disposing)
        {
            // Only transition 0 (not disposed, not pinned) -> 1 (disposed).
            int existingRefCount = Interlocked.CompareExchange(ref _refCount, 1, 0);
            if ((existingRefCount & 1) == 1)
            {
                Fail("Attempt to double-dispose this instance.");
            }
            else if (existingRefCount != 0)
            {
                Fail("Attempt to dispose an instance that has an outstanding call to Pin().");
            }

            _handle.Decommit();
        }

        /// <summary>
        /// Reports a usage error: logs it to an attached debugger, then throws
        /// <see cref="InvalidOperationException"/>.
        /// </summary>
        [MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.NoOptimization)]
        private static void Fail(string message)
        {
            if (Debugger.IsAttached)
            {
                Debugger.Log(0, null, message);
            }

            try
            {
                throw new InvalidOperationException(message);
            }
            finally
            {
                GC.KeepAlive(message); // to make sure it's present in debugger 'locals' window
            }
        }

        private void FailIfDisposed()
        {
            if ((Volatile.Read(ref _refCount) & 1) == 1)
            {
                Fail("Attempt to access a disposed object.");
            }
        }

        public override Span<T> GetSpan()
        {
            FailIfDisposed();

            bool refAdded = false;
            try
            {
                _handle.DangerousAddRef(ref refAdded);
                return new Span<T>((void*)_handle.DangerousGetHandle(), _elementCount);
            }
            finally
            {
                if (refAdded)
                {
                    _handle.DangerousRelease();
                }
            }
        }

        /// <summary>
        /// Pins the buffer and returns a handle pointing at element <paramref name="elementIndex"/>.
        /// </summary>
        /// <param name="elementIndex">
        /// Offset of the first element to point at; <see cref="Memory{T}.Pin"/> passes the slice's start
        /// index here. A value equal to the length is allowed and points just past the last element.
        /// </param>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="elementIndex"/> is negative or greater than the length.</exception>
        public override MemoryHandle Pin(int elementIndex = 0)
        {
            // The unsigned comparison rejects negative indices as well as indices past the end.
            if ((uint)elementIndex > (uint)_elementCount)
            {
                throw new ArgumentOutOfRangeException(paramName: nameof(elementIndex));
            }

            bool safeHandleRefAdded = false;
            try
            {
                Interlocked.Add(ref _refCount, 2); // this won't throw
                FailIfDisposed(); // Interlocked call introduces thread synchronization, so check disposed after
                _handle.DangerousAddRef(ref safeHandleRefAdded); // this *shouldn't* throw
                return new MemoryHandle(
                    pointer: Unsafe.Add<T>((void*)_handle.DangerousGetHandle(), elementIndex),
                    handle: default,
                    pinnable: new ValidatingPinnable(this)); // this allocation might throw (OOM)
            }
            catch
            {
                // Undo our pin count so a failed Pin() doesn't block a later Dispose().
                Interlocked.Add(ref _refCount, -2);
                throw;
            }
            finally
            {
                if (safeHandleRefAdded)
                {
                    _handle.DangerousRelease();
                }
            }
        }

        protected override bool TryGetArray(out ArraySegment<T> segment)
        {
            // The memory is native, so there is never a backing array.
            FailIfDisposed();
            segment = default;
            return false;
        }

        public override void Unpin()
        {
            // We pass a different 'IPinnable' instance to the MemoryHandle ctor,
            // so this Unpin method should never be called.
            Fail("This method should never be called.");
        }

        /// <summary>
        /// Per-Pin() token handed to <see cref="MemoryHandle"/>. Its Unpin() releases exactly one pin
        /// on the owning manager; a second Unpin(), or never calling Unpin() at all, is reported as an error.
        /// </summary>
        private sealed class ValidatingPinnable : IPinnable
        {
            // Set to null by the first Unpin() so a second call can be detected.
            private ValidatingMemoryManager _manager;

            public ValidatingPinnable(ValidatingMemoryManager manager)
            {
                _manager = manager;
            }

            // Runs only if Unpin() was never called (Unpin suppresses finalization). Fail throws on the
            // finalizer thread, which terminates the process: a deliberately loud leak report.
            ~ValidatingPinnable()
            {
                if (Debugger.IsAttached)
                {
                    Debugger.Log(0, null, "Leak detected: Unpin() never called.");
                    Debugger.Break();
                }

                Fail("Leak detected: Unpin() never called.");
            }

            public MemoryHandle Pin(int elementIndex)
            {
                Fail("This method should never be called.");
                return default;
            }

            public void Unpin()
            {
                GC.SuppressFinalize(this);

                // Atomically claim the manager reference; only the first caller gets a non-null value.
                ValidatingMemoryManager manager = Interlocked.Exchange(ref _manager, null);
                if (manager != null)
                {
                    // Use the local: the field has just been cleared.
                    Interlocked.Add(ref manager._refCount, -2);
                }
                else
                {
                    Fail("Unpin() called twice; should only have been called once.");
                }
            }
        }
    }
}
/// <summary>
/// Owns a block of memory obtained from VirtualAlloc. <see cref="Decommit"/> drops the block's
/// physical backing while keeping its address range reserved; the range itself is released
/// (MEM_RELEASE) only when the handle is released.
/// </summary>
internal sealed class VirtualAllocSafeHandle : SafeHandleZeroOrMinusOneIsInvalid
{
    // Size of the block in bytes; needed again by VirtualProtect / VirtualFree(MEM_DECOMMIT).
    private readonly IntPtr _cb;

    private VirtualAllocSafeHandle(IntPtr cb)
        : base(ownsHandle: true)
    {
        _cb = cb;
    }

    /// <summary>Allocates (reserves + commits) <paramref name="cb"/> bytes of read/write memory.</summary>
    /// <param name="cb">Number of bytes to allocate.</param>
    /// <returns>A handle owning the new block.</returns>
    /// <remarks>On failure, throws the exception that <see cref="Marshal.ThrowExceptionForHR(int)"/> maps the Win32 error to.</remarks>
    public static VirtualAllocSafeHandle Alloc(IntPtr cb)
    {
        // Create the (still invalid) SafeHandle up front so nothing can fail between
        // VirtualAlloc returning a pointer and that pointer being owned.
        VirtualAllocSafeHandle retVal = new VirtualAllocSafeHandle(cb);
        bool success = false;

        RuntimeHelpers.PrepareConstrainedRegions();
        try
        {
        }
        finally
        {
            // Performed in the CER 'finally' so the allocation and SetHandle are not interrupted.
            IntPtr handle = UnsafeNativeMethods.VirtualAlloc(
                lpAddress: IntPtr.Zero,
                dwSize: cb,
                flAllocationType: VirtualAllocAllocationType.MEM_RESERVE | VirtualAllocAllocationType.MEM_COMMIT,
                flProtect: VirtualAllocMemoryProtection.PAGE_READWRITE);
            if (handle != IntPtr.Zero)
            {
                retVal.SetHandle(handle);
                success = true;
            }
        }

        if (!success)
        {
            // Capture the error first; then dispose the unused handle so it doesn't linger in the
            // finalizer queue (its handle is invalid, so ReleaseHandle is not invoked).
            int hr = Marshal.GetHRForLastWin32Error();
            retVal.Dispose();
            Marshal.ThrowExceptionForHR(hr);
        }

        return retVal;
    }

    /// <summary>
    /// Makes the block inaccessible and decommits it without releasing the reservation, so later
    /// accesses through stale pointers fault rather than touching unrelated memory.
    /// </summary>
    public void Decommit()
    {
        bool refAdded = false;
        try
        {
            DangerousAddRef(ref refAdded);
            IntPtr address = DangerousGetHandle();

            if (!UnsafeNativeMethods.VirtualProtect(address, _cb, VirtualAllocMemoryProtection.PAGE_NOACCESS, out _))
            {
                Marshal.ThrowExceptionForHR(Marshal.GetHRForLastWin32Error());
            }

            if (!UnsafeNativeMethods.VirtualFree(address, _cb, VirtualAllocAllocationType.MEM_DECOMMIT))
            {
                Marshal.ThrowExceptionForHR(Marshal.GetHRForLastWin32Error());
            }
        }
        finally
        {
            // Release only the reference we actually took (the original test was inverted).
            if (refAdded)
            {
                DangerousRelease();
            }
        }
    }

    /// <summary>Releases the whole reserved region; VirtualFree requires a size of 0 with MEM_RELEASE.</summary>
    [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)]
    protected override bool ReleaseHandle()
    {
        return UnsafeNativeMethods.VirtualFree(handle, IntPtr.Zero, VirtualAllocAllocationType.MEM_RELEASE);
    }
}
/// <summary>
/// P/Invoke declarations for the kernel32 virtual-memory functions. Each is declared with
/// SetLastError = true, so Marshal.GetLastWin32Error / GetHRForLastWin32Error report why a call failed.
/// </summary>
[SuppressUnmanagedCodeSecurity]
internal static class UnsafeNativeMethods
{
    private const string Kernel32Dll = "kernel32.dll";

    // Resolve kernel32 only from the System32 directory.
    private const DllImportSearchPath Kernel32SearchPaths = DllImportSearchPath.System32;

    /// <summary>Reserves and/or commits a region of pages; returns IntPtr.Zero on failure.</summary>
    /// <remarks>https://msdn.microsoft.com/en-us/library/windows/desktop/aa366887(v=vs.85).aspx</remarks>
    [DllImport(Kernel32Dll, CallingConvention = CallingConvention.Winapi, SetLastError = true), DefaultDllImportSearchPaths(Kernel32SearchPaths), ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)]
    public static extern IntPtr VirtualAlloc(
        [In] IntPtr lpAddress,
        [In] IntPtr dwSize,
        [In] VirtualAllocAllocationType flAllocationType,
        [In] VirtualAllocMemoryProtection flProtect);

    /// <summary>Decommits and/or releases a region of pages; returns false on failure.</summary>
    /// <remarks>https://msdn.microsoft.com/en-us/library/windows/desktop/aa366892(v=vs.85).aspx</remarks>
    [DllImport(Kernel32Dll, CallingConvention = CallingConvention.Winapi, SetLastError = true), DefaultDllImportSearchPaths(Kernel32SearchPaths), ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)]
    [return: MarshalAs(UnmanagedType.Bool)]
    public static extern bool VirtualFree(
        [In] IntPtr lpAddress,
        [In] IntPtr dwSize,
        [In] VirtualAllocAllocationType dwFreeType);

    /// <summary>Changes the protection on a region of committed pages; returns false on failure.</summary>
    /// <remarks>https://msdn.microsoft.com/en-us/library/windows/desktop/aa366898(v=vs.85).aspx</remarks>
    [DllImport(Kernel32Dll, CallingConvention = CallingConvention.Winapi, SetLastError = true), DefaultDllImportSearchPaths(Kernel32SearchPaths)]
    [return: MarshalAs(UnmanagedType.Bool)]
    public static extern bool VirtualProtect(
        [In] IntPtr lpAddress,
        [In] IntPtr dwSize,
        [In] VirtualAllocMemoryProtection flNewProtect,
        [Out] out VirtualAllocMemoryProtection lpflOldProtect);
}
/// <summary>
/// Allocation-type flags passed to VirtualAlloc / VirtualFree (values match the Win32 MEM_* constants).
/// </summary>
[Flags]
internal enum VirtualAllocAllocationType : uint
{
    /// <summary>Commit physical storage for the pages.</summary>
    MEM_COMMIT = 0x1000,

    /// <summary>Reserve a range of address space without committing it.</summary>
    MEM_RESERVE = 0x2000,

    /// <summary>Decommit committed pages; the address range stays reserved.</summary>
    MEM_DECOMMIT = 0x4000,

    /// <summary>Release the entire reserved region (VirtualFree must be given a size of 0).</summary>
    MEM_RELEASE = 0x8000,
}
/// <summary>
/// Page-protection values passed to VirtualAlloc / VirtualProtect (values match the Win32 PAGE_* constants).
/// </summary>
internal enum VirtualAllocMemoryProtection : uint
{
    /// <summary>Any access to the pages faults.</summary>
    PAGE_NOACCESS = 0x1,

    /// <summary>Pages may be read and written.</summary>
    PAGE_READWRITE = 0x4,
}
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.