Created
February 22, 2020 19:07
-
-
Save bzm3r/9078999cbc209af2cd059e3d5b0536e0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#version 450
// Subgroup built-ins (gl_SubgroupSize, gl_SubgroupInvocationID) and
// subgroupShuffleXor() are used below and must be enabled explicitly.
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_shuffle : enable

// NOTE(review): ~WG_SIZE~ looks like a template placeholder that the host is
// expected to replace with the workgroup size before compilation - confirm.
#define WORKGROUP_SIZE ~WG_SIZE~
layout(local_size_x = WORKGROUP_SIZE, local_size_y = 1) in;
// Storage buffer holding the bit matrices: a runtime-sized array in which
// every element is one matrix stored as 32 uint rows. main() reads a row
// from it and writes the result row back in place.
layout(set = 0, binding = 0) buffer BM {
    uint bms[][32];
};
// Run parameters read by main().
struct Uniforms {
    uint num_bms;         // matrices with a global index >= this are not processed
    uint num_executions;  // iteration count of the outer loop in main()
};
// Uniform block exposing the run parameters as u_consts.
layout(set = 0, binding = 1) uniform UniformInput {
    Uniforms u_consts;
};
// One exchange round between a row and its partner row.
//
//   dst_tid  index of the row being produced
//   a        this row's current value
//   b        the partner row's current value
//   m        mask of the bits a row with (dst_tid & s) == 0 keeps from `a`
//   s        shift distance for this round
//
// A row whose index has bit `s` clear keeps a's bits under m and takes the
// remaining bits from b shifted left by s; a row with bit `s` set keeps a's
// bits under ~m and takes the bits under m from b shifted right by s.
uint shuffle_round_tg(uint dst_tid, uint a, uint b, uint m, uint s) {
    if ((dst_tid & s) != 0u) {
        return (a & ~m) | ((b >> s) & m);
    }
    return (a & m) | ((b << s) & ~m);
}
// Subgroup variant of the exchange round: the partner row is fetched with
// subgroupShuffleXor() instead of being passed in.
//
//   a  this invocation's current row value
//   m  mask of the bits an invocation with (gl_SubgroupInvocationID & s) == 0
//      keeps from `a`
//   s  XOR distance to the partner invocation, also used as the shift amount
//
// NOTE(review): the side is chosen from gl_SubgroupInvocationID, so this
// presumably relies on bit `s` of the subgroup invocation index matching
// bit `s` of the row index - confirm for the target hardware.
uint shuffle_round(uint a, uint m, uint s) {
    // Fetch the partner's row. The value exchanged must be this invocation's
    // own row `a`; the original referenced an undeclared identifier here.
    uint b = subgroupShuffleXor(a, s);
    uint c;
    if ((gl_SubgroupInvocationID & s) == 0u) {
        c = b << s;
    } else {
        m = ~m;
        c = b >> s;
    }
    return (a & m) | (c & ~m);
}
// With 32x32 bit matrices and one invocation per row, a workgroup handles
// gl_WorkGroupSize.x >> 5 matrices. Each gets a 32-row staging area in shared
// memory for the rounds that main() does not perform with subgroup shuffles.
// NOTE(review): assumes the workgroup size is a non-zero multiple of 32 - confirm.
shared uint[32] tg_bms[gl_WorkGroupSize.x >> 5];
const uint num_mats_per_wg = gl_WorkGroupSize.x >> 5;

// Per-round shift distances and the matching bit masks, indexed together.
const uint shifts[5] = uint[5](16u, 8u, 4u, 2u, 1u);
const uint masks[5] = uint[5](0x0000ffffu, 0x00ff00ffu, 0x0f0f0f0fu, 0x33333333u, 0x55555555u);
// Entry point. Each run of 32 consecutive invocations handles one matrix from
// `bms`, one row per invocation, and applies the five exchange rounds defined
// by `shifts`/`masks`. Rounds whose distance is at least gl_SubgroupSize go
// through shared memory; the others use subgroup shuffles. The load and the
// rounds are repeated u_consts.num_executions times from the same input, and
// the final row is written back once.
void main() {
    uint local_bm_ix = gl_LocalInvocationID.x >> 5;
    // Row within the matrix. Only the low five bits may be used: the rows of
    // bms[...] and tg_bms[...] have 32 elements, so indexing them with the
    // full local invocation index overruns them for workgroups larger than 32.
    uint row_ix = gl_LocalInvocationID.x & 31u;
    uint global_bm_ix = num_mats_per_wg * gl_WorkGroupID.x + local_bm_ix;

    // Invocations past the last matrix still run the loops below so that every
    // barrier() is reached by the whole workgroup; only their buffer accesses
    // are skipped. tg_bms stays in bounds for them because local_bm_ix is
    // always below num_mats_per_wg.
    bool is_active = global_bm_ix < u_consts.num_bms;
    uint dst_dat = 0u;

    for (uint iter = 0u; iter < u_consts.num_executions; iter++) {
        if (is_active) {
            dst_dat = bms[global_bm_ix][row_ix];
        }
        // Retained from the original; the barriers inside each shared-memory
        // round below already order the accesses to tg_bms.
        memoryBarrierShared();
        barrier();

        for (int i = 0; i < 5; i++) {
            uint s = shifts[i];
            uint m = masks[i];
            if (s >= gl_SubgroupSize) {
                // Publish this row, wait, then read the partner row.
                tg_bms[local_bm_ix][row_ix] = dst_dat;
                memoryBarrierShared();
                barrier();
                uint src_dat = tg_bms[local_bm_ix][row_ix ^ s];
                // Everyone must finish reading before a later round overwrites.
                memoryBarrierShared();
                barrier();
                dst_dat = shuffle_round_tg(row_ix, dst_dat, src_dat, m, s);
            } else {
                dst_dat = shuffle_round(dst_dat, m, s);
            }
        }
    }

    // Skip the store when nothing was computed, so the input is left intact
    // instead of being overwritten with a value that was never loaded.
    if (is_active && u_consts.num_executions > 0u) {
        bms[global_bm_ix][row_ix] = dst_dat;
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment