Skip to content

Instantly share code, notes, and snippets.

@DevJohnC
Created March 2, 2024 03:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save DevJohnC/ec8a07eb2f1370b7a9a601b6d5a37254 to your computer and use it in GitHub Desktop.
Save DevJohnC/ec8a07eb2f1370b7a9a601b6d5a37254 to your computer and use it in GitHub Desktop.
DotProduct operation using SIMD
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
namespace VectorSearch.Memory;
/// <summary>
/// Wraps a managed reference to a <typeparamref name="T"/> and exposes
/// offset-based access to the elements that follow it in memory.
/// </summary>
/// <typeparam name="T">Type of the referenced element.</typeparam>
public readonly ref struct MemoryRef<T>
{
    // Reference to the base element; all offsets are taken relative to it.
    private readonly ref T _value;

    /// <summary>Creates a wrapper around <paramref name="value"/>.</summary>
    /// <param name="value">Reference used as the base for element access.</param>
    public MemoryRef(ref T value)
    {
        _value = ref value;
    }

    /// <summary>
    /// Creates a <see cref="MemoryRef{T}"/> whose base is the reference returned by
    /// <see cref="MemoryMarshal.GetReference{T}(ReadOnlySpan{T})"/> for <paramref name="span"/>.
    /// </summary>
    /// <param name="span">Span whose first element becomes the base reference.</param>
    /// <returns>A wrapper positioned at the start of <paramref name="span"/>.</returns>
    /// <remarks>
    /// The span length is not retained, so later element access is not range-checked here.
    /// </remarks>
    public static MemoryRef<T> GetReference(ReadOnlySpan<T> span)
        => new(ref MemoryMarshal.GetReference(span));

    /// <summary>
    /// Returns a reference to the element located <paramref name="index"/> elements
    /// past the base reference.
    /// </summary>
    /// <param name="index">Element offset from the base reference.</param>
    /// <returns>
    /// A writable reference computed with <see cref="Unsafe.Add{T}(ref T, nint)"/>;
    /// this type does not validate <paramref name="index"/>.
    /// </returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public ref T GetElement(nint index)
    {
        return ref Unsafe.Add(ref _value, index);
    }
}
using VectorSearch.Memory;
namespace VectorSearch.Math;
/// <summary>
/// Non-vectorized math routines operating one element at a time.
/// </summary>
public static class Scalar
{
    /// <summary>
    /// Computes the dot product of two vectors by accumulating the element-wise
    /// products in index order.
    /// </summary>
    /// <param name="vector1">First operand.</param>
    /// <param name="vector2">Second operand.</param>
    /// <returns>The sum of <c>vector1[i] * vector2[i]</c> over every index of <paramref name="vector1"/>.</returns>
    /// <remarks>
    /// Both spans are first passed to <c>VectorHelper.GuardMismatchedSizes</c>
    /// (defined elsewhere; presumably rejects spans of different length — confirm).
    /// </remarks>
    public static double DotProduct(ReadOnlySpan<double> vector1, ReadOnlySpan<double> vector2)
    {
        VectorHelper.GuardMismatchedSizes(vector1, vector2);

        var left = MemoryRef<double>.GetReference(vector1);
        var right = MemoryRef<double>.GetReference(vector2);
        nint length = vector1.Length;

        double sum = 0d;
        nint index = 0;
        while (index < length)
        {
            // Element access goes through MemoryRef, which performs no range check;
            // the loop condition keeps the index inside vector1.
            sum += left.GetElement(index) * right.GetElement(index);
            index++;
        }

        return sum;
    }
}
using System.Numerics;
using System.Runtime.CompilerServices;
using VectorSearch.Memory;
namespace VectorSearch.Math;
/// <summary>
/// Vectorized math routines built on <see cref="Vector{T}"/>.
/// </summary>
public static class SIMD
{
    /// <summary>
    /// Computes the dot product of two vectors using <see cref="Vector{T}"/> operations,
    /// delegating to <see cref="Scalar.DotProduct"/> when vectorization is unavailable
    /// or the input is shorter than one <see cref="Vector{T}"/>.
    /// </summary>
    /// <param name="vector1">First operand.</param>
    /// <param name="vector2">Second operand.</param>
    /// <returns>The sum of <c>vector1[i] * vector2[i]</c> over every index of <paramref name="vector1"/>.</returns>
    /// <remarks>
    /// Both spans are first passed to <c>VectorHelper.GuardMismatchedSizes</c>
    /// (defined elsewhere; presumably rejects spans of different length — confirm).
    /// Because partial sums are formed per vector, the result may differ from the
    /// scalar routine by floating-point rounding.
    /// </remarks>
    public static double DotProduct(ReadOnlySpan<double> vector1, ReadOnlySpan<double> vector2)
    {
        const nint operationsPerIteration = 2;

        VectorHelper.GuardMismatchedSizes(vector1, vector2);

        if (!Vector.IsHardwareAccelerated || !Vector<double>.IsSupported || vector1.Length < Vector<double>.Count)
        {
            return Scalar.DotProduct(vector1, vector2);
        }

        var vector1Ref = MemoryRef<double>.GetReference(vector1);
        var vector2Ref = MemoryRef<double>.GetReference(vector2);
        var dotProduct = 0d;
        nint dimensionCount = vector1.Length;
        nint i = 0;
        nint step = Vector<double>.Count;
        nint unrolledStep = step * operationsPerIteration;

        // Largest multiple of the unrolled width that fits in the input.
        nint unrolledBound = dimensionCount - (dimensionCount % unrolledStep);
        for (; i < unrolledBound; i += unrolledStep)
        {
            var simdRegister1Vector1 = Unsafe.As<double, Vector<double>>(ref vector1Ref.GetElement(i));
            var simdRegister1Vector2 = Unsafe.As<double, Vector<double>>(ref vector2Ref.GetElement(i));
            var simdRegister2Vector1 = Unsafe.As<double, Vector<double>>(ref vector1Ref.GetElement(i + step));
            var simdRegister2Vector2 = Unsafe.As<double, Vector<double>>(ref vector2Ref.GetElement(i + step));
            dotProduct += Vector.Dot(simdRegister1Vector1, simdRegister1Vector2) +
                          Vector.Dot(simdRegister2Vector1, simdRegister2Vector2);
        }

        // Consume any remaining full vector. A full-width load is only valid while
        // at least `step` elements remain, so advance by `step`, not by one.
        nint vectorBound = dimensionCount - (dimensionCount % step);
        for (; i < vectorBound; i += step)
        {
            var simdRegisterVector1 = Unsafe.As<double, Vector<double>>(ref vector1Ref.GetElement(i));
            var simdRegisterVector2 = Unsafe.As<double, Vector<double>>(ref vector2Ref.GetElement(i));
            dotProduct += Vector.Dot(simdRegisterVector1, simdRegisterVector2);
        }

        // Fewer than `step` elements are left: finish element by element so no
        // element is counted twice and nothing beyond the spans is read.
        for (; i < dimensionCount; i++)
        {
            dotProduct += vector1Ref.GetElement(i) * vector2Ref.GetElement(i);
        }

        return dotProduct;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment