Skip to content

Instantly share code, notes, and snippets.

@raver119
Created May 13, 2016 12:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save raver119/79be39e6fe8d3b865d88e440c4b981d5 to your computer and use it in GitHub Desktop.
Save raver119/79be39e6fe8d3b865d88e440c4b981d5 to your computer and use it in GitHub Desktop.
/**
 * Reduces each TAD of `dx` (the sub-tensors described by `tadOnlyShapeInfo`
 * and selected by `dimension`) to a single value, writing it to `result[r]`
 * where `r` is the TAD index.
 *
 * Work distribution, as implemented below:
 *   - block b handles TADs b, b + gridDim.x, b + 2*gridDim.x, ... (< tad->numTads);
 *   - within a block, thread t folds elements t, t + blockDim.x, ... of the
 *     current TAD into its own slot of the shared reduction buffer;
 *   - aggregatePartials combines the per-thread partials, and thread 0 writes
 *     postProcess(partial) to result[r].
 *
 * Contains __syncthreads(): every thread of the block must call this method.
 * The loop trip count depends only on blockIdx.x, gridDim.x and the shared
 * tad->numTads, so all threads of a block iterate the same number of times.
 *
 * @param dx               input buffer
 * @param xShapeInfo       shape info of dx (forwarded to TAD initialisation)
 * @param extraParams      op-specific parameters forwarded to op/update/postProcess
 * @param result           output buffer; one element is written per TAD
 * @param resultShapeInfo  shape info of result (not read by this implementation)
 * @param dimension        dimensions defining the TADs
 * @param dimensionLength  number of entries in `dimension`
 * @param reductionBuffer  not read by this implementation
 * @param manager          supplies the shared reduction buffer and TAD storage
 * @param tadOnlyShapeInfo precomputed shape info of a single TAD
 *
 * NOTE(review): elements are addressed as tadOffsetForBlock + i * tadEWS, which
 * is only correct when the TAD has a usable element-wise stride (presumably
 * tadEWS >= 1). TADs without one would need an index->offset fallback.
 * TODO confirm callers never pass such TADs.
 */
__inline__ __device__ virtual void transformCuda1D(T *dx,
        int *xShapeInfo,
        T *extraParams,
        T *result,
        int *resultShapeInfo,
        int *dimension,
        int dimensionLength,
        T *reductionBuffer, UnifiedSharedMemory<T> *manager, int *tadOnlyShapeInfo) {
    // One partial-result slot per thread, in the manager-provided buffer.
    T *sPartials = manager->getSharedReductionBuffer();

    // TAD descriptor and its scalar properties, initialised once per block by thread 0.
    __shared__ shape::TAD *tad;
    __shared__ int tadLength;
    __shared__ int tadEWS;

    if (threadIdx.x == 0) {
        tad = new(manager->getTADSpace()) shape::TAD();
        tad->setExternalBuffers((void *) manager);
        tad->initWithExternalTAD(tadOnlyShapeInfo, xShapeInfo, dimension, dimensionLength);

        tadLength = shape::length(tadOnlyShapeInfo);
        tadEWS = shape::elementWiseStride(tadOnlyShapeInfo);
    }
    // Make the shared TAD state written by thread 0 visible to the whole block.
    __syncthreads();

    for (int r = blockIdx.x; r < tad->numTads; r += gridDim.x) {
        if (threadIdx.x == 0)
            tad->createOffsetForBlock(r);
        // tad->tadOffsetForBlock must be published before any thread reads it.
        __syncthreads();

        // Each thread seeds its own slot, then folds its strided share of the TAD.
        sPartials[threadIdx.x] = this->startingValue(dx + tad->tadOffsetForBlock);
        for (int i = threadIdx.x; i < tadLength; i += blockDim.x) {
            sPartials[threadIdx.x] = this->update(sPartials[threadIdx.x],
                                                  this->op(dx[tad->tadOffsetForBlock + i * tadEWS], extraParams),
                                                  extraParams);
        }
        // All partials must be written before they are combined.
        __syncthreads();

        // Slots at index >= tadLength only hold the starting value, so limit
        // aggregation to min(blockDim.x, tadLength) items.
        aggregatePartials(sPartials, threadIdx.x, nd4j::math::nd4j_min<int>(blockDim.x, tadLength), extraParams);
        __syncthreads();

        if (threadIdx.x == 0) {
            result[r] = this->postProcess(sPartials[0], tadLength, extraParams);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment