Skip to content

Instantly share code, notes, and snippets.

@bzm3r
Created February 22, 2020 01:39
Show Gist options
  • Save bzm3r/fee14d40d8840be89a123715529a0ec9 to your computer and use it in GitHub Desktop.
Save bzm3r/fee14d40d8840be89a123715529a0ec9 to your computer and use it in GitHub Desktop.
#version 450
#extension GL_KHR_shader_subgroup_shuffle: enable
#define WORKGROUP_SIZE ~WG_SIZE~
// Unlike the threadgroup case, the Y-dimension of the workgroup size is not used.
// This is because the Y-dimension will be implicit in the number of subgroups in a workgroup.
// WORKGROUP_SIZE expands to the ~WG_SIZE~ placeholder, which is presumably substituted
// with a literal by the host before the shader is compiled — TODO confirm.
layout(local_size_x = WORKGROUP_SIZE) in;
// Storage buffer at set 0, binding 0: a runtime-sized array whose elements are
// arrays of 32 uints. main() reads and writes one uint per invocation, addressed
// as bms[matrix index][row index].
// NOTE(review): presumably each element is a 32x32 bit matrix stored one row per
// uint — confirm against the host code.
layout(set = 0, binding = 0) buffer BM {
uint[32] bms[];
};
// Constants provided through the uniform block at set 0, binding 1.
struct Uniforms
{
// NOTE(review): presumably the number of elements in bms[]; no active code in
// this shader reads it — confirm intended use.
uint num_bms;
// Number of times main() repeats the five shuffle rounds before storing the result.
uint num_executions;
};
// Uniform block at set 0, binding 1, exposing a single Uniforms value as u_consts.
layout(set=0, binding = 1) uniform UniformInput
{
Uniforms u_consts;
};
// Performs one exchange round between this invocation and the invocation whose
// subgroup invocation ID differs from it by XOR with `s`.
//
//   r - this invocation's current value.
//   m - mask of the bits an invocation with (gl_SubgroupInvocationID & s) == 0
//       keeps from its own value; invocations with that bit set keep ~m instead.
//   s - XOR distance to the partner invocation, also used as the bit-shift amount.
//
// Returns the kept bits of `r` merged with the partner's value, shifted left by
// `s` (bit clear) or right by `s` (bit set), in the bit positions that were not kept.
uint shuffle_round(uint r, uint m, uint s) {
    uint partner = subgroupShuffleXor(r, s);
    bool bit_clear = (gl_SubgroupInvocationID & s) == 0;

    // The two sides of each pair use complementary masks and opposite shifts.
    uint keep = bit_clear ? m : ~m;
    uint incoming = bit_clear ? (partner << s) : (partner >> s);

    return (r & keep) | (incoming & ~keep);
}
// Shift amounts and masks for the five shuffle rounds, listed in the same order
// and with the same values that main() passes to shuffle_round().
// NOTE: no active code references these tables; main() uses the literals directly
// (the only loop that indexed them is commented out).
const uint shifts[5] = uint[5](16, 8, 4, 2, 1);
const uint masks[5] = uint[5](0xffff, 0xff00ff, 0xf0f0f0f, 0x33333333, 0x55555555);
// Entry point. Each group of 32 consecutive invocations within a subgroup works
// on one element of bms[]: every invocation loads one uint ("row") of that
// element, runs five shuffle_round() exchanges on it, and stores the result back
// to the same location. The five rounds are repeated u_consts.num_executions
// times, each repetition restarting from the originally loaded value.
//
// NOTE(review): assumes gl_SubgroupSize is a multiple of 32; the first round
// shuffles across an invocation-ID distance of 16 — confirm the target hardware.
void main() {
    // Note that x / 2^y == x >> y; 2^5 = 32 is the relevant case here.
    // Number of 32-invocation groups (one per bms[] element) in a subgroup.
    uint bms_per_subgroup = gl_SubgroupSize >> 5;

    // Index of this subgroup across the dispatch: gl_NumSubgroups subgroups per
    // workgroup, workgroups laid out along x only.
    uint global_subgroup_ix = gl_NumSubgroups * gl_WorkGroupID.x + gl_SubgroupID;

    // The per-subgroup element count has to scale the whole subgroup index, not
    // just the workgroup term; otherwise, when gl_SubgroupSize > 32, neighbouring
    // subgroups would address overlapping elements. (gl_SubgroupInvocationID >> 5)
    // selects the element within the subgroup.
    uint global_bm_ix = bms_per_subgroup * global_subgroup_ix + (gl_SubgroupInvocationID >> 5);

    // NOTE(review): there is no bounds check of global_bm_ix against
    // u_consts.num_bms, so the dispatch must not cover more elements than bms[]
    // holds. A guard `if (global_bm_ix < u_consts.num_bms)` was present but
    // disabled in the original — confirm whether it should be restored.
    uint row_id = gl_SubgroupInvocationID & 31;
    uint global_row = bms[global_bm_ix][row_id];

    // Initialised so that num_executions == 0 stores the loaded value back
    // unchanged instead of an undefined one.
    uint row = global_row;
    for (uint iter = 0; iter < u_consts.num_executions; iter++) {
        // Every repetition restarts from the loaded value.
        row = global_row;
        // Unrolled rounds; the (mask, shift) pairs match the masks[]/shifts[] tables.
        row = shuffle_round(row, 0xffff, 16);
        row = shuffle_round(row, 0xff00ff, 8);
        row = shuffle_round(row, 0xf0f0f0f, 4);
        row = shuffle_round(row, 0x33333333, 2);
        row = shuffle_round(row, 0x55555555, 1);
    }

    bms[global_bm_ix][row_id] = row;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment