Created
March 6, 2020 08:36
-
-
Save Turnerj/c3f1962825f28f2e7798f85d25c3a137 to your computer and use it in GitHub Desktop.
Third attempt - still not as good as still having branches
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// <summary>
/// Using SSE4.1, calculates the costs for the virtual matrix, processing four source rows per pass.
/// This performs a 4x outer loop unrolling allowing fewer lookups of target character and deletion cost data across the rows.
/// </summary>
/// <remarks>
/// Character data is loaded 4 characters (64 bits) at a time so that no load reads beyond
/// <c>source.Length</c> or <paramref name="targetLength"/> characters.
/// </remarks>
/// <param name="previousRowPtr">Cost buffer indexed by target column. It is handed to <c>CalculateColumn_4Rows_Sse41</c>, which reads one entry per column and overwrites it with the cost of the fourth row of the group.</param>
/// <param name="source">Source characters; each group of 4 rows consumes <c>source[rowIndex .. rowIndex + 3]</c>.</param>
/// <param name="rowIndex">Index of the first source row to process. Advanced by 4 after every completed group; rows from the returned index onward (fewer than 4 remaining) are not processed here.</param>
/// <param name="targetPtr">Pointer to the target characters; dereferenced for <paramref name="targetLength"/> characters.</param>
/// <param name="targetLength">Number of target characters (columns) to process for every row group.</param>
private static unsafe void CalculateRows_4Rows_Sse41(int* previousRowPtr, ReadOnlySpan<char> source, ref int rowIndex, char* targetPtr, int targetLength)
{
    // A group needs 4 full rows, so the last valid group start is source.Length - 4.
    var acceptableRowCount = source.Length - 3;

    Vector128<int> row1Costs, row2Costs, row3Costs, row4Costs, row5Costs;
    Vector128<int> sourceChars;
    var allOnesVector = Vector128.Create(1);

    fixed (char* sourcePtr = source)
    {
        var sourceUShortPtr = (ushort*)sourcePtr;
        var targetUShortPtr = (ushort*)targetPtr;

        for (; rowIndex < acceptableRowCount; rowIndex += 4)
        {
            // Load the 4 source characters belonging to THIS row group (offset by rowIndex),
            // zero-extended to 32-bit lanes. The pointer overload reads exactly 64 bits.
            sourceChars = Sse41.ConvertToVector128Int32(sourceUShortPtr + rowIndex);

            // Seed the costs for the (virtual) column before the first target character.
            row1Costs = Vector128.Create(rowIndex); //Sub
            row2Costs = Sse2.Add(row1Costs, allOnesVector); //Insert, Sub
            row3Costs = Sse2.Add(row2Costs, allOnesVector); //Insert, Sub
            row4Costs = Sse2.Add(row3Costs, allOnesVector); //Insert, Sub
            row5Costs = Sse2.Add(row4Costs, allOnesVector); //Insert

            var columnIndex = 0;
            var rowColumnsRemaining = targetLength;
            Vector128<int> targetChar;

            // Each CalculateColumn call handles one column, advances columnIndex and shifts
            // the next target character into lane 0, so one 4-character load feeds 4 calls.
            while (rowColumnsRemaining >= 8)
            {
                rowColumnsRemaining -= 8;

                targetChar = Sse41.ConvertToVector128Int32(targetUShortPtr + columnIndex);
                CalculateColumn_4Rows_Sse41(ref targetChar, previousRowPtr, ref row1Costs, ref row2Costs, ref row3Costs, ref row4Costs, ref row5Costs, ref allOnesVector, sourceChars, ref columnIndex);
                CalculateColumn_4Rows_Sse41(ref targetChar, previousRowPtr, ref row1Costs, ref row2Costs, ref row3Costs, ref row4Costs, ref row5Costs, ref allOnesVector, sourceChars, ref columnIndex);
                CalculateColumn_4Rows_Sse41(ref targetChar, previousRowPtr, ref row1Costs, ref row2Costs, ref row3Costs, ref row4Costs, ref row5Costs, ref allOnesVector, sourceChars, ref columnIndex);
                CalculateColumn_4Rows_Sse41(ref targetChar, previousRowPtr, ref row1Costs, ref row2Costs, ref row3Costs, ref row4Costs, ref row5Costs, ref allOnesVector, sourceChars, ref columnIndex);

                targetChar = Sse41.ConvertToVector128Int32(targetUShortPtr + columnIndex);
                CalculateColumn_4Rows_Sse41(ref targetChar, previousRowPtr, ref row1Costs, ref row2Costs, ref row3Costs, ref row4Costs, ref row5Costs, ref allOnesVector, sourceChars, ref columnIndex);
                CalculateColumn_4Rows_Sse41(ref targetChar, previousRowPtr, ref row1Costs, ref row2Costs, ref row3Costs, ref row4Costs, ref row5Costs, ref allOnesVector, sourceChars, ref columnIndex);
                CalculateColumn_4Rows_Sse41(ref targetChar, previousRowPtr, ref row1Costs, ref row2Costs, ref row3Costs, ref row4Costs, ref row5Costs, ref allOnesVector, sourceChars, ref columnIndex);
                CalculateColumn_4Rows_Sse41(ref targetChar, previousRowPtr, ref row1Costs, ref row2Costs, ref row3Costs, ref row4Costs, ref row5Costs, ref allOnesVector, sourceChars, ref columnIndex);
            }

            if (rowColumnsRemaining >= 4)
            {
                rowColumnsRemaining -= 4;

                targetChar = Sse41.ConvertToVector128Int32(targetUShortPtr + columnIndex);
                CalculateColumn_4Rows_Sse41(ref targetChar, previousRowPtr, ref row1Costs, ref row2Costs, ref row3Costs, ref row4Costs, ref row5Costs, ref allOnesVector, sourceChars, ref columnIndex);
                CalculateColumn_4Rows_Sse41(ref targetChar, previousRowPtr, ref row1Costs, ref row2Costs, ref row3Costs, ref row4Costs, ref row5Costs, ref allOnesVector, sourceChars, ref columnIndex);
                CalculateColumn_4Rows_Sse41(ref targetChar, previousRowPtr, ref row1Costs, ref row2Costs, ref row3Costs, ref row4Costs, ref row5Costs, ref allOnesVector, sourceChars, ref columnIndex);
                CalculateColumn_4Rows_Sse41(ref targetChar, previousRowPtr, ref row1Costs, ref row2Costs, ref row3Costs, ref row4Costs, ref row5Costs, ref allOnesVector, sourceChars, ref columnIndex);
            }

            // Fewer than 4 columns left: read one character at a time and broadcast it,
            // so no vector load touches memory past the end of the target.
            while (rowColumnsRemaining > 0)
            {
                rowColumnsRemaining--;

                targetChar = Vector128.Create((int)targetUShortPtr[columnIndex]);
                CalculateColumn_4Rows_Sse41(ref targetChar, previousRowPtr, ref row1Costs, ref row2Costs, ref row3Costs, ref row4Costs, ref row5Costs, ref allOnesVector, sourceChars, ref columnIndex);
            }
        }
    }
}
/// <summary>
/// Using SSE4.1, calculates the cost for 4 vertically adjacent cells (one column across 4 rows) in the virtual matrix.
/// Handling the 4 cells together needs only one target character, one read of the deletion cost and one write back of it.
/// A compare mask selects between the "characters match" and "characters differ" results, so no branch is taken per cell.
/// Only lane 0 of each vector is read back by this method (via <c>GetElement(0)</c>).
/// </summary>
/// <param name="targetChar">Target characters with the current column's character in lane 0. Shifted right by one 32-bit lane before returning.</param>
/// <param name="previousRowPtr">Cost buffer; the entry at <paramref name="columnIndex"/> is read as the cost above row 1 and overwritten with the row 4 result.</param>
/// <param name="row1Costs">On entry the diagonal cost for row 1; on exit the value read from <paramref name="previousRowPtr"/>.</param>
/// <param name="row2Costs">On entry the cost to the left of row 1 / diagonal of row 2; on exit the newly computed row 1 cost.</param>
/// <param name="row3Costs">On entry the cost to the left of row 2 / diagonal of row 3; on exit the newly computed row 2 cost.</param>
/// <param name="row4Costs">On entry the cost to the left of row 3 / diagonal of row 4; on exit the newly computed row 3 cost.</param>
/// <param name="row5Costs">On entry the cost to the left of row 4; on exit the newly computed row 4 cost.</param>
/// <param name="allOnesVector">Vector added to the minimum neighbour cost when the characters differ. Not modified.</param>
/// <param name="sourceChars">Source characters for the 4 rows, one per 32-bit lane starting at lane 0. Passed by value; shifted locally.</param>
/// <param name="columnIndex">Index into <paramref name="previousRowPtr"/> for this column; incremented by one before returning.</param>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe void CalculateColumn_4Rows_Sse41(
    ref Vector128<int> targetChar,
    int* previousRowPtr,
    ref Vector128<int> row1Costs,
    ref Vector128<int> row2Costs,
    ref Vector128<int> row3Costs,
    ref Vector128<int> row4Costs,
    ref Vector128<int> row5Costs,
    ref Vector128<int> allOnesVector,
    Vector128<int> sourceChars,
    ref int columnIndex
)
{
    // ---- Row 1 ----
    var costAbove = Vector128.Create(previousRowPtr[columnIndex]);
    var isMatch = Sse2.CompareEqual(targetChar, sourceChars);
    var cheapestNeighbour = Sse41.Min(Sse41.Min(row2Costs, row1Costs), costAbove);
    var costWhenDifferent = Sse2.Add(cheapestNeighbour, allOnesVector);
    // Blend: diagonal cost where the characters match, min + 1 where they do not.
    var cellCost = Sse2.Or(
        Sse2.And(isMatch, row1Costs),
        Sse2.AndNot(isMatch, costWhenDifferent)
    );
    row1Costs = costAbove;
    costAbove = cellCost;

    // ---- Row 2 ---- (bring the next source character into lane 0)
    sourceChars = Sse2.ShiftRightLogical128BitLane(sourceChars, 4);
    isMatch = Sse2.CompareEqual(targetChar, sourceChars);
    cheapestNeighbour = Sse41.Min(Sse41.Min(row3Costs, row2Costs), costAbove);
    costWhenDifferent = Sse2.Add(cheapestNeighbour, allOnesVector);
    cellCost = Sse2.Or(
        Sse2.And(isMatch, row2Costs),
        Sse2.AndNot(isMatch, costWhenDifferent)
    );
    row2Costs = costAbove;
    costAbove = cellCost;

    // ---- Row 3 ----
    sourceChars = Sse2.ShiftRightLogical128BitLane(sourceChars, 4);
    isMatch = Sse2.CompareEqual(targetChar, sourceChars);
    cheapestNeighbour = Sse41.Min(Sse41.Min(row4Costs, row3Costs), costAbove);
    costWhenDifferent = Sse2.Add(cheapestNeighbour, allOnesVector);
    cellCost = Sse2.Or(
        Sse2.And(isMatch, row3Costs),
        Sse2.AndNot(isMatch, costWhenDifferent)
    );
    row3Costs = costAbove;
    costAbove = cellCost;

    // ---- Row 4 ----
    sourceChars = Sse2.ShiftRightLogical128BitLane(sourceChars, 4);
    isMatch = Sse2.CompareEqual(targetChar, sourceChars);
    cheapestNeighbour = Sse41.Min(Sse41.Min(row5Costs, row4Costs), costAbove);
    costWhenDifferent = Sse2.Add(cheapestNeighbour, allOnesVector);
    cellCost = Sse2.Or(
        Sse2.And(isMatch, row4Costs),
        Sse2.AndNot(isMatch, costWhenDifferent)
    );
    row4Costs = costAbove;
    row5Costs = cellCost;

    // Store the row 4 result for this column, then move on to the next column
    // and bring the next target character into lane 0.
    previousRowPtr[columnIndex] = row5Costs.GetElement(0);
    columnIndex++;
    targetChar = Sse2.ShiftRightLogical128BitLane(targetChar, 4);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment