Skip to content

Instantly share code, notes, and snippets.

@xoofx
Last active July 12, 2023 17:08
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save xoofx/b6c569d83678f85ac87a436f8e241917 to your computer and use it in GitHub Desktop.
Save xoofx/b6c569d83678f85ac87a436f8e241917 to your computer and use it in GitHub Desktop.
Optimized AVX2 version of finding the index of an integer from an array
// Discussion about https://mastodon.social/@denisio@dotnet.social/110644302160625267
// Optimized version using AVX2. From 4x to 10x faster than a simple version.
// - nint for indexing
// - Unsafe.Add
// - Unrolling of 4 Vector256 + Or of the results to have only 1 branch per loop
// - Finding the local index within the Vector256 without a loop using AVX movemask
// Similar code can be done for Vector128
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
internal static class BatchFinder
{
    /// <summary>
    /// Finds the index of a value in a span of integers using AVX2 256-bit vectors.
    /// </summary>
    /// <param name="data">The span to search.</param>
    /// <param name="value">The value to look for.</param>
    /// <returns>Index of the first occurrence of <paramref name="value"/>, or -1 if not found.</returns>
    public static int Find_AVX2_256_Optimized(ReadOnlySpan<int> data, int value)
    {
        ref var pv = ref MemoryMarshal.GetReference(data);
        var compareValue = Vector256.Create(value);
        nint length = data.Length;

        // Main loop: process 32 ints (4 x Vector256<int> = 128 bytes) per iteration,
        // so there is only one branch per 32 elements.
        nint bound1024 = length & ~(Vector256<int>.Count * 4 - 1);
        nint i = 0;
        for (; i < bound1024; i += Vector256<int>.Count * 4)
        {
            var r1 = Vector256.Equals(Unsafe.As<int, Vector256<int>>(ref Unsafe.Add(ref pv, i)), compareValue);
            var r2 = Vector256.Equals(Unsafe.As<int, Vector256<int>>(ref Unsafe.Add(ref pv, i + Vector256<int>.Count)), compareValue);
            var r3 = Vector256.Equals(Unsafe.As<int, Vector256<int>>(ref Unsafe.Add(ref pv, i + Vector256<int>.Count * 2)), compareValue);
            var r4 = Vector256.Equals(Unsafe.As<int, Vector256<int>>(ref Unsafe.Add(ref pv, i + Vector256<int>.Count * 3)), compareValue);
            // Pack the four 32-bit masks down to two 16-bit masks BEFORE branching:
            // the two packs cost the same as the three ORs they used to replace,
            // and the packed values are reused on the (rare) hit path below.
            Vector256<short> r12 = Avx2.PackSignedSaturate(r1, r2).AsInt16();
            Vector256<short> r34 = Avx2.PackSignedSaturate(r3, r4).AsInt16();
            if ((r12 | r34) != Vector256<short>.Zero)
            {
                // PackSignedSaturate interleaves per 128-bit lane, so restore the
                // element order with a 4x64 permute before packing further.
                r12 = Avx2.Permute4x64(r12.AsInt64(), 0b_11_01_10_00).AsInt16();
                r34 = Avx2.Permute4x64(r34.AsInt64(), 0b_11_01_10_00).AsInt16();
                // Pack 16 -> 8 bits: one sbyte per int searched (32 bytes total).
                Vector256<sbyte> r = Avx2.PackSignedSaturate(r12, r34);
                // Again reorder the 128-bit lanes after the lane-wise pack.
                r = Avx2.Permute4x64(r.AsInt64(), 0b_11_01_10_00).AsSByte();
                // movemask yields one bit per byte; the lowest set bit is the first match.
                var idx = Avx2.MoveMask(r);
                return (int)(i + BitOperations.TrailingZeroCount(idx));
            }
        }

        // Tail loop: process 8 ints (one Vector256<int>, 32 bytes) per iteration.
        nint bound256 = length & ~(Vector256<int>.Count - 1);
        for (; i < bound256; i += Vector256<int>.Count)
        {
            var r1 = Vector256.Equals(Unsafe.As<int, Vector256<int>>(ref Unsafe.Add(ref pv, i)), compareValue);
            if (r1 != Vector256<int>.Zero)
            {
                // One mask bit per 32-bit element; lowest set bit is the local index.
                var rByte = Avx.MoveMask(r1.AsSingle());
                var idx = BitOperations.TrailingZeroCount((uint)rByte);
                return (int)(i + idx);
            }
        }

        // Scalar processing of the remaining 0..7 elements.
        for (; i < length; i++)
        {
            if (Unsafe.Add(ref pv, i) == value)
                return (int)i;
        }
        return -1;
    }
}
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
using BenchmarkDotNet.Attributes;
using BenchmarkDotNet.Jobs;
namespace BenchFind;
[SimpleJob(RuntimeMoniker.Net70)]
public class BenchmarkFinder
{
    /// <summary>Number of elements in the array being searched.</summary>
    [Params(32, 64, 128, 256, 512, 1024 * 1, 1024 * 2, 1024 * 3, 1024 * 4, 1024 * 5, 1024 * 6, 1024 * 7, 1024 * 8)]
    public int N { get; set; }

    int[] ints;
    int findValue;

    public BenchmarkFinder()
    {
        ints = Array.Empty<int>();
    }

    [GlobalSetup]
    public void Setup()
    {
        ints = Enumerable.Range(0, N).ToArray();
        // Search for a value near the end so every implementation scans most of the array.
        findValue = ints[^3];
    }

    [Benchmark(Baseline = true)]
    public int Find_Simple()
    {
        return FindSimple_(ints, findValue);
    }

    [Benchmark]
    public int Find_Generic_128()
    {
        return Find_Generic_128_(ints, findValue);
    }

    [Benchmark]
    public int Find_Generic_256()
    {
        return Find_Generic_256_(ints, findValue);
    }

    [Benchmark]
    public int Find_AVX_256_Optimized()
    {
        // Fixed: was calling BatchFinder.Find_AVX_256_Optimized, which does not
        // exist — the method defined on BatchFinder is Find_AVX2_256_Optimized.
        return BatchFinder.Find_AVX2_256_Optimized(ints, findValue);
    }

    /// <summary>Baseline: scalar linear search.</summary>
    static int FindSimple_(ReadOnlySpan<int> data, int value)
    {
        for (var i = 0; i < data.Length; i++)
            if (data[i] == value)
                return i;
        return -1;
    }

    /// <summary>Portable search using generic Vector128 (4 ints per step).</summary>
    static int Find_Generic_128_(ReadOnlySpan<int> data, int value)
    {
        // In theory we should check for Vector128.IsHardwareAccelerated and dispatch
        // accordingly; in practice we don't, to keep the code simple.
        var vInts = MemoryMarshal.Cast<int, Vector128<int>>(data);
        var compareValue = Vector128.Create(value);
        var vectorLength = Vector128<int>.Count;
        // Batch <4 x int> per loop
        for (var i = 0; i < vInts.Length; i++)
        {
            var result = Vector128.Equals(vInts[i], compareValue);
            if (result == Vector128<int>.Zero) continue;
            for (var k = 0; k < vectorLength; k++)
                if (result.GetElement(k) != 0)
                    return i * vectorLength + k;
        }
        // Scalar process of the remaining
        for (var i = vInts.Length * vectorLength; i < data.Length; i++)
            if (data[i] == value)
                return i;
        return -1;
    }

    /// <summary>Portable search using generic Vector256 (8 ints per step).</summary>
    static int Find_Generic_256_(ReadOnlySpan<int> data, int value)
    {
        // In theory we should check for Vector256.IsHardwareAccelerated and dispatch
        // accordingly; in practice we don't, to keep the code simple.
        var vInts = MemoryMarshal.Cast<int, Vector256<int>>(data);
        var compareValue = Vector256.Create(value);
        var vectorLength = Vector256<int>.Count;
        // Batch <8 x int> per loop
        for (var i = 0; i < vInts.Length; i++)
        {
            var result = Vector256.Equals(vInts[i], compareValue);
            if (result == Vector256<int>.Zero) continue;
            for (var k = 0; k < vectorLength; k++)
                if (result.GetElement(k) != 0)
                    return i * vectorLength + k;
        }
        // Scalar process of the remaining
        for (var i = vInts.Length * vectorLength; i < data.Length; i++)
            if (data[i] == value)
                return i;
        return -1;
    }
}
; Assembly listing for method BenchFind.BatchFinder:Find_AVX_256_Optimized(System.ReadOnlySpan`1[int],int):int
; Emitting BLENDED_CODE for X64 CPU with AVX - Windows
; optimized code
; rsp based frame
; fully interruptible
; No PGO data
; 2 inlinees with PGO data; 3 single block inlinees; 0 inlinees without PGO data
G_M000_IG01: ;; offset=0000H
C5F877 vzeroupper
G_M000_IG02: ;; offset=0003H
488B01 mov rax, bword ptr [rcx]
C4E1796EC2 vmovd xmm0, edx
C4E27D58C0 vpbroadcastd ymm0, ymm0
8B4908 mov ecx, dword ptr [rcx+08H]
4863C9 movsxd rcx, ecx
4C8BC1 mov r8, rcx
4983E0E0 and r8, -32
4533C9 xor r9d, r9d
4D85C0 test r8, r8
7E45 jle SHORT G_M000_IG04
align [0 bytes for IG03]
G_M000_IG03: ;; offset=0025H
4D8BD1 mov r10, r9
49C1E202 shl r10, 2
C4A17D760C10 vpcmpeqd ymm1, ymm0, ymmword ptr[rax+r10]
C4A17D76541020 vpcmpeqd ymm2, ymm0, ymmword ptr[rax+r10+20H]
C4A17D765C1040 vpcmpeqd ymm3, ymm0, ymmword ptr[rax+r10+40H]
C4A17D76641060 vpcmpeqd ymm4, ymm0, ymmword ptr[rax+r10+60H]
C4E175EBEA vpor ymm5, ymm1, ymm2
C4E155EBEB vpor ymm5, ymm5, ymm3
C4E155EBEC vpor ymm5, ymm5, ymm4
C4E27D17ED vptest ymm5, ymm5
7571 jne SHORT G_M000_IG14
4983C120 add r9, 32
4D3BC8 cmp r9, r8
7CBB jl SHORT G_M000_IG03
G_M000_IG04: ;; offset=006AH
4C8BC1 mov r8, rcx
4983E0F8 and r8, -8
4D3BC8 cmp r9, r8
7D20 jge SHORT G_M000_IG06
66660F1F840000000000 align [10 bytes for IG05]
G_M000_IG05: ;; offset=0080H
C4A17D761C88 vpcmpeqd ymm3, ymm0, ymmword ptr[rax+4*r9]
C4E27D17DB vptest ymm3, ymm3
7531 jne SHORT G_M000_IG12
4983C108 add r9, 8
4D3BC8 cmp r9, r8
7CEA jl SHORT G_M000_IG05
G_M000_IG06: ;; offset=0096H
4C3BC9 cmp r9, rcx
7D13 jge SHORT G_M000_IG08
0F1F440000 align [5 bytes for IG07]
G_M000_IG07: ;; offset=00A0H
42391488 cmp dword ptr [rax+4*r9], edx
7411 je SHORT G_M000_IG10
49FFC1 inc r9
4C3BC9 cmp r9, rcx
7CF2 jl SHORT G_M000_IG07
G_M000_IG08: ;; offset=00AEH
B8FFFFFFFF mov eax, -1
G_M000_IG09: ;; offset=00B3H
C5F877 vzeroupper
C3 ret
G_M000_IG10: ;; offset=00B7H
418BC1 mov eax, r9d
G_M000_IG11: ;; offset=00BAH
C5F877 vzeroupper
C3 ret
G_M000_IG12: ;; offset=00BEH
C5FC50C3 vmovmskps yrax, ymm3
F30FBCC0 tzcnt eax, eax
4103C1 add eax, r9d
G_M000_IG13: ;; offset=00C9H
C5F877 vzeroupper
C3 ret
G_M000_IG14: ;; offset=00CDH
C5E56BC4 vpackssdw ymm0, ymm3, ymm4
C4E3FD00C0D8 vpermq ymm0, ymm0, -40
C5F56BCA vpackssdw ymm1, ymm1, ymm2
C4E3FD00C9D8 vpermq ymm1, ymm1, -40
C5F563C0 vpacksswb ymm0, ymm1, ymm0
C4E3FD00C0D8 vpermq ymm0, ymm0, -40
C5FDD7C0 vpmovmskb eax, ymm0
F30FBCC0 tzcnt eax, eax
4103C1 add eax, r9d
G_M000_IG15: ;; offset=00F6H
C5F877 vzeroupper
C3 ret
; Total bytes of code 250
BenchmarkDotNet=v0.13.5, OS=Windows 11 (10.0.22621.1848/22H2/2022Update/SunValley2)
AMD Ryzen 9 5950X, 1 CPU, 32 logical and 16 physical cores
.NET SDK=8.0.100-preview.4.23260.5
  [Host]   : .NET 7.0.2 (7.0.222.60605), X64 RyuJIT AVX2
  .NET 7.0 : .NET 7.0.7 (7.0.723.27404), X64 RyuJIT AVX2

Job=.NET 7.0  Runtime=.NET 7.0  
Method N Mean Error StdDev Median Ratio RatioSD
Find_Simple 32 9.497 ns 0.2087 ns 0.2993 ns 9.555 ns 1.00 0.00
Find_Generic_128 32 4.572 ns 0.0025 ns 0.0020 ns 4.572 ns 0.48 0.01
Find_Generic_256 32 7.308 ns 0.5040 ns 1.4861 ns 7.455 ns 0.78 0.17
Find_AVX2_256_Optimized 32 2.398 ns 0.0085 ns 0.0075 ns 2.397 ns 0.25 0.01
Find_Simple 64 16.557 ns 0.3431 ns 0.4580 ns 16.622 ns 1.00 0.00
Find_Generic_128 64 8.531 ns 0.0269 ns 0.0238 ns 8.543 ns 0.52 0.01
Find_Generic_256 64 6.626 ns 0.0900 ns 0.0752 ns 6.589 ns 0.41 0.01
Find_AVX2_256_Optimized 64 2.936 ns 0.0161 ns 0.0143 ns 2.935 ns 0.18 0.00
Find_Simple 128 35.024 ns 0.3709 ns 0.3097 ns 35.064 ns 1.00 0.00
Find_Generic_128 128 15.533 ns 0.0437 ns 0.0341 ns 15.546 ns 0.44 0.00
Find_Generic_256 128 10.098 ns 0.0235 ns 0.0208 ns 10.096 ns 0.29 0.00
Find_AVX2_256_Optimized 128 5.223 ns 0.0132 ns 0.0117 ns 5.221 ns 0.15 0.00
Find_Simple 256 64.626 ns 1.1894 ns 1.1126 ns 64.496 ns 1.00 0.00
Find_Generic_128 256 35.388 ns 0.0965 ns 0.0855 ns 35.392 ns 0.55 0.01
Find_Generic_256 256 16.866 ns 0.0433 ns 0.0384 ns 16.881 ns 0.26 0.00
Find_AVX2_256_Optimized 256 10.103 ns 0.0524 ns 0.0491 ns 10.131 ns 0.16 0.00
Find_Simple 512 120.302 ns 1.6310 ns 1.5256 ns 119.891 ns 1.00 0.00
Find_Generic_128 512 63.086 ns 0.1117 ns 0.1044 ns 63.058 ns 0.52 0.01
Find_Generic_256 512 39.328 ns 0.8087 ns 2.3845 ns 38.056 ns 0.33 0.02
Find_AVX2_256_Optimized 512 15.840 ns 0.0257 ns 0.0215 ns 15.842 ns 0.13 0.00
Find_Simple 1024 232.160 ns 1.9791 ns 1.8512 ns 232.436 ns 1.00 0.00
Find_Generic_128 1024 119.290 ns 0.2275 ns 0.2017 ns 119.350 ns 0.51 0.00
Find_Generic_256 1024 65.283 ns 0.1176 ns 0.1100 ns 65.236 ns 0.28 0.00
Find_AVX2_256_Optimized 1024 28.667 ns 0.0405 ns 0.0359 ns 28.656 ns 0.12 0.00
Find_Simple 2048 454.078 ns 3.3386 ns 3.1229 ns 453.596 ns 1.00 0.00
Find_Generic_128 2048 230.883 ns 0.1299 ns 0.1085 ns 230.879 ns 0.51 0.00
Find_Generic_256 2048 121.295 ns 0.1702 ns 0.1421 ns 121.315 ns 0.27 0.00
Find_AVX2_256_Optimized 2048 51.331 ns 0.2110 ns 0.1973 ns 51.436 ns 0.11 0.00
Find_Simple 3072 680.201 ns 6.9948 ns 6.5430 ns 680.013 ns 1.00 0.00
Find_Generic_128 3072 342.390 ns 0.7282 ns 0.6812 ns 342.483 ns 0.50 0.01
Find_Generic_256 3072 176.961 ns 0.1075 ns 0.0839 ns 176.933 ns 0.26 0.00
Find_AVX2_256_Optimized 3072 71.984 ns 0.0830 ns 0.0776 ns 71.974 ns 0.11 0.00
Find_Simple 4096 894.287 ns 3.6022 ns 3.0080 ns 894.541 ns 1.00 0.00
Find_Generic_128 4096 454.083 ns 0.3423 ns 0.3035 ns 454.020 ns 0.51 0.00
Find_Generic_256 4096 234.391 ns 2.5310 ns 2.1135 ns 234.057 ns 0.26 0.00
Find_AVX2_256_Optimized 4096 113.839 ns 0.5868 ns 0.5489 ns 113.575 ns 0.13 0.00
Find_Simple 5120 1,122.998 ns 9.4198 ns 8.8112 ns 1,119.175 ns 1.00 0.00
Find_Generic_128 5120 564.274 ns 1.4472 ns 1.3537 ns 564.942 ns 0.50 0.00
Find_Generic_256 5120 288.396 ns 0.4312 ns 0.4034 ns 288.519 ns 0.26 0.00
Find_AVX2_256_Optimized 5120 138.652 ns 0.6755 ns 0.6319 ns 139.035 ns 0.12 0.00
Find_Simple 6144 1,342.714 ns 7.1180 ns 6.3099 ns 1,339.940 ns 1.00 0.00
Find_Generic_128 6144 676.572 ns 0.3413 ns 0.2665 ns 676.500 ns 0.50 0.00
Find_Generic_256 6144 343.961 ns 0.7668 ns 0.7173 ns 344.316 ns 0.26 0.00
Find_AVX2_256_Optimized 6144 135.806 ns 0.1486 ns 0.1390 ns 135.735 ns 0.10 0.00
Find_Simple 7168 1,558.207 ns 6.6785 ns 5.9203 ns 1,558.427 ns 1.00 0.00
Find_Generic_128 7168 789.545 ns 1.2981 ns 1.2143 ns 789.045 ns 0.51 0.00
Find_Generic_256 7168 403.446 ns 3.4413 ns 3.2190 ns 401.874 ns 0.26 0.00
Find_AVX2_256_Optimized 7168 189.534 ns 0.5474 ns 0.5120 ns 189.874 ns 0.12 0.00
Find_Simple 8192 1,796.290 ns 13.4828 ns 12.6118 ns 1,792.636 ns 1.00 0.00
Find_Generic_128 8192 901.999 ns 1.8796 ns 1.7582 ns 902.707 ns 0.50 0.00
Find_Generic_256 8192 465.352 ns 5.0166 ns 4.6925 ns 462.971 ns 0.26 0.00
Find_AVX2_256_Optimized 8192 183.790 ns 0.8620 ns 0.8063 ns 183.384 ns 0.10 0.00
@manofstick
Copy link

Tiny tweak!

Can move the PackSignedSaturate out of the if, as replacement for the ORs. Should be same cost, and then avoid the need to do it later.

    Vector256<short> r12 = Avx2.PackSignedSaturate(r1, r2).AsInt16();
    Vector256<short> r34 = Avx2.PackSignedSaturate(r3, r4).AsInt16();

    var r5 = r12 | r34;
    if (r5 != Vector256<short>.Zero)
    {

Seems to do what I expect on my old CPU...

BenchmarkDotNet v0.13.6, Windows 10 (10.0.19045.3086/22H2/2022Update)
Intel Core i7-6770HQ CPU 2.60GHz (Skylake), 1 CPU, 8 logical and 4 physical cores
.NET SDK 7.0.304
[Host] : .NET 7.0.7 (7.0.723.27404), X64 RyuJIT AVX2
.NET 7.0 : .NET 7.0.7 (7.0.723.27404), X64 RyuJIT AVX2

Job=.NET 7.0 Runtime=.NET 7.0

Method N Mean Error StdDev Ratio RatioSD
Find_AVX_256_Optimized 32 4.027 ns 0.0668 ns 0.0592 ns 1.00 0.00
Find_AVX_256_OptimizedX 32 3.856 ns 0.0420 ns 0.0393 ns 0.96 0.02
Find_AVX_256_Optimized 64 5.894 ns 0.0968 ns 0.0905 ns 1.00 0.00
Find_AVX_256_OptimizedX 64 5.674 ns 0.0735 ns 0.0688 ns 0.96 0.02
Find_AVX_256_Optimized 128 9.717 ns 0.0399 ns 0.0373 ns 1.00 0.00
Find_AVX_256_OptimizedX 128 9.351 ns 0.0429 ns 0.0381 ns 0.96 0.01
Find_AVX_256_Optimized 256 15.419 ns 0.0682 ns 0.0605 ns 1.00 0.00
Find_AVX_256_OptimizedX 256 15.313 ns 0.0509 ns 0.0476 ns 0.99 0.00
Find_AVX_256_Optimized 512 27.756 ns 0.1678 ns 0.1487 ns 1.00 0.00
Find_AVX_256_OptimizedX 512 27.735 ns 0.1148 ns 0.1017 ns 1.00 0.01
Find_AVX_256_Optimized 1024 51.82 ns 0.257 ns 0.228 ns 1.00
Find_AVX_256_OptimizedX 1024 51.81 ns 0.148 ns 0.139 ns 1.00
Find_AVX_256_Optimized 2048 108.846 ns 0.7748 ns 0.7248 ns 1.00 0.00
Find_AVX_256_OptimizedX 2048 103.212 ns 1.9822 ns 1.8541 ns 0.95 0.02
Find_AVX_256_Optimized 3072 165.297 ns 0.9947 ns 0.9304 ns 1.00 0.00
Find_AVX_256_OptimizedX 3072 163.884 ns 0.6934 ns 0.6486 ns 0.99 0.01
Find_AVX_256_Optimized 4096 214.716 ns 0.8964 ns 0.8385 ns 1.00 0.00
Find_AVX_256_OptimizedX 4096 213.264 ns 0.9748 ns 0.8641 ns 0.99 0.00
Find_AVX_256_Optimized 5120 264.547 ns 1.4912 ns 1.3949 ns 1.00 0.00
Find_AVX_256_OptimizedX 5120 263.055 ns 1.0962 ns 1.0254 ns 0.99 0.01
Find_AVX_256_Optimized 6144 315.956 ns 2.5487 ns 2.1283 ns 1.00 0.00
Find_AVX_256_OptimizedX 6144 311.699 ns 2.4050 ns 2.2496 ns 0.99 0.01
Find_AVX_256_Optimized 7168 363.677 ns 2.2572 ns 2.0009 ns 1.00 0.00
Find_AVX_256_OptimizedX 7168 362.104 ns 2.3084 ns 2.1593 ns 0.99 0.01
Find_AVX_256_Optimized 8192 425.753 ns 3.2725 ns 3.0611 ns 1.00 0.00
Find_AVX_256_OptimizedX 8192 420.958 ns 2.3809 ns 2.2271 ns 0.99 0.01

@xoofx
Copy link
Author

xoofx commented Jul 12, 2023

Can move the PackSignedSaturate out of the if, as replacement for the ORs. Should be same cost, and then avoid the need to do it later

Yep, indeed, I forgot to do that after changing the order of the pack and permute — before, the ORs were less costly, but now it makes complete sense to change it 👍

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment