Skip to content

Instantly share code, notes, and snippets.

@gfoidl
Last active August 23, 2023 19:08
Show Gist options
  • Save gfoidl/b3da2a1bacbd617c04584b2398cedc6e to your computer and use it in GitHub Desktop.
Save gfoidl/b3da2a1bacbd617c04584b2398cedc6e to your computer and use it in GitHub Desktop.
ReduceMin/Max for Vector128<T> where T : IBinaryInteger
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
/// <summary>
/// Horizontal minimum/maximum reductions over all elements of a
/// <see cref="Vector128{T}"/> or <see cref="Vector256{T}"/> of integers.
/// </summary>
/// <remarks>
/// When SSE4.1 is available, 8-, 16- and 32-bit element types use branchless
/// shuffle/min/max sequences (the 8- and 16-bit ones finish with
/// <see cref="Sse41.MinHorizontal(Vector128{ushort})"/>). Every other case
/// (64-bit or native-sized elements, or no SSE4.1) falls back to a scalar loop.
/// Casts that intentionally rely on two's-complement wrap-around are wrapped in
/// <c>unchecked</c> so results stay correct when the project enables overflow checking.
/// </remarks>
internal static class VectorExtensions
{
    /// <summary>Returns the smallest element of <paramref name="vector"/>.</summary>
    /// <typeparam name="T">A primitive integer element type.</typeparam>
    /// <param name="vector">The vector to reduce.</param>
    /// <returns>The minimum over all elements of <paramref name="vector"/>.</returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static T ReduceMin<T>(this Vector128<T> vector) where T : struct, IBinaryInteger<T>
    {
        if (Sse41.IsSupported)
        {
            if (typeof(T) == typeof(byte))
            {
                // Reinterpret as 8 shorts and shift each right by 8 bits:
                //   v0:      <b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15>
                //   shifted: <b1, 0,  b3, 0,  b5, 0,  b7, 0,  b9, 0,  b11, 0,   b13, 0,   b15, 0>
                // A byte-wise min then holds min(b[2k], b[2k+1]) in the low byte of every
                // ushort and 0 in the high byte, so SSE4.1's _mm_minpos_epu16 produces the
                // overall minimum. That value is always within byte range, so the final
                // ushort -> byte cast is lossless.
                Vector128<byte> v0 = vector.AsByte();
                Vector128<byte> shifted = Sse2.ShiftRightLogical(v0.AsInt16(), 8).AsByte();
                Vector128<byte> pairMin = Sse2.Min(v0, shifted);
                Vector128<ushort> minTmp = Sse41.MinHorizontal(pairMin.AsUInt16());
                return (T)(object)(byte)minTmp.ToScalar();
            }
            else if (typeof(T) == typeof(sbyte))
            {
                // Bias the values (order preserving) so sbyte.MinValue -> 0 and
                // sbyte.MaxValue -> byte.MaxValue, reuse the byte reduction above,
                // then undo the bias. The undo yields 128..383 as an int and relies on
                // wrap-around, hence unchecked.
                Vector128<byte> v0 = (vector.AsSByte() - Vector128.Create(sbyte.MinValue)).AsByte();
                byte minTmp = ReduceMin(v0);
                return (T)(object)unchecked((sbyte)(minTmp - sbyte.MinValue));
            }
            else if (typeof(T) == typeof(ushort))
            {
                // _mm_minpos_epu16 places the minimum in element 0.
                Vector128<ushort> minTmp = Sse41.MinHorizontal(vector.AsUInt16());
                return (T)(object)minTmp.ToScalar();
            }
            else if (typeof(T) == typeof(short))
            {
                // Bias the values (order preserving) so short.MinValue -> 0 and
                // short.MaxValue -> ushort.MaxValue, reuse the ushort reduction above,
                // then undo the bias (wraps on purpose, hence unchecked).
                Vector128<ushort> v0 = (vector.AsInt16() - Vector128.Create(short.MinValue)).AsUInt16();
                ushort minTmp = ReduceMin(v0);
                return (T)(object)unchecked((short)(minTmp - short.MinValue));
            }
            else if (typeof(T) == typeof(uint))
            {
                // v1 = <a1, a1, a3, a3>, so val0[0] = min(a0, a1) and val0[2] = min(a2, a3).
                // val1[0] = val0[2], so the final element 0 is min(a0, a1, a2, a3).
                // Sse41.Min is called directly rather than Vector128.Min;
                // see https://github.com/dotnet/runtime/issues/75892
                Vector128<uint> v0 = vector.AsUInt32();
                Vector128<uint> v1 = Sse2.Shuffle(v0, 0b_11_11_01_01);
                Vector128<uint> val0 = Sse41.Min(v0, v1);
                Vector128<uint> val1 = Sse2.Shuffle(val0, 0b_11_10_01_10);
                return (T)(object)Sse41.Min(val0, val1).ToScalar();
            }
            else if (typeof(T) == typeof(int))
            {
                // Same shuffle scheme as the uint case, with signed comparison.
                // https://github.com/dotnet/runtime/issues/75892
                Vector128<int> v0 = vector.AsInt32();
                Vector128<int> v1 = Sse2.Shuffle(v0, 0b_11_11_01_01);
                Vector128<int> val0 = Sse41.Min(v0, v1);
                Vector128<int> val1 = Sse2.Shuffle(val0, 0b_11_10_01_10);
                return (T)(object)Sse41.Min(val0, val1).ToScalar();
            }
            // 64-bit elements deliberately use the scalar fallback below:
            // a vectorized variant was judged not worth it.
        }

        // Scalar fallback (loop gets unrolled if Count <= 4).
        T min = vector[0];
        for (int i = 1; i < Vector128<T>.Count; i++)
        {
            if (vector[i] < min)
            {
                min = vector[i];
            }
        }
        return min;
    }
    //-------------------------------------------------------------------------
    /// <summary>Returns the largest element of <paramref name="vector"/>.</summary>
    /// <typeparam name="T">A primitive integer element type.</typeparam>
    /// <param name="vector">The vector to reduce.</param>
    /// <returns>The maximum over all elements of <paramref name="vector"/>.</returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static T ReduceMax<T>(this Vector128<T> vector) where T : struct, IBinaryInteger<T>
    {
        if (Sse41.IsSupported)
        {
            if (typeof(T) == typeof(byte))
            {
                // Bitwise NOT reverses the order (byte.MaxValue -> 0, byte.MinValue -> byte.MaxValue),
                // so the max becomes a min of the byte vector. ~minTmp is a negative int,
                // so narrowing it back to byte wraps on purpose, hence unchecked.
                Vector128<byte> v0 = ~vector.AsByte();
                byte minTmp = ReduceMin(v0);
                return (T)(object)unchecked((byte)~minTmp);
            }
            else if (typeof(T) == typeof(sbyte))
            {
                // Map sbyte.MaxValue -> 0 and sbyte.MinValue -> byte.MaxValue (order reversing),
                // reduce as a byte minimum, then undo the mapping.
                Vector128<byte> v0 = (Vector128.Create(sbyte.MaxValue) - vector.AsSByte()).AsByte();
                byte minTmp = ReduceMin(v0);
                return (T)(object)(sbyte)(sbyte.MaxValue - minTmp);
            }
            else if (typeof(T) == typeof(ushort))
            {
                // Bitwise NOT reverses the order (ushort.MaxValue -> 0, ushort.MinValue -> ushort.MaxValue),
                // so the max becomes a ushort minimum. Undoing it wraps, hence unchecked.
                ushort minTmp = ReduceMin(~vector.AsUInt16());
                return (T)(object)unchecked((ushort)~minTmp);
            }
            else if (typeof(T) == typeof(short))
            {
                // Map short.MaxValue -> 0 and short.MinValue -> ushort.MaxValue (order reversing),
                // reduce as a ushort minimum, then undo the mapping.
                Vector128<ushort> v0 = (Vector128.Create(short.MaxValue) - vector.AsInt16()).AsUInt16();
                ushort minTmp = ReduceMin(v0);
                return (T)(object)(short)(short.MaxValue - minTmp);
            }
            else if (typeof(T) == typeof(uint))
            {
                // Same shuffle scheme as ReduceMin's uint case, using Max instead of Min.
                // https://github.com/dotnet/runtime/issues/75892
                Vector128<uint> v0 = vector.AsUInt32();
                Vector128<uint> v1 = Sse2.Shuffle(v0, 0b_11_11_01_01);
                Vector128<uint> val0 = Sse41.Max(v0, v1);
                Vector128<uint> val1 = Sse2.Shuffle(val0, 0b_11_10_01_10);
                return (T)(object)Sse41.Max(val0, val1).ToScalar();
            }
            else if (typeof(T) == typeof(int))
            {
                // Same shuffle scheme as ReduceMin's int case, using Max instead of Min.
                // https://github.com/dotnet/runtime/issues/75892
                Vector128<int> v0 = vector.AsInt32();
                Vector128<int> v1 = Sse2.Shuffle(v0, 0b_11_11_01_01);
                Vector128<int> val0 = Sse41.Max(v0, v1);
                Vector128<int> val1 = Sse2.Shuffle(val0, 0b_11_10_01_10);
                return (T)(object)Sse41.Max(val0, val1).ToScalar();
            }
            // 64-bit elements deliberately use the scalar fallback below:
            // a vectorized variant was judged not worth it.
        }

        // Scalar fallback (loop gets unrolled if Count <= 4).
        T max = vector[0];
        for (int i = 1; i < Vector128<T>.Count; i++)
        {
            if (vector[i] > max)
            {
                max = vector[i];
            }
        }
        return max;
    }
    //-------------------------------------------------------------------------
    /// <summary>Returns the smallest element of <paramref name="vector"/>.</summary>
    /// <remarks>
    /// Folds the two 128-bit halves together with an element-wise minimum, then
    /// reduces the resulting <see cref="Vector128{T}"/>.
    /// </remarks>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static T ReduceMin<T>(this Vector256<T> vector) where T : struct, IBinaryInteger<T>
        => Vector128.Min(vector.GetLower(), vector.GetUpper()).ReduceMin();
    //-------------------------------------------------------------------------
    /// <summary>Returns the largest element of <paramref name="vector"/>.</summary>
    /// <remarks>
    /// Folds the two 128-bit halves together with an element-wise maximum, then
    /// reduces the resulting <see cref="Vector128{T}"/>.
    /// </remarks>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static T ReduceMax<T>(this Vector256<T> vector) where T : struct, IBinaryInteger<T>
        => Vector128.Max(vector.GetLower(), vector.GetUpper()).ReduceMax();
}
@gfoidl
Copy link
Author

gfoidl commented Sep 25, 2022

The vectorized reduction is

  • branchless
  • smaller in code size

E.g. for `byte` (scalar loop vs. vectorized reduction):

; Bench`1[[System.Byte, System.Private.CoreLib]].MaxDefault()
       sub       rsp,18
       vzeroupper
       vmovupd   xmm0,[rcx+18]
       vpextrb   eax,xmm0,0
       mov       edx,1
       nop       dword ptr [rax+rax]
M00_L00:
       vmovupd   [rsp+8],xmm0
       movzx     ecx,byte ptr [rsp+rdx+8]
       mov       r8d,ecx
       cmp       r8d,eax
       jle       short M00_L01
       mov       eax,ecx
M00_L01:
       inc       edx
       cmp       edx,10
       jl        short M00_L00
       add       rsp,18
       ret
; Total bytes of code 65

; VectorExtensions.ReduceMin[[System.Byte, System.Private.CoreLib]](System.Runtime.Intrinsics.Vector128`1<Byte>)
       vzeroupper
       vmovupd   xmm0,[rcx]
       vpsrlw    xmm1,xmm0,8
       vpminub   xmm0,xmm0,xmm1
       vphminposuw xmm0,xmm0
       vmovd     eax,xmm0
       movzx     eax,ax
       movzx     eax,al
       ret
; Total bytes of code 32

Or for `int` (scalar loop vs. vectorized reduction):

; Bench`1[[System.Int32, System.Private.CoreLib]].MaxDefault()
       vzeroupper
       vmovupd   xmm0,[rcx+18]
       vmovd     eax,xmm0
       vpextrd   edx,xmm0,1
       cmp       edx,eax
       jle       short M00_L00
       mov       eax,edx
M00_L00:
       vpextrd   edx,xmm0,2
       cmp       edx,eax
       jle       short M00_L01
       mov       eax,edx
M00_L01:
       vpextrd   edx,xmm0,3
       cmp       edx,eax
       jle       short M00_L02
       mov       eax,edx
M00_L02:
       ret
; Total bytes of code 49

; Bench`1[[System.Int32, System.Private.CoreLib]].MaxVectorMadness()
       vzeroupper
       vmovupd   xmm0,[rcx+18]
       vpshufd   xmm1,xmm0,0F5
       vpmaxsd   xmm0,xmm0,xmm1
       vpshufd   xmm1,xmm0,0E6
       vpmaxsd   xmm0,xmm0,xmm1
       vmovd     eax,xmm0
       ret
; Total bytes of code 33
Program.cs
#define BENCH
//#define DUMP_ASM

using System.Numerics;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
using BenchmarkDotNet.Attributes;
using BenchmarkDotNet.Configs;

// Element types exercised by Bench<T>. For each one, the scalar ("Default") and
// vectorized ("VectorMadness") variants are run once. In DEBUG builds their results
// are compared; with DUMP_ASM they are invoked repeatedly (presumably so the JIT
// tiers them up before the disassembly is dumped — TODO confirm).
Type[] types =
{
    typeof(byte),
    typeof(sbyte),
    typeof(short),
    typeof(ushort),
    typeof(int),
    typeof(uint),
    typeof(long),
    typeof(ulong)
};

foreach (Type intType in types)
{
    Console.WriteLine(intType);
    Console.WriteLine(new string('=', intType.FullName!.Length));

    // Bench<T> is generic, so it is closed over the current type and driven via reflection.
    Type benchType              = typeof(Bench<>).MakeGenericType(intType);
    MethodInfo setup            = benchType.GetMethod(nameof(Bench<byte>.Setup))!;
    MethodInfo minDefault       = benchType.GetMethod(nameof(Bench<byte>.MinDefault))!;
    MethodInfo maxDefault       = benchType.GetMethod(nameof(Bench<byte>.MaxDefault))!;
    MethodInfo minVectorMadness = benchType.GetMethod(nameof(Bench<byte>.MinVectorMadness))!;
    MethodInfo maxVectorMadness = benchType.GetMethod(nameof(Bench<byte>.MaxVectorMadness))!;

    object bench = Activator.CreateInstance(benchType)!;
    setup.Invoke(bench, null);

    object? min0 = minDefault.Invoke(bench, null);
    object? max0 = maxDefault.Invoke(bench, null);
    object? min1 = minVectorMadness.Invoke(bench, null);
    object? max1 = maxVectorMadness.Invoke(bench, null);

#if DEBUG
    Console.WriteLine("             min\tmax");
    Console.WriteLine($"default:     {min0,3}\t{max0,3}");
    Console.WriteLine($"vec-madness: {min1,3}\t{max1,3}");

    // Both variants return the same boxed primitive type, so object.Equals compares
    // type and value directly (no string round-trip needed).
    if (!object.Equals(min0, min1) || !object.Equals(max0, max1))
    {
        Console.ForegroundColor = ConsoleColor.Red;
        Console.WriteLine("Error!");
        Console.ResetColor();
    }

    Console.WriteLine(new string('-', 60));

#else
#if DUMP_ASM
    for (int i = 0; i < 100; ++i)
    {
        if (i % 10 == 0) Thread.Sleep(50);

        _ = minDefault.Invoke(bench, null);
        _ = maxDefault.Invoke(bench, null);
        _ = minVectorMadness.Invoke(bench, null);
        _ = maxVectorMadness.Invoke(bench, null);
    }
#endif
#endif
}

#if !DEBUG && BENCH
BenchmarkDotNet.Running.BenchmarkSwitcher.FromAssembly(typeof(Program).Assembly).Run(args);
//BenchmarkDotNet.Running.BenchmarkSwitcher.FromAssembly(typeof(Program).Assembly).RunAllJoined();
#endif

/// <summary>
/// Benchmarks a scalar min/max loop against <c>ReduceMin</c>/<c>ReduceMax</c>
/// on a <see cref="Vector128{T}"/> of integer type <typeparamref name="T"/>.
/// </summary>
[ShortRunJob]
[GenericTypeArguments(typeof(byte))]
[GenericTypeArguments(typeof(sbyte))]
[GenericTypeArguments(typeof(short))]
[GenericTypeArguments(typeof(ushort))]
[GenericTypeArguments(typeof(int))]
[GenericTypeArguments(typeof(uint))]
[GenericTypeArguments(typeof(long))]
[GenericTypeArguments(typeof(ulong))]
[CategoriesColumn]
[GroupBenchmarksBy(BenchmarkLogicalGroupRule.ByCategory)]
[DisassemblyDiagnoser(maxDepth: 3)]
public class Bench<T> where T : struct, IBinaryInteger<T>
{
    private Vector128<T> _vecMin;
    private Vector128<T> _vecMax;

    [GlobalSetup]
    public void Setup()
    {
        (_vecMin, _vecMax) = CreateVector128();
    }

    [Benchmark(Baseline = true, Description = "Default")]
    [BenchmarkCategory("Min")]
    public T MinDefault() => MinMax<MinCalc<T>>(_vecMin);

    [Benchmark(Baseline = true, Description = "Default")]
    [BenchmarkCategory("Max")]
    public T MaxDefault() => MinMax<MaxCalc<T>>(_vecMax);

    [Benchmark(Description = "VectorMadness")]
    [BenchmarkCategory("Min")]
    public T MinVectorMadness() => _vecMin.ReduceMin();

    [Benchmark(Description = "VectorMadness")]
    [BenchmarkCategory("Max")]
    public T MaxVectorMadness() => _vecMax.ReduceMax();

    /// <summary>
    /// Scalar reduction: walks the elements and replaces the running value whenever
    /// <c>TMinMax.Compare(candidate, current)</c> returns true.
    /// </summary>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static T MinMax<TMinMax>(Vector128<T> vec) where TMinMax : IMinMaxCalc<T>
    {
        T value = vec[0];
        for (int i = 1; i < Vector128<T>.Count; i++)
        {
            if (TMinMax.Compare(vec[i], value))
            {
                value = vec[i];
            }
        }

        return value;
    }

    /// <summary>
    /// Builds the benchmark inputs: the "min" vector holds &lt;Count, ..., 2, 1&gt; and the
    /// "max" vector holds &lt;1, 2, ..., Count&gt;, so in both the wanted extreme is the last element.
    /// </summary>
    private static (Vector128<T>, Vector128<T>) CreateVector128()
    {
        Span<T> values = new T[Vector128<T>.Count];

        // Descending values Count..1. Count is at most 16, so every value fits in any
        // integer T, and T.CreateTruncating converts without per-type special-casing.
        for (int idx = 0; idx < values.Length; idx++)
        {
            values[idx] = T.CreateTruncating(values.Length - idx);
        }

        Vector128<T> vectorMin = Vector128.LoadUnsafe(ref values[0]);
        values.Reverse();
        Vector128<T> vectorMax = Vector128.LoadUnsafe(ref values[0]);

#if DEBUG
        Console.WriteLine($"min vec: {vectorMin}");
        Console.WriteLine($"max vec: {vectorMax}");
        Console.WriteLine();
#endif

        return (vectorMin, vectorMax);
    }
}

/// <summary>
/// Static strategy abstraction that selects "min" or "max" behavior at compile time
/// through a generic type argument (see <c>MinCalc&lt;T&gt;</c> and <c>MaxCalc&lt;T&gt;</c>).
/// </summary>
internal interface IMinMaxCalc<T> where T : struct, IBinaryInteger<T>
{
    /// <summary>True when <paramref name="left"/> should replace <paramref name="right"/> as the running result.</summary>
    static abstract bool Compare(T left, T right);

    /// <summary>Element-wise combination of two 128-bit vectors.</summary>
    static abstract Vector128<T> Compare(Vector128<T> left, Vector128<T> right);

    /// <summary>Element-wise combination of two 256-bit vectors.</summary>
    static abstract Vector256<T> Compare(Vector256<T> left, Vector256<T> right);
}

/// <summary>"Min" strategy: smaller values win.</summary>
internal struct MinCalc<T> : IMinMaxCalc<T> where T : struct, IBinaryInteger<T>
{
    /// <summary>True when <paramref name="left"/> is strictly less than <paramref name="right"/>.</summary>
    public static bool Compare(T left, T right) => left < right;

    /// <summary>
    /// Element-wise minimum. With SSE4.1, 32-bit elements call <see cref="Sse41"/>.Min directly;
    /// everything else goes through <see cref="Vector128.Min{T}(Vector128{T}, Vector128{T})"/>.
    /// </summary>
    public static Vector128<T> Compare(Vector128<T> left, Vector128<T> right)
    {
        if (!Sse41.IsSupported)
        {
            return Vector128.Min(left, right);
        }

        if (typeof(T) == typeof(uint))
        {
            return Sse41.Min(left.AsUInt32(), right.AsUInt32()).As<uint, T>();
        }

        if (typeof(T) == typeof(int))
        {
            return Sse41.Min(left.AsInt32(), right.AsInt32()).As<int, T>();
        }

        return Vector128.Min(left, right);
    }

    /// <summary>Element-wise minimum of two 256-bit vectors.</summary>
    public static Vector256<T> Compare(Vector256<T> left, Vector256<T> right) => Vector256.Min(left, right);
}

/// <summary>"Max" strategy: larger values win.</summary>
internal struct MaxCalc<T> : IMinMaxCalc<T> where T : struct, IBinaryInteger<T>
{
    /// <summary>True when <paramref name="left"/> is strictly greater than <paramref name="right"/>.</summary>
    public static bool Compare(T left, T right) => left > right;

    /// <summary>
    /// Element-wise maximum. With SSE4.1, 32-bit elements call <see cref="Sse41"/>.Max directly;
    /// everything else goes through <see cref="Vector128.Max{T}(Vector128{T}, Vector128{T})"/>.
    /// </summary>
    public static Vector128<T> Compare(Vector128<T> left, Vector128<T> right)
    {
        if (!Sse41.IsSupported)
        {
            return Vector128.Max(left, right);
        }

        if (typeof(T) == typeof(uint))
        {
            return Sse41.Max(left.AsUInt32(), right.AsUInt32()).As<uint, T>();
        }

        if (typeof(T) == typeof(int))
        {
            return Sse41.Max(left.AsInt32(), right.AsInt32()).As<int, T>();
        }

        return Vector128.Max(left, right);
    }

    /// <summary>Element-wise maximum of two 256-bit vectors.</summary>
    public static Vector256<T> Compare(Vector256<T> left, Vector256<T> right) => Vector256.Max(left, right);
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment