Created
November 24, 2017 05:27
-
-
Save vbkaisetsu/53d0d4cd715fee5f8fe584ebf8459232 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// One step of the local-memory tree reduction used by the argmax kernels.
// Work-items with tid < k merge slot tid + k into slot tid, then the whole
// group synchronizes. A candidate wins when its value is strictly larger, or
// when the values are equal and its index is smaller, so ties always resolve
// to the lowest input index regardless of the group size.
#define REDUCE(k, GROUP_SIZE) \
  if ((GROUP_SIZE) >= ((k) << 1)) { \
    if (tid < (k)) { \
      const float other_val = max_val[tid + (k)]; \
      const unsigned other_idx = argmax_val[tid + (k)]; \
      if (other_val > max_val[tid] || \
          (other_val == max_val[tid] && other_idx < argmax_val[tid])) { \
        max_val[tid] = other_val; \
        argmax_val[tid] = other_idx; \
      } \
    } \
    barrier(CLK_LOCAL_MEM_FENCE); \
  }

// Generates argmax_kernel_<GROUP_SIZE>.
//
// Work-group `bid` scans n elements of px that start at offset
// bid % skip + (bid / skip) * skip * n and are `skip` elements apart, and
// writes the position (0 .. n-1) of the largest one to py[bid].
//
// Every slot starts as (-INFINITY, 0), so the stored index is always
// initialized: if no element compares greater than -INFINITY (all inputs are
// -inf or NaN, or n == 0) the result is 0 instead of uninitialized memory.
//
// NOTE(review): the local arrays hold GROUP_SIZE entries and are indexed by
// get_local_id(0), so the kernel presumably must be launched with a local
// size of exactly GROUP_SIZE -- confirm against the host code.
#define ARGMAX_KERNEL(GROUP_SIZE) \
kernel void argmax_kernel_##GROUP_SIZE( \
    const global float *px, const unsigned skip, \
    const unsigned n, global unsigned *py) { \
  const unsigned bid = get_group_id(0); \
  const unsigned tid = get_local_id(0); \
  local float max_val[GROUP_SIZE]; \
  local unsigned argmax_val[GROUP_SIZE]; \
  px += bid % skip + (bid / skip) * skip * n; \
  max_val[tid] = -INFINITY; \
  argmax_val[tid] = 0; \
  for (unsigned i = tid; i < n; i += GROUP_SIZE) { \
    const float val = px[i * skip]; \
    if (val > max_val[tid]) { \
      max_val[tid] = val; \
      argmax_val[tid] = i; \
    } \
  } \
  barrier(CLK_LOCAL_MEM_FENCE); \
  REDUCE(512, GROUP_SIZE) \
  REDUCE(256, GROUP_SIZE) \
  REDUCE(128, GROUP_SIZE) \
  REDUCE(64, GROUP_SIZE) \
  REDUCE(32, GROUP_SIZE) \
  REDUCE(16, GROUP_SIZE) \
  REDUCE(8, GROUP_SIZE) \
  REDUCE(4, GROUP_SIZE) \
  REDUCE(2, GROUP_SIZE) \
  REDUCE(1, GROUP_SIZE) \
  if (tid == 0) py[bid] = argmax_val[0]; \
}

ARGMAX_KERNEL(1024)
ARGMAX_KERNEL(512)
ARGMAX_KERNEL(256)
ARGMAX_KERNEL(128)
ARGMAX_KERNEL(64)
ARGMAX_KERNEL(32)
ARGMAX_KERNEL(16)
ARGMAX_KERNEL(8)
ARGMAX_KERNEL(4)
ARGMAX_KERNEL(2)
ARGMAX_KERNEL(1)
#undef REDUCE
// One step of the local-memory tree reduction used by the argmin kernels.
// Work-items with tid < k merge slot tid + k into slot tid, then the whole
// group synchronizes. A candidate wins when its value is strictly smaller, or
// when the values are equal and its index is smaller, so ties always resolve
// to the lowest input index regardless of the group size.
#define REDUCE(k, GROUP_SIZE) \
  if ((GROUP_SIZE) >= ((k) << 1)) { \
    if (tid < (k)) { \
      const float other_val = min_val[tid + (k)]; \
      const unsigned other_idx = argmin_val[tid + (k)]; \
      if (other_val < min_val[tid] || \
          (other_val == min_val[tid] && other_idx < argmin_val[tid])) { \
        min_val[tid] = other_val; \
        argmin_val[tid] = other_idx; \
      } \
    } \
    barrier(CLK_LOCAL_MEM_FENCE); \
  }

// Generates argmin_kernel_<GROUP_SIZE>.
//
// Work-group `bid` scans n elements of px that start at offset
// bid % skip + (bid / skip) * skip * n and are `skip` elements apart, and
// writes the position (0 .. n-1) of the smallest one to py[bid].
//
// Every slot starts as (+INFINITY, 0), so the stored index is always
// initialized: if no element compares less than +INFINITY (all inputs are
// +inf or NaN, or n == 0) the result is 0 instead of uninitialized memory.
//
// NOTE(review): the local arrays hold GROUP_SIZE entries and are indexed by
// get_local_id(0), so the kernel presumably must be launched with a local
// size of exactly GROUP_SIZE -- confirm against the host code.
#define ARGMIN_KERNEL(GROUP_SIZE) \
kernel void argmin_kernel_##GROUP_SIZE( \
    const global float *px, const unsigned skip, \
    const unsigned n, global unsigned *py) { \
  const unsigned bid = get_group_id(0); \
  const unsigned tid = get_local_id(0); \
  local float min_val[GROUP_SIZE]; \
  local unsigned argmin_val[GROUP_SIZE]; \
  px += bid % skip + (bid / skip) * skip * n; \
  min_val[tid] = INFINITY; \
  argmin_val[tid] = 0; \
  for (unsigned i = tid; i < n; i += GROUP_SIZE) { \
    const float val = px[i * skip]; \
    if (val < min_val[tid]) { \
      min_val[tid] = val; \
      argmin_val[tid] = i; \
    } \
  } \
  barrier(CLK_LOCAL_MEM_FENCE); \
  REDUCE(512, GROUP_SIZE) \
  REDUCE(256, GROUP_SIZE) \
  REDUCE(128, GROUP_SIZE) \
  REDUCE(64, GROUP_SIZE) \
  REDUCE(32, GROUP_SIZE) \
  REDUCE(16, GROUP_SIZE) \
  REDUCE(8, GROUP_SIZE) \
  REDUCE(4, GROUP_SIZE) \
  REDUCE(2, GROUP_SIZE) \
  REDUCE(1, GROUP_SIZE) \
  if (tid == 0) py[bid] = argmin_val[0]; \
}

ARGMIN_KERNEL(1024)
ARGMIN_KERNEL(512)
ARGMIN_KERNEL(256)
ARGMIN_KERNEL(128)
ARGMIN_KERNEL(64)
ARGMIN_KERNEL(32)
ARGMIN_KERNEL(16)
ARGMIN_KERNEL(8)
ARGMIN_KERNEL(4)
ARGMIN_KERNEL(2)
ARGMIN_KERNEL(1)
#undef REDUCE
// Fills the first `size` elements of py with 1 where the element index is a
// multiple of `skip` and with 0 everywhere else.
kernel void set_identity_kernel(
    const unsigned size, const unsigned skip, global float *py) {
  const unsigned gid = get_global_id(0);
  if (gid >= size) return;
  py[gid] = (gid % skip == 0) ? 1.f : 0.f;
}
// Gathers sy elements per batch from px into py.
//
// For batch `b` (work-group id along dimension 1) the source window starts at
// b * sx + pi[b * si] * wy; output element t is read from row t / wy (rows
// are wx elements apart) and column t % wy of that window.
kernel void pick_fw_kernel(
    const global float *px, const global unsigned *pi,
    const unsigned wx, const unsigned wy, const unsigned sx,
    const unsigned si, const unsigned sy, global float *py) {
  const unsigned elem = get_global_id(0);
  const unsigned batch = get_group_id(1);
  const unsigned src_base = batch * sx + pi[batch * si] * wy;
  const unsigned dst_base = batch * sy;
  if (elem >= sy) return;
  const unsigned row = elem / wy;
  const unsigned col = elem % wy;
  py[dst_base + elem] = px[src_base + row * wx + col];
}
// Copies `size` elements from px to py. Output element i comes from chunk
// i / span (chunks are `skip` elements apart in px), position i % span inside
// the chunk, displaced by `shift`.
kernel void slice_fw_kernel(
    const global float *px, const unsigned shift, const unsigned span,
    const unsigned skip, const unsigned size, global float *py) {
  const unsigned gid = get_global_id(0);
  if (gid >= size) return;
  const unsigned chunk = gid / span;
  const unsigned within = gid % span;
  py[gid] = px[chunk * skip + within + shift];
}
// Scatters y_size elements into py. Work-item i writes to chunk i / span
// (chunks are `skip` elements apart in py), position i % span, displaced by
// `shift`; the value is px[i % x_size], so px is reused cyclically when
// y_size exceeds x_size.
kernel void concat_fw_kernel(
    const global float *px, const unsigned span, const unsigned skip,
    const unsigned x_size, const unsigned y_size,
    global float *py, const unsigned shift) {
  const unsigned gid = get_global_id(0);
  if (gid >= y_size) return;
  const unsigned chunk = gid / span;
  const unsigned within = gid % span;
  py[chunk * skip + within + shift] = px[gid % x_size];
}
// Atomically performs *source += operand on a float in global memory.
//
// Implemented as a compare-and-swap loop on the 32-bit pattern of the float:
// compute the sum from the last observed value and try to install it; if
// another work-item changed the location in between, retry with the value
// that atomic_cmpxchg reported.
inline void atomic_add_float(global float *source, const float operand) {
  union {
    unsigned u;
    float f;
  } expected, desired;
  expected.f = *source;
  for (;;) {
    desired.f = expected.f + operand;
    const unsigned seen = atomic_cmpxchg(
        (global unsigned *) source, expected.u, desired.u);
    if (seen == expected.u) break;  // our value was stored
    expected.u = seen;              // lost the race: retry from the new value
  }
}
// Gradient of pick_fw_kernel: accumulates pgy into the positions of pgx that
// the forward pass read from, using the same offset arithmetic.
//
// The accumulation is atomic, so several work-items may target the same pgx
// element.
kernel void pick_bw_kernel(
    const global float *pgy, const global unsigned *pi,
    const unsigned wx, const unsigned wy,
    const unsigned sx, const unsigned si, const unsigned sy,
    global float *pgx) {
  const unsigned elem = get_global_id(0);
  const unsigned batch = get_group_id(1);
  const unsigned dst_base = batch * sx + pi[batch * si] * wy;
  const unsigned src_base = batch * sy;
  if (elem >= sy) return;
  const float grad = pgy[src_base + elem];
  atomic_add_float(pgx + dst_base + (elem / wy) * wx + (elem % wy), grad);
}
// Atomically accumulates pgy into a strided window of pgx.
//
// wy * max(nx, ny) work-items take part. Work-item i reads
// pgy[i % (wy * ny)] and adds it at
// pgx[shift + ((i / wy) * wx + i % wy) % (wx * nx)], so both sides wrap
// around when the other has more blocks.
kernel void slice_bw_kernel(
    const global float *pgy, const unsigned wx, const unsigned wy,
    const unsigned nx, const unsigned ny,
    global float *pgx, const unsigned shift) {
  const unsigned gid = get_global_id(0);
  const unsigned total = wy * max(nx, ny);
  if (gid >= total) return;
  const unsigned dst = ((gid / wy) * wx + (gid % wy)) % (wx * nx);
  const float grad = pgy[gid % (wy * ny)];
  atomic_add_float(pgx + shift + dst, grad);
}
// Generates <name>_fw_kernel: for each of the first `size` elements,
// py[i] = op. `op` is an expression evaluated inside the kernel and may refer
// to px and the element index i.
#define OPENCLDEV_KERNEL_FW_X(name, op) \
kernel void name##_fw_kernel( \
    const global float *px, const unsigned size, global float *py) { \
  const unsigned i = get_global_id(0); \
  if (i >= size) return; \
  py[i] = (op); \
}
// Generates <name>_bw_kernel: for each of the first `size` elements,
// pgx[i] += op. `op` is an expression evaluated inside the kernel and may
// refer to px, py, pgy and the element index i.
#define OPENCLDEV_KERNEL_BW_X(name, op) \
kernel void name##_bw_kernel( \
    const global float *px, const global float *py, const global float *pgy, \
    const unsigned size, global float *pgx) { \
  const unsigned i = get_global_id(0); \
  if (i >= size) return; \
  pgx[i] += (op); \
}
// Generates <name>_fw_kernel taking an extra scalar argument k: for each of
// the first `size` elements, py[i] = op. `op` may refer to px, k and the
// element index i.
#define OPENCLDEV_KERNEL_FW_X_CONST(name, op) \
kernel void name##_fw_kernel( \
    const global float *px, const float k, \
    const unsigned size, global float *py) { \
  const unsigned i = get_global_id(0); \
  if (i >= size) return; \
  py[i] = (op); \
}
// Generates <name>_bw_kernel taking an extra scalar argument k: for each of
// the first `size` elements, pgx[i] += op. `op` may refer to px, py, pgy, k
// and the element index i.
#define OPENCLDEV_KERNEL_BW_X_CONST(name, op) \
kernel void name##_bw_kernel( \
    const global float *px, const global float *py, const global float *pgy, \
    const float k, const unsigned size, global float *pgx) { \
  const unsigned i = get_global_id(0); \
  if (i >= size) return; \
  pgx[i] += (op); \
}
// Generates <name>_fw_kernel computing, for batch b (work-group id along
// dimension 1) and element e < size:
//   py[e + b * size] = px[e + mbx * b * size] op pk[mbk * b]
// i.e. the per-batch scalar is the right-hand operand.
// NOTE(review): mbx / mbk scale the batch offset; presumably 0 for an operand
// shared by all batches and 1 otherwise -- confirm against the host code.
#define OPENCLDEV_KERNEL_FW_X_SCALAR_R_INFIX(name, op) \
kernel void name##_fw_kernel( \
    const global float *px, const global float *pk, const unsigned size, \
    const unsigned mbx, const unsigned mbk, global float *py) { \
  const unsigned elem = get_global_id(0); \
  const unsigned batch = get_group_id(1); \
  const unsigned base = batch * size; \
  if (elem >= size) return; \
  const float lhs = px[elem + mbx * base]; \
  const float rhs = pk[mbk * batch]; \
  py[elem + base] = lhs op rhs; \
}
// Generates <name>_fw_kernel computing, for batch b (work-group id along
// dimension 1) and element e < size:
//   py[e + b * size] = pk[mbk * b] op px[e + mbx * b * size]
// i.e. the per-batch scalar is the left-hand operand.
// NOTE(review): mbx / mbk scale the batch offset; presumably 0 for an operand
// shared by all batches and 1 otherwise -- confirm against the host code.
#define OPENCLDEV_KERNEL_FW_X_SCALAR_L_INFIX(name, op) \
kernel void name##_fw_kernel( \
    const global float *px, const global float *pk, const unsigned size, \
    const unsigned mbx, const unsigned mbk, global float *py) { \
  const unsigned elem = get_global_id(0); \
  const unsigned batch = get_group_id(1); \
  const unsigned base = batch * size; \
  if (elem >= size) return; \
  const float lhs = pk[mbk * batch]; \
  const float rhs = px[elem + mbx * base]; \
  py[elem + base] = lhs op rhs; \
}
// Generates <name>_fw_kernel computing, for batch b (work-group id along
// dimension 1) and element e < size:
//   py[e + b * size] = pa[e + mba * b * size] op pb[e + mbb * b * size]
// NOTE(review): mba / mbb scale the batch offset; presumably 0 for an operand
// shared by all batches and 1 otherwise -- confirm against the host code.
#define OPENCLDEV_KERNEL_FW_AB_INFIX(name, op) \
kernel void name##_fw_kernel( \
    const global float *pa, const global float *pb, const unsigned size, \
    const unsigned mba, const unsigned mbb, global float *py) { \
  const unsigned elem = get_global_id(0); \
  const unsigned batch = get_group_id(1); \
  const unsigned base = batch * size; \
  if (elem >= size) return; \
  const float lhs = pa[elem + mba * base]; \
  const float rhs = pb[elem + mbb * base]; \
  py[elem + base] = lhs op rhs; \
}
// Unary element-wise forward kernels: py[i] = f(px[i]).
OPENCLDEV_KERNEL_FW_X(negate, -px[i])
OPENCLDEV_KERNEL_FW_X(sqrt, sqrt(px[i]))
OPENCLDEV_KERNEL_FW_X(exp, exp(px[i]))
OPENCLDEV_KERNEL_FW_X(log, log(px[i]))
OPENCLDEV_KERNEL_FW_X(tanh, tanh(px[i]))
// Sigmoid written in terms of tanh.
OPENCLDEV_KERNEL_FW_X(sigmoid, 0.5f + 0.5f * tanh(0.5f * px[i]))
// Softplus split into max(x, 0) + log(1 + exp(-|x|)).
OPENCLDEV_KERNEL_FW_X(
    softplus, max(px[i], 0.0f) + log(1.0f + exp(-fabs(px[i]))))
OPENCLDEV_KERNEL_FW_X(sin, sin(px[i]))
OPENCLDEV_KERNEL_FW_X(cos, cos(px[i]))
OPENCLDEV_KERNEL_FW_X(tan, tan(px[i]))

// Unary element-wise backward kernels: pgx[i] += f'(px[i]) * pgy[i], written
// with py[i] where the derivative can be expressed through the output.
OPENCLDEV_KERNEL_BW_X(sqrt, 0.5f * pgy[i] / py[i])
OPENCLDEV_KERNEL_BW_X(exp, py[i] * pgy[i])
OPENCLDEV_KERNEL_BW_X(log, pgy[i] / px[i])
OPENCLDEV_KERNEL_BW_X(tanh, (1.0f - py[i] * py[i]) * pgy[i])
OPENCLDEV_KERNEL_BW_X(sigmoid, py[i] * (1.0f - py[i]) * pgy[i])
OPENCLDEV_KERNEL_BW_X(softplus, (0.5f + 0.5f * tanh(0.5f * px[i])) * pgy[i])
OPENCLDEV_KERNEL_BW_X(sin, cos(px[i]) * pgy[i])
OPENCLDEV_KERNEL_BW_X(cos, -sin(px[i]) * pgy[i])
OPENCLDEV_KERNEL_BW_X(tan, (1.0f + py[i] * py[i]) * pgy[i])

// Forward kernels combining each element with the scalar argument k.
OPENCLDEV_KERNEL_FW_X_CONST(add_const, px[i] + k)
OPENCLDEV_KERNEL_FW_X_CONST(subtract_const_r, px[i] - k)
OPENCLDEV_KERNEL_FW_X_CONST(subtract_const_l, k - px[i])
OPENCLDEV_KERNEL_FW_X_CONST(multiply_const, px[i] * k)
OPENCLDEV_KERNEL_FW_X_CONST(divide_const_r, px[i] / k)
OPENCLDEV_KERNEL_FW_X_CONST(divide_const_l, k / px[i])
OPENCLDEV_KERNEL_FW_X_CONST(prelu, max(px[i], 0.0f) + k * min(px[i], 0.0f))
OPENCLDEV_KERNEL_FW_X_CONST(
    elu, max(px[i], 0.0f) + k * (exp(min(px[i], 0.0f)) - 1.0f))

// Backward kernels matching the scalar-argument forward kernels.
OPENCLDEV_KERNEL_BW_X_CONST(add_const, pgy[i])
OPENCLDEV_KERNEL_BW_X_CONST(subtract_const_r, pgy[i])
OPENCLDEV_KERNEL_BW_X_CONST(subtract_const_l, -pgy[i])
OPENCLDEV_KERNEL_BW_X_CONST(multiply_const, k * pgy[i])
OPENCLDEV_KERNEL_BW_X_CONST(divide_const_r, pgy[i] / k)
OPENCLDEV_KERNEL_BW_X_CONST(divide_const_l, -py[i] * pgy[i] / px[i])
OPENCLDEV_KERNEL_BW_X_CONST(
    prelu, pgy[i] * ((px[i] > 0.0f) + k * (px[i] <= 0.0f)))
OPENCLDEV_KERNEL_BW_X_CONST(
    elu, pgy[i] * ((px[i] > 0.0f) + (py[i] + k) * (px[i] <= 0.0f)))

// Forward kernels combining each element with a per-batch scalar read from
// memory; the _r / _l suffix tells on which side of the operator it sits.
OPENCLDEV_KERNEL_FW_X_SCALAR_R_INFIX(add_scalar, +)
OPENCLDEV_KERNEL_FW_X_SCALAR_R_INFIX(subtract_scalar_r, -)
OPENCLDEV_KERNEL_FW_X_SCALAR_L_INFIX(subtract_scalar_l, -)
OPENCLDEV_KERNEL_FW_X_SCALAR_R_INFIX(multiply_scalar, *)
OPENCLDEV_KERNEL_FW_X_SCALAR_R_INFIX(divide_scalar_r, /)
OPENCLDEV_KERNEL_FW_X_SCALAR_L_INFIX(divide_scalar_l, /)

// Binary element-wise forward kernels.
OPENCLDEV_KERNEL_FW_AB_INFIX(add, +)
OPENCLDEV_KERNEL_FW_AB_INFIX(subtract, -)
OPENCLDEV_KERNEL_FW_AB_INFIX(multiply, *)
OPENCLDEV_KERNEL_FW_AB_INFIX(divide, /)
// Retire the kernel-generator macros once every kernel has been instantiated.
// The last three names previously referred to CUDADEV_* macros that are never
// defined in this file, which left the *_INFIX generators defined.
#undef OPENCLDEV_KERNEL_FW_X
#undef OPENCLDEV_KERNEL_BW_X
#undef OPENCLDEV_KERNEL_FW_X_CONST
#undef OPENCLDEV_KERNEL_BW_X_CONST
#undef OPENCLDEV_KERNEL_FW_X_SCALAR_R_INFIX
#undef OPENCLDEV_KERNEL_FW_X_SCALAR_L_INFIX
#undef OPENCLDEV_KERNEL_FW_AB_INFIX
// Backward pass of element-wise addition: the incoming gradient
// pgy[e + b * size] is atomically added to both pga[e + mba * b * size] and
// pgb[e + mbb * b * size], where b is the work-group id along dimension 1.
// pa, pb and py are accepted but not read.
kernel void add_bw_kernel(
    const global float *pa, const global float *pb,
    const global float *py, const global float *pgy,
    const unsigned size, const unsigned mba, const unsigned mbb,
    global float *pga, global float *pgb) {
  const unsigned elem = get_global_id(0);
  const unsigned batch = get_group_id(1);
  const unsigned base = batch * size;
  if (elem >= size) return;
  const float grad = pgy[elem + base];
  atomic_add_float(pga + elem + mba * base, grad);
  atomic_add_float(pgb + elem + mbb * base, grad);
}
// Backward pass of element-wise subtraction: the incoming gradient
// pgy[e + b * size] is atomically added to pga[e + mba * b * size] and its
// negation to pgb[e + mbb * b * size], where b is the work-group id along
// dimension 1. pa, pb and py are accepted but not read.
kernel void subtract_bw_kernel(
    const global float *pa, const global float *pb,
    const global float *py, const global float *pgy,
    const unsigned size, const unsigned mba, const unsigned mbb,
    global float *pga, global float *pgb) {
  const unsigned elem = get_global_id(0);
  const unsigned batch = get_group_id(1);
  const unsigned base = batch * size;
  if (elem >= size) return;
  const float grad = pgy[elem + base];
  atomic_add_float(pga + elem + mba * base, grad);
  atomic_add_float(pgb + elem + mbb * base, -grad);
}
// Backward pass of element-wise multiplication: with g = pgy[e + b * size],
// atomically adds g * pb[...] to pga[...] and g * pa[...] to pgb[...], using
// offset e + mba * b * size for the a-side and e + mbb * b * size for the
// b-side (b is the work-group id along dimension 1). py is not read.
kernel void multiply_bw_kernel(
    const global float *pa, const global float *pb,
    const global float *py, const global float *pgy,
    const unsigned size, const unsigned mba, const unsigned mbb,
    global float *pga, global float *pgb) {
  const unsigned elem = get_global_id(0);
  const unsigned batch = get_group_id(1);
  const unsigned base = batch * size;
  if (elem >= size) return;
  const unsigned a_pos = elem + mba * base;
  const unsigned b_pos = elem + mbb * base;
  const float grad = pgy[elem + base];
  atomic_add_float(pga + a_pos, grad * pb[b_pos]);
  atomic_add_float(pgb + b_pos, grad * pa[a_pos]);
}
// Backward pass of element-wise division y = a / b: with
// q = pgy[e + b * size] / pb[e + mbb * b * size], atomically adds q to the
// a-gradient and -q * py[e + b * size] to the b-gradient (b in the offsets is
// the work-group id along dimension 1). pa is not read.
kernel void divide_bw_kernel(
    const global float *pa, const global float *pb,
    const global float *py, const global float *pgy,
    const unsigned size, const unsigned mba, const unsigned mbb,
    global float *pga, global float *pgb) {
  const unsigned elem = get_global_id(0);
  const unsigned batch = get_group_id(1);
  const unsigned base = batch * size;
  if (elem >= size) return;
  const unsigned b_pos = elem + mbb * base;
  const unsigned y_pos = elem + base;
  const float quot = pgy[y_pos] / pb[b_pos];
  atomic_add_float(pga + elem + mba * base, quot);
  atomic_add_float(pgb + b_pos, -quot * py[y_pos]);
}
// Transposes one rows x cols matrix per work-group along dimension 2.
// Element (i, j), read from px at i + j * rows, is written to py at
// j + i * cols; consecutive matrices are rows * cols elements apart.
kernel void transpose_fw_kernel(
    const global float *px, unsigned rows, unsigned cols, global float *py) {
  const unsigned r = get_global_id(0);
  const unsigned c = get_global_id(1);
  const unsigned base = get_group_id(2) * rows * cols;
  if (r >= rows || c >= cols) return;
  py[base + c + r * cols] = px[base + r + c * rows];
}
// Backward pass of the transpose: for each matrix (one per work-group along
// dimension 2) the value at py[j + i * cols] is added to px[i + j * rows].
// The accumulation is a plain +=; each (i, j) pair maps to its own px element.
kernel void transpose_bw_kernel(
    const global float *py, const unsigned rows, const unsigned cols,
    global float *px) {
  const unsigned r = get_global_id(0);
  const unsigned c = get_global_id(1);
  const unsigned base = get_group_id(2) * rows * cols;
  if (r >= rows || c >= cols) return;
  px[base + r + c * rows] += py[base + c + r * cols];
}
// One step of the local-memory tree reduction for the sum kernels: work-items
// with tid < k add slot tid + k into slot tid, then the group synchronizes.
// Steps whose k is too large for the group size are skipped.
#define REDUCE(k, GROUP_SIZE) \
  if ((GROUP_SIZE) >= ((k) << 1)) { \
    if (tid < (k)) { \
      temp[tid] += temp[tid + (k)]; \
    } \
    barrier(CLK_LOCAL_MEM_FENCE); \
  }

// Generates sum_fw_kernel_<GROUP_SIZE>.
//
// Work-group `bid` sums n elements of px that start at offset
// bid % skip + (bid / skip) * skip * n and are `skip` elements apart, and
// writes the total to py[bid]. Each work-item first accumulates every
// GROUP_SIZE-th element, then the partial sums are combined in local memory.
//
// NOTE(review): presumably launched with a local size of exactly GROUP_SIZE
// -- confirm against the host code.
#define SUM_FW_KERNEL(GROUP_SIZE) \
kernel void sum_fw_kernel_##GROUP_SIZE( \
    const global float *px, const unsigned skip, const unsigned n, \
    global float *py) { \
  const unsigned bid = get_group_id(0); \
  const unsigned tid = get_local_id(0); \
  local float temp[GROUP_SIZE]; \
  px += bid % skip + (bid / skip) * skip * n; \
  float partial = 0; \
  for (unsigned pos = tid; pos < n; pos += GROUP_SIZE) { \
    partial += px[pos * skip]; \
  } \
  temp[tid] = partial; \
  barrier(CLK_LOCAL_MEM_FENCE); \
  REDUCE(512, GROUP_SIZE) \
  REDUCE(256, GROUP_SIZE) \
  REDUCE(128, GROUP_SIZE) \
  REDUCE(64, GROUP_SIZE) \
  REDUCE(32, GROUP_SIZE) \
  REDUCE(16, GROUP_SIZE) \
  REDUCE(8, GROUP_SIZE) \
  REDUCE(4, GROUP_SIZE) \
  REDUCE(2, GROUP_SIZE) \
  REDUCE(1, GROUP_SIZE) \
  if (tid == 0) { \
    py[bid] = temp[0]; \
  } \
}

SUM_FW_KERNEL(1024)
SUM_FW_KERNEL(512)
SUM_FW_KERNEL(256)
SUM_FW_KERNEL(128)
SUM_FW_KERNEL(64)
SUM_FW_KERNEL(32)
SUM_FW_KERNEL(16)
SUM_FW_KERNEL(8)
SUM_FW_KERNEL(4)
SUM_FW_KERNEL(2)
SUM_FW_KERNEL(1)
#undef REDUCE
// Returns log(exp(a) + exp(b)) without overflowing for large inputs.
//
// The larger argument is factored out so that exp() only ever sees a
// non-positive value: result = hi + log1p(exp(lo - hi)).
//  - log1p keeps precision when exp(lo - hi) is tiny.
//  - When both arguments compare equal the difference is taken as 0; this
//    covers a == b == +/-inf, where lo - hi would be inf - inf = NaN.
//  - A NaN argument still yields NaN.
inline float logsumexp2_fw_kernel(float a, float b) {
  const float hi = a > b ? a : b;
  const float lo = a > b ? b : a;
  const float diff = (hi == lo) ? 0.f : lo - hi;
  return hi + log1p(exp(diff));
}
// One step of the local-memory tree reduction for the logsumexp kernels:
// work-items with tid < k fold slot tid + k into slot tid with
// logsumexp2_fw_kernel, then the group synchronizes. Steps whose k is too
// large for the group size are skipped.
#define REDUCE(k, GROUP_SIZE) \
  if ((GROUP_SIZE) >= ((k) << 1)) { \
    if (tid < (k)) { \
      temp[tid] = logsumexp2_fw_kernel(temp[tid], temp[tid + (k)]); \
    } \
    barrier(CLK_LOCAL_MEM_FENCE); \
  }

// Generates logsumexp_fw_kernel_<GROUP_SIZE>.
//
// Work-group `bid` combines n elements of px that start at offset
// bid % skip + (bid / skip) * skip * n and are `skip` elements apart with
// logsumexp2_fw_kernel and writes the result to py[bid]. Each work-item
// starts from -1e38 and folds in every GROUP_SIZE-th element before the
// partial results are combined in local memory.
//
// NOTE(review): presumably launched with a local size of exactly GROUP_SIZE
// -- confirm against the host code.
#define LOGSUMEXP_FW_KERNEL(GROUP_SIZE) \
kernel void logsumexp_fw_kernel_##GROUP_SIZE( \
    const global float *px, const unsigned skip, const unsigned n, \
    global float *py) { \
  const unsigned bid = get_group_id(0); \
  const unsigned tid = get_local_id(0); \
  local float temp[GROUP_SIZE]; \
  px += bid % skip + (bid / skip) * skip * n; \
  float partial = -1e38; \
  for (unsigned pos = tid; pos < n; pos += GROUP_SIZE) { \
    partial = logsumexp2_fw_kernel(partial, px[pos * skip]); \
  } \
  temp[tid] = partial; \
  barrier(CLK_LOCAL_MEM_FENCE); \
  REDUCE(512, GROUP_SIZE) \
  REDUCE(256, GROUP_SIZE) \
  REDUCE(128, GROUP_SIZE) \
  REDUCE(64, GROUP_SIZE) \
  REDUCE(32, GROUP_SIZE) \
  REDUCE(16, GROUP_SIZE) \
  REDUCE(8, GROUP_SIZE) \
  REDUCE(4, GROUP_SIZE) \
  REDUCE(2, GROUP_SIZE) \
  REDUCE(1, GROUP_SIZE) \
  if (tid == 0) { \
    py[bid] = temp[0]; \
  } \
}

LOGSUMEXP_FW_KERNEL(1024)
LOGSUMEXP_FW_KERNEL(512)
LOGSUMEXP_FW_KERNEL(256)
LOGSUMEXP_FW_KERNEL(128)
LOGSUMEXP_FW_KERNEL(64)
LOGSUMEXP_FW_KERNEL(32)
LOGSUMEXP_FW_KERNEL(16)
LOGSUMEXP_FW_KERNEL(8)
LOGSUMEXP_FW_KERNEL(4)
LOGSUMEXP_FW_KERNEL(2)
LOGSUMEXP_FW_KERNEL(1)
#undef REDUCE
// Expands px into py: output element i (i < size) is read from
// px[i % skip1 + (i / skip2) * skip1].
kernel void broadcast_fw_kernel(
    const global float *px, const unsigned skip1, const unsigned skip2,
    const unsigned size, global float *py) {
  const unsigned gid = get_global_id(0);
  if (gid >= size) return;
  const unsigned inner = gid % skip1;
  const unsigned outer = gid / skip2;
  py[gid] = px[inner + outer * skip1];
}
// Sums `batch` consecutive blocks of `size` elements: py[i] receives the sum
// of px[i], px[i + size], px[i + 2 * size], ... over all batches.
kernel void batch_sum_fw_kernel(
    const global float *px, const unsigned size,
    const unsigned batch, global float *py) {
  const unsigned gid = get_global_id(0);
  if (gid >= size) return;
  float total = .0f;
  const global float *cursor = px + gid;
  for (unsigned remaining = batch; remaining > 0; --remaining) {
    total += *cursor;
    cursor += size;  // same element in the next batch
  }
  py[gid] = total;
}
// Scales the first `size` elements of px by k in place.
kernel void inplace_multiply_const_kernel(
    const float k, const unsigned size, global float *px) {
  const unsigned gid = get_global_id(0);
  if (gid >= size) return;
  px[gid] *= k;
}
// Atomically adds px[e + mbx * b * size] to py[e + mby * b * size] for each
// element e < size, where b is the work-group id along dimension 1.
// NOTE(review): mbx / mby scale the batch offset; presumably 0 for a buffer
// shared by all batches and 1 otherwise -- confirm against the host code.
kernel void inplace_add_kernel(
    const global float *px, const unsigned size,
    const unsigned mbx, const unsigned mby, global float *py) {
  const unsigned elem = get_global_id(0);
  const unsigned batch = get_group_id(1);
  const unsigned base = batch * size;
  if (elem >= size) return;
  const float delta = px[elem + mbx * base];
  atomic_add_float(py + elem + mby * base, delta);
}
// Atomically subtracts px[e + mbx * b * size] from py[e + mby * b * size] for
// each element e < size, where b is the work-group id along dimension 1.
// NOTE(review): mbx / mby scale the batch offset; presumably 0 for a buffer
// shared by all batches and 1 otherwise -- confirm against the host code.
kernel void inplace_subtract_kernel(
    const global float *px, const unsigned size,
    const unsigned mbx, const unsigned mby, global float *py) {
  const unsigned elem = get_global_id(0);
  const unsigned batch = get_group_id(1);
  const unsigned base = batch * size;
  if (elem >= size) return;
  const float delta = -px[elem + mbx * base];
  atomic_add_float(py + elem + mby * base, delta);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment