Created
March 2, 2024 03:24
-
-
Save DevJohnC/ec8a07eb2f1370b7a9a601b6d5a37254 to your computer and use it in GitHub Desktop.
DotProduct operation using SIMD
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.Runtime.CompilerServices; | |
using System.Runtime.InteropServices; | |
namespace VectorSearch.Memory; | |
/// <summary>
/// A stack-only wrapper around a managed reference ("element 0") that allows
/// unchecked, offset-based access to the elements that follow it.
/// </summary>
/// <typeparam name="T">The element type being referenced.</typeparam>
public readonly ref struct MemoryRef<T>
{
    // Base reference from which every element offset is computed.
    private readonly ref T _origin;

    /// <summary>Wraps the given reference as element 0.</summary>
    /// <param name="value">The reference to treat as the first element.</param>
    public MemoryRef(ref T value)
    {
        _origin = ref value;
    }

    /// <summary>
    /// Returns a reference to the element <paramref name="index"/> positions after the origin.
    /// No bounds checking is performed; callers must keep <paramref name="index"/> in range.
    /// </summary>
    /// <param name="index">Element offset from the origin.</param>
    /// <returns>A writable reference to the addressed element.</returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public ref T GetElement(nint index)
    {
        return ref Unsafe.Add(ref _origin, index);
    }

    /// <summary>
    /// Creates a <see cref="MemoryRef{T}"/> anchored at the first element of <paramref name="span"/>
    /// (obtained via <see cref="MemoryMarshal.GetReference{T}(ReadOnlySpan{T})"/>).
    /// </summary>
    /// <param name="span">The span whose first element becomes the origin.</param>
    public static MemoryRef<T> GetReference(ReadOnlySpan<T> span)
    {
        ref T first = ref MemoryMarshal.GetReference(span);
        return new MemoryRef<T>(ref first);
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using VectorSearch.Memory; | |
namespace VectorSearch.Math; | |
/// <summary>
/// Non-vectorized (element-at-a-time) vector math routines.
/// </summary>
public static class Scalar
{
    /// <summary>
    /// Computes the dot product of two vectors one element at a time, accumulating
    /// the products in ascending index order starting from zero.
    /// </summary>
    /// <param name="vector1">First operand; its length determines how many elements are summed.</param>
    /// <param name="vector2">Second operand.</param>
    /// <returns>The sum of <c>vector1[k] * vector2[k]</c> over every index <c>k</c> of <paramref name="vector1"/>.</returns>
    public static double DotProduct(ReadOnlySpan<double> vector1, ReadOnlySpan<double> vector2)
    {
        // Size validation is delegated to VectorHelper (defined elsewhere in the project).
        VectorHelper.GuardMismatchedSizes(vector1, vector2);

        MemoryRef<double> left = MemoryRef<double>.GetReference(vector1);
        MemoryRef<double> right = MemoryRef<double>.GetReference(vector2);
        nint length = vector1.Length;

        double sum = 0d;
        nint k = 0;
        while (k < length)
        {
            // Unchecked access; presumably the guard above ensures vector2 is at least as long — TODO confirm.
            sum += left.GetElement(k) * right.GetElement(k);
            k++;
        }

        return sum;
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.Numerics; | |
using System.Runtime.CompilerServices; | |
using VectorSearch.Memory; | |
namespace VectorSearch.Math; | |
/// <summary>
/// Vector math routines accelerated with <see cref="Vector{T}"/> SIMD operations.
/// </summary>
public static class SIMD
{
    /// <summary>
    /// Computes the dot product of two vectors using <see cref="Vector{T}"/> registers,
    /// processing two registers per loop iteration. Falls back to <see cref="Scalar.DotProduct"/>
    /// when SIMD is not hardware-accelerated or the input is shorter than one register.
    /// </summary>
    /// <param name="vector1">First operand; its length determines how many elements are summed.</param>
    /// <param name="vector2">
    /// Second operand. Size validation is delegated to <c>VectorHelper.GuardMismatchedSizes</c>
    /// (defined elsewhere); presumably it must match <paramref name="vector1"/> — confirm there.
    /// </param>
    /// <returns>The sum of element-wise products.</returns>
    public static double DotProduct(ReadOnlySpan<double> vector1, ReadOnlySpan<double> vector2)
    {
        const nint operationsPerIteration = 2;

        VectorHelper.GuardMismatchedSizes(vector1, vector2);

        if (!Vector.IsHardwareAccelerated || !Vector<double>.IsSupported || vector1.Length < Vector<double>.Count)
        {
            return Scalar.DotProduct(vector1, vector2);
        }

        var vector1Ref = MemoryRef<double>.GetReference(vector1);
        var vector2Ref = MemoryRef<double>.GetReference(vector2);
        var dotProduct = 0d;
        nint dimensionCount = vector1.Length;
        nint i = 0;
        nint step = Vector<double>.Count;
        nint stride = step * operationsPerIteration;

        // Vector<double>.Count is a power of two, so this mask rounds dimensionCount
        // down to a multiple of the unrolled stride.
        nint bound = dimensionCount & ~(stride - 1);

        for (; i < bound; i += stride)
        {
            var simdRegister1Vector1 = Unsafe.As<double, Vector<double>>(ref vector1Ref.GetElement(i));
            var simdRegister1Vector2 = Unsafe.As<double, Vector<double>>(ref vector2Ref.GetElement(i));
            var simdRegister2Vector1 = Unsafe.As<double, Vector<double>>(ref vector1Ref.GetElement(i + step));
            var simdRegister2Vector2 = Unsafe.As<double, Vector<double>>(ref vector2Ref.GetElement(i + step));
            dotProduct += Vector.Dot(simdRegister1Vector1, simdRegister1Vector2) +
                          Vector.Dot(simdRegister2Vector1, simdRegister2Vector2);
        }

        // Up to (stride - 1) elements remain. Consume one more full register when it fits
        // entirely inside the spans.
        if (dimensionCount - i >= step)
        {
            var simdRegisterVector1 = Unsafe.As<double, Vector<double>>(ref vector1Ref.GetElement(i));
            var simdRegisterVector2 = Unsafe.As<double, Vector<double>>(ref vector2Ref.GetElement(i));
            dotProduct += Vector.Dot(simdRegisterVector1, simdRegisterVector2);
            i += step;
        }

        // Fewer than `step` elements remain. Finish element-by-element so no load crosses
        // the end of either span and no element is counted twice. (The previous version
        // loaded a whole Vector<double> per remaining index, which read past the end and
        // over-counted.)
        for (; i < dimensionCount; i++)
        {
            dotProduct += vector1Ref.GetElement(i) * vector2Ref.GetElement(i);
        }

        return dotProduct;
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment