Created March 16, 2017 20:10
-
-
Save killeent/fb6a6391cb0097dab5ad5b0628dc4b1f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Block-wide reduction helper that goes through shared memory.
//
// Contract:
//  - Every thread of the block must call this function with the same
//    numVals, because it contains __syncthreads() barriers.
//  - Threads with threadIdx.x < numVals contribute threadVal; threadVal from
//    every other thread is ignored.
//  - smem must hold at least numVals elements. Entries
//    [0, min(numVals, blockDim.x)) are overwritten. If numVals exceeds
//    blockDim.x, entries [blockDim.x, numVals) are not written here and are
//    read exactly as the caller left them.
//  - reduceOp is called as reduceOp(T, T) -> T in device code. Values are not
//    combined in index order (lane t first folds t, t + warpSize, ...), so
//    reduceOp is presumably required to be associative and commutative.
//
// Returns the reduced value in threadIdx.x == 0 only; every other thread gets
// init. If numVals <= 0, all threads return init.
// A barrier is executed before returning, so the caller may reuse smem
// immediately afterwards.
template <typename T, typename ReduceOp>
__device__ T reduceBlock(T* smem,
                         int numVals,
                         T threadVal,
                         ReduceOp reduceOp,
                         T init) {
  // numVals is block-uniform, so the whole block leaves here together and no
  // barrier below is reached by only part of the block.
  if (numVals <= 0) {
    return init;
  }

  // Signed thread index so every comparison against numVals is int vs int.
  const int tid = static_cast<int>(threadIdx.x);

  // Phase 1: each participating thread publishes its value.
  if (tid < numVals) {
    smem[tid] = threadVal;
  }
  __syncthreads();

  // Phase 2: the first warp performs reductions across warps.
  // Lane t folds smem[t], smem[t + warpSize], ... and writes back smem[t].
  // Each lane touches only indices congruent to its own lane, so lanes do
  // not race with one another.
  if (tid < warpSize) {
    T partial = tid < numVals ? smem[tid] : init;
    for (int i = tid + warpSize; i < numVals; i += warpSize) {
      partial = reduceOp(partial, smem[i]);
    }
    // Phase 3 only reads indices < min(numVals, warpSize). Skipping the store
    // for other lanes keeps every access inside smem[0, numVals).
    if (tid < numVals) {
      smem[tid] = partial;
    }
  }
  __syncthreads();

  // Phase 3: the first thread serially combines the per-lane partials.
  T r = init;
  if (tid == 0) {
    r = smem[0];
    const int numLanesParticipating = min(numVals, warpSize);
    if (numLanesParticipating == 32) {
      // Unroll for warpSize == 32 and numVals >= 32
#pragma unroll
      for (int i = 1; i < 32; ++i) {
        r = reduceOp(r, smem[i]);
      }
    } else {
      for (int i = 1; i < numLanesParticipating; ++i) {
        r = reduceOp(r, smem[i]);
      }
    }
  }

  // Other threads would otherwise return immediately and could overwrite
  // smem (e.g. in a following reduction) while thread 0 is still reading it.
  __syncthreads();
  return r;
}
// Block-wide reduction of numVals values that the caller has already placed
// in smem[0, numVals).
//
// Each thread first folds, in registers, the entries
//   smem[threadIdx.x], smem[threadIdx.x + blockDim.x], ...
// that lie below numVals. The first N of these steps are unrolled, and any
// remainder is folded by a strided loop. A single reduceBlock call then
// combines the per-thread partials.
//
// Preconditions (not checked here):
//  - Every thread of the block calls this with the same numVals, because
//    reduceBlock contains barriers.
//  - smem[0, numVals) has been written and made visible to all threads before
//    the call (NOTE(review): presumably via a __syncthreads() in the caller;
//    confirm).
//  - reduceOp(x, init) == x, because out-of-range slots in the unrolled loop
//    are padded with init.
//
// smem[0, min(numVals, blockDim.x)) is overwritten. A negative numVals is
// treated as zero values. As with reduceBlock, only threadIdx.x == 0 receives
// the reduced value; other threads get init.
template <typename T, typename ReduceOp, int N>
__device__ T reduceBlockN(T *smem,
                          int numVals,
                          ReduceOp reduceOp,
                          T init) {
  // Signed copies so every bounds check below is an int vs int comparison.
  const int tid = static_cast<int>(threadIdx.x);
  const int stride = static_cast<int>(blockDim.x);
  const int count = numVals > 0 ? numVals : 0;
  // With N <= 1 the unrolled loop below does nothing and only smem[tid] has
  // been consumed by then, so the tail loop must start one stride later.
  const int unrolled = N > 1 ? N : 1;

  T local = tid < count ? smem[tid] : init;
#pragma unroll
  for (int i = 1; i < N; ++i) {
    const int index = tid + i * stride;
    T next = index < count ? smem[index] : init;
    local = reduceOp(local, next);
  }

  // Without this loop, entries at or beyond N * blockDim.x would be silently
  // dropped. It does not execute when count <= N * blockDim.x.
  // Every index read here is >= blockDim.x, so it never overlaps the slots
  // reduceBlock overwrites.
  for (int index = tid + unrolled * stride; index < count; index += stride) {
    local = reduceOp(local, smem[index]);
  }

  // Only the first min(count, blockDim.x) threads hold a partial result.
  return reduceBlock<T, ReduceOp>(smem, stride < count ? stride : count,
                                  local, reduceOp, init);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.