Created
February 14, 2020 07:20
-
-
Save Turnerj/76b69b10acefa5d9dde99f362b7f13f8 to your computer and use it in GitHub Desktop.
Second attempt - probably just as bad as my first, code is neater though.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#if NETCOREAPP3_0 | |
using System; | |
using System.Buffers; | |
using System.Runtime.CompilerServices; | |
using System.Runtime.Intrinsics.X86; | |
using System.Runtime.Intrinsics; | |
namespace Quickenshtein
{
    /// <summary>
    /// Quick Levenshtein Distance Calculator
    /// </summary>
    public static partial class Levenshtein
    {
        /// <summary>
        /// Calculates the Levenshtein distance between <paramref name="source"/> and <paramref name="target"/>,
        /// processing the target in 8-column chunks with SSE2/SSE3/SSE4.1 intrinsics and finishing any
        /// remaining columns with scalar code.
        /// </summary>
        /// <remarks>
        /// Only one row of the edit-distance matrix is kept, rented from <see cref="ArrayPool{T}.Shared"/>.
        /// Costs are stored as <see cref="ushort"/>, so both inputs must be shorter than <see cref="ushort.MaxValue"/>.
        /// NOTE(review): this method does not check <c>Sse2/Sse3/Sse41.IsSupported</c>; presumably the caller does — confirm.
        /// NOTE(review): <c>FillRow</c> is defined elsewhere; the recurrence below requires it to set
        /// previousRow[j] = j + 1 (the cost of inserting the first j + 1 target characters) — confirm.
        /// </remarks>
        /// <param name="source">The string whose characters index the rows of the matrix.</param>
        /// <param name="target">The string whose characters index the columns of the matrix.</param>
        /// <returns>The Levenshtein distance between the two inputs.</returns>
        /// <exception cref="ArgumentOutOfRangeException">
        /// Either input has <see cref="ushort.MaxValue"/> or more characters (and the target is non-empty).
        /// </exception>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static unsafe int CalculateDistance_Intrinsics(ReadOnlySpan<char> source, ReadOnlySpan<char> target)
        {
            var sourceLength = source.Length;
            var targetLength = target.Length;

            // With no target columns the distance is the number of source characters to delete;
            // the row-based code below would otherwise read previousRow[-1].
            if (targetLength == 0)
            {
                return sourceLength;
            }

            // All costs live in ushort slots/lanes. Keeping both lengths below ushort.MaxValue guarantees that
            // no cost, and no "cost + 1" derived from one, wraps around to a small value.
            if (sourceLength >= ushort.MaxValue)
            {
                throw new ArgumentOutOfRangeException(nameof(source), "The intrinsics implementation only supports inputs shorter than 65535 characters.");
            }

            if (targetLength >= ushort.MaxValue)
            {
                throw new ArgumentOutOfRangeException(nameof(target), "The intrinsics implementation only supports inputs shorter than 65535 characters.");
            }

            var arrayPool = ArrayPool<ushort>.Shared;
            var pooledArray = arrayPool.Rent(targetLength);
            try
            {
                // ArrayPool may hand back a larger array than requested; only the first targetLength slots are used.
                var previousRow = new Span<ushort>(pooledArray, 0, targetLength);
                FillRow(previousRow);

                // Serves two purposes: masks the 0xFFFF/0x0000 equality result down to a 0/1 "not equal" cost,
                // and supplies the +1 added to every deletion cost.
                var onesVector = Vector128.Create((ushort)1);

                fixed (char* targetPtr = target)
                fixed (ushort* previousRowPtr = previousRow)
                {
                    var targetUShortPtr = (ushort*)targetPtr;

                    // Let D(i, j) be the distance between the first i source and first j target characters.
                    // On entry to a row previousRow[j] holds D(rowIndex, j + 1); on exit it holds D(rowIndex + 1, j + 1).
                    for (var rowIndex = 0; rowIndex < sourceLength; rowIndex++)
                    {
                        // Diagonal neighbour D(rowIndex, column); starts at D(rowIndex, 0) = rowIndex.
                        var lastSubstitutionCost = (ushort)rowIndex;
                        // Left neighbour D(rowIndex + 1, column); starts at D(rowIndex + 1, 0) = rowIndex + 1.
                        var lastInsertionCost = (ushort)(rowIndex + 1);
                        var sourcePrevChar = source[rowIndex];
                        var columnIndex = 0;
                        var rowColumnsRemaining = targetLength;
                        ushort lastDeletionCost;
                        ushort localCost;

                        if (rowColumnsRemaining >= 8)
                        {
                            var sourcePrevCharVector = Vector128.Create((ushort)sourcePrevChar);

                            // Diagonal costs for the next 8 columns: the saved diagonal, then previousRow[0..6].
                            var substitutionCostsVector = Vector128.Create(
                                lastSubstitutionCost,
                                previousRowPtr[0],
                                previousRowPtr[1],
                                previousRowPtr[2],
                                previousRowPtr[3],
                                previousRowPtr[4],
                                previousRowPtr[5],
                                previousRowPtr[6]);

                            while (rowColumnsRemaining >= 8)
                            {
                                rowColumnsRemaining -= 8;

                                var targetCharVector = Sse3.LoadDquVector128(targetUShortPtr + columnIndex);
                                // AndNot computes (~equalMask) & 1: 1 where the characters differ, 0 where they match.
                                var notEqualVector = Sse2.AndNot(Sse2.CompareEqual(sourcePrevCharVector, targetCharVector), onesVector);
                                var deletionCostsVector = Sse3.LoadDquVector128(previousRowPtr + columnIndex);
                                var adjustedDeletionCostsVector = Sse2.Add(deletionCostsVector, onesVector);
                                var adjustedSubstitutionCostsVector = Sse2.Add(notEqualVector, substitutionCostsVector);

                                // Each column's insertion cost depends on the column to its left,
                                // so the final minimum is resolved one lane at a time.
                                lastInsertionCost = MinimumOfThree_Intrinsics((ushort)(lastInsertionCost + 1), adjustedDeletionCostsVector.GetElement(0), adjustedSubstitutionCostsVector.GetElement(0));
                                previousRowPtr[columnIndex++] = lastInsertionCost;
                                lastInsertionCost = MinimumOfThree_Intrinsics((ushort)(lastInsertionCost + 1), adjustedDeletionCostsVector.GetElement(1), adjustedSubstitutionCostsVector.GetElement(1));
                                previousRowPtr[columnIndex++] = lastInsertionCost;
                                lastInsertionCost = MinimumOfThree_Intrinsics((ushort)(lastInsertionCost + 1), adjustedDeletionCostsVector.GetElement(2), adjustedSubstitutionCostsVector.GetElement(2));
                                previousRowPtr[columnIndex++] = lastInsertionCost;
                                lastInsertionCost = MinimumOfThree_Intrinsics((ushort)(lastInsertionCost + 1), adjustedDeletionCostsVector.GetElement(3), adjustedSubstitutionCostsVector.GetElement(3));
                                previousRowPtr[columnIndex++] = lastInsertionCost;
                                lastInsertionCost = MinimumOfThree_Intrinsics((ushort)(lastInsertionCost + 1), adjustedDeletionCostsVector.GetElement(4), adjustedSubstitutionCostsVector.GetElement(4));
                                previousRowPtr[columnIndex++] = lastInsertionCost;
                                lastInsertionCost = MinimumOfThree_Intrinsics((ushort)(lastInsertionCost + 1), adjustedDeletionCostsVector.GetElement(5), adjustedSubstitutionCostsVector.GetElement(5));
                                previousRowPtr[columnIndex++] = lastInsertionCost;
                                lastInsertionCost = MinimumOfThree_Intrinsics((ushort)(lastInsertionCost + 1), adjustedDeletionCostsVector.GetElement(6), adjustedSubstitutionCostsVector.GetElement(6));
                                previousRowPtr[columnIndex++] = lastInsertionCost;
                                lastInsertionCost = MinimumOfThree_Intrinsics((ushort)(lastInsertionCost + 1), adjustedDeletionCostsVector.GetElement(7), adjustedSubstitutionCostsVector.GetElement(7));
                                previousRowPtr[columnIndex++] = lastInsertionCost;

                                // The previous-row value of the last column processed is the diagonal for the next column.
                                lastSubstitutionCost = deletionCostsVector.GetElement(7);

                                if (rowColumnsRemaining >= 8)
                                {
                                    // Slot columnIndex - 1 already holds this row's value, so lane 0 is replaced
                                    // with the saved previous-row value.
                                    substitutionCostsVector = Sse3.LoadDquVector128(previousRowPtr + columnIndex - 1)
                                        .WithElement(0, lastSubstitutionCost);
                                }
                            }
                        }

                        // Scalar tail (fewer than 8 columns left): one 4-column unrolled step, then single columns.
                        // For each column: if the characters match, the diagonal cost carries over unchanged;
                        // otherwise the cost is 1 + the cheapest of insertion, deletion and substitution.
                        if (rowColumnsRemaining >= 4)
                        {
                            rowColumnsRemaining -= 4;

                            localCost = lastSubstitutionCost;
                            lastDeletionCost = previousRowPtr[columnIndex];
                            if (sourcePrevChar != targetPtr[columnIndex])
                            {
                                localCost = Math.Min(lastInsertionCost, localCost);
                                localCost = Math.Min(lastDeletionCost, localCost);
                                localCost++;
                            }
                            lastInsertionCost = localCost;
                            previousRowPtr[columnIndex++] = localCost;
                            lastSubstitutionCost = lastDeletionCost;

                            localCost = lastSubstitutionCost;
                            lastDeletionCost = previousRowPtr[columnIndex];
                            if (sourcePrevChar != targetPtr[columnIndex])
                            {
                                localCost = Math.Min(lastInsertionCost, localCost);
                                localCost = Math.Min(lastDeletionCost, localCost);
                                localCost++;
                            }
                            lastInsertionCost = localCost;
                            previousRowPtr[columnIndex++] = localCost;
                            lastSubstitutionCost = lastDeletionCost;

                            localCost = lastSubstitutionCost;
                            lastDeletionCost = previousRowPtr[columnIndex];
                            if (sourcePrevChar != targetPtr[columnIndex])
                            {
                                localCost = Math.Min(lastInsertionCost, localCost);
                                localCost = Math.Min(lastDeletionCost, localCost);
                                localCost++;
                            }
                            lastInsertionCost = localCost;
                            previousRowPtr[columnIndex++] = localCost;
                            lastSubstitutionCost = lastDeletionCost;

                            localCost = lastSubstitutionCost;
                            lastDeletionCost = previousRowPtr[columnIndex];
                            if (sourcePrevChar != targetPtr[columnIndex])
                            {
                                localCost = Math.Min(lastInsertionCost, localCost);
                                localCost = Math.Min(lastDeletionCost, localCost);
                                localCost++;
                            }
                            lastInsertionCost = localCost;
                            previousRowPtr[columnIndex++] = localCost;
                            lastSubstitutionCost = lastDeletionCost;
                        }

                        while (rowColumnsRemaining > 0)
                        {
                            rowColumnsRemaining--;

                            localCost = lastSubstitutionCost;
                            lastDeletionCost = previousRowPtr[columnIndex];
                            if (sourcePrevChar != targetPtr[columnIndex])
                            {
                                localCost = Math.Min(lastInsertionCost, localCost);
                                localCost = Math.Min(lastDeletionCost, localCost);
                                localCost++;
                            }
                            lastInsertionCost = localCost;
                            previousRowPtr[columnIndex++] = localCost;
                            lastSubstitutionCost = lastDeletionCost;
                        }
                    }
                }

                // The last slot now holds D(sourceLength, targetLength).
                return previousRow[targetLength - 1];
            }
            finally
            {
                // Return the rented buffer even if FillRow or the row loop throws.
                arrayPool.Return(pooledArray);
            }
        }

        /// <summary>
        /// Returns the smallest of three costs using <see cref="Sse41.MinHorizontal(Vector128{ushort})"/>.
        /// The five unused lanes are padded with <see cref="ushort.MaxValue"/> so they never win.
        /// </summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static ushort MinimumOfThree_Intrinsics(ushort first, ushort second, ushort third)
        {
            return Sse41.MinHorizontal(Vector128.Create(
                first,
                second,
                third,
                ushort.MaxValue, ushort.MaxValue, ushort.MaxValue, ushort.MaxValue, ushort.MaxValue))
                .GetElement(0);
        }
    }
}
#endif |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment