Created
February 22, 2020 01:39
-
-
Save bzm3r/fee14d40d8840be89a123715529a0ec9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#version 450
// Subgroup built-ins used by this shader: gl_SubgroupSize, gl_SubgroupID,
// gl_NumSubgroups and gl_SubgroupInvocationID (basic extension) and
// subgroupShuffleXor (shuffle extension). The basic extension is implied by
// the shuffle one; it is listed here only to make the dependency explicit.
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_shuffle : enable

// ~WG_SIZE~ is a template placeholder that must be textually substituted
// before compilation; it is not valid GLSL on its own.
// NOTE(review): presumably replaced by the host program — confirm.
#define WORKGROUP_SIZE ~WG_SIZE~

// One-dimensional workgroup. Unlike the threadgroup version, no Y dimension is
// used: the row index inside a matrix comes from the subgroup layout instead
// (main() assigns 32 consecutive subgroup invocations to each matrix).
layout(local_size_x = WORKGROUP_SIZE, local_size_y = 1, local_size_z = 1) in;
// Bit-matrix storage: bms[i] is one matrix made of 32 uint rows, and
// bms[i][r] is row r. main() reads one row per invocation and writes the
// result back to the same slot.
// std430 is spelled out so the uint rows are tightly packed (4-byte stride).
layout(std430, set = 0, binding = 0) buffer BM {
    uint bms[][32];
};
// Host-supplied parameters.
//   num_bms        - presumably the number of matrices stored in `bms`
//                    (TODO confirm against the host code).
//   num_executions - iteration count of the repeated transpose loop in main().
struct Uniforms {
    uint num_bms;
    uint num_executions;
};

// Uniform block wrapping the parameters; std140 is stated explicitly.
layout(std140, set = 0, binding = 1) uniform UniformInput {
    Uniforms u_consts;
};
// One exchange stage of a 32x32 bit-matrix transpose across subgroup lanes.
//   r : this lane's current row.
//   m : bits kept by a lane whose invocation ID has bit `s` clear; the partner
//       lane (bit `s` set) keeps ~m instead.
//   s : stage distance. The partner is the lane at gl_SubgroupInvocationID ^ s.
//       Its row is shifted left by s (for the lane with bit s clear) or right by
//       s (for the lane with bit s set) before filling the non-kept bits.
// Returns the kept bits of r merged with the shifted partner row in the
// remaining bit positions. Every lane executes the shuffle unconditionally.
uint shuffle_round(uint r, uint m, uint s) {
    uint partner_row = subgroupShuffleXor(r, s);
    bool low_lane = (gl_SubgroupInvocationID & s) == 0u;
    uint keep = low_lane ? m : ~m;
    uint incoming = low_lane ? (partner_row << s) : (partner_row >> s);
    return (r & keep) | (incoming & ~keep);
}
// Per-stage transpose parameters, paired by index: stage i calls
// shuffle_round(row, masks[i], shifts[i]).
// NOTE(review): no live code path reads these; main() passes the same values
// as literals.
const uint shifts[5] = uint[5](16u, 8u, 4u, 2u, 1u);
const uint masks[5] = uint[5](0x0000ffffu, 0x00ff00ffu, 0x0f0f0f0fu, 0x33333333u, 0x55555555u);
// Transposes 32x32 bit matrices stored in `bms`.
//
// Layout of the work:
//  - Each group of 32 consecutive subgroup invocations handles one matrix.
//  - Each invocation owns one row of that matrix (one uint, bit c = column c).
//  - After the five shuffle_round stages, row r holds what was column r.
//
// Repetition: the transpose is recomputed u_consts.num_executions times from
// the same loaded row (NOTE(review): presumably for benchmarking). The value
// written back is therefore one transpose, or the unchanged input when
// num_executions == 0.
//
// Assumptions (not checked here; confirm on the host side):
//  - gl_SubgroupSize >= 32. Shuffle partners up to 16 lanes away must be in the
//    same subgroup; with smaller subgroups matrices_per_subgroup is 0 and the
//    result is meaningless.
//  - WORKGROUP_SIZE is a multiple of gl_SubgroupSize, so every subgroup is full
//    and no shuffle partner is inactive.
//  - u_consts.num_bms is filled by the host with the number of matrices.
void main() {
    // x >> 5 == x / 32: number of 32-row matrices one subgroup processes.
    uint matrices_per_subgroup = gl_SubgroupSize >> 5;

    // Index of this invocation's matrix. It counts:
    //  - all matrices of earlier workgroups,
    //  - all matrices of earlier subgroups in this workgroup,
    //  - this lane's 32-lane group inside its subgroup.
    // gl_SubgroupID must be scaled by matrices_per_subgroup. Otherwise
    // subgroups wider than 32 lanes would overlap their neighbours' matrices.
    uint global_bm_ix = matrices_per_subgroup * (gl_NumSubgroups * gl_WorkGroupID.x + gl_SubgroupID)
        + (gl_SubgroupInvocationID >> 5);

    // Invocations past the last matrix (dispatch rounded up) must not touch the
    // buffer. All 32 lanes of one matrix share global_bm_ix, so they exit
    // together and every remaining lane's shuffle partner stays active.
    if (global_bm_ix >= u_consts.num_bms) {
        return;
    }

    uint row_id = gl_SubgroupInvocationID & 31u;
    uint global_row = bms[global_bm_ix][row_id];
    // Initialized so that num_executions == 0 writes the input back instead of
    // an undefined value.
    uint row = global_row;
    for (uint iter = 0u; iter < u_consts.num_executions; iter++) {
        row = global_row;
        row = shuffle_round(row, 0x0000ffffu, 16u);
        row = shuffle_round(row, 0x00ff00ffu, 8u);
        row = shuffle_round(row, 0x0f0f0f0fu, 4u);
        row = shuffle_round(row, 0x33333333u, 2u);
        row = shuffle_round(row, 0x55555555u, 1u);
    }
    bms[global_bm_ix][row_id] = row;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment