Skip to content

Instantly share code, notes, and snippets.

@bzm3r
Created February 22, 2020 19:07
Show Gist options
  • Save bzm3r/9078999cbc209af2cd059e3d5b0536e0 to your computer and use it in GitHub Desktop.
Save bzm3r/9078999cbc209af2cd059e3d5b0536e0 to your computer and use it in GitHub Desktop.
#version 450
// Subgroup built-ins used by this shader: gl_SubgroupSize and
// gl_SubgroupInvocationID (basic) and subgroupShuffleXor (shuffle). Without
// these extensions enabled the shader does not compile.
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_KHR_shader_subgroup_shuffle : require
// ~WG_SIZE~ is not valid GLSL, so it is presumably substituted by the host
// before compilation (TODO confirm). The rest of the shader assumes the value
// is a multiple of 32: one 32x32 bit matrix per 32 invocations.
#define WORKGROUP_SIZE ~WG_SIZE~
layout(local_size_x = WORKGROUP_SIZE, local_size_y = 1) in;
// Storage buffer of bit matrices; main() reads each matrix from here and writes
// its transposed result back to the same slot. Each element is one 32x32 bit
// matrix stored as 32 uint rows (uint bms[][32] is the same type as the
// original uint[32] bms[]).
layout(set = 0, binding = 0) buffer BM {
    uint bms[][32];
};
// Dispatch parameters, read by main() through the UniformInput block.
struct Uniforms {
    uint num_bms;        // number of valid matrices in bms; larger indices are skipped
    uint num_executions; // number of times main() reloads and transposes each matrix
};
// Uniform buffer (set 0, binding 1) exposing the dispatch parameters as u_consts.
layout(set = 0, binding = 1) uniform UniformInput {
    Uniforms u_consts;
};
// One butterfly round of the 32x32 bit transpose, given the partner row `b`
// (the row whose index is dst_tid ^ s) already fetched.
//   dst_tid - index of this invocation's row; only bit `s` of it is inspected
//   a       - this row's current bits
//   b       - the partner row's current bits
//   m       - mask selecting the low half of every 2*s-bit group
//   s       - shift for this round (main() passes 16, 8, 4, 2, 1)
// A row with bit s clear keeps its low halves (m) and takes the partner's low
// halves shifted up. A row with bit s set keeps its high halves (~m) and takes
// the partner's high halves shifted down.
uint shuffle_round_tg(uint dst_tid, uint a, uint b, uint m, uint s) {
    bool upper_row = (dst_tid & s) != 0u;
    uint keep = upper_row ? ~m : m;
    uint incoming = upper_row ? (b >> s) : (b << s);
    return (a & keep) | (incoming & ~keep);
}
// One butterfly round of the 32x32 bit transpose for the case where the partner
// row (invocation gl_SubgroupInvocationID ^ s) is in the same subgroup. The
// partner's row is fetched with a subgroup shuffle instead of shared memory,
// then combined exactly as in shuffle_round_tg.
//   a - this invocation's current row
//   m - mask selecting the low half of every 2*s-bit group
//   s - shift for this round; the caller only uses this path when s < gl_SubgroupSize
// The partner invocation must be active and also execute this call, because
// subgroupShuffleXor reads the partner's `a`.
uint shuffle_round(uint a, uint m, uint s) {
    // The value exchanged is this invocation's own row `a`. The original code
    // shuffled an undeclared identifier `r`.
    uint b = subgroupShuffleXor(a, s);
    return shuffle_round_tg(gl_SubgroupInvocationID, a, b, m, s);
}
// Each 32x32 bit matrix is handled by 32 invocations (one per row), so a
// workgroup processes gl_WorkGroupSize.x >> 5 matrices. This assumes
// WORKGROUP_SIZE is a multiple of 32.
const uint num_mats_per_wg = gl_WorkGroupSize.x >> 5;
// Per-workgroup staging rows for transpose rounds whose partner row lives in
// another subgroup; one 32-row slot per matrix handled by the workgroup.
shared uint tg_bms[num_mats_per_wg][32];
// Butterfly schedule for the transpose. Round i pairs row r with row
// r ^ shifts[i]. masks[i] selects the low half of every 2*shifts[i]-bit group.
const uint shifts[5] = uint[5](16u, 8u, 4u, 2u, 1u);
const uint masks[5] = uint[5](0x0000ffffu, 0x00ff00ffu, 0x0f0f0f0fu, 0x33333333u, 0x55555555u);
// Entry point. Each run of 32 consecutive invocations transposes one 32x32 bit
// matrix of `bms`, and invocation (gl_LocalInvocationID.x & 31) owns that row
// index. The transpose is repeated u_consts.num_executions times, each time
// starting from the same input row loaded from `bms` (presumably for
// benchmarking; TODO confirm). The final result overwrites the matrix in place.
// Assumes WORKGROUP_SIZE is a multiple of 32.
void main() {
    uint local_bm_ix = gl_LocalInvocationID.x >> 5;
    uint global_bm_ix = num_mats_per_wg * gl_WorkGroupID.x + local_bm_ix;
    // Row index inside the matrix. The rows of bms/tg_bms only have 32 entries,
    // so the full local invocation id must not be used here.
    uint dst_tix = gl_LocalInvocationID.x & 31u;
    // barrier() must be reached by every invocation of the workgroup (uniform
    // control flow). Out-of-range invocations therefore still run every round,
    // on dummy data in their own tg_bms slot, and only skip global memory.
    bool in_bounds = global_bm_ix < u_consts.num_bms;

    // Initialized so that num_executions == 0 writes the input back unchanged
    // instead of an undefined value.
    uint dst_dat = 0u;
    if (in_bounds) {
        dst_dat = bms[global_bm_ix][dst_tix];
    }

    for (uint iter = 0u; iter < u_consts.num_executions; iter++) {
        if (in_bounds) {
            dst_dat = bms[global_bm_ix][dst_tix];
        }
        memoryBarrierShared();
        barrier();
        for (int i = 0; i < 5; i++) {
            uint s = shifts[i];
            uint m = masks[i];
            // gl_SubgroupSize is uniform, so every invocation takes the same branch.
            if (s >= gl_SubgroupSize) {
                // Partner row is in another subgroup: exchange it through shared memory.
                tg_bms[local_bm_ix][dst_tix] = dst_dat;
                memoryBarrierShared();
                barrier();
                uint src_dat = tg_bms[local_bm_ix][dst_tix ^ s];
                // All reads must finish before the next round overwrites tg_bms.
                memoryBarrierShared();
                barrier();
                dst_dat = shuffle_round_tg(dst_tix, dst_dat, src_dat, m, s);
            } else {
                dst_dat = shuffle_round(dst_dat, m, s);
            }
        }
    }

    if (in_bounds) {
        bms[global_bm_ix][dst_tix] = dst_dat;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment