Skip to content

Instantly share code, notes, and snippets.

Last active April 30, 2016 02:52
Show Gist options
  • Save pengsun/ac9d40f33deba1e36e18389110994fa2 to your computer and use it in GitHub Desktop.
Save pengsun/ac9d40f33deba1e36e18389110994fa2 to your computer and use it in GitHub Desktop.
/**
 * Embedding lookup: for every word, copies one row of `weight` into `output`.
 *
 * Launch layout: 2D grid of 2D blocks. The x dimension covers the C feature
 * columns, the y dimension covers the B*M words. Threads that fall past
 * either extent return without touching memory.
 *
 * inputInd      word indices stored as floats, one-based (Lua convention),
 *               read as a flat array of B*M entries.
 * weight        lookup table; row v starts at weight + v * weightStride.
 * weightStride  distance in elements between consecutive rows of weight.
 * B, M          the number of words processed is B*M.
 * V             number of rows in weight; valid one-based indices are 1..V.
 * C             number of features copied per word.
 * output        destination; row w starts at output + w * outputStride.
 * outputStride  distance in elements between consecutive rows of output.
 *
 * A word whose index lies outside 1..V yields zeros instead of reading
 * outside the weight table.
 */
__global__ void OHNN_CudaLookupTable2_updateOutput_kernel(
  float *inputInd, float *weight, int weightStride, int B, int M, int V, int C,
  float *output, int outputStride)
{
  const int iFet = blockIdx.x * blockDim.x + threadIdx.x;
  const int iWord = blockIdx.y * blockDim.y + threadIdx.y;

  // The grid is rounded up to whole blocks, so trailing threads have no work.
  if (iFet >= C || iWord >= B*M) {
    return;
  }

  const int iVocab = (int)(inputInd[iWord] - 1); // C zero base <- lua one base

  // 64-bit offsets: row * stride can exceed the range of a 32-bit int.
  const long long nDst = (long long)iWord * outputStride + iFet;

  // Reject indices that do not name a row of the table.
  if (iVocab < 0 || iVocab >= V) {
    output[nDst] = 0.0f;
    return;
  }

  const long long nSrc = (long long)iVocab * weightStride + iFet;
  output[nDst] = weight[nSrc];
}
/// Expose
/**
 * Host entry point: launches OHNN_CudaLookupTable2_updateOutput_kernel on the
 * current THC stream to fill `output` with the rows of `weight` selected by
 * the one-based indices in `input`.
 *
 * state   THC state; supplies the current CUDA stream.
 * input   index tensor; sizes 0 and 1 are taken as B and M.
 * weight  table tensor; sizes 0 and 1 are taken as V and C.
 * output  resized to (B*M, C) for the launch and to (B, M, C) on return.
 *
 * Raises a Torch error (via THAssert / THCudaCheck) if the tensors fail the
 * GPU check or if the kernel launch is rejected.
 */
extern "C"
void OHNN_CudaLookupTable2_updateOutput(
  THCState *state,
  // In
  THCudaTensor *input,
  THCudaTensor *weight,
  // Out
  THCudaTensor *output)
{
  THAssert(THCudaTensor_checkGPU(state, 3, input, weight, output));

  // Cheat sheet:
  // B = batch size,
  // M = sequence length,
  // V = vocabulary size = input dim
  // C = embedding size = output dim = feature size
  // input: B, M (,V)
  // weight: V, C
  // output: B, M, C
  int B = THCudaTensor_size(state, input, 0);
  int M = THCudaTensor_size(state, input, 1);
  int V = THCudaTensor_size(state, weight, 0);
  int C = THCudaTensor_size(state, weight, 1);

  // prepare data
  THCudaTensor_resize2d(state, output, B*M, C);

  // A grid with a zero extent is an invalid launch configuration, so only
  // launch when there is something to copy.
  if (B*M > 0 && C > 0) {
    // The kernel reads the indices as one flat array, so hand it a
    // contiguous view (a new reference that is released below).
    THCudaTensor *inputFlat = THCudaTensor_newContiguous(state, input);

    int outputStride = output->stride[0];
    // NOTE(review): only the row stride is forwarded; the kernel assumes the
    // elements within a weight row are adjacent - confirm against callers.
    int weightStride = weight->stride[0];

    // update output
    cudaStream_t stream = THCState_getCurrentStream(state);
    dim3 grid(DIV_CEIL(C, 32), DIV_CEIL(B*M, 32));
    dim3 block(32, 32); // better memory access coalescing
    OHNN_CudaLookupTable2_updateOutput_kernel<<<grid, block, 0, stream>>>(
      THCudaTensor_data(state, inputFlat),
      THCudaTensor_data(state, weight), weightStride,
      B, M, V, C,
      THCudaTensor_data(state, output), outputStride
    );

    // Launch failures (e.g. a grid dimension over the device limit) are only
    // visible through cudaGetLastError. Release the temporary reference
    // before raising so it is not leaked on the error path.
    cudaError_t launchErr = cudaGetLastError();
    THCudaTensor_free(state, inputFlat);
    THCudaCheck(launchErr);
  }

  // post process
  THCudaTensor_resize3d(state, output, B, M, C);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment