Skip to content

Instantly share code, notes, and snippets.

@dondragmer
Created December 5, 2020 00:11
Show Gist options
  • Star 29 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save dondragmer/0c0b3eed0f7c30f7391deb11121a5aa1 to your computer and use it in GitHub Desktop.
Save dondragmer/0c0b3eed0f7c30f7391deb11121a5aa1 to your computer and use it in GitHub Desktop.
A very fast GPU sort for sorting values within a wavefront
//values to be sorted; each test kernel thread reads the element at its dispatch thread index
Buffer<uint> Input;
//receives the values written by the test kernels
RWBuffer<uint> Output;
//Returns the index this lane's value should be moved to in order to sort the wave's values in
//ascending order. Lanes holding equal values keep their relative lane order.
//  value     - the key held by this lane
//  laneIndex - this lane's index within the wave, expected to be in [0, 31]
//Only the low 32 bits of each ballot are consulted, so at most 32 lanes are handled.
//Lanes that are inactive at the call are ignored instead of being counted as zero-valued lanes.
uint CuteSort(uint value, uint laneIndex)
{
    //inactive lanes never set a ballot bit, so a plain ~ballot would mistake them for lanes
    //whose bit is clear; this mask restricts every complement to lanes that are really running
    const uint activeMask = WaveActiveBallot(true).x;

    uint smallerValuesMask = 0;
    uint equalValuesMask = activeMask;

    //walk from the least to the most significant bit so that later (more significant) bits
    //override the verdict reached on earlier ones
    //don't need to test every bit if your value is constrained to a smaller range
    for (uint bit = 0; bit < 32; bit++)
    {
        const bool isBitSet = (value & (1u << bit)) != 0;
        //WaveActiveBallot returns a uint4; .x carries lanes 0-31
        const uint bitSetMask = WaveActiveBallot(isBitSet).x;
        const uint bitClearMask = ~bitSetMask & activeMask;

        if (isBitSet)
        {
            //every lane with this bit clear is smaller than us as far as this bit is concerned
            smallerValuesMask |= bitClearMask;
            equalValuesMask &= bitSetMask;
        }
        else
        {
            //every lane with this bit set is larger than us, whatever the lower bits said
            smallerValuesMask &= bitClearMask;
            equalValuesMask &= bitClearMask;
        }
    }

    //count up all the lanes with values that should be in front of this one
    const uint numSmallerThanThis = countbits(smallerValuesMask);
    //keep only the bits below laneIndex; the shift is split in two so that lane 0 shifts by
    //31 and then 1 rather than by a single out-of-range 32
    const uint numEqualBeforeThis = countbits((equalValuesMask << (31 - laneIndex)) << 1);
    return numSmallerThanThis + numEqualBeforeThis;
}
//Moves values between lanes: each lane supplies a value and the lane index it should end up
//on, and receives the value whose destination is this lane.
//  value     - the value this lane contributes
//  dstIndex  - the lane index this lane's value should be delivered to (low 5 bits are used)
//  laneIndex - this lane's index within the wave, expected to be in [0, 31]
//If several lanes target the same destination, the lowest-numbered one wins. If no active lane
//targets this lane, the lane's own value is returned.
uint ShuffleTo(uint value, uint dstIndex, uint laneIndex)
{
    //start from the running lanes only, so an inactive lane can never be chosen as a source
    uint equalIndexMask = WaveActiveBallot(true).x;

    //narrow the candidates one index bit at a time: keep the lanes whose dstIndex bit
    //matches the corresponding bit of our own lane index
    for (uint bit = 0; bit < 5; bit++)
    {
        const uint bitValue = 1u << bit;
        //WaveActiveBallot returns a uint4; .x carries lanes 0-31
        const uint bitSetMask = WaveActiveBallot((dstIndex & bitValue) != 0).x;
        const bool ourBitSet = (laneIndex & bitValue) != 0;
        equalIndexMask &= ourBitSet ? bitSetMask : ~bitSetMask;
    }

    //firstbitlow reports "no bit found" as 0xFFFFFFFF, which is not a valid lane to read from,
    //so fall back to reading our own lane when nothing was sent to us
    const uint laneWithOurValue = (equalIndexMask != 0) ? firstbitlow(equalIndexMask) : laneIndex;
    return WaveReadLaneAt(value, laneWithOurValue);
}
//Test kernel for CuteSort: each aligned group of 32 consecutive threads sorts its own 32 inputs
//by scattering every value to its sorted position within that group.
//NOTE(review): treats id.x & 0x1F as the lane index, i.e. assumes 32-lane waves laid out
//linearly over the dispatch thread index - confirm for the target hardware.
[numthreads(1024, 1, 1)]
void CuteSortTest(uint3 id : SV_DispatchThreadID)
{
    const uint threadIndex = id.x;
    const uint laneIndex = threadIndex & 0x1F;
    //index of the first thread in this group of 32
    const uint groupBase = threadIndex & ~0x1F;

    const uint key = Input[threadIndex];
    const uint sortedLane = CuteSort(key, laneIndex);

    /*
    //alt version which presorts before outputting
    Output[threadIndex] = ShuffleTo(key, sortedLane, laneIndex);
    */

    Output[groupBase + sortedLane] = key;
}
//Bitonic sorting network across 32 lanes, exchanging values with WaveReadLaneAt.
//Returns the value that belongs at this lane's index once the values are sorted in ascending
//order. Slower than CuteSort in most circumstances.
//  value     - the key held by this lane
//  laneIndex - this lane's index within the wave, expected to be in [0, 31]
uint BitonicSort(uint value, uint laneIndex)
{
    //mergeSpan is the length of the runs being produced by the current merge stage
    for (uint mergeSpan = 2; mergeSpan <= 32; mergeSpan <<= 1)
    {
        //alternate runs are built in the opposite direction so neighbouring runs form a bitonic sequence
        const bool reversedRun = (laneIndex & mergeSpan) != 0;

        for (uint partnerBit = mergeSpan >> 1; partnerBit != 0; partnerBit >>= 1)
        {
            //the partner lane differs from us only in partnerBit
            const bool isUpperLane = (laneIndex & partnerBit) != 0;
            //true when this lane should end up with the smaller value of the pair
            const bool keepSmaller = (isUpperLane == reversedRun);

            const uint partnerValue = WaveReadLaneAt(value, laneIndex ^ partnerBit);
            value = ((partnerValue < value) == keepSmaller) ? partnerValue : value;
        }
    }
    return value;
}
//Test kernel for BitonicSort: each thread reads one input, sorts it against the other lanes,
//and writes the value that landed on its lane back to the same index.
//NOTE(review): treats id.x & 0x1F as the lane index, i.e. assumes 32-lane waves laid out
//linearly over the dispatch thread index - confirm for the target hardware.
[numthreads(1024, 1, 1)]
void BitonicSortTest(uint3 id : SV_DispatchThreadID)
{
    const uint threadIndex = id.x;
    const uint laneIndex = threadIndex & 0x1F;
    const uint sortedValue = BitonicSort(Input[threadIndex], laneIndex);
    Output[threadIndex] = sortedValue;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment