Skip to content

Instantly share code, notes, and snippets.

@tcantenot
Forked from dondragmer/PrefixSort.compute
Created January 25, 2021 10:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tcantenot/f5e70156f910b36e6a426d31645ce39f to your computer and use it in GitHub Desktop.
Save tcantenot/f5e70156f910b36e6a426d31645ce39f to your computer and use it in GitHub Desktop.
An optimized GPU counting sort
#pragma use_dxc //enable SM 6.0 features, in Unity this is only supported on version 2020.2.0a8 or later with D3D12 enabled
//One 8-bit counting-sort pass is made of these four kernels; the C# driver below dispatches them in this order
//(CountTotalsInBlock, BlockCountPostfixSum, CalculateOffsetsForEachKey, FinalSort) once for each byte of the key
#pragma kernel CountTotalsInBlock
#pragma kernel BlockCountPostfixSum
#pragma kernel CalculateOffsetsForEachKey
#pragma kernel FinalSort
//bit offset of the 8-bit digit sorted by the current pass (the C# driver sets 0, 8, 16 and 24)
uint _FirstBitToSort;
//number of valid keys; threads whose global index is at or past this value are treated as padding
int _NumElements;
//number of 1024-element blocks (the C# driver sets ceil(_NumElements / 1024))
int _NumBlocks;
//when false the payload buffers are neither read nor written
bool _ShouldSortPayload;
//keys for this pass and their destination; the driver swaps the two buffers between passes
Buffer<uint> KeyInputBuffer;
RWBuffer<uint> KeyOutputBuffer;
//optional payload moved alongside each key
Buffer<uint> PayloadInputBuffer;
RWBuffer<uint> PayloadOutputBuffer;
//_NumBlocks x 64 texture: column = block, row r holds the values for keys 4r .. 4r+3 (one per component).
//CountTotalsInBlock fills it with per-block key counts, BlockCountPostfixSum turns each row into an inclusive prefix sum over blocks
RWTexture2D<int4> PerBlockKeyCountsTexture;
//same layout: CountTotalsInBlock stores each key's start index inside its locally sorted block, CalculateOffsetsForEachKey
//replaces it with the value FinalSort adds to a local sorted position to get the global output index
RWTexture2D<int4> BlockToGlobalKeyOffsetsTexture;
//this program assumes 32 lane wide waves (i.e. Nvidia cards), 64 lane waves would require more changes than just adjusting these values
static const uint WAVE_SIZE = 32;
//index of the last lane in a wave
static const uint HIGHEST_LANE = WAVE_SIZE - 1;
//row stride of FinalSort's per-wave count table in LDS (presumably padded by one to avoid bank conflicts)
static const uint WAVE_SIZE_PLUS_PAD = WAVE_SIZE + 1;
static const uint HALF_WAVE_SIZE = WAVE_SIZE / 2;
//number of key bits that select a lane in CountTotalsInBlock (5 for 32-wide waves)
static const uint LOG2_WAVE_SIZE = firstbitlow(WAVE_SIZE);
/*
* -----------------------------------------------------------------------------------------------------------
*/
//CountTotalsInBlock: one 1024-thread group per block of 1024 keys (the driver dispatches _NumBlocks groups).
//Counts how often each value of the current 8-bit digit occurs in this block and writes column groupID.x of:
//  PerBlockKeyCountsTexture       - the count of each key
//  BlockToGlobalKeyOffsetsTexture - the exclusive prefix sum of those counts, i.e. where each key starts once the block is sorted
//Counts are packed two per uint: keys 0-127 in the low 16 bits and keys 128-255 in the high 16 bits.
//A block has at most 1024 keys, so neither 16-bit half can overflow.
//element (laneIndex + subIndex * WAVE_SIZE) holds the packed counts for keys (4 * laneIndex + subIndex) and (128 + 4 * laneIndex + subIndex)
groupshared uint totalCountsInBlock[128];
[numthreads(1024, 1, 1)]
void CountTotalsInBlock(uint3 threadID : SV_GroupThreadID, uint3 groupID : SV_GroupID)
{
uint laneIndex = WaveGetLaneIndex(); //the index of this thread in its wavefront
uint waveIndex = threadID.x / WAVE_SIZE; //the index of this wavefront in the group
uint globalIndex = threadID.x + (groupID.x * 1024); //the index this thread is loading from the global array
//the load is clamped so padding threads past _NumElements stay in bounds; their digit is forced to 0xFF instead
uint rawSortKey = KeyInputBuffer[min(globalIndex, _NumElements - 1)];
uint sortKey = (globalIndex < _NumElements) ? ((rawSortKey >> _FirstBitToSort) & 0xFF) : 0xFF;
//initialize counts to 0 because we are going to atomically add to them
if (threadID.x < 128)
{
totalCountsInBlock[threadID.x] = 0;
}
GroupMemoryBarrierWithGroupSync();
//this will contain the total occurrences in this block for 8 keys
//4 sequential keys from the first 128 keys in the lower 16 bits of each element
//4 sequential keys from the latter 128 keys in the upper 16 bits of each element
uint4 countsForWave;
{
uint equalsLaneMask = ~0; //bitmask of lanes in the wave whose key bits 2-6 equal this lane's index
for (int bit = 0; bit < LOG2_WAVE_SIZE; bit++)
{
//start at the third bit so each lane covers 4 consecutive keys (bits 0 and 1 are separated below)
bool isBitSet = sortKey & (1 << (bit + 2));
uint bitSetMask = WaveActiveBallot(isBitSet).x;
equalsLaneMask &= (laneIndex.x & (1 << bit)) ? bitSetMask : ~bitSetMask;
}
//split the lanes in equalsLaneMask by all 8 combinations of key bits 0, 1 and 7:
bool isBitSet = sortKey & (1 << 0);
uint firstBitSetMask = WaveActiveBallot(isBitSet).x;
isBitSet = sortKey & (1 << 1);
uint secondBitSetMask = WaveActiveBallot(isBitSet).x;
isBitSet = sortKey & (1 << 7);
uint eighthBitSetMask = WaveActiveBallot(isBitSet).x;
//each loop iteration inverts its mask first, so index 0 selects lanes with the bit clear and index 1 lanes with it set;
//the inner loop's two inversions restore firstBitSetMask before the next outer iteration
for (int secondBit = 0; secondBit < 2; secondBit++)
{
secondBitSetMask = ~secondBitSetMask;
for (int firstBit = 0; firstBit < 2; firstBit++)
{
firstBitSetMask = ~firstBitSetMask;
//pack two counts with different 8th bits into the same value
//(component firstBit + secondBit * 2 is key 4 * laneIndex + firstBit + 2 * secondBit, plus 128 in the high half)
uint countA = countbits(equalsLaneMask & firstBitSetMask & secondBitSetMask & ~eighthBitSetMask);
countsForWave[firstBit + secondBit * 2] = (countA & 0xFFFF);
uint countB = countbits(equalsLaneMask & firstBitSetMask & secondBitSetMask & eighthBitSetMask);
countsForWave[firstBit + secondBit * 2] |= (countB << 16);
}
}
}
//atomically add the counts from every wave together (both packed halves are added in one operation)
for (uint subIndex = 0; subIndex < 4; subIndex++)
{
uint writeIndex = laneIndex + (subIndex * WAVE_SIZE);
InterlockedAdd(totalCountsInBlock[writeIndex], countsForWave[subIndex]);
}
GroupMemoryBarrierWithGroupSync();
//have the first two waves output the results
//both read the same packed counts; wave 0 unpacks keys 0-127 into rows 0-31, wave 1 unpacks keys 128-255 into rows 32-63
[branch]
if (waveIndex <= 1)
{
uint4 countsForBlock;
for (int subIndex = 0; subIndex < 4; subIndex++)
{
uint readIndex = laneIndex + (subIndex * WAVE_SIZE);
countsForBlock[subIndex] = totalCountsInBlock[readIndex];
}
//output the total count of each 1-byte key in this group
uint4 unpackedCounts = (waveIndex == 0) ? (countsForBlock & 0xFFFF) : (countsForBlock >> 16);
PerBlockKeyCountsTexture[uint2(groupID.x, threadID.x)] = (int4) unpackedCounts;
//calculate a prefix sum of every key's count which would equal the index that key starts at in the sorted block
//(computed on the packed values, so both halves are scanned at once)
uint4 countPrefix;
countPrefix.x = 0;
countPrefix.y = countsForBlock.x;
countPrefix.z = countPrefix.y + countsForBlock.y;
countPrefix.w = countPrefix.z + countsForBlock.z;
//add in the prefix from the other lanes
countPrefix += WavePrefixSum(countPrefix.w + countsForBlock.w);
//add the total count of the first half of the keys to the second half:
//the highest lane's inclusive sum is the packed block total, and shifting it left by 16 moves the
//keys 0-127 total into the high half while discarding the keys 128-255 total
uint firstHalfTotal = countPrefix.w + countsForBlock.w;
firstHalfTotal = WaveReadLaneAt(firstHalfTotal, HIGHEST_LANE);
countPrefix += firstHalfTotal << 16;
//output the final starting index
uint4 unpackedPrefix = (waveIndex == 0) ? (countPrefix & 0xFFFF) : (countPrefix >> 16);
BlockToGlobalKeyOffsetsTexture[uint2(groupID.x, threadID.x)] = (int4) unpackedPrefix;
}
}
/*
* -----------------------------------------------------------------------------------------------------------
*/
//BlockCountPostfixSum: the driver dispatches 1 x 64 groups, one per row of PerBlockKeyCountsTexture (4 keys per row).
//Each group rewrites its row in place so column b holds the sum of the counts over blocks 0..b
//(an inclusive prefix sum despite the "postfix" name). The row is processed in chunks of 1024 blocks,
//with runningTotals carrying the sum of all earlier chunks.
//per-wave inclusive totals of the current 1024-block chunk
groupshared int4 eachWaveTotals[32];
[numthreads(1024, 1, 1)]
void BlockCountPostfixSum(uint3 threadID : SV_GroupThreadID, uint3 groupID : SV_GroupID)
{
uint laneIndex = WaveGetLaneIndex(); //the index of this thread in its wavefront
uint waveIndex = threadID.x / WAVE_SIZE; //the index of this wavefront in the group
//first chunk: thread t loads block t's counts for this row (0 past the last block)
int4 blockCounts = (threadID.x < _NumBlocks) ? PerBlockKeyCountsTexture[uint2(threadID.x, groupID.y)] : 0;
//calculate the postfix 1024 blocks at a time
int4 runningTotals = 0;
int startingBlock;
//every chunk except the final (possibly partial) one, which is handled after the loop
for (startingBlock = 0; startingBlock < (_NumBlocks - 1024); startingBlock += 1024)
{
//inclusive prefix within the wave
int4 blockCountPostfix = WavePrefixSum(blockCounts) + blockCounts;
//load the next set of key counts now to maximize the time until we need to use it
int blockLoadIndex = threadID.x + startingBlock;
blockCounts = (blockLoadIndex + 1024 < _NumBlocks) ? PerBlockKeyCountsTexture[uint2(blockLoadIndex + 1024, groupID.y)] : 0;
//have last lane in each wave output the total counts for this wave
if (laneIndex == HIGHEST_LANE)
{
eachWaveTotals[waveIndex] = blockCountPostfix;
}
GroupMemoryBarrierWithGroupSync();
//get the totals of all waves before this one (lane i reads wave i's total)
int4 allWaveTotals = eachWaveTotals[laneIndex];
int4 previousWaveTotal = (laneIndex < waveIndex) ? allWaveTotals : 0; //only keep totals for waves less than this one
previousWaveTotal = WaveActiveSum(previousWaveTotal);
blockCountPostfix += previousWaveTotal;
if (blockLoadIndex < _NumBlocks)
{
PerBlockKeyCountsTexture[uint2(blockLoadIndex, groupID.y)] = blockCountPostfix + runningTotals;
}
//get totals from all 1024 blocks to add to the next set
runningTotals += WaveActiveSum(allWaveTotals);
//all reads of eachWaveTotals must finish before the next chunk overwrites it
GroupMemoryBarrierWithGroupSync();
}
//calculate postfix for final set of blocks
int4 blockCountPostfix = WavePrefixSum(blockCounts) + blockCounts;
//have last lane in each wave output the total count for this wave
if (laneIndex == HIGHEST_LANE)
{
eachWaveTotals[waveIndex] = blockCountPostfix;
}
GroupMemoryBarrierWithGroupSync();
//get the totals of all waves before this one
int4 previousWaveTotal = eachWaveTotals[laneIndex];
previousWaveTotal = (laneIndex < waveIndex) ? previousWaveTotal : 0; //only keep totals for waves less than this one
previousWaveTotal = WaveActiveSum(previousWaveTotal);
blockCountPostfix += previousWaveTotal;
int blockLoadIndex = threadID.x + startingBlock;
if (blockLoadIndex < _NumBlocks)
{
PerBlockKeyCountsTexture[uint2(blockLoadIndex, groupID.y)] = blockCountPostfix + runningTotals;
}
}
/*
* -----------------------------------------------------------------------------------------------------------
*/
//CalculateOffsetsForEachKey: one 32-thread group per block (the driver dispatches _NumBlocks groups).
//Inputs: PerBlockKeyCountsTexture holds inclusive per-key prefix sums over blocks (from BlockCountPostfixSum),
//BlockToGlobalKeyOffsetsTexture holds each key's start index inside this block once locally sorted (from CountTotalsInBlock).
//Output (in place, column groupID.x of BlockToGlobalKeyOffsetsTexture):
//  (global start of the key + occurrences of the key in earlier blocks) - local start of the key in this block,
//so that FinalSort can add it to an element's local sorted position.
//Thread y handles row y (keys 4y..4y+3) and row y + WAVE_SIZE (keys 128+4y..128+4y+3).
[numthreads(1, 32, 1)]
void CalculateOffsetsForEachKey(uint3 threadID : SV_GroupThreadID, uint3 groupID : SV_GroupID)
{
    //the last column of the prefix texture holds each key's count over the whole array
    int lastBlockColumn = _NumBlocks - 1;
    //number of elements whose keys fall in the halves of the key range already processed
    int keysInEarlierHalves = 0;

    [unroll]
    for (uint halfIndex = 0; halfIndex < 2; halfIndex++)
    {
        uint row = threadID.y + halfIndex * WAVE_SIZE;

        int4 keyTotals = PerBlockKeyCountsTexture[uint2(lastBlockColumn, row)];

        //occurrences of each key in the blocks before this one
        int4 keysInPreviousBlocks = 0;
        if (groupID.x > 0)
        {
            keysInPreviousBlocks = PerBlockKeyCountsTexture[uint2(groupID.x - 1, row)];
        }

        int4 localKeyStart = BlockToGlobalKeyOffsetsTexture[uint2(groupID.x, row)];

        //inclusive scan over the 4 keys owned by this lane
        int4 laneInclusive = keyTotals;
        laneInclusive.y += laneInclusive.x;
        laneInclusive.z += laneInclusive.y;
        laneInclusive.w += laneInclusive.z;

        //elements with smaller keys: those held by lower lanes of this half plus everything in earlier halves
        int lanesBelow = WavePrefixSum(laneInclusive.w) + keysInEarlierHalves;

        int4 globalKeyStart = int4(0, laneInclusive.xyz) + lanesBelow + keysInPreviousBlocks;
        BlockToGlobalKeyOffsetsTexture[uint2(groupID.x, row)] = globalKeyStart - localKeyStart;

        //the highest lane's inclusive value is the running total through the end of this half
        keysInEarlierHalves = WaveReadLaneAt(lanesBelow + laneInclusive.w, HIGHEST_LANE);
    }
}
/*
* -----------------------------------------------------------------------------------------------------------
*/
//FinalSort: one 1024-thread group per block. Sorts the block locally by the current 8-bit digit with two
//stable 4-bit counting-sort stages in LDS (low nibble, then high nibble), then writes each element to
//its local sorted position + BlockToGlobalKeyOffsetsTexture[block, digit / 4][digit % 4].
//countsOrSortedData is reused for different purposes:
//  counting: entry (waveIndex * WAVE_SIZE_PLUS_PAD + laneIndex) holds per-wave counts (32 rows with stride 33),
//            followed at WAVE_SIZE * WAVE_SIZE_PLUS_PAD by the 16 per-nibble start indices for the block
//  scatter:  [0, 1024) sorted keys, [1024, 2048) sorted payloads
groupshared uint countsOrSortedData[1024 * 2];
[numthreads(1024, 1, 1)]
void FinalSort(uint3 threadID : SV_GroupThreadID, uint3 groupID : SV_GroupID)
{
uint laneIndex = WaveGetLaneIndex(); //the index of this thread in its wavefront
uint waveIndex = threadID.x / WAVE_SIZE; //the index of this wavefront in the group
uint globalIndex = threadID.x + (groupID.x * 1024); //the index this thread is loading from the global array
uint rawSortKey = KeyInputBuffer[min(globalIndex, _NumElements - 1)];
uint sortPayload = 0;
if (_ShouldSortPayload) //only load a payload if we are actually sorting it
{
sortPayload = PayloadInputBuffer[min(globalIndex, _NumElements - 1)];
}
//do a local sort for this block in two stages, the lower 4 bits of the current digit and then the upper 4 bits
[unroll]
for (int subsortFirstBit = 0; subsortFirstBit < 8; subsortFirstBit += 4)
{
//the second stage must not overwrite LDS until every thread has read its first-stage result
if (subsortFirstBit != 0)
{
GroupMemoryBarrierWithGroupSync();
}
//in the second stage globalIndex is this thread's position after the first stage; padding elements had their
//nibble forced to 0x0F and started at the end of the block, so the stable first stage left them at the end
//and the same positional test still identifies them
uint sortKey = (globalIndex < _NumElements) ? ((rawSortKey >> (_FirstBitToSort + subsortFirstBit)) & 0x0F) : 0x0F;
uint internalSortIndex = 0;
//count keys in the wave
{
//equalsKeyMask: lanes whose nibble equals this lane's nibble
//equalsLaneMask / lessThanLaneMask: lanes whose nibble is equal to / less than laneSortValue (laneIndex & 0x0F)
uint equalsKeyMask = ~0;
uint equalsLaneMask = ~0;
uint lessThanLaneMask = 0;
uint laneSortValue = laneIndex & 0x0F;
//bits are scanned from lowest to highest, so a differing higher bit overrides the comparison of lower bits
for (int bit = 0; bit < 4; bit++)
{
bool isBitSet = sortKey & (1 << bit);
uint bitSetMask = WaveActiveBallot(isBitSet).x;
equalsKeyMask &= isBitSet ? bitSetMask : ~bitSetMask;
if (laneSortValue & (1 << bit))
{
lessThanLaneMask |= ~bitSetMask;
equalsLaneMask &= bitSetMask;
}
else
{
lessThanLaneMask &= ~bitSetMask;
equalsLaneMask &= ~bitSetMask;
}
}
//count the number of lanes before this one with the same key
//(two shifts so lane 0 shifts by 32 in total; a single shift count of 32 would be masked to 0)
internalSortIndex = countbits((equalsKeyMask << (HIGHEST_LANE - laneIndex)) << 1);
//first half of the wave outputs count of keys equal to its value, second half outputs count of keys less than its value
uint countToOutput = (laneIndex < HALF_WAVE_SIZE) ? countbits(equalsLaneMask) : countbits(lessThanLaneMask);
uint writeIndex = laneIndex + (waveIndex * WAVE_SIZE_PLUS_PAD);
countsOrSortedData[writeIndex] = countToOutput;
}
GroupMemoryBarrierWithGroupSync();
//calculate the prefix sums and count of smaller keys for each of the 16 keys
//thread (wave w, lane l) reads column w of wave l's row, i.e. the table is read transposed
{
uint readIndex = (laneIndex * WAVE_SIZE_PLUS_PAD) + waveIndex;
uint waveCounts = countsOrSortedData[readIndex];
[branch]
if (waveIndex < 16) //first 16 waves are calculating prefix sums for each key
{
//number of elements with nibble waveIndex in waves before wave laneIndex
uint prefixSumForKey = WavePrefixSum(waveCounts);
countsOrSortedData[readIndex] = prefixSumForKey;
}
else //last 16 waves are calculating starting index for each key
{
//number of elements in the whole block with a nibble smaller than (waveIndex - 16)
uint offsetForKey = WaveActiveSum(waveCounts);
if (WaveIsFirstLane())
{
countsOrSortedData[(waveIndex - 16) + (WAVE_SIZE * WAVE_SIZE_PLUS_PAD)] = offsetForKey;
}
}
}
GroupMemoryBarrierWithGroupSync();
//get the sorted index and then scatter into LDS
//sorted index = same nibble in earlier lanes + same nibble in earlier waves + all smaller nibbles in the block
{
uint readIndex = sortKey + (waveIndex * WAVE_SIZE_PLUS_PAD);
internalSortIndex += countsOrSortedData[readIndex];
internalSortIndex += countsOrSortedData[sortKey + (WAVE_SIZE * WAVE_SIZE_PLUS_PAD)];
//every thread must finish reading the count table before it is overwritten with sorted data
GroupMemoryBarrierWithGroupSync();
countsOrSortedData[internalSortIndex] = rawSortKey;
if (_ShouldSortPayload)
{
countsOrSortedData[internalSortIndex + 1024] = sortPayload;
}
}
GroupMemoryBarrierWithGroupSync();
//read the sorted data out of LDS
{
rawSortKey = countsOrSortedData[threadID.x];
if (_ShouldSortPayload)
{
sortPayload = countsOrSortedData[threadID.x + 1024];
}
}
}
//only real elements are written; in the last block the padding occupies the final positions after the local sort
if (globalIndex < _NumElements)
{
//load sorted data
//recompute the full 8-bit digit; global index = local sorted position + per-key block offset
uint sortKey = (rawSortKey >> _FirstBitToSort) & 0xFF;
int finalSortIndex = (int) threadID.x + BlockToGlobalKeyOffsetsTexture[uint2(groupID.x, sortKey / 4)][sortKey % 4];
KeyOutputBuffer[finalSortIndex] = rawSortKey;
if (_ShouldSortPayload)
{
PayloadOutputBuffer[finalSortIndex] = sortPayload;
}
}
}
using UnityEngine;
using System;
//this setup class is made for Unity but the shader will work in any engine that supports D3D12 and HLSL SM 6.0
/// <summary>
/// Unity test harness for the GPU counting sort. Generates random 32-bit keys (and an index payload),
/// sorts them every frame with four 8-bit passes of <see cref="m_sortShader"/>, and when Space is pressed
/// reads the result back and logs whether it is sorted.
/// Controls: Q / E step through <see cref="m_tests"/>, R toggles payload sorting.
/// </summary>
public class PrefixSorterSetup : MonoBehaviour
{
    //upper bound on the element count (one texture column is allocated per 1024 elements)
    static readonly int maxElements = 1024 * 1024 * 8;

    public ComputeShader m_sortShader;
    public int[] m_tests = { 1048576 };
    public bool m_shouldSortPayload = true;
    public string m_debugOutputElements = "";

    int m_testSizeIndex = 0;
    int m_numElements = -1;

    //kernel indices; -1 until SetupComputeShader resolves them, which makes DoSort fail and Update rebuild
    int m_countTotalsKernel = -1;
    int m_blockPostfixKernel = -1;
    int m_calculateOffsetsKernel = -1;
    int m_finalSortKernel = -1;

    uint[] m_sortingKeys;
    uint[] m_sortingPayload;
    ComputeBuffer m_keysBufferA;
    ComputeBuffer m_keysBufferB;
    ComputeBuffer m_payloadBufferA;
    ComputeBuffer m_payloadBufferB;
    RenderTexture m_perBlockKeyCountsTexture;
    RenderTexture m_blockToGlobalKeyOffsetsTexture;
    private TextMesh m_debugDisplayText;

    /// <summary>Number of 1024-element blocks needed to cover <paramref name="numElements"/>, rounded up.</summary>
    static int BlockCount(int numElements)
    {
        return (numElements + 1023) / 1024;
    }

    // Start is called before the first frame update
    void Start()
    {
        m_debugDisplayText = GetComponent<TextMesh>();
        SetupComputeShader();
        ProcessControlsAndEditorSettings(); //sets up resources
    }

    // Update is called once per frame
    void Update()
    {
        ProcessControlsAndEditorSettings();

        if (DoSort())
        {
            if (Input.GetKeyDown(KeyCode.Space))
            {
                ValidateKeys();
                if (m_shouldSortPayload)
                {
                    ValidatePayload();
                }
            }
        }
        else
        {
            //the shader, a kernel or a resource is missing: resolve the kernels again and rebuild the resources
            m_countTotalsKernel = -1;
            m_blockPostfixKernel = -1;
            m_calculateOffsetsKernel = -1;
            m_finalSortKernel = -1;
            SetupComputeShader();
            BuildResources();
        }

        if (m_debugDisplayText != null)
        {
            if (m_sortShader == null)
            {
                m_debugDisplayText.text = "SHADER NOT SET!";
            }
            else
            {
                m_debugDisplayText.text = "Elements: " + m_numElements.ToString() + "\nSorting ";
                m_debugDisplayText.text += m_shouldSortPayload ? "Keys and Payload" : "Keys Only";
            }
        }
    }

    /// <summary>
    /// Applies keyboard controls (Q / E change the test size, R toggles payload sorting), clamps the selected
    /// test size to [1, maxElements] and rebuilds the GPU resources when the element count changes.
    /// </summary>
    void ProcessControlsAndEditorSettings()
    {
        //keyboard controls
        if (Input.GetKeyDown(KeyCode.Q))
        {
            m_testSizeIndex--;
        }
        if (Input.GetKeyDown(KeyCode.E))
        {
            m_testSizeIndex++;
        }
        if (Input.GetKeyDown(KeyCode.R))
        {
            m_shouldSortPayload = !m_shouldSortPayload;
        }

        //make sure there is at least 1 valid test size
        if (m_tests == null || m_tests.Length == 0)
        {
            m_tests = new int[1];
            m_tests[0] = 1048576;
        }

        //wrap test index to valid range
        if (m_testSizeIndex < 0)
        {
            m_testSizeIndex = m_tests.Length - 1;
        }
        if (m_testSizeIndex >= m_tests.Length)
        {
            m_testSizeIndex = 0;
        }

        //pick the number of elements and clamp it
        int newNumElements = m_tests[m_testSizeIndex];
        if (newNumElements > maxElements)
        {
            newNumElements = maxElements;
        }
        else if (newNumElements < 1)
        {
            newNumElements = 1;
        }

        //rebuild resources if number of elements changed
        if (m_numElements != newNumElements)
        {
            m_numElements = newNumElements;
            BuildResources();
        }
    }

    /// <summary>Looks up the four kernel indices; does nothing when no shader is assigned.</summary>
    void SetupComputeShader()
    {
        //check if the shader exists
        if (m_sortShader == null)
        {
            return;
        }

        m_countTotalsKernel = m_sortShader.FindKernel("CountTotalsInBlock");
        m_blockPostfixKernel = m_sortShader.FindKernel("BlockCountPostfixSum");
        m_calculateOffsetsKernel = m_sortShader.FindKernel("CalculateOffsetsForEachKey");
        m_finalSortKernel = m_sortShader.FindKernel("FinalSort");
    }

    /// <summary>
    /// Regenerates random keys (payload = original index) for <see cref="m_numElements"/> elements, recreates the
    /// ping-pong key/payload buffers and the two numBlocks x 64 integer textures, and binds the textures to the kernels.
    /// Previously created resources are released and destroyed first.
    /// </summary>
    void BuildResources()
    {
        if (m_sortShader == null)
        {
            return;
        }

        //create an unsorted array of values
        m_sortingKeys = new uint[m_numElements];
        m_sortingPayload = new uint[m_numElements];
        for (uint i = 0; i < m_numElements; i++)
        {
            m_sortingKeys[i] = (uint)UnityEngine.Random.Range(int.MinValue, int.MaxValue);
            m_sortingPayload[i] = i;
        }

        //create the buffers
        ReleaseBuffer(ref m_keysBufferA);
        m_keysBufferA = new ComputeBuffer(m_numElements, sizeof(int));
        ReleaseBuffer(ref m_keysBufferB);
        m_keysBufferB = new ComputeBuffer(m_numElements, sizeof(int));
        ReleaseBuffer(ref m_payloadBufferA);
        m_payloadBufferA = new ComputeBuffer(m_numElements, sizeof(int));
        ReleaseBuffer(ref m_payloadBufferB);
        m_payloadBufferB = new ComputeBuffer(m_numElements, sizeof(int));

        //create the textures: one column per block, one row per 4 keys (int4 per texel)
        int numBlocks = BlockCount(m_numElements);
        RenderTextureDescriptor groupTotalsTexDesc = new RenderTextureDescriptor(numBlocks, 64, RenderTextureFormat.ARGBInt, 0);
        groupTotalsTexDesc.enableRandomWrite = true;

        ReleaseTexture(ref m_perBlockKeyCountsTexture);
        m_perBlockKeyCountsTexture = new RenderTexture(groupTotalsTexDesc);
        m_perBlockKeyCountsTexture.Create();
        m_sortShader.SetTexture(m_countTotalsKernel, "PerBlockKeyCountsTexture", m_perBlockKeyCountsTexture, 0);
        m_sortShader.SetTexture(m_blockPostfixKernel, "PerBlockKeyCountsTexture", m_perBlockKeyCountsTexture, 0);
        m_sortShader.SetTexture(m_calculateOffsetsKernel, "PerBlockKeyCountsTexture", m_perBlockKeyCountsTexture, 0);

        ReleaseTexture(ref m_blockToGlobalKeyOffsetsTexture);
        m_blockToGlobalKeyOffsetsTexture = new RenderTexture(groupTotalsTexDesc);
        m_blockToGlobalKeyOffsetsTexture.Create();
        m_sortShader.SetTexture(m_countTotalsKernel, "BlockToGlobalKeyOffsetsTexture", m_blockToGlobalKeyOffsetsTexture, 0);
        m_sortShader.SetTexture(m_calculateOffsetsKernel, "BlockToGlobalKeyOffsetsTexture", m_blockToGlobalKeyOffsetsTexture, 0);
        m_sortShader.SetTexture(m_finalSortKernel, "BlockToGlobalKeyOffsetsTexture", m_blockToGlobalKeyOffsetsTexture, 0);
    }

    /// <summary>
    /// Uploads the unsorted keys/payload and runs four 8-bit sort passes (lowest byte first), ping-ponging between
    /// the A and B buffers; after the fourth pass the result is back in the A buffers.
    /// </summary>
    /// <returns>false when the shader, a kernel or a GPU resource is missing (nothing is dispatched).</returns>
    bool DoSort()
    {
        if (m_sortShader == null || m_numElements < 1 || m_numElements > maxElements
            || m_countTotalsKernel < 0 || m_blockPostfixKernel < 0
            || m_calculateOffsetsKernel < 0 || m_finalSortKernel < 0)
        {
            return false;
        }

        //resources may not exist yet if the shader was assigned after they would have been built
        if (m_keysBufferA == null || m_keysBufferB == null
            || m_payloadBufferA == null || m_payloadBufferB == null
            || m_perBlockKeyCountsTexture == null || m_blockToGlobalKeyOffsetsTexture == null)
        {
            return false;
        }

        m_keysBufferA.SetData(m_sortingKeys, 0, 0, m_numElements);
        m_payloadBufferA.SetData(m_sortingPayload, 0, 0, m_numElements);

        int numBlocks = BlockCount(m_numElements);
        m_sortShader.SetInt("_NumElements", m_numElements);
        m_sortShader.SetInt("_NumBlocks", numBlocks);
        m_sortShader.SetBool("_ShouldSortPayload", m_shouldSortPayload);

        //sorting is done four iterations each sorting 1 byte (from lowest to highest)
        for (int i = 0; i < 4; i++)
        {
            m_sortShader.SetInt("_FirstBitToSort", i * 8);

            //flip the buffers every other sort
            if ((i % 2) == 0)
            {
                m_sortShader.SetBuffer(m_countTotalsKernel, "KeyInputBuffer", m_keysBufferA);
                m_sortShader.SetBuffer(m_finalSortKernel, "KeyInputBuffer", m_keysBufferA);
                m_sortShader.SetBuffer(m_finalSortKernel, "KeyOutputBuffer", m_keysBufferB);
                m_sortShader.SetBuffer(m_finalSortKernel, "PayloadInputBuffer", m_payloadBufferA);
                m_sortShader.SetBuffer(m_finalSortKernel, "PayloadOutputBuffer", m_payloadBufferB);
            }
            else
            {
                m_sortShader.SetBuffer(m_countTotalsKernel, "KeyInputBuffer", m_keysBufferB);
                m_sortShader.SetBuffer(m_finalSortKernel, "KeyInputBuffer", m_keysBufferB);
                m_sortShader.SetBuffer(m_finalSortKernel, "KeyOutputBuffer", m_keysBufferA);
                m_sortShader.SetBuffer(m_finalSortKernel, "PayloadInputBuffer", m_payloadBufferB);
                m_sortShader.SetBuffer(m_finalSortKernel, "PayloadOutputBuffer", m_payloadBufferA);
            }

            m_sortShader.Dispatch(m_countTotalsKernel, numBlocks, 1, 1);
            m_sortShader.Dispatch(m_blockPostfixKernel, 1, 64, 1);
            m_sortShader.Dispatch(m_calculateOffsetsKernel, numBlocks, 1, 1);
            m_sortShader.Dispatch(m_finalSortKernel, numBlocks, 1, 1);
        }
        return true;
    }

    /// <summary>Reads the sorted keys back from buffer A and logs whether they are in non-decreasing order.</summary>
    void ValidateKeys()
    {
        uint[] values = new uint[m_numElements];
        m_keysBufferA.GetData(values, 0, 0, m_numElements);
        ReportSortedness("Keys", values);
    }

    /// <summary>
    /// Reads the sorted payload (original element indices) back from buffer A, maps each index to its original key
    /// and logs whether those keys are in non-decreasing order.
    /// </summary>
    void ValidatePayload()
    {
        //payload contains mapping back to original unsorted indices
        uint[] originalIndices = new uint[m_numElements];
        m_payloadBufferA.GetData(originalIndices, 0, 0, m_numElements);

        uint[] keysInPayloadOrder = new uint[m_numElements];
        for (int i = 0; i < m_numElements; i++)
        {
            keysInPayloadOrder[i] = m_sortingKeys[originalIndices[i]];
        }
        ReportSortedness("Payload", keysInPayloadOrder);
    }

    /// <summary>
    /// Checks that <paramref name="values"/> is in non-decreasing order, stores the first 1024 values in
    /// <see cref="m_debugOutputElements"/> and logs a summary with the values around the first failure
    /// (or the first values when sorted).
    /// </summary>
    /// <param name="label">prefix for the log line ("Keys" or "Payload").</param>
    /// <param name="values">the keys to check; must contain at least one element.</param>
    void ReportSortedness(string label, uint[] values)
    {
        const int numToPrint = 10;

        //compare as uint: converting 32-bit keys to float rounds away the low bits and can hide ordering errors
        var debugText = new System.Text.StringBuilder();
        int failedOn = -1;
        for (int i = 0; i < values.Length; i++)
        {
            uint value = values[i];

            //print out the first 1024 elements in a debug string
            if (i < 1024)
            {
                if (i % 32 == 0)
                {
                    debugText.Append('\n');
                }
                if (i % 256 == 0)
                {
                    debugText.Append('\n');
                }
                debugText.Append(value).Append(", ");
            }

            if (failedOn < 0 && i > 0 && value < values[i - 1])
            {
                failedOn = i;
            }
        }
        m_debugOutputElements = debugText.ToString();
        bool isSorted = failedOn < 0;

        //print the values surrounding where the sort failed (or just the first values)
        string output = label + ": Size = " + values.Length + " | Sorted = " + isSorted + " | Failed On = " + failedOn + " | Values: ";
        int startIndex = Math.Max(0, failedOn - 5);
        for (int i = startIndex; i < startIndex + numToPrint && i < values.Length; i++)
        {
            output += values[i] + ", ";
        }
        Debug.Log(output);
    }

    /// <summary>Releases a compute buffer if it exists and clears the reference.</summary>
    static void ReleaseBuffer(ref ComputeBuffer buffer)
    {
        if (buffer != null)
        {
            buffer.Release();
            buffer = null;
        }
    }

    /// <summary>
    /// Releases a render texture's GPU memory and destroys the texture object itself
    /// (Release alone leaves the RenderTexture object alive), then clears the reference.
    /// </summary>
    static void ReleaseTexture(ref RenderTexture texture)
    {
        if (texture != null)
        {
            texture.Release();
            Destroy(texture);
            texture = null;
        }
    }

    private void OnDestroy()
    {
        ReleaseBuffer(ref m_keysBufferA);
        ReleaseBuffer(ref m_keysBufferB);
        ReleaseBuffer(ref m_payloadBufferA);
        ReleaseBuffer(ref m_payloadBufferB);
        ReleaseTexture(ref m_perBlockKeyCountsTexture);
        ReleaseTexture(ref m_blockToGlobalKeyOffsetsTexture);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment