// Gist by @redknightlois, created July 10, 2023.
// Requires: System.Diagnostics, System.Numerics, System.Runtime.CompilerServices,
// System.Runtime.InteropServices, System.Runtime.Intrinsics and System.Runtime.Intrinsics.X86.
private static ReadOnlySpan<byte> LoadMaskTable => new byte[]
{
0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00
};
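// The table above is 64 bytes: 32 bytes of 0xFF followed by 32 bytes of 0x00. Read as 8 ints of all-ones
// followed by 8 ints of zero, so loading a Vector256<int> starting at int index (8 - n) yields a mask whose
// first n lanes are enabled. CompareAvx2 uses this below to mask-load only the leftover ints.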
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static unsafe int CompareAvx2(ref byte p1, ref byte p2, int size)
{
// PERF: Given all the preparation that must happen before we even touch the data, these prefetches may grow
// the method by 10+ bytes, but by the time we access the data it is already sitting in the L1 cache.
Sse.Prefetch0(Unsafe.AsPointer(ref p1));
Sse.Prefetch0(Unsafe.AsPointer(ref p2));
// PERF: This allows us to do pointer arithmetic and use relative addressing with the
// hardware instructions without needing an extra register.
ref byte bpx = ref p1;
ref byte bpy = ref p2;
nuint length = (nuint)size;
ref byte bpxEnd = ref Unsafe.AddByteOffset(ref bpx, length);
uint matches;
// PERF: The alignment unit will be decided in terms of the total size, because we can use the exact same code
// for a length smaller than a vector or to force alignment to a certain memory boundary. This will cause some
// multi-modal behavior to appear (especially close to the vector size) because we will become dependent on
// the input. The biggest gains will be seen when the compares are several times bigger than the vector size,
// where the aligned memory access (no penalty) will dominate the runtime. So this formula will calculate how
// many bytes are required to get to an aligned pointer.
nuint alignmentUnit = length >= (nuint)Vector256<byte>.Count ? (nuint)(Vector256<byte>.Count - (long)Unsafe.AsPointer(ref bpx) % Vector256<byte>.Count) : length;
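// For example (illustrative numbers): with 32-byte vectors, a buffer that starts 13 bytes past a 32-byte
// boundary needs 32 - 13 = 19 leftover bytes before bpx becomes aligned, while for a 20-byte compare the
// whole length (20) is used as the alignment unit and the same leftover code path below handles the input.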
// Check if we are completely aligned, or the size falls in a range (32 to 512 bytes) where the alignment
// work below does not pay off; in either case just skip everything and go straight to the core of the
// routine. We have much bigger fish to fry.
if ((alignmentUnit & (nuint)(Vector256<byte>.Count - 1)) == 0 || length is >= 32 and <= 512)
goto ProcessAligned;
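// Consume the scalar leftovers that precede the int-sized portion of the alignment unit: a 2-byte chunk
// if bit 1 is set, then a single byte if bit 0 is set.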
if ((alignmentUnit & 2) != 0)
{
if (Unsafe.ReadUnaligned<ushort>(ref bpx) != Unsafe.ReadUnaligned<ushort>(ref bpy))
{
if (bpx == bpy)
{
bpx = ref Unsafe.AddByteOffset(ref bpx, 1);
bpy = ref Unsafe.AddByteOffset(ref bpy, 1);
}
return bpx - bpy;
}
bpx = ref Unsafe.AddByteOffset(ref bpx, 2);
bpy = ref Unsafe.AddByteOffset(ref bpy, 2);
}
// We have a potential problem: AVX2 does not provide a masked load with byte granularity, only 32-bit
// elements, so the sub-int leftovers must be handled here with scalar compares, as fast as possible.
if ((alignmentUnit & 1) != 0)
{
if (bpx != bpy)
return bpx - bpy;
bpx = ref Unsafe.AddByteOffset(ref bpx, 1);
bpy = ref Unsafe.AddByteOffset(ref bpy, 1);
}
if ((long)Unsafe.AsPointer(ref bpxEnd) == (long)Unsafe.AsPointer(ref bpx))
return 0;
// PERF: From now on, at least one of the two memory sites will be 4-byte aligned. That improves the chances
// to hit a 16-byte (128-bit) alignment and also lets us perform a single masked load to finish the alignment
// work. The reason why we want that is that natural alignment can impact the L1 data cache latency.
// For example, on AMD Family 17h: a misaligned load operation suffers, at minimum, a one cycle penalty in the
// load-store pipeline if it spans a 32-byte boundary. Throughput for misaligned loads and stores is half
// that of aligned loads and stores, since a misaligned load or store requires two cycles to access the data
// cache (versus a single cycle for aligned loads and stores).
// Source: https://developer.amd.com/wordpress/media/2013/12/55723_SOG_Fam_17h_Processors_3.00.pdf
// Now that we know we are 4-byte aligned, we can use that knowledge to perform a masked load of the leftovers
// and achieve 32-byte alignment. If the buffer is smaller than a vector, the same masked load will simply find
// the difference (if any) and we will jump to Difference. Masked loads and stores will not cause memory access
// violations, because no memory access happens for masked-off elements, as per Intel's presentation:
// https://llvm.org/devmtg/2015-04/slides/MaskedIntrinsics.pdf
Debug.Assert(alignmentUnit > 0, "Cannot be 0 because that means that we have completed already.");
Debug.Assert(alignmentUnit < (nuint)Vector256<int>.Count * sizeof(int), $"Cannot be {Vector256<int>.Count * sizeof(int)} or greater because that means it is a full vector.");
int* tablePtr = (int*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(LoadMaskTable));
var mask = Avx.LoadDquVector256(tablePtr + ((nuint)Vector256<int>.Count - alignmentUnit / sizeof(uint)));
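// For instance (illustrative): with alignmentUnit = 27 there are 27 / 4 = 6 whole ints left before the
// 32-byte boundary, so the mask is loaded from tablePtr + (8 - 6) = tablePtr + 2; its first 6 lanes are
// all-ones and the last 2 are zero, which makes the MaskLoad below read exactly 24 bytes from each buffer.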
matches = (uint)Avx2.MoveMask(
Avx2.CompareEqual(
Avx2.MaskLoad((int*)Unsafe.AsPointer(ref bpx), mask).AsByte(),
Avx2.MaskLoad((int*)Unsafe.AsPointer(ref bpy), mask).AsByte()
)
);
if (matches != uint.MaxValue)
goto Difference;
// PERF: The reason we don't keep the original alignment bookkeeping is that discarding the initial leftovers
// would require masking them off anyway; advancing both cursors by the int-sized portion of the leftovers
// (alignmentUnit & ~3) achieves the same effect directly.
bpx = ref Unsafe.AddByteOffset(ref bpx, alignmentUnit & unchecked((nuint)~3));
bpy = ref Unsafe.AddByteOffset(ref bpy, alignmentUnit & unchecked((nuint)~3));
ProcessAligned:
ref byte loopEnd = ref Unsafe.SubtractByteOffset(ref bpxEnd, (nuint)Vector256<byte>.Count);
while (Unsafe.IsAddressGreaterThan(ref loopEnd, ref bpx))
{
matches = (uint)Avx2.MoveMask(
Avx2.CompareEqual(
Vector256.LoadUnsafe(ref bpx),
Vector256.LoadUnsafe(ref bpy)
)
);
// Note that MoveMask has converted the equal vector elements into a set of bit flags,
// so the bit position in 'matches' corresponds to the element offset.
// There are 32 elements in Vector256<byte>, so we compare against uint.MaxValue to check if everything matched.
if (matches == uint.MaxValue)
{
// All matched
bpx = ref Unsafe.AddByteOffset(ref bpx, (nuint)Vector256<byte>.Count);
bpy = ref Unsafe.AddByteOffset(ref bpy, (nuint)Vector256<byte>.Count);
continue;
}
goto Difference;
}
// It can happen that we are already done, in which case we can avoid the last unaligned access.
if (Unsafe.AreSame(ref bpx, ref bpxEnd))
return 0;
// Reposition both cursors so the final (possibly overlapping) vector load ends exactly at the end of each buffer.
bpy = ref Unsafe.AddByteOffset(ref bpy, Unsafe.ByteOffset(ref bpx, ref loopEnd));
bpx = ref loopEnd;
matches = (uint)Avx2.MoveMask(
Avx2.CompareEqual(
Vector256.LoadUnsafe(ref bpx),
Vector256.LoadUnsafe(ref bpy)
)
);
if (matches == uint.MaxValue)
return 0;
Difference:
// We invert 'matches' so that differences become set bits; the index of the first set bit is the offset of
// the first differing byte, so we advance both cursors by that amount and compare that specific byte.
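// For example (illustrative): if the vectors differ only at byte 5, matches is 0xFFFF_FFDF, ~matches is
// 0x0000_0020 and TrailingZeroCount yields 5, the offset of the differing byte.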
nuint bytesToAdvance = (nuint)BitOperations.TrailingZeroCount(~matches);
bpx = ref Unsafe.AddByteOffset(ref bpx, bytesToAdvance);
bpy = ref Unsafe.AddByteOffset(ref bpy, bytesToAdvance);
return bpx - bpy;
}
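// A minimal usage sketch, assuming the two members above live in the same class. 'Compare' is a hypothetical
// wrapper: it guards on AVX2 availability and equal lengths before calling the intrinsic path, and otherwise
// falls back to the BCL's own vectorized span comparison.
internal static int Compare(ReadOnlySpan<byte> x, ReadOnlySpan<byte> y)
{
    if (Avx2.IsSupported && x.Length == y.Length)
    {
        return CompareAvx2(
            ref MemoryMarshal.GetReference(x),
            ref MemoryMarshal.GetReference(y),
            x.Length);
    }

    // Different lengths or no AVX2 support: let the runtime handle it.
    return x.SequenceCompareTo(y);
}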