Skip to content

Instantly share code, notes, and snippets.

@Turnerj
Created March 6, 2020 08:36
Show Gist options
  • Save Turnerj/c3f1962825f28f2e7798f85d25c3a137 to your computer and use it in GitHub Desktop.
Save Turnerj/c3f1962825f28f2e7798f85d25c3a137 to your computer and use it in GitHub Desktop.
Third attempt - still not as good as still having branches
/// <summary>
/// Using SSE4.1, calculates the costs for the virtual matrix, four source rows per pass.
/// This performs a 4x outer loop unrolling allowing fewer lookups of target character and deletion cost data across the rows.
/// </summary>
/// <remarks>
/// Rows are only processed while at least four source characters remain; any leftover rows
/// (fewer than four) are left for the caller, with <paramref name="rowIndex"/> pointing at the first of them.
/// Character loads read exactly four UTF-16 code units, so neither the source nor the target buffer is read past its end.
/// </remarks>
/// <param name="previousRowPtr">Cost row buffer; entry <c>i</c> is read and then overwritten for every target column <c>i</c> in <c>[0, targetLength)</c>.</param>
/// <param name="source">Source characters; one matrix row per character.</param>
/// <param name="rowIndex">Index of the first source row to process. Advanced by 4 for each completed group of rows.</param>
/// <param name="targetPtr">Pointer to the first target character.</param>
/// <param name="targetLength">Number of target characters (matrix columns) to process.</param>
private static unsafe void CalculateRows_4Rows_Sse41(int* previousRowPtr, ReadOnlySpan<char> source, ref int rowIndex, char* targetPtr, int targetLength)
{
    // A group of four rows needs source[rowIndex .. rowIndex + 3] to exist.
    var acceptableRowCount = source.Length - 3;

    Vector128<int> row1Costs, row2Costs, row3Costs, row4Costs, row5Costs;
    Vector128<int> sourceChars;
    var allOnesVector = Vector128.Create(1);

    fixed (char* sourcePtr = source)
    {
        var sourceUShortPtr = (ushort*)sourcePtr;
        var targetUShortPtr = (ushort*)targetPtr;

        for (; rowIndex < acceptableRowCount; rowIndex += 4)
        {
            // Load the four source characters for THIS group of rows (offset by rowIndex) and
            // zero-extend them to 32-bit lanes. The pointer overload reads only 8 bytes.
            sourceChars = Sse41.ConvertToVector128Int32(sourceUShortPtr + rowIndex);

            // Seed the left-hand edge of the matrix (column 0) for the four rows.
            row1Costs = Vector128.Create(rowIndex); //Sub
            row2Costs = Sse2.Add(row1Costs, allOnesVector); //Insert, Sub
            row3Costs = Sse2.Add(row2Costs, allOnesVector); //Insert, Sub
            row4Costs = Sse2.Add(row3Costs, allOnesVector); //Insert, Sub
            row5Costs = Sse2.Add(row4Costs, allOnesVector); //Insert

            var columnIndex = 0;
            var rowColumnsRemaining = targetLength;
            Vector128<int> targetChar;

            // Eight columns per iteration: two loads of four target characters each.
            // Each column call shifts the next target character into lane 0 and advances columnIndex.
            while (rowColumnsRemaining >= 8)
            {
                rowColumnsRemaining -= 8;

                targetChar = Sse41.ConvertToVector128Int32(targetUShortPtr + columnIndex);
                CalculateColumn_4Rows_Sse41(ref targetChar, previousRowPtr, ref row1Costs, ref row2Costs, ref row3Costs, ref row4Costs, ref row5Costs, ref allOnesVector, sourceChars, ref columnIndex);
                CalculateColumn_4Rows_Sse41(ref targetChar, previousRowPtr, ref row1Costs, ref row2Costs, ref row3Costs, ref row4Costs, ref row5Costs, ref allOnesVector, sourceChars, ref columnIndex);
                CalculateColumn_4Rows_Sse41(ref targetChar, previousRowPtr, ref row1Costs, ref row2Costs, ref row3Costs, ref row4Costs, ref row5Costs, ref allOnesVector, sourceChars, ref columnIndex);
                CalculateColumn_4Rows_Sse41(ref targetChar, previousRowPtr, ref row1Costs, ref row2Costs, ref row3Costs, ref row4Costs, ref row5Costs, ref allOnesVector, sourceChars, ref columnIndex);

                targetChar = Sse41.ConvertToVector128Int32(targetUShortPtr + columnIndex);
                CalculateColumn_4Rows_Sse41(ref targetChar, previousRowPtr, ref row1Costs, ref row2Costs, ref row3Costs, ref row4Costs, ref row5Costs, ref allOnesVector, sourceChars, ref columnIndex);
                CalculateColumn_4Rows_Sse41(ref targetChar, previousRowPtr, ref row1Costs, ref row2Costs, ref row3Costs, ref row4Costs, ref row5Costs, ref allOnesVector, sourceChars, ref columnIndex);
                CalculateColumn_4Rows_Sse41(ref targetChar, previousRowPtr, ref row1Costs, ref row2Costs, ref row3Costs, ref row4Costs, ref row5Costs, ref allOnesVector, sourceChars, ref columnIndex);
                CalculateColumn_4Rows_Sse41(ref targetChar, previousRowPtr, ref row1Costs, ref row2Costs, ref row3Costs, ref row4Costs, ref row5Costs, ref allOnesVector, sourceChars, ref columnIndex);
            }

            // One remaining block of four columns.
            if (rowColumnsRemaining >= 4)
            {
                rowColumnsRemaining -= 4;

                targetChar = Sse41.ConvertToVector128Int32(targetUShortPtr + columnIndex);
                CalculateColumn_4Rows_Sse41(ref targetChar, previousRowPtr, ref row1Costs, ref row2Costs, ref row3Costs, ref row4Costs, ref row5Costs, ref allOnesVector, sourceChars, ref columnIndex);
                CalculateColumn_4Rows_Sse41(ref targetChar, previousRowPtr, ref row1Costs, ref row2Costs, ref row3Costs, ref row4Costs, ref row5Costs, ref allOnesVector, sourceChars, ref columnIndex);
                CalculateColumn_4Rows_Sse41(ref targetChar, previousRowPtr, ref row1Costs, ref row2Costs, ref row3Costs, ref row4Costs, ref row5Costs, ref allOnesVector, sourceChars, ref columnIndex);
                CalculateColumn_4Rows_Sse41(ref targetChar, previousRowPtr, ref row1Costs, ref row2Costs, ref row3Costs, ref row4Costs, ref row5Costs, ref allOnesVector, sourceChars, ref columnIndex);
            }

            // Up to three trailing columns, one character at a time (broadcast to every lane).
            while (rowColumnsRemaining > 0)
            {
                rowColumnsRemaining--;

                targetChar = Vector128.Create((int)targetUShortPtr[columnIndex]);
                CalculateColumn_4Rows_Sse41(ref targetChar, previousRowPtr, ref row1Costs, ref row2Costs, ref row3Costs, ref row4Costs, ref row5Costs, ref allOnesVector, sourceChars, ref columnIndex);
            }
        }
    }
}
/// <summary>
/// Using SSE4.1, calculates the cost for 4 vertically adjacent cells (one column across four rows) in the virtual matrix.
/// Handling the four cells together avoids 3 target character lookups, 3 deletion cost lookups and 3 saves of the deletion cost.
/// The match / no-match choice is made with a compare mask and bitwise select rather than a branch.
/// </summary>
/// <remarks>
/// Every operation is lane-wise and only element 0 of the final cost is stored, so lane 0 is the lane that carries the result.
/// </remarks>
/// <param name="targetChar">Target characters as 32-bit lanes; lane 0 is the current column's character. Shifted down by one lane before returning.</param>
/// <param name="previousRowPtr">Cost row buffer; the entry at <paramref name="columnIndex"/> is read as the cost above row 1 and overwritten with the row 4 result.</param>
/// <param name="row1Costs">In: diagonal cost for row 1. Out: the cost that was above row 1 in this column.</param>
/// <param name="row2Costs">In: left cost for row 1 / diagonal cost for row 2. Out: row 1's cost for this column.</param>
/// <param name="row3Costs">In: left cost for row 2 / diagonal cost for row 3. Out: row 2's cost for this column.</param>
/// <param name="row4Costs">In: left cost for row 3 / diagonal cost for row 4. Out: row 3's cost for this column.</param>
/// <param name="row5Costs">In: left cost for row 4. Out: row 4's cost for this column.</param>
/// <param name="allOnesVector">Vector added to the minimum neighbour cost when the characters differ (the caller passes a vector of 1s).</param>
/// <param name="sourceChars">Source characters for the four rows as 32-bit lanes, row 1 in lane 0. Passed by value, so the shifts below do not affect the caller.</param>
/// <param name="columnIndex">Index of the column being calculated; incremented by one before returning.</param>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe void CalculateColumn_4Rows_Sse41(
    ref Vector128<int> targetChar,
    int* previousRowPtr,
    ref Vector128<int> row1Costs,
    ref Vector128<int> row2Costs,
    ref Vector128<int> row3Costs,
    ref Vector128<int> row4Costs,
    ref Vector128<int> row5Costs,
    ref Vector128<int> allOnesVector,
    Vector128<int> sourceChars,
    ref int columnIndex
)
{
    // Cost of the cell directly above the current row; starts as the stored value for this column.
    var costAbove = Vector128.Create(previousRowPtr[columnIndex]);

    // ---- Row 1 ----
    var isMatch = Sse2.CompareEqual(targetChar, sourceChars);
    var cheapestNeighbour = Sse41.Min(Sse41.Min(row2Costs, row1Costs), costAbove);
    var costIfMatch = Sse2.And(isMatch, row1Costs);
    var costIfEdit = Sse2.AndNot(isMatch, Sse2.Add(cheapestNeighbour, allOnesVector));
    row1Costs = costAbove;
    costAbove = Sse2.Or(costIfMatch, costIfEdit);

    // ---- Row 2 ---- (bring the next source character into lane 0)
    sourceChars = Sse2.ShiftRightLogical128BitLane(sourceChars, 4);
    isMatch = Sse2.CompareEqual(targetChar, sourceChars);
    cheapestNeighbour = Sse41.Min(Sse41.Min(row3Costs, row2Costs), costAbove);
    costIfMatch = Sse2.And(isMatch, row2Costs);
    costIfEdit = Sse2.AndNot(isMatch, Sse2.Add(cheapestNeighbour, allOnesVector));
    row2Costs = costAbove;
    costAbove = Sse2.Or(costIfMatch, costIfEdit);

    // ---- Row 3 ----
    sourceChars = Sse2.ShiftRightLogical128BitLane(sourceChars, 4);
    isMatch = Sse2.CompareEqual(targetChar, sourceChars);
    cheapestNeighbour = Sse41.Min(Sse41.Min(row4Costs, row3Costs), costAbove);
    costIfMatch = Sse2.And(isMatch, row3Costs);
    costIfEdit = Sse2.AndNot(isMatch, Sse2.Add(cheapestNeighbour, allOnesVector));
    row3Costs = costAbove;
    costAbove = Sse2.Or(costIfMatch, costIfEdit);

    // ---- Row 4 ----
    sourceChars = Sse2.ShiftRightLogical128BitLane(sourceChars, 4);
    isMatch = Sse2.CompareEqual(targetChar, sourceChars);
    cheapestNeighbour = Sse41.Min(Sse41.Min(row5Costs, row4Costs), costAbove);
    costIfMatch = Sse2.And(isMatch, row4Costs);
    costIfEdit = Sse2.AndNot(isMatch, Sse2.Add(cheapestNeighbour, allOnesVector));
    row4Costs = costAbove;
    row5Costs = Sse2.Or(costIfMatch, costIfEdit);

    // Store the bottom row's cost for this column, then move on to the next column.
    previousRowPtr[columnIndex] = row5Costs.GetElement(0);
    columnIndex++;

    // Bring the next target character into lane 0 for the following call.
    targetChar = Sse2.ShiftRightLogical128BitLane(targetChar, 4);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment