Skip to content

Instantly share code, notes, and snippets.

@scott-gray
Last active September 11, 2016 23:21
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save scott-gray/5a3cd70465dcd2fe1df1 to your computer and use it in GitHub Desktop.
Save scott-gray/5a3cd70465dcd2fe1df1 to your computer and use it in GitHub Desktop.
Custom pooling kernels
#!/usr/bin/python
import numpy as np
import pycuda.driver as drv
from pycuda.tools import context_dependent_memoize
from pycuda.compiler import SourceModule
class GaussianPool(object):
    """Shape bookkeeping shared by the CPU and GPU Gaussian pooling layers.

    Stores the layer geometry and the Gaussian parameters as attributes and
    derives the output extent ``(P, Q)`` from the input extent, the window
    size and the strides.

    Arguments:
        N: size of the last axis of ``dimI`` / ``dimO``.
        C: size of the first axis; also stored as ``K``.
        H, W: input height and width.
        R, S: pooling window height and width.
        stride_h, stride_w: step between consecutive windows.
        var_y, var_x: per-axis factors stored for the Gaussian weighting.
        mean_y, mean_x: per-axis offsets stored for the Gaussian weighting.

    Note:
        ``__init__`` returns ``(P, Q)``.  That works when a subclass calls it
        through ``super()``, but Python raises ``TypeError`` if this class is
        instantiated directly, because ``__init__`` must return ``None`` then.
    """

    def __init__(self,
                 N, C, H, W, R, S,
                 stride_h, stride_w,
                 var_y, var_x,
                 mean_y, mean_x):
        # Number of window positions along each axis.
        out_h = _ceil_div(H - R + 1, stride_h)
        out_w = _ceil_div(W - S + 1, stride_w)

        self.N, self.C, self.K = N, C, C
        self.H, self.W = H, W
        self.R, self.S = R, S
        self.P, self.Q = out_h, out_w
        self.str_h, self.str_w = stride_h, stride_w
        self.var_y, self.var_x = var_y, var_x
        self.mean_y, self.mean_x = mean_y, mean_x

        # Tensor layouts used by fprop/bprop: (channel, row, column, N).
        self.dimI = (C, H, W, N)
        self.dimO = (C, out_h, out_w, N)

        return out_h, out_w
class GaussianPoolGPU(GaussianPool):
    """Gaussian pooling layer evaluated with the PyCUDA kernels in this file.

    The constructor precomputes, once, the launch geometry and the scalar
    argument list of the forward and backward kernels.  ``fprop`` / ``bprop``
    only prepend the two device buffers and the ``alpha`` / ``beta`` scalars.

    Arguments are the same as for ``GaussianPool``.
    """

    def __init__(self,
                 N, C, H, W, R, S,
                 stride_h, stride_w,
                 var_y, var_x,
                 mean_y, mean_x):
        super(GaussianPoolGPU, self).__init__(
            N, C, H, W, R, S, stride_h, stride_w,
            var_y, var_x, mean_y, mean_x)

        # Read the output extent from the attributes the base class stores
        # rather than from the value its __init__ happens to return.
        P, Q = self.P, self.Q

        # (multiplier, shift) pairs that let the kernels replace integer
        # division by S / stride_h / stride_w with a multiply and a shift.
        # The strides must be this constructor's own arguments: the previous
        # code used the names ``str_h`` / ``str_w``, which only resolved to
        # unrelated module-level globals of the demo script.
        magic_S = _magic32(R*S + 32, S)
        magic_str_h = _magic32(H + R, stride_h)
        magic_str_w = _magic32(W + S, stride_w)

        # Each entry is [grid, block, flattened scalar kernel arguments].
        # fprop launches one block per output pixel and channel, bprop one
        # block per input pixel and channel; both use N threads per block.
        self.fprop_args = [(Q, P, C), (N, 1, 1), _flatten([
            var_y, var_x, mean_y, mean_x,
            Q, N, Q*N, P*Q*N, H, W, W*N, H*W*N, R, S, R*S,
            magic_S, stride_h, stride_w])]

        self.bprop_args = [(W, H, C), (N, 1, 1), _flatten([
            var_y, var_x, mean_y, mean_x,
            P, Q, N, Q*N, P*Q*N, W, W*N, H*W*N, R, S, R*S,
            magic_S, stride_h, stride_w, magic_str_h, magic_str_w])]

        # The kernels read their lookup table four entries at a time, so the
        # entry count is rounded up to a multiple of 4.  Every entry is an
        # int2: 2 values * 4 bytes.
        lut_size = R*S
        if lut_size % 4 != 0:
            lut_size += 4 - lut_size % 4
        self.shared_size = lut_size * 4 * 2

    def fprop(self, I, O, alpha=1.0, beta=0.0):
        """Launch the forward kernel.

        ``I`` and ``O`` must expose a ``gpudata`` device pointer; ``I`` is
        laid out as ``dimI`` and ``O`` as ``dimO``.  The kernel stores
        ``alpha * pooled + beta * O`` into ``O``.
        """
        grid, block, scalars = self.fprop_args
        params = [grid, block, O.gpudata, I.gpudata, alpha, beta] + scalars
        kernel = _get_fprop_kernel()
        kernel.prepared_call(*params, shared_size=self.shared_size)

    def bprop(self, I, O, alpha=1.0, beta=0.0):
        """Launch the backward kernel.

        ``I`` and ``O`` must expose a ``gpudata`` device pointer.  Here ``I``
        is laid out as ``dimO`` (values at output resolution) and ``O`` as
        ``dimI``; the kernel stores ``alpha * result + beta * O`` into ``O``.
        """
        grid, block, scalars = self.bprop_args
        params = [grid, block, O.gpudata, I.gpudata, alpha, beta] + scalars
        kernel = _get_bprop_kernel()
        kernel.prepared_call(*params, shared_size=self.shared_size)
class GaussianPoolCPU(GaussianPool):
    """NumPy reference implementation of the Gaussian pooling layer.

    The window weights are computed once in the constructor and stored in
    ``self.filter`` with shape ``(1, R*S, 1)`` so they broadcast against
    tensors reshaped to ``(channels, pixels, N)``.

    Arguments are the same as for ``GaussianPool``.
    """

    def __init__(self,
                 N, C, H, W, R, S,
                 stride_h, stride_w,
                 var_y, var_x,
                 mean_y, mean_x):
        super(GaussianPoolCPU, self).__init__(
            N, C, H, W, R, S, stride_h, stride_w,
            var_y, var_x, mean_y, mean_x)

        weights = np.array(
            [self.guassian(r, s) for r in range(R) for s in range(S)],
            dtype=np.float32)

        # Store sqrt(w) / sqrt(sum(w)), so the squared filter sums to one.
        weights = np.sqrt(weights) / np.sqrt(np.sum(weights))
        self.filter = weights.reshape((1, -1, 1))

    def guassian(self, r, s):
        """Return the unnormalized weight of window tap ``(r, s)``.

        The tap offset is measured from ``(R//2, S//2)`` and shifted by the
        means; ``var_y`` / ``var_x`` multiply the squared offsets.
        """
        dy = float(r - self.R // 2) - self.mean_y
        dx = float(s - self.S // 2) - self.mean_x
        return np.exp(-(self.var_y*0.5*dy*dy + self.var_x*0.5*dx*dx), dtype=np.float32)

    def fpool_slice(self, q, S, W, stride):
        """Return the input coordinates covered by output position ``q``.

        Gives two parallel lists: the in-range input coordinates along one
        axis and the window tap index that lands on each of them.
        """
        start = q * stride
        taps = [s for s in range(S) if 0 <= start + s < W]
        coords = [start + s for s in taps]
        return coords, taps

    def fprop(self, I, O, alpha=1.0, beta=0.0):
        """Forward pass: ``O = beta * O + alpha * pooled(I)``, in place.

        ``I`` is reshaped to ``(C, -1, N)``; ``O`` is indexed as
        ``O[:, p, q, :]``.
        """
        flat_in = I.reshape((self.C, -1, self.N))
        weights = self.filter
        width, win_w = self.W, self.S

        O *= beta

        for p in range(self.P):
            rows, taps_r = self.fpool_slice(p, self.R, self.H, self.str_h)

            for q in range(self.Q):
                cols, taps_s = self.fpool_slice(q, self.S, self.W, self.str_w)

                # Flattened (row-major) indices into the image and the filter.
                pixel_idx = np.array(
                    [y*width + x for y in rows for x in cols], dtype=np.intp)
                tap_idx = np.array(
                    [r*win_w + s for r in taps_r for s in taps_s], dtype=np.intp)

                weighted = flat_in[:, pixel_idx, :] * weights[:, tap_idx, :]
                O[:, p, q, :] += np.sum(weighted, axis=1) * alpha

    def bpool_slice(self, x, S, Q, stride):
        """Return the output positions whose window covers input position ``x``.

        Gives two parallel lists: the in-range output coordinates along one
        axis and the window tap index through which each of them touches
        ``x``.
        """
        base = x - (S - 1)
        outputs = []
        taps = []
        for s in range(S):
            # Only offsets that are an exact multiple of the stride map back
            # to an output position.
            q, remainder = divmod(base + s, stride)
            if remainder == 0 and 0 <= q < Q:
                outputs.append(q)
                taps.append(S - s - 1)
        return outputs, taps

    def bprop(self, I, O, alpha=1.0, beta=0.0):
        """Backward pass: ``O = beta * O + alpha * backprojected(I)``, in place.

        ``I`` is reshaped to ``(K, -1, N)`` and addressed at output
        resolution; ``O`` is indexed as ``O[:, y, x, :]``.
        """
        flat_in = I.reshape((self.K, -1, self.N))
        weights = self.filter
        out_w, win_w = self.Q, self.S

        O *= beta

        for y in range(self.H):
            out_rows, taps_r = self.bpool_slice(y, self.R, self.P, self.str_h)

            for x in range(self.W):
                out_cols, taps_s = self.bpool_slice(x, self.S, self.Q, self.str_w)

                # Flattened (row-major) indices into the output map and filter.
                pixel_idx = np.array(
                    [p*out_w + q for p in out_rows for q in out_cols], dtype=np.intp)
                tap_idx = np.array(
                    [r*win_w + s for r in taps_r for s in taps_s], dtype=np.intp)

                weighted = flat_in[:, pixel_idx, :] * weights[:, tap_idx, :]
                O[:, y, x, :] += np.sum(weighted, axis=1) * alpha
def _ceil_div(x, y):
    """Return ``x / y`` rounded up, using only floor division."""
    # Flooring the negated quotient and negating again rounds toward +inf.
    negated_floor = (-x) // y
    return -negated_floor
# Magic numbers and shift amounts for integer division
def _magic32(nmax, d):
    """Return ``(magic, shift)`` so that ``(n * magic) >> shift == n // d``.

    The pair is searched for dividends up to ``nmax``.  Raises ``ValueError``
    if no shift of at most twice the bit length of ``nmax`` works.
    """
    # Largest value not above nmax that is one below a multiple of d.
    limit = ((nmax + 1) // d) * d - 1
    bit_count = len(bin(nmax)) - 2

    for shift in range(0, 2 * bit_count + 1):
        power = 1 << shift
        # Distance from ``power`` up to the next multiple of d.
        gap = d - 1 - (power - 1) % d
        if power > limit * gap:
            magic = (power + gap) // d
            return (magic, shift)

    raise ValueError("Can't find magic number for division")
# flatten a nested list of lists or values
def _flatten(lst):
    """Return a flat list of the leaves of arbitrarily nested lists/tuples.

    Lists and tuples are expanded recursively in order; every other value is
    kept as a single element.  Always returns a new ``list``.
    """
    # Append/extend into one accumulator: concatenating with sum(..., [])
    # copies the partial result for every element and is quadratic.
    flat = []
    for item in lst:
        if isinstance(item, (list, tuple)):
            flat.extend(_flatten(item))
        else:
            flat.append(item)
    return flat
@context_dependent_memoize
def _get_fprop_kernel():
    """Compile and return the prepared forward Gaussian pooling kernel.

    The kernel expects one block per (q, p, channel) output position with one
    thread per element of the innermost N axis, and a dynamic shared-memory
    lookup table of R*S int2 entries rounded up to a multiple of four.
    The result is memoized per CUDA context by the decorator.
    """
    code = r"""
// One lookup-table slot: the input offset of a window tap and its weight,
// stored to / loaded from shared memory as a single int2.
union LutEntry {
    struct {
        int sliceI;
        float funcVal;
    } data;
    int2 data2;
};

__global__ void spool_fprop_guassian(
    float* O, const float* I, float alpha, float beta,
    float var_y, float var_x, float mean_y, float mean_x,
    int Q, int N, int QN, int PQN, int H, int W, int WN, int HWN,
    int R, int S, int RS, int magic_S, int shift_S,
    int stride_h, int stride_w)
{
    float __shared__ rcpSqrtSum;
    extern __shared__ int2 lut[];

    int tid = threadIdx.x;
    int n = tid;
    int q = blockIdx.x;
    int p = blockIdx.y;
    int k = blockIdx.z;

    // zigzag q back and forth to improve L2 cache perf
    if (p & 1)
        q = Q - q - 1;

    I += n;
    O += k*PQN + p*QN + q*N + n;

    // The previous output is only read when it contributes to the result.
    float O_val = beta > 0.0f ? __ldg(O) : 0.0f;

    // The first 32 threads build the lookup table and the normalization.
    if (tid < 32)
    {
        int pr = p * stride_h;
        int qs = q * stride_w;
        int r_half = R >> 1;
        int s_half = S >> 1;
        int chan_offset = k * HWN;
        float var_y2 = var_y * 0.5f;
        float var_x2 = var_x * 0.5f;
        float sum = 0.0f;

        int rs = tid;
        while (rs < RS)
        {
            // r = rs / S;
            // s = rs % S;
            int r = rs * magic_S; r >>= shift_S;
            int s = rs - r*S;

            int x = qs + s;
            int y = pr + r;

            LutEntry entry;
            entry.data.sliceI = chan_offset + y*WN + x*N;

            float fy = (float)(r - r_half) - mean_y;
            float fx = (float)(s - s_half) - mean_x;
            float val = expf( -(var_y2*fy*fy + var_x2*fx*fx) );

            entry.data.funcVal = sqrtf(val);
            sum += val;

            lut[rs] = entry.data2;
            rs += 32;
        }
        // Butterfly reduction: afterwards every lane holds the full sum.
        #pragma unroll
        for (int i = 16; i > 0; i >>= 1)
            sum += __shfl_xor(sum, i);

        rcpSqrtSum = 1.0f / sqrtf(sum);
    }
    __syncthreads();

    float rcp_sqrt_sum = rcpSqrtSum;

    int rs = 0;
    float out = 0.0f;
    while (rs < RS)
    {
        LutEntry entry0;
        LutEntry entry1;
        LutEntry entry2;
        LutEntry entry3;

        entry0.data2 = lut[rs + 0];
        entry1.data2 = lut[rs + 1];
        entry2.data2 = lut[rs + 2];
        entry3.data2 = lut[rs + 3];

        float val0 = rs + 0 < RS ? __ldg(I + entry0.data.sliceI) : 0.0f;
        float val1 = rs + 1 < RS ? __ldg(I + entry1.data.sliceI) : 0.0f;
        float val2 = rs + 2 < RS ? __ldg(I + entry2.data.sliceI) : 0.0f;
        float val3 = rs + 3 < RS ? __ldg(I + entry3.data.sliceI) : 0.0f;

        // Slots at or past RS are padding that was never written.  Their
        // weight bits are arbitrary, and 0 * NaN (or 0 * Inf) would poison
        // the sum, so the weight is forced to zero there as well.
        float wgt0 = rs + 0 < RS ? entry0.data.funcVal : 0.0f;
        float wgt1 = rs + 1 < RS ? entry1.data.funcVal : 0.0f;
        float wgt2 = rs + 2 < RS ? entry2.data.funcVal : 0.0f;
        float wgt3 = rs + 3 < RS ? entry3.data.funcVal : 0.0f;

        out += val0 * wgt0 * rcp_sqrt_sum;
        out += val1 * wgt1 * rcp_sqrt_sum;
        out += val2 * wgt2 * rcp_sqrt_sum;
        out += val3 * wgt3 * rcp_sqrt_sum;

        rs += 4;
    }
    *O = out*alpha + O_val*beta;
}
"""
    module = SourceModule(code, options=["--use_fast_math"])
    kernel = module.get_function("spool_fprop_guassian")
    # Two pointers, six floats (alpha, beta, var_y, var_x, mean_y, mean_x),
    # then the fifteen integer arguments.
    kernel.prepare("PPffffffIIIIIIIIIIIIIII")
    return kernel
@context_dependent_memoize
def _get_bprop_kernel():
    """Compile and return the prepared backward Gaussian pooling kernel.

    The kernel expects one block per (x, y, channel) input position with one
    thread per element of the innermost N axis, and a dynamic shared-memory
    lookup table of R*S int2 entries rounded up to a multiple of four.
    The result is memoized per CUDA context by the decorator.
    """
    code = r"""
// One lookup-table slot: the offset of a contributing output value and its
// weight, stored to / loaded from shared memory as a single int2.
union LutEntry {
    struct {
        int sliceI;
        float funcVal;
    } data;
    int2 data2;
};

__global__ void spool_bprop_guassian(
    float* O, const float* I, float alpha, float beta,
    float var_y, float var_x, float mean_y, float mean_x,
    int P, int Q, int N, int QN, int PQN, int W, int WN, int HWN,
    int R, int S, int RS, int magic_S, int shift_S,
    int stride_h, int stride_w,
    int magic_stride_h, int shift_stride_h,
    int magic_stride_w, int shift_stride_w)
{
    int __shared__ lutSize;
    extern __shared__ int2 lut[];

    int tid = threadIdx.x;
    int n = tid;
    int x = blockIdx.x;
    int y = blockIdx.y;
    int c = blockIdx.z;

    // zigzag x back and forth to improve L2 cache perf
    if (y & 1)
        x = W - x - 1;

    I += n;
    O += c*HWN + y*WN + x*N + n;

    // The previous output is only read when it contributes to the result.
    float O_val = beta > 0.0f ? __ldg(O) : 0.0f;

    int lut_size;
    // The first 32 threads build the lookup table.
    if (tid < 32)
    {
        int r_half = R >> 1;
        int s_half = S >> 1;
        float var_y2 = var_y * 0.5f;
        float var_x2 = var_x * 0.5f;

        // Pass 1: sum of the unnormalized weights over the whole window.
        float sum = 0.0f;
        int rs = tid;
        while (rs < RS)
        {
            // r = rs / S;
            // s = rs % S;
            int r = rs * magic_S; r >>= shift_S;
            int s = rs - r*S;

            float fy = (float)(r - r_half) - mean_y;
            float fx = (float)(s - s_half) - mean_x;
            sum += expf( -(var_y2*fy*fy + var_x2*fx*fx) );

            rs += 32;
        }
        // Butterfly reduction: afterwards every lane holds the full sum.
        #pragma unroll
        for (int i = 16; i > 0; i >>= 1)
            sum += __shfl_xor(sum, i);

        float rcp_sqrt_sum = 1.0f / sqrtf(sum);

        int pr = y - (R - 1);
        int qs = x - (S - 1);
        int chan_offset = c * PQN;

        // Mask of the lanes below this one.  Written as (1 << tid) - 1
        // because shifting a 32-bit value right by 32 (the tid == 0 case of
        // 0xffffffff >> (32 - tid)) is undefined.
        unsigned dep_thd_mask = (1u << tid) - 1u;

        // Pass 2: compact the valid taps into the lookup table.
        lut_size = 0;
        rs = tid;
        while (rs < RS)
        {
            // r = rs / S;
            // s = rs % S;
            int r = rs * magic_S; r >>= shift_S;
            int s = rs - r*S;

            int p_prime = pr + r;
            int q_prime = qs + s;

            // Invert kernel coordinates
            r = R - r - 1;
            s = S - s - 1;

            // p = p_prime / stride_h
            // p_mod = p_prime % stride_h
            int p = p_prime * magic_stride_h; p >>= shift_stride_h;
            int p_mod = p_prime - p*stride_h;
            bool p_bounds = p_mod == 0 && p >= 0 && p < P;

            // q = q_prime / stride_w
            // q_mod = q_prime % stride_w
            // (uses the stride_w magic pair; the stride_h multiplier was
            // used here by mistake and is only correct when both agree)
            int q = q_prime * magic_stride_w; q >>= shift_stride_w;
            int q_mod = q_prime - q*stride_w;
            bool q_bounds = q_mod == 0 && q >= 0 && q < Q;

            bool in_bounds = q_bounds && p_bounds;

            // Get a mask of all valid slices in the warp
            unsigned ballot = __ballot(in_bounds);

            // Count the total valid slices
            unsigned warp_slices = __popc(ballot);

            if (in_bounds)
            {
                // Count all the valid slices below this threadid
                unsigned dep_thd_cnt = __popc(dep_thd_mask & ballot);

                LutEntry entry;
                entry.data.sliceI = chan_offset + p*QN + q*N;

                float fy = (float)(r - r_half) - mean_y;
                float fx = (float)(s - s_half) - mean_x;
                entry.data.funcVal = sqrtf(expf( -(var_y2*fy*fy + var_x2*fx*fx) )) * rcp_sqrt_sum;

                lut[lut_size + dep_thd_cnt] = entry.data2;
            }
            lut_size += warp_slices;
            rs += 32;
        }
        // Lanes that left the loop early hold a stale count.  Thread 0 takes
        // part in every iteration, so only its total is published instead of
        // letting 32 lanes race on the same shared variable.
        if (tid == 0)
            lutSize = lut_size;
    }
    __syncthreads();

    lut_size = lutSize;

    int rs = 0;
    float out = 0.0f;
    while (rs < lut_size)
    {
        LutEntry entry0;
        LutEntry entry1;
        LutEntry entry2;
        LutEntry entry3;

        entry0.data2 = lut[rs + 0];
        entry1.data2 = lut[rs + 1];
        entry2.data2 = lut[rs + 2];
        entry3.data2 = lut[rs + 3];

        float val0 = rs + 0 < lut_size ? __ldg(I + entry0.data.sliceI) : 0.0f;
        float val1 = rs + 1 < lut_size ? __ldg(I + entry1.data.sliceI) : 0.0f;
        float val2 = rs + 2 < lut_size ? __ldg(I + entry2.data.sliceI) : 0.0f;
        float val3 = rs + 3 < lut_size ? __ldg(I + entry3.data.sliceI) : 0.0f;

        // Slots at or past lut_size were never written.  Their weight bits
        // are arbitrary, and 0 * NaN (or 0 * Inf) would poison the sum, so
        // the weight is forced to zero there as well.
        float wgt0 = rs + 0 < lut_size ? entry0.data.funcVal : 0.0f;
        float wgt1 = rs + 1 < lut_size ? entry1.data.funcVal : 0.0f;
        float wgt2 = rs + 2 < lut_size ? entry2.data.funcVal : 0.0f;
        float wgt3 = rs + 3 < lut_size ? entry3.data.funcVal : 0.0f;

        out += val0 * wgt0;
        out += val1 * wgt1;
        out += val2 * wgt2;
        out += val3 * wgt3;

        rs += 4;
    }
    *O = out*alpha + O_val*beta;
}
"""
    module = SourceModule(code, options=["--use_fast_math"])
    kernel = module.get_function("spool_bprop_guassian")
    # Two pointers, six floats (alpha, beta, var_y, var_x, mean_y, mean_x),
    # then the nineteen integer arguments.
    kernel.prepare("PPffffffIIIIIIIIIIIIIIIIIII")
    return kernel
# ---------------------------------------------------------------------------
# Demo: run the CPU reference and the GPU kernels on all-ones tensors and
# print one channel / one N-slice of each result for comparison.
# ---------------------------------------------------------------------------
from neon.backends.nervanagpu import NervanaGPU

# Supplies the device array constructors used below (ng.ones / ng.zeros).
# Presumably constructing it also sets up the CUDA context the PyCUDA
# kernels compile in -- confirm against neon.
ng = NervanaGPU()

N, C = (32, 1)
H, W = (6, 6)
R, S = (3, 3)
str_h, str_w = (3, 3)
var_y, var_x = (1, 1)
mean_y, mean_x = (0, 0)

cpu_pool = GaussianPoolCPU(
    N, C, H, W, R, S,
    str_h, str_w,
    var_y, var_x,
    mean_y, mean_x)

# Swap in these two lines for random inputs instead of all-ones tensors.
#I = np.random.uniform(-1.0, 1.0, cpu_pool.dimI)
#E = np.random.uniform(-1.0, 1.0, cpu_pool.dimO)
I = np.ones(cpu_pool.dimI)
E = np.ones(cpu_pool.dimO)
O = np.zeros(cpu_pool.dimO)
B = np.zeros(cpu_pool.dimI)

cpu_pool.fprop(I, O)
cpu_pool.bprop(E, B)

# Parenthesized single-argument form: prints the same under Python 2 and is
# also valid Python 3.
print(O[0, :, :, 0])
print(B[0, :, :, 0])

gpu_pool = GaussianPoolGPU(
    N, C, H, W, R, S,
    str_h, str_w,
    var_y, var_x,
    mean_y, mean_x)

I = ng.ones(gpu_pool.dimI)
E = ng.ones(gpu_pool.dimO)
O = ng.zeros(gpu_pool.dimO)
B = ng.zeros(gpu_pool.dimI)

gpu_pool.fprop(I, O)
gpu_pool.bprop(E, B)

print(O.get()[0, :, :, 0])
print(B.get()[0, :, :, 0])
// One lookup-table slot: the offset of a contributing value and its weight,
// stored to / loaded from shared memory as a single int2.
union LutEntry {
    struct {
        int sliceI;
        float funcVal;
    } data;
    int2 data2;
};

// Backward pass of Gaussian pooling with padding.
//
// Launch layout read by the kernel: blockIdx.x/y/z select the x, y and
// channel of the value written to O; threadIdx.x selects the element of the
// innermost N axis.  O is addressed as c*HWN + y*WN + x*N + n and I as
// c*PQN + p*QN + q*N + n.
//
// The first 32 threads compact the (p, q) positions whose window reaches
// (y, x) into the shared lookup table `lut`; all threads then accumulate the
// weighted sum of I over that table.  The table is read four entries at a
// time, so the dynamic shared allocation must cover the entry count rounded
// up to a multiple of four.
//
// magic_* / shift_* are multiply-and-shift replacements for division by S,
// stride_h and stride_w.
extern "C"
__global__ void spool_bprop_gaussian(
    float* O,
    const float* I,
    int P,
    int Q,
    int N,
    int QN,
    int PQN,
    int W,
    int WN,
    int HWN,
    int R,
    int S,
    int RS,
    int magic_S,
    int shift_S,
    int stride_h,
    int stride_w,
    int magic_stride_h,
    int shift_stride_h,
    int magic_stride_w,
    int shift_stride_w,
    int pad_h,
    int pad_w,
    float var_y,
    float var_x,
    float mean_y,
    float mean_x
)
{
    int __shared__ lutSize;
    extern __shared__ int2 lut[];

    int tid = threadIdx.x;
    int n = tid;
    int x = blockIdx.x;
    int y = blockIdx.y;
    int c = blockIdx.z;

    // zigzag x back and forth to improve L2 cache perf
    if (y & 1)
        x = W - x - 1;

    I += n;
    O += c*HWN + y*WN + x*N + n;

    int lut_size;
    if (tid < 32)
    {
        int pr = y - (R - pad_h - 1);
        int qs = x - (S - pad_w - 1);
        int r_half = (R - 1) >> 1;
        int s_half = (S - 1) >> 1;
        int chan_offset = c * PQN;
        float var_y2 = var_y * 0.5f;
        float var_x2 = var_x * 0.5f;

        // Mask of the lanes below this one.  Written as (1 << tid) - 1
        // because shifting a 32-bit value right by 32 (the tid == 0 case of
        // 0xffffffff >> (32 - tid)) is undefined.
        unsigned dep_thd_mask = (1u << tid) - 1u;

        lut_size = 0;
        int rs = tid;
        while (rs < RS)
        {
            // r = rs / S;
            // s = rs % S;
            int r = rs * magic_S; r >>= shift_S;
            int s = rs - r*S;

            int p_prime = pr + r;
            int q_prime = qs + s;

            // Invert kernel coordinates
            r = R - r - 1;
            s = S - s - 1;

            // p = p_prime / stride_h
            // p_mod = p_prime % stride_h
            int p = p_prime * magic_stride_h; p >>= shift_stride_h;
            int p_mod = p_prime - p*stride_h;
            bool p_bounds = p_mod == 0 && p >= 0 && p < P;

            // q = q_prime / stride_w
            // q_mod = q_prime % stride_w
            // (uses the stride_w magic pair; the stride_h multiplier was
            // used here by mistake and is only correct when both agree)
            int q = q_prime * magic_stride_w; q >>= shift_stride_w;
            int q_mod = q_prime - q*stride_w;
            bool q_bounds = q_mod == 0 && q >= 0 && q < Q;

            bool in_bounds = q_bounds && p_bounds;

            // Get a mask of all valid slices in the warp
            unsigned ballot = __ballot(in_bounds);

            // Count the total valid slices
            unsigned warp_slices = __popc(ballot);

            if (in_bounds)
            {
                // Count all the valid slices below this threadid
                unsigned dep_thd_cnt = __popc(dep_thd_mask & ballot);

                LutEntry entry;
                entry.data.sliceI = chan_offset + p*QN + q*N;

                float fy = (float)(r - r_half) - mean_y;
                float fx = (float)(s - s_half) - mean_x;
                entry.data.funcVal = expf( -(var_y2*fy*fy + var_x2*fx*fx) );

                lut[lut_size + dep_thd_cnt] = entry.data2;
            }
            lut_size += warp_slices;
            rs += 32;
        }
        // Lanes that left the loop early hold a stale count.  Thread 0 takes
        // part in every iteration, so only its total is published instead of
        // letting 32 lanes race on the same shared variable.
        if (tid == 0)
            lutSize = lut_size;
    }
    __syncthreads();

    lut_size = lutSize;

    int rs = 0;
    float out = 0.0f;
    while (rs < lut_size)
    {
        LutEntry entry0;
        LutEntry entry1;
        LutEntry entry2;
        LutEntry entry3;

        entry0.data2 = lut[rs + 0];
        entry1.data2 = lut[rs + 1];
        entry2.data2 = lut[rs + 2];
        entry3.data2 = lut[rs + 3];

        float val0 = rs + 0 < lut_size ? __ldg(I + entry0.data.sliceI) : 0.0f;
        float val1 = rs + 1 < lut_size ? __ldg(I + entry1.data.sliceI) : 0.0f;
        float val2 = rs + 2 < lut_size ? __ldg(I + entry2.data.sliceI) : 0.0f;
        float val3 = rs + 3 < lut_size ? __ldg(I + entry3.data.sliceI) : 0.0f;

        // Slots at or past lut_size were never written.  Their weight bits
        // are arbitrary, and 0 * NaN (or 0 * Inf) would poison the sum, so
        // the weight is forced to zero there as well.
        float wgt0 = rs + 0 < lut_size ? entry0.data.funcVal : 0.0f;
        float wgt1 = rs + 1 < lut_size ? entry1.data.funcVal : 0.0f;
        float wgt2 = rs + 2 < lut_size ? entry2.data.funcVal : 0.0f;
        float wgt3 = rs + 3 < lut_size ? entry3.data.funcVal : 0.0f;

        out += val0 * wgt0;
        out += val1 * wgt1;
        out += val2 * wgt2;
        out += val3 * wgt3;

        rs += 4;
    }
    *O = out;
}
// One lookup-table slot: the input offset of a window tap and its weight,
// stored to / loaded from shared memory as a single int2.
union LutEntry {
    struct {
        int sliceI;
        float funcVal;
    } data;
    int2 data2;
};

// Forward pass of Gaussian pooling with padding.
//
// Launch layout read by the kernel: blockIdx.x/y/z select the q, p and
// channel of the value written to O; threadIdx.x selects the element of the
// innermost N axis.  O is addressed as k*PQN + p*QN + q*N + n and I as
// k*HWN + y*WN + x*N + n.
//
// The first 32 threads compact the window taps that fall inside the H x W
// input into the shared lookup table `lut`; all threads then accumulate the
// weighted sum of I over that table.  The table is read four entries at a
// time, so the dynamic shared allocation must cover the entry count rounded
// up to a multiple of four.
//
// magic_S / shift_S are a multiply-and-shift replacement for division by S.
extern "C"
__global__ void spool_fprop_gaussian(
    float* O,
    const float* I,
    int Q,
    int N,
    int QN,
    int PQN,
    int H,
    int W,
    int WN,
    int HWN,
    int R,
    int S,
    int RS,
    int magic_S,
    int shift_S,
    int stride_h,
    int stride_w,
    int pad_h,
    int pad_w,
    float var_y,
    float var_x,
    float mean_y,
    float mean_x
)
{
    int __shared__ lutSize;
    extern __shared__ int2 lut[];

    int tid = threadIdx.x;
    int n = tid;
    int q = blockIdx.x;
    int p = blockIdx.y;
    int k = blockIdx.z;

    // zigzag q back and forth to improve L2 cache perf
    if (p & 1)
        q = Q - q - 1;

    I += n;
    O += k*PQN + p*QN + q*N + n;

    int lut_size;
    if (tid < 32)
    {
        int pr = p * stride_h - pad_h;
        int qs = q * stride_w - pad_w;
        int r_half = (R - 1) >> 1;
        int s_half = (S - 1) >> 1;
        int chan_offset = k * HWN;
        float var_y2 = var_y * 0.5f;
        float var_x2 = var_x * 0.5f;

        // Mask of the lanes below this one.  Written as (1 << tid) - 1
        // because shifting a 32-bit value right by 32 (the tid == 0 case of
        // 0xffffffff >> (32 - tid)) is undefined.
        unsigned dep_thd_mask = (1u << tid) - 1u;

        lut_size = 0;
        int rs = tid;
        while (rs < RS)
        {
            // r = rs / S;
            // s = rs % S;
            int r = rs * magic_S; r >>= shift_S;
            int s = rs - r*S;

            int x = qs + s;
            int y = pr + r;

            bool in_bounds = x >= 0 && x < W && y >= 0 && y < H;

            // Get a mask of all valid slices in the warp
            unsigned ballot = __ballot(in_bounds);

            // Count the total valid slices
            unsigned warp_slices = __popc(ballot);

            if (in_bounds)
            {
                // Count all the valid slices below this threadid
                unsigned dep_thd_cnt = __popc(dep_thd_mask & ballot);

                LutEntry entry;
                entry.data.sliceI = chan_offset + y*WN + x*N;

                float fy = (float)(r - r_half) - mean_y;
                float fx = (float)(s - s_half) - mean_x;
                entry.data.funcVal = expf( -(var_y2*fy*fy + var_x2*fx*fx) );

                lut[lut_size + dep_thd_cnt] = entry.data2;
            }
            lut_size += warp_slices;
            rs += 32;
        }
        // Lanes that left the loop early hold a stale count.  Thread 0 takes
        // part in every iteration, so only its total is published instead of
        // letting 32 lanes race on the same shared variable.
        if (tid == 0)
            lutSize = lut_size;
    }
    __syncthreads();

    lut_size = lutSize;

    int rs = 0;
    float out = 0.0f;
    while (rs < lut_size)
    {
        LutEntry entry0;
        LutEntry entry1;
        LutEntry entry2;
        LutEntry entry3;

        entry0.data2 = lut[rs + 0];
        entry1.data2 = lut[rs + 1];
        entry2.data2 = lut[rs + 2];
        entry3.data2 = lut[rs + 3];

        float val0 = rs + 0 < lut_size ? __ldg(I + entry0.data.sliceI) : 0.0f;
        float val1 = rs + 1 < lut_size ? __ldg(I + entry1.data.sliceI) : 0.0f;
        float val2 = rs + 2 < lut_size ? __ldg(I + entry2.data.sliceI) : 0.0f;
        float val3 = rs + 3 < lut_size ? __ldg(I + entry3.data.sliceI) : 0.0f;

        // Slots at or past lut_size were never written.  Their weight bits
        // are arbitrary, and 0 * NaN (or 0 * Inf) would poison the sum, so
        // the weight is forced to zero there as well.
        float wgt0 = rs + 0 < lut_size ? entry0.data.funcVal : 0.0f;
        float wgt1 = rs + 1 < lut_size ? entry1.data.funcVal : 0.0f;
        float wgt2 = rs + 2 < lut_size ? entry2.data.funcVal : 0.0f;
        float wgt3 = rs + 3 < lut_size ? entry3.data.funcVal : 0.0f;

        out += val0 * wgt0;
        out += val1 * wgt1;
        out += val2 * wgt2;
        out += val3 * wgt3;

        rs += 4;
    }
    *O = out;
}
// One lookup-table slot: the input offset of a window tap and its weight,
// stored to / loaded from shared memory as a single int2.
union LutEntry {
    struct {
        int sliceI;
        float funcVal;
    } data;
    int2 data2;
};

// Forward pass of Gaussian pooling without padding.
//
// Launch layout read by the kernel: blockIdx.x/y/z select the q, p and
// channel of the value written to O; threadIdx.x selects the element of the
// innermost N axis.  O is addressed as k*PQN + p*QN + q*N + n and I as
// k*HWN + y*WN + x*N + n.
//
// No bounds test is applied to the window taps, so every window must lie
// inside the input.  All threads of the block cooperate to fill the RS-entry
// shared lookup table `lut`.  It is read four entries at a time, so the
// dynamic shared allocation must cover RS rounded up to a multiple of four.
//
// magic_S / shift_S are a multiply-and-shift replacement for division by S.
extern "C"
__global__ void spool_fprop_gaussian_nopad(
    float* O,
    const float* I,
    int Q,
    int N,
    int QN,
    int PQN,
    int WN,
    int HWN,
    int R,
    int S,
    int RS,
    int magic_S,
    int shift_S,
    int stride_y,
    int stride_x,
    float var_y,
    float var_x,
    float mean_y,
    float mean_x
)
{
    extern __shared__ int2 lut[];

    int tid = threadIdx.x;
    int n = tid;
    int q = blockIdx.x;
    int p = blockIdx.y;
    int k = blockIdx.z;

    // zigzag q back and forth to improve L2 cache perf
    if (p & 1)
        q = Q - q - 1;

    I += n;
    O += k*PQN + p*QN + q*N + n;

    int pr = p * stride_y;
    int qs = q * stride_x;
    int r_half = (R - 1) >> 1;
    int s_half = (S - 1) >> 1;
    int chan_offset = k * HWN;
    float var_y2 = var_y * 0.5f;
    float var_x2 = var_x * 0.5f;

    // Each thread fills the table slots tid, tid + blockDim.x, ...
    int rs = tid;
    while (rs < RS)
    {
        // r = rs / S;
        // s = rs % S;
        int r = rs * magic_S; r >>= shift_S;
        int s = rs - r*S;

        int x = qs + s;
        int y = pr + r;

        LutEntry entry;
        entry.data.sliceI = chan_offset + y*WN + x*N;

        float fy = (float)(r - r_half) - mean_y;
        float fx = (float)(s - s_half) - mean_x;
        entry.data.funcVal = expf( -(var_y2*fy*fy + var_x2*fx*fx) );

        lut[rs] = entry.data2;
        rs += blockDim.x;
    }
    __syncthreads();

    rs = 0;
    float out = 0.0f;
    while (rs < RS)
    {
        LutEntry entry0;
        LutEntry entry1;
        LutEntry entry2;
        LutEntry entry3;

        entry0.data2 = lut[rs + 0];
        entry1.data2 = lut[rs + 1];
        entry2.data2 = lut[rs + 2];
        entry3.data2 = lut[rs + 3];

        float val0 = rs + 0 < RS ? __ldg(I + entry0.data.sliceI) : 0.0f;
        float val1 = rs + 1 < RS ? __ldg(I + entry1.data.sliceI) : 0.0f;
        float val2 = rs + 2 < RS ? __ldg(I + entry2.data.sliceI) : 0.0f;
        float val3 = rs + 3 < RS ? __ldg(I + entry3.data.sliceI) : 0.0f;

        // Slots at or past RS are padding that was never written.  Their
        // weight bits are arbitrary, and 0 * NaN (or 0 * Inf) would poison
        // the sum, so the weight is forced to zero there as well.
        float wgt0 = rs + 0 < RS ? entry0.data.funcVal : 0.0f;
        float wgt1 = rs + 1 < RS ? entry1.data.funcVal : 0.0f;
        float wgt2 = rs + 2 < RS ? entry2.data.funcVal : 0.0f;
        float wgt3 = rs + 3 < RS ? entry3.data.funcVal : 0.0f;

        out += val0 * wgt0;
        out += val1 * wgt1;
        out += val2 * wgt2;
        out += val3 * wgt3;

        rs += 4;
    }
    *O = out;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment