LZ4 decompressor for Vulkan written in Slang (gist by @nanokatze, last active April 16, 2024)

typedef uint64_t size_t;
typedef uint64_t uintptr_t;

struct Queue<T> {
    // pop side would really benefit from a lock tbh
    uint32_t head;
    uint32_t tail;
    T buf[1];
};
// Atomic helpers lowered to raw SPIR-V. Scope 1 = Device; the memory-semantics
// masks combine Acquire/Release with the Vulkan memory model's
// MakeAvailable/MakeVisible bits.
uint myAtomicLoad(uint *p) {
    uint scope = 1;          // Device
    uint semantics = 0x2002; // Acquire | MakeVisible
    return spirv_asm { OpAtomicLoad $$uint result $p $scope $semantics };
}

uint myAtomicAdd(uint *p, uint value) {
    uint scope = 1;          // Device
    uint semantics = 0x6006; // Acquire | Release | MakeAvailable | MakeVisible
    return spirv_asm { OpAtomicIAdd $$uint result $p $scope $semantics $value };
}

uint myAtomicCompareExchange(uint *p, uint comparator, uint value) {
    uint scope = 1;                 // Device
    uint semantics = 0x6006;        // on success: Acquire | Release | MakeAvailable | MakeVisible
    uint semanticsUnequal = 0x2002; // on failure: Acquire | MakeVisible
    return spirv_asm { OpAtomicCompareExchange $$uint result $p $scope $semantics $semanticsUnequal $value $comparator };
}

#define MAX_COPY_JOB_SIZE 8192 // this should be specified by the host in launch params.
#define JOB_TYPE_COPY_MEMORY 1

struct Job {
    int8_t type;
    uint8_t *dst;
    uint8_t *src;
    size_t size;
};

/*
struct Comp {
    Queue<Job> jobQueue;
};
*/

struct Results {
    uint64_t consumed;
    uint64_t produced;
    uint64_t t0;
    uint64_t t1;
    uint32_t spinDown; // there won't be any more jobs. TODO: move into a Comp struct
};

struct Push {
    Queue<Job> *jobQueue;
    uint8_t *src;
    size_t srcLen;
    uint8_t *dst;
    size_t dstLen;
    Results *results;
};

[[vk::push_constant]] Push push;

// Reassemble a little-endian 32-bit word from four unaligned bytes.
uint32_t load32_unaligned(uint8_t *p) {
    return (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16) | ((uint32_t)p[3] << 24);
}
[shader("compute")]
[numthreads(16, 1, 1)]
void decompress(uint3 groupID : SV_GroupID, uint32_t index : SV_DispatchThreadID) {
// TODO: we probably want only one thread doing this.
// TODO: get a ticket with compare exchange instead
if (index == 0) {
size_t i = 0;
size_t j = 0;
while (i < push.srcLen) {
uint32_t magic = load32_unaligned(push.src + i);
if (magic != 0x184d2204) {
// halt and report error
}
i += 4;
uint8_t flg = push.src[i];
i++;
uint8_t bd = push.src[i];
i++;
// TODO: consult block independence as
if (flg & (1<<3)) {
// content size
//
// TODO: require this. We don't necessarily *need* this, but if
// content size is not present, we'll need to decompress this frame
// before we can proceed.
i += 8;
}
if (flg & (1<<0)) {
// dictionary ID
i += 4;
}
// header checksum
i++;
uint32_t endMark = 0;
while (true) {
uint32_t word = load32_unaligned(push.src + i);
if (word == 0x00000000) {
endMark = word;
break;
}
i += 4;
bool uncompressed = (word & 0x80000000) != 0;
uint32_t blockSize = word & 0x7fffffff;
// TODO: we need to do more work here to figure out block
// uncompressed size and such.
if (uncompressed) {
// TODO: we might be able to use the entire subgroup to speed
// this part up
for (size_t k = 0; k < blockSize; k += MAX_COPY_JOB_SIZE) {
Job job = {};
job.type = JOB_TYPE_COPY_MEMORY;
job.dst = push.dst + j + k;
job.src = push.src + i + k;
job.size = min(blockSize, MAX_COPY_JOB_SIZE);
uint32_t index = myAtomicAdd(&push.jobQueue.head, 1);
push.jobQueue.buf[index] = job;
}
i += blockSize;
j += blockSize;
} else {
// we need to scan the block ourselves and fan out jobs
i += blockSize;
}
if (flg & (1<<4)) {
// block checksum
i += 4;
}
}
// if (endMark != 0) {
// invalid end mark
// }
i += 4;
if (flg & (1 << 2)) {
// content checksum
i += 4;
}
// handle only one frame for now
break;
}
push.results.consumed = i;
push.results.produced = j;
// TODO: should be myAtomicStore
myAtomicAdd(&push.results.spinDown, 1);
}
    // Drain copy jobs until the producer signals spin-down and the queue is empty.
    while (true) {
        bool quit = false;
        uint32_t index;
        if (WaveIsFirstLane()) {
            uint32_t tail = myAtomicLoad(&push.jobQueue.tail);
            while (true) {
                uint32_t head = myAtomicLoad(&push.jobQueue.head);
                if (tail < head) {
                    uint32_t was = myAtomicCompareExchange(&push.jobQueue.tail, tail, tail+1);
                    if (was == tail) {
                        index = tail;
                        break;
                    } else {
                        tail = was;
                    }
                } else if (myAtomicLoad(&push.results.spinDown) != 0) {
                    quit = true;
                    break;
                }
            }
        }
        index = WaveReadLaneFirst(index);
        if (WaveReadLaneFirst(quit))
            return;
        Job job = push.jobQueue.buf[index];
        // Strided copy: lane n handles bytes n, n + laneCount, n + 2*laneCount, ...
        for (size_t i = WavePrefixSum(1); i < job.size; i += WaveActiveSum(1)) {
            job.dst[i] = job.src[i];
        }
    }
    // myAtomicLoad(&push.results.spinDown);
}
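
For context (not part of the gist): every field of the shader's Push block is 8 bytes, so the host can fill it with plain 64-bit values, passing the raw pointers as buffer device addresses. A minimal sketch of a matching C struct, assuming VK_KHR_buffer_device_address (or Vulkan 1.2) and the hypothetical name DecompressPush:

/* Hypothetical host-side mirror of the decompress shader's Push block.
   Field order and sizes follow the Slang declaration; pointers travel as
   buffer device addresses. */
#include <vulkan/vulkan.h>

typedef struct DecompressPush {
    VkDeviceAddress jobQueue; /* Queue<Job>* */
    VkDeviceAddress src;      /* uint8_t*    */
    uint64_t        srcLen;   /* size_t      */
    VkDeviceAddress dst;      /* uint8_t*    */
    uint64_t        dstLen;   /* size_t      */
    VkDeviceAddress results;  /* Results*    */
} DecompressPush;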

// --- Second shader: decompress a single LZ4 block with one subgroup ---

typedef uint64_t size_t;

// TODO: pass array of blocks to decompress, instead of params for a single block
struct Push {
    uint8_t *src;
    size_t srcLen;
    uint8_t *dst;
    size_t dstLen;
};

[[vk::push_constant]] Push push;

// Lane index of the lowest set bit in a subgroup ballot.
uint subgroupBallotFindLSB(uint4 value) {
    return spirv_asm {
        OpCapability GroupNonUniformBallot;
        OpGroupNonUniformBallotFindLSB $$uint result Subgroup $value
    };
}

void subgroupMemcpy(uint8_t *dst, uint8_t *src, size_t n) {
    // TODO: use a bigger unit when possible
    // Each lane copies every laneCount-th byte, starting at its lane index.
    for (size_t i = WavePrefixSum(1); i < n; i += WaveActiveSum(1)) {
        dst[i] = src[i];
    }
    /*
    if (!WaveIsFirstLane())
        return;
    for (size_t i = 0; i < n; i++) {
        dst[i] = src[i];
    }
    */
}

struct ConsumedVarint {
    uint64_t value;
    size_t len;
};

// Cooperatively parse an LZ4 length-extension field: each lane speculatively
// loads one byte, a ballot finds the first byte != 255, and the bytes of the
// run are summed.
ConsumedVarint consumeVarint(uint8_t *p, size_t n) {
    // TODO: put this inside a loop, our subgroup might be narrower than
    // the length of the run.
    // TODO: do a scan with 1 instead of WaveGetLaneIndex
    uint64_t value = 0;
    size_t i = 0;
    while (true) {
        // Speculatively load varint bytes.
        //
        // TODO: we should have a separate tunable knob for how far we
        // should speculate
        uint8_t byte = 0;
        if (i + WaveGetLaneIndex() < n) {
            byte = p[i + WaveGetLaneIndex()];
        }
        uint run = subgroupBallotFindLSB(WaveActiveBallot(byte != 255)) + 1;
        if (WaveGetLaneIndex() >= run) {
            byte = 0;
        }
        value += WaveActiveSum((size_t)byte);
        i += run;
        if (run < WaveGetLaneCount())
            break;
    }
    /*
    for (; i < n;) {
        uint8_t add = p[i++];
        value += add;
        if (add != 255)
            break;
    }
    */
    ConsumedVarint result = {value, i};
    return result;
}

[shader("compute")]
[numthreads(16, 1, 1)] // BUG: we just want single-subgroup workgroups
void decompressBlock( uint3 what : SV_GroupID /* uint32_t index : SV_DispatchThreadID */) {
    // Currently we use a subgroup per block, we might want to scale up to
    // workgroup per block if blocks are big enough in practice
    if (any(what != uint3(0))) // only workgroup (0,0,0) does any work
        return;
    size_t i = 0;
    size_t j = 0;
    while (i < push.srcLen) {
        uint8_t token = push.src[i++];
        size_t literallength = (size_t)(token >> 4);
        if ((token >> 4) == 15) {
            // TODO: we can load this varint speculatively
            ConsumedVarint wat = consumeVarint(push.src + i, push.srcLen - i);
            literallength += wat.value;
            i += wat.len;
        }
        subgroupMemcpy(push.dst + j, push.src + i, literallength);
        i += literallength;
        j += literallength;
        if (i == push.srcLen)
            break;
        size_t offset = (size_t)((uint16_t)push.src[i+0] | ((uint16_t)push.src[i+1] << 8));
        i += 2;
        size_t matchlength = 4 + (size_t)(token & 0xf);
        if ((token & 0xf) == 15) {
            // TODO: we can load this varint speculatively
            ConsumedVarint wat = consumeVarint(push.src + i, push.srcLen - i);
            matchlength += wat.value;
            i += wat.len;
        }
        // TODO: do this with the entire subgroup
        if (WaveIsFirstLane()) {
            // Byte-by-byte forward copy so overlapping matches (offset < matchlength) work.
            for (size_t k = 0; k < matchlength; k++) {
                push.dst[j+k] = push.dst[j-offset+k];
            }
        }
        j += matchlength;
    }
}
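
A rough host-side sketch of recording a dispatch for decompressBlock follows. None of it is from the gist: cmd, pipeline, pipelineLayout, srcBuffer, dstBuffer, and the helper names are assumptions, the buffers are assumed to be created with VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT, and the pipeline layout is assumed to expose a push-constant range for VK_SHADER_STAGE_COMPUTE_BIT covering BlockPush.

/* Hypothetical host-side sketch: fill the block decompressor's push constants
   with buffer device addresses and dispatch a single workgroup. */
#include <vulkan/vulkan.h>

typedef struct BlockPush {
    VkDeviceAddress src;    /* uint8_t* */
    uint64_t        srcLen; /* size_t   */
    VkDeviceAddress dst;    /* uint8_t* */
    uint64_t        dstLen; /* size_t   */
} BlockPush;

static VkDeviceAddress bufferAddress(VkDevice device, VkBuffer buffer) {
    VkBufferDeviceAddressInfo info = {
        .sType = VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO,
        .buffer = buffer,
    };
    return vkGetBufferDeviceAddress(device, &info);
}

void recordDecompressBlock(VkDevice device, VkCommandBuffer cmd,
                           VkPipeline pipeline, VkPipelineLayout pipelineLayout,
                           VkBuffer srcBuffer, uint64_t srcLen,
                           VkBuffer dstBuffer, uint64_t dstLen) {
    BlockPush push = {
        .src    = bufferAddress(device, srcBuffer),
        .srcLen = srcLen,
        .dst    = bufferAddress(device, dstBuffer),
        .dstLen = dstLen,
    };
    vkCmdBindPipeline(cmd, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
    vkCmdPushConstants(cmd, pipelineLayout, VK_SHADER_STAGE_COMPUTE_BIT,
                       0, sizeof(push), &push);
    /* The shader lets only workgroup (0,0,0) do work, so one group suffices. */
    vkCmdDispatch(cmd, 1, 1, 1);
}

Since the shader currently handles a single block described by the push constants, one group is enough; scaling to many blocks would need the TODO at the top of this file (an array of block descriptors) plus a larger dispatch.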