Skip to content

Instantly share code, notes, and snippets.

@Turnerj
Created February 14, 2020 07:20
Show Gist options
  • Save Turnerj/76b69b10acefa5d9dde99f362b7f13f8 to your computer and use it in GitHub Desktop.
Save Turnerj/76b69b10acefa5d9dde99f362b7f13f8 to your computer and use it in GitHub Desktop.
Second attempt - probably just as bad as my first, code is neater though.
#if NETCOREAPP3_0
using System;
using System.Buffers;
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics.X86;
using System.Runtime.Intrinsics;
namespace Quickenshtein
{
/// <summary>
/// Quick Levenshtein Distance Calculator
/// </summary>
public static partial class Levenshtein
{
    /// <summary>
    /// Calculates the Levenshtein distance between <paramref name="source"/> and
    /// <paramref name="target"/>, processing eight target columns at a time with
    /// SSE2 / SSE3 / SSE4.1 intrinsics.
    /// </summary>
    /// <remarks>
    /// <para>
    /// A single row buffer is updated in place. While a row is walked, the old value of the
    /// cell above-left (the substitution candidate) and the freshly written cell to the left
    /// (the insertion candidate) are carried in locals.
    /// </para>
    /// <para>
    /// Costs are held as <see cref="ushort"/> so that eight columns fit in one 128-bit vector.
    /// For each block of eight columns, the deletion and substitution candidates are combined
    /// with a single vector minimum. Only the insertion candidate, which depends on the lane to
    /// its left, is resolved lane by lane.
    /// </para>
    /// <para>
    /// NOTE(review): this method makes no <c>IsSupported</c> checks for <see cref="Sse2"/>,
    /// <see cref="Sse3"/> or <see cref="Sse41"/>. Presumably the caller only takes this path
    /// when they are available; confirm.
    /// </para>
    /// </remarks>
    /// <param name="source">The string to transform.</param>
    /// <param name="target">The string to transform into.</param>
    /// <returns>
    /// The minimum number of single-character insertions, deletions and substitutions that
    /// turn <paramref name="source"/> into <paramref name="target"/>.
    /// </returns>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="source"/> or <paramref name="target"/> is <see cref="ushort.MaxValue"/>
    /// characters or longer, so intermediate costs could not be represented in 16-bit lanes.
    /// </exception>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static unsafe int CalculateDistance_Intrinsics(ReadOnlySpan<char> source, ReadOnlySpan<char> target)
    {
        // Number of ushort lanes in a Vector128<ushort>.
        const int LaneCount = 8;

        var sourceLength = source.Length;
        var targetLength = target.Length;

        // The distance to or from an empty string is the other string's length. Handling this
        // up front also avoids reading previousRow[-1] when target is empty.
        if (sourceLength == 0)
        {
            return targetLength;
        }

        if (targetLength == 0)
        {
            return sourceLength;
        }

        // Every cost, including the "+1" candidates, must fit in a ushort lane without wrapping.
        if (sourceLength >= ushort.MaxValue)
        {
            throw new ArgumentOutOfRangeException(nameof(source), "Inputs of 65535 or more characters are not supported by the intrinsics implementation.");
        }

        if (targetLength >= ushort.MaxValue)
        {
            throw new ArgumentOutOfRangeException(nameof(target), "Inputs of 65535 or more characters are not supported by the intrinsics implementation.");
        }

        var arrayPool = ArrayPool<ushort>.Shared;
        var pooledArray = arrayPool.Rent(targetLength);

        // Rented arrays may be longer than requested; use exactly targetLength entries.
        var previousRow = new Span<ushort>(pooledArray, 0, targetLength);

        // FillRow is defined elsewhere in this partial class. The recurrence below needs
        // previousRow[j] == j + 1 (the first DP row) on entry; presumably that is what FillRow writes.
        FillRow(previousRow);

        // Each lane holds the value 1 (not all bits set).
        var onesVector = Vector128.Create((ushort)1);

        fixed (char* targetPtr = target)
        fixed (ushort* previousRowPtr = previousRow)
        {
            // char and ushort are both 16 bits wide, so target can be loaded directly as ushort lanes.
            var targetLanesPtr = (ushort*)targetPtr;

            for (var rowIndex = 0; rowIndex < sourceLength; rowIndex++)
            {
                var sourceChar = source[rowIndex];
                var sourceCharVector = Vector128.Create((ushort)sourceChar);

                // Implicit column before column 0: D[row][0] = row (diagonal), D[row + 1][0] = row + 1 (left).
                var lastSubstitutionCost = (ushort)rowIndex;
                var lastInsertionCost = (ushort)(rowIndex + 1);
                var columnIndex = 0;

                for (; targetLength - columnIndex >= LaneCount; columnIndex += LaneCount)
                {
                    var blockPtr = previousRowPtr + columnIndex;

                    // Old previous-row costs directly above each column are the deletion candidates.
                    // Shifting them up by one lane (lane k receives the value above column k - 1) and
                    // seeding lane 0 with the carried diagonal gives the substitution candidates.
                    var aboveCosts = Sse3.LoadDquVector128(blockPtr);
                    var diagonalCosts = Sse2.ShiftLeftLogical128BitLane(aboveCosts, 2)
                        .WithElement(0, lastSubstitutionCost);

                    // 1 in lanes where the characters differ, 0 where they match.
                    var mismatch = Sse2.AndNot(
                        Sse2.CompareEqual(sourceCharVector, Sse3.LoadDquVector128(targetLanesPtr + columnIndex)),
                        onesVector);

                    // min(deletion, substitution) for all lanes at once, written over the block.
                    Sse2.Store(blockPtr, Sse41.Min(
                        Sse2.Add(aboveCosts, onesVector),
                        Sse2.Add(diagonalCosts, mismatch)));

                    // Insertion depends on the lane to the left, so it is folded in serially.
                    for (var lane = 0; lane < LaneCount; lane++)
                    {
                        lastInsertionCost = Math.Min((ushort)(lastInsertionCost + 1), blockPtr[lane]);
                        blockPtr[lane] = lastInsertionCost;
                    }

                    // The last old value above this block is the diagonal for the next column.
                    lastSubstitutionCost = aboveCosts.GetElement(LaneCount - 1);
                }

                // Scalar tail for the remaining (fewer than LaneCount) columns.
                for (; columnIndex < targetLength; columnIndex++)
                {
                    var deletionCost = previousRowPtr[columnIndex];
                    var cellCost = lastSubstitutionCost;

                    // On a match the diagonal never exceeds a neighbour + 1, so it is already the minimum.
                    if (sourceChar != targetPtr[columnIndex])
                    {
                        cellCost = (ushort)(Math.Min(Math.Min(lastInsertionCost, lastSubstitutionCost), deletionCost) + 1);
                    }

                    previousRowPtr[columnIndex] = cellCost;
                    lastInsertionCost = cellCost;
                    lastSubstitutionCost = deletionCost;
                }
            }
        }

        var result = previousRow[targetLength - 1];
        arrayPool.Return(pooledArray);
        return result;
    }
}
}
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment