Last active
April 27, 2016 22:20
-
-
Save pavanky/5c2ee34314e20605625bdc7b2bc54fea to your computer and use it in GitHub Desktop.
custom reduction
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Element-wise combiner used by the reduction: softplus of the product,
 * i.e. log(1 + exp(a * b)).
 *
 * The expression is evaluated in a numerically stable form:
 *   x >  0 :  x + log1p(exp(-x))
 *   x <= 0 :      log1p(exp(x))
 * so exp() is only ever called on a non-positive argument. The naive form
 * overflows to +inf once a * b exceeds the exp() range of T, and rounds to
 * exactly 0 for moderately negative products because 1 + tiny == 1.
 *
 * @param a first operand
 * @param b second operand
 * @return  log(1 + exp(a * b))
 */
template<typename T>
__device__ T func(T a, T b)
{
    const T x = a * b;
    return (x > T(0)) ? x + log1p(exp(-x)) : log1p(exp(x));
}
/**
 * Custom reduction kernel: for every (m, n) pair computes
 *
 *     C[n * M + m] = sum over k in [0, K) of func(A[k * M + m], B[k * N + n])
 *
 * Launch layout expected by the kernel:
 *   - blockIdx.x * blockDim.x + threadIdx.x selects n (guarded against N).
 *   - blockIdx.y selects m, so gridDim.y must equal M.
 *   - threadIdx.y strides over k; the blockDim.y partial sums of each column
 *     are then combined through shared memory. Any blockDim.y >= 1 works,
 *     it does not have to be a power of two.
 *   - Dynamic shared memory: blockDim.x * blockDim.y * sizeof(T) bytes.
 *
 * @param C output buffer, M * N elements
 * @param A input buffer, read at A[k * M + m]
 * @param B input buffer, read at B[k * N + n]
 * @param M extent of m
 * @param N extent of n
 * @param K extent of the reduced index k
 */
template<typename T>
__global__ void my_reduce(T *C, T *A, T *B, int M, int N, int K)
{
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int n  = blockIdx.x * blockDim.x + tx;
    const int m  = blockIdx.y;
    const bool cond = n < N;

    // Per-thread partial sum over the k values owned by this threadIdx.y.
    // Out-of-range columns skip the loads but stay alive so that every thread
    // of the block reaches the barriers below.
    T res = 0;
    if (cond) {
        for (int k = ty; k < K; k += blockDim.y) {
            // size_t arithmetic keeps the flat index from overflowing int.
            const T a = A[(size_t)k * M + m];
            const T b = B[(size_t)k * N + n];
            res += func(a, b);
        }
    }

    extern __shared__ char s_mem[];
    T *s_res = (T *)s_mem;

    // Row-major (ty, tx) slot: the row pitch is blockDim.x so that every
    // thread of the block owns a distinct element.
    const int id = ty * blockDim.x + tx;
    s_res[id] = res;

    // Fold the blockDim.y rows onto row 0. Each step keeps the lower `half`
    // rows and adds the remaining upper rows onto them; rounding `half` up
    // handles odd row counts. No thread exits early, so __syncthreads() is
    // always executed by the whole block.
    for (int active = blockDim.y; active > 1;) {
        const int half = (active + 1) / 2;
        __syncthreads();
        if (ty < active - half) {
            s_res[id] += s_res[id + half * blockDim.x];
        }
        active = half;
    }

    // Row 0 holds the full sum for its column; it last wrote s_res[id] itself,
    // so no further barrier is needed before reading it back.
    if (cond && ty == 0) {
        C[(size_t)n * M + m] = s_res[id];
    }
}
/**
 * Host wrapper that runs my_reduce<float> on ArrayFire's CUDA stream.
 *
 * With M = A.dims(0), N = B.dims(0) and K = A.dims(1), returns an M x N array
 * whose (m, n) entry is the sum over k of func(A(m, k), B(n, k)).
 *
 * @param A f32 array; dims(0) = M, dims(1) = K
 * @param B f32 array; dims(0) = N, dims(1) must equal K
 * @return  newly allocated M x N array
 * @throws af::exception if the inputs are not f32, if their second
 *         dimensions differ, or if the kernel launch is rejected
 *         (for example when M exceeds the gridDim.y limit of the device).
 */
af::array my_reduce_launcher(const af::array &A, const af::array &B)
{
    // The kernel reinterprets both buffers as float.
    if (A.type() != f32 || B.type() != f32) {
        throw af::exception("my_reduce_launcher: A and B must be f32 arrays");
    }
    // The kernel reads B[k * N + n] for every k < A.dims(1); a shorter B
    // would be read out of bounds.
    if (A.dims(1) != B.dims(1)) {
        throw af::exception("my_reduce_launcher: A.dims(1) must equal B.dims(1)");
    }

    int M = A.dims(0);
    int N = B.dims(0);
    int K = A.dims(1);

    af::array C = af::array(M, N);

    // Nothing to compute; a grid with a zero extent is an invalid launch.
    if (M == 0 || N == 0) return C;

    // Assuming N >> M: n is spread over blockIdx.x / threadIdx.x,
    // while every m gets its own blockIdx.y.
    dim3 threads(32, 8);
    dim3 blocks((N + threads.x - 1) / threads.x, M);

    // One float of scratch space per thread of the block.
    size_t smem = (size_t)threads.x * threads.y * sizeof(float);

    // device<T>() hands out the raw pointer and locks the array; every
    // pointer taken here is released with unlock() below.
    float *d_C = C.device<float>();
    float *d_A = A.device<float>();
    float *d_B = B.device<float>();

    my_reduce<<<blocks, threads, smem, afcu::getStream(af::getDevice())>>>(d_C, d_A, d_B,
                                                                           M, N, K);

    // Kernel launches do not return a status; fetch launch-configuration
    // errors explicitly instead of dropping them.
    cudaError_t err = cudaGetLastError();

    A.unlock();
    B.unlock();
    C.unlock();

    if (err != cudaSuccess) {
        throw af::exception(cudaGetErrorString(err));
    }
    return C;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment