Skip to content

Instantly share code, notes, and snippets.

@vbkaisetsu
Created November 24, 2017 05:27
Show Gist options
  • Save vbkaisetsu/53d0d4cd715fee5f8fe584ebf8459232 to your computer and use it in GitHub Desktop.
Save vbkaisetsu/53d0d4cd715fee5f8fe584ebf8459232 to your computer and use it in GitHub Desktop.
// One step of the local-memory reduction used by ARGMAX_KERNEL.
// When the group has at least 2*k slots, each of the first k work-items
// compares its slot with the slot k positions higher and, if that one holds a
// strictly larger value, copies both the value and its index down.
// All work-items then synchronise on local memory.
// Uses `tid`, `max_val` and `argmax_val` from the scope it is expanded in.
#define REDUCE(k, GROUP_SIZE) \
  if ((GROUP_SIZE) >= ((k) << 1)) { \
    if (tid < (k) && max_val[tid] < max_val[tid + (k)]) { \
      max_val[tid] = max_val[tid + (k)]; \
      argmax_val[tid] = argmax_val[tid + (k)]; \
    } \
    barrier(CLK_LOCAL_MEM_FENCE); \
  }
// Defines argmax_kernel_<GROUP_SIZE>, which writes one index per work-group.
//
// Work-group `bid` reads n floats that start at
//   px + bid % skip + (bid / skip) * skip * n
// and are spaced `skip` elements apart, then stores in py[bid] the position
// (0 .. n-1) of a maximum among them.
//
// Each work-item first scans positions tid, tid + GROUP_SIZE, ... and keeps
// its own best candidate; the REDUCE steps then merge the candidates held in
// local memory until slot 0 holds the result.
//
// The running maximum starts at -INFINITY with index 0. This lets every
// finite input win the comparison and guarantees that the stored index is
// always defined, even when n == 0 or no element compares greater than
// -INFINITY (e.g. all inputs are -inf or NaN).
//
// NOTE(review): presumably enqueued with a local size equal to GROUP_SIZE --
// confirm at the host call site.
#define ARGMAX_KERNEL(GROUP_SIZE) \
  kernel void argmax_kernel_##GROUP_SIZE( \
      const global float *px, const unsigned skip, \
      const unsigned n, global unsigned *py) { \
    const unsigned bid = get_group_id(0); \
    const unsigned tid = get_local_id(0); \
    local float max_val[GROUP_SIZE]; \
    local unsigned argmax_val[GROUP_SIZE]; \
    px += bid % skip + (bid / skip) * skip * n; \
    max_val[tid] = -INFINITY; \
    argmax_val[tid] = 0; \
    for (unsigned i = tid; i < n; i += GROUP_SIZE) { \
      const float val = px[i * skip]; \
      if (val > max_val[tid]) { \
        max_val[tid] = val; \
        argmax_val[tid] = i; \
      } \
    } \
    barrier(CLK_LOCAL_MEM_FENCE); \
    REDUCE(512, GROUP_SIZE) \
    REDUCE(256, GROUP_SIZE) \
    REDUCE(128, GROUP_SIZE) \
    REDUCE(64, GROUP_SIZE) \
    REDUCE(32, GROUP_SIZE) \
    REDUCE(16, GROUP_SIZE) \
    REDUCE(8, GROUP_SIZE) \
    REDUCE(4, GROUP_SIZE) \
    REDUCE(2, GROUP_SIZE) \
    REDUCE(1, GROUP_SIZE) \
    if (tid == 0) py[bid] = argmax_val[0]; \
  }

// One kernel per supported power-of-two group size.
ARGMAX_KERNEL(1024)
ARGMAX_KERNEL(512)
ARGMAX_KERNEL(256)
ARGMAX_KERNEL(128)
ARGMAX_KERNEL(64)
ARGMAX_KERNEL(32)
ARGMAX_KERNEL(16)
ARGMAX_KERNEL(8)
ARGMAX_KERNEL(4)
ARGMAX_KERNEL(2)
ARGMAX_KERNEL(1)
#undef REDUCE
// One step of the local-memory reduction used by ARGMIN_KERNEL.
// When the group has at least 2*k slots, each of the first k work-items
// compares its slot with the slot k positions higher and, if that one holds a
// strictly smaller value, copies both the value and its index down.
// All work-items then synchronise on local memory.
// Uses `tid`, `min_val` and `argmin_val` from the scope it is expanded in.
#define REDUCE(k, GROUP_SIZE) \
  if ((GROUP_SIZE) >= ((k) << 1)) { \
    if (tid < (k) && min_val[tid] > min_val[tid + (k)]) { \
      min_val[tid] = min_val[tid + (k)]; \
      argmin_val[tid] = argmin_val[tid + (k)]; \
    } \
    barrier(CLK_LOCAL_MEM_FENCE); \
  }
// Defines argmin_kernel_<GROUP_SIZE>, which writes one index per work-group.
//
// Work-group `bid` reads n floats that start at
//   px + bid % skip + (bid / skip) * skip * n
// and are spaced `skip` elements apart, then stores in py[bid] the position
// (0 .. n-1) of a minimum among them.
//
// Each work-item first scans positions tid, tid + GROUP_SIZE, ... and keeps
// its own best candidate; the REDUCE steps then merge the candidates held in
// local memory until slot 0 holds the result.
//
// The running minimum starts at INFINITY with index 0. This lets every finite
// input win the comparison and guarantees that the stored index is always
// defined, even when n == 0 or no element compares less than INFINITY
// (e.g. all inputs are +inf or NaN).
//
// NOTE(review): presumably enqueued with a local size equal to GROUP_SIZE --
// confirm at the host call site.
#define ARGMIN_KERNEL(GROUP_SIZE) \
  kernel void argmin_kernel_##GROUP_SIZE( \
      const global float *px, const unsigned skip, \
      const unsigned n, global unsigned *py) { \
    const unsigned bid = get_group_id(0); \
    const unsigned tid = get_local_id(0); \
    local float min_val[GROUP_SIZE]; \
    local unsigned argmin_val[GROUP_SIZE]; \
    px += bid % skip + (bid / skip) * skip * n; \
    min_val[tid] = INFINITY; \
    argmin_val[tid] = 0; \
    for (unsigned i = tid; i < n; i += GROUP_SIZE) { \
      const float val = px[i * skip]; \
      if (val < min_val[tid]) { \
        min_val[tid] = val; \
        argmin_val[tid] = i; \
      } \
    } \
    barrier(CLK_LOCAL_MEM_FENCE); \
    REDUCE(512, GROUP_SIZE) \
    REDUCE(256, GROUP_SIZE) \
    REDUCE(128, GROUP_SIZE) \
    REDUCE(64, GROUP_SIZE) \
    REDUCE(32, GROUP_SIZE) \
    REDUCE(16, GROUP_SIZE) \
    REDUCE(8, GROUP_SIZE) \
    REDUCE(4, GROUP_SIZE) \
    REDUCE(2, GROUP_SIZE) \
    REDUCE(1, GROUP_SIZE) \
    if (tid == 0) py[bid] = argmin_val[0]; \
  }

// One kernel per supported power-of-two group size.
ARGMIN_KERNEL(1024)
ARGMIN_KERNEL(512)
ARGMIN_KERNEL(256)
ARGMIN_KERNEL(128)
ARGMIN_KERNEL(64)
ARGMIN_KERNEL(32)
ARGMIN_KERNEL(16)
ARGMIN_KERNEL(8)
ARGMIN_KERNEL(4)
ARGMIN_KERNEL(2)
ARGMIN_KERNEL(1)
#undef REDUCE
// Fills py[0 .. size-1] with 1 at every index that is a multiple of `skip`
// and 0 everywhere else. Work-items with a global id >= size do nothing.
kernel void set_identity_kernel(
    const unsigned size, const unsigned skip, global float *py) {
  const unsigned idx = get_global_id(0);
  if (idx >= size) return;
  py[idx] = (idx % skip == 0);
}
// Gathers sy elements per dimension-1 group from px into py.
// For group `g` (get_group_id(1)) the source offset is
//   g * sx + pi[g * si] * wy
// and output element t is read from that offset plus (t / wy) * wx + (t % wy).
// Work-items with t >= sy write nothing.
kernel void pick_fw_kernel(
    const global float *px, const global unsigned *pi,
    const unsigned wx, const unsigned wy, const unsigned sx,
    const unsigned si, const unsigned sy, global float *py) {
  const unsigned t = get_global_id(0);
  if (t >= sy) return;
  const unsigned group = get_group_id(1);
  const unsigned src_base = group * sx + pi[group * si] * wy;
  const unsigned dst_base = group * sy;
  const unsigned row = t / wy;
  const unsigned col = t % wy;
  py[dst_base + t] = px[src_base + row * wx + col];
}
// Copies `size` elements from px to py. Output element i is taken from
//   px[(i / span) * skip + (i % span) + shift],
// i.e. runs of `span` elements that are `skip` apart in the source, offset by
// `shift`.
kernel void slice_fw_kernel(
    const global float *px, const unsigned shift, const unsigned span,
    const unsigned skip, const unsigned size, global float *py) {
  const unsigned idx = get_global_id(0);
  if (idx >= size) return;
  const unsigned run = idx / span;
  const unsigned within = idx % span;
  py[idx] = px[run * skip + within + shift];
}
// Scatters y_size elements into py. Element i is read from px[i % x_size]
// (so the source repeats once i reaches x_size) and written to
//   py[(i / span) * skip + (i % span) + shift].
kernel void concat_fw_kernel(
    const global float *px, const unsigned span, const unsigned skip,
    const unsigned x_size, const unsigned y_size,
    global float *py, const unsigned shift) {
  const unsigned idx = get_global_id(0);
  if (idx >= y_size) return;
  const unsigned run = idx / span;
  const unsigned within = idx % span;
  py[run * skip + within + shift] = px[idx % x_size];
}
// Adds `operand` to the float at `source` using a compare-and-swap loop on
// the value's 32-bit pattern. The loop retries with the value returned by
// atomic_cmpxchg until the swap succeeds against the value the sum was
// computed from.
inline void atomic_add_float(global float *source, const float operand) {
  union {
    unsigned u;
    float f;
  } expected, desired;
  expected.f = *source;
  for (;;) {
    desired.f = expected.f + operand;
    const unsigned seen = atomic_cmpxchg(
        (global unsigned *) source, expected.u, desired.u);
    if (seen == expected.u) break;
    // Someone else changed the slot; recompute from what is there now.
    expected.u = seen;
  }
}
// Accumulates pgy into pgx at the positions pick_fw_kernel reads from.
// For dimension-1 group `g`, element t of pgy (offset g * sy) is atomically
// added to pgx at
//   g * sx + pi[g * si] * wy + (t / wy) * wx + (t % wy).
// Work-items with t >= sy do nothing.
kernel void pick_bw_kernel(
    const global float *pgy, const global unsigned *pi,
    const unsigned wx, const unsigned wy,
    const unsigned sx, const unsigned si, const unsigned sy,
    global float *pgx) {
  const unsigned t = get_global_id(0);
  if (t >= sy) return;
  const unsigned group = get_group_id(1);
  const unsigned dst_base = group * sx + pi[group * si] * wy;
  const unsigned src_base = group * sy;
  const unsigned row = t / wy;
  const unsigned col = t % wy;
  atomic_add_float(pgx + dst_base + row * wx + col, pgy[src_base + t]);
}
// Atomically accumulates pgy into pgx for wy * max(nx, ny) elements.
// Element i reads pgy[i % (wy * ny)] and adds it to
//   pgx[shift + ((i / wy) * wx + (i % wy)) % (wx * nx)].
// The modulo on each side wraps whichever operand is the shorter one.
kernel void slice_bw_kernel(
    const global float *pgy, const unsigned wx, const unsigned wy,
    const unsigned nx, const unsigned ny,
    global float *pgx, const unsigned shift) {
  const unsigned idx = get_global_id(0);
  if (idx >= wy * max(nx, ny)) return;
  const unsigned dst = ((idx / wy) * wx + (idx % wy)) % (wx * nx);
  const unsigned src = idx % (wy * ny);
  atomic_add_float(pgx + shift + dst, pgy[src]);
}
// Generates <name>_fw_kernel: py[i] = op for every i < size.
// `op` is an expression that may refer to `px` and `i`.
#define OPENCLDEV_KERNEL_FW_X(name, op) \
  kernel void name##_fw_kernel( \
      const global float *px, const unsigned size, global float *py) { \
    const unsigned i = get_global_id(0); \
    if (i >= size) return; \
    py[i] = (op); \
  }
// Generates <name>_bw_kernel: pgx[i] += op for every i < size.
// `op` is an expression that may refer to `px`, `py`, `pgy` and `i`.
#define OPENCLDEV_KERNEL_BW_X(name, op) \
  kernel void name##_bw_kernel( \
      const global float *px, const global float *py, \
      const global float *pgy, \
      const unsigned size, global float *pgx) { \
    const unsigned i = get_global_id(0); \
    if (i >= size) return; \
    pgx[i] += (op); \
  }
// Generates <name>_fw_kernel taking a scalar constant: py[i] = op for every
// i < size. `op` is an expression that may refer to `px`, `k` and `i`.
#define OPENCLDEV_KERNEL_FW_X_CONST(name, op) \
  kernel void name##_fw_kernel( \
      const global float *px, const float k, \
      const unsigned size, global float *py) { \
    const unsigned i = get_global_id(0); \
    if (i >= size) return; \
    py[i] = (op); \
  }
// Generates <name>_bw_kernel taking a scalar constant: pgx[i] += op for every
// i < size. `op` may refer to `px`, `py`, `pgy`, `k` and `i`.
#define OPENCLDEV_KERNEL_BW_X_CONST(name, op) \
  kernel void name##_bw_kernel( \
      const global float *px, const global float *py, \
      const global float *pgy, \
      const float k, const unsigned size, global float *pgx) { \
    const unsigned i = get_global_id(0); \
    if (i >= size) return; \
    pgx[i] += (op); \
  }
// Generates <name>_fw_kernel computing, for dimension-1 group g and i < size,
//   py[i + g * size] = px[i + mbx * g * size] op pk[mbk * g]
// where `op` is an infix operator and the scalar is the right-hand operand.
#define OPENCLDEV_KERNEL_FW_X_SCALAR_R_INFIX(name, op) \
  kernel void name##_fw_kernel( \
      const global float *px, const global float *pk, const unsigned size, \
      const unsigned mbx, const unsigned mbk, global float *py) { \
    const unsigned idx = get_global_id(0); \
    if (idx >= size) return; \
    const unsigned group = get_group_id(1); \
    const unsigned base = group * size; \
    const float elem = px[idx + mbx * base]; \
    const float scalar = pk[mbk * group]; \
    py[idx + base] = elem op scalar; \
  }
// Generates <name>_fw_kernel computing, for dimension-1 group g and i < size,
//   py[i + g * size] = pk[mbk * g] op px[i + mbx * g * size]
// where `op` is an infix operator and the scalar is the left-hand operand.
#define OPENCLDEV_KERNEL_FW_X_SCALAR_L_INFIX(name, op) \
  kernel void name##_fw_kernel( \
      const global float *px, const global float *pk, const unsigned size, \
      const unsigned mbx, const unsigned mbk, global float *py) { \
    const unsigned idx = get_global_id(0); \
    if (idx >= size) return; \
    const unsigned group = get_group_id(1); \
    const unsigned base = group * size; \
    const float scalar = pk[mbk * group]; \
    const float elem = px[idx + mbx * base]; \
    py[idx + base] = scalar op elem; \
  }
// Generates <name>_fw_kernel computing, for dimension-1 group g and i < size,
//   py[i + g * size] = pa[i + mba * g * size] op pb[i + mbb * g * size]
// where `op` is an infix operator applied element by element.
#define OPENCLDEV_KERNEL_FW_AB_INFIX(name, op) \
  kernel void name##_fw_kernel( \
      const global float *pa, const global float *pb, const unsigned size, \
      const unsigned mba, const unsigned mbb, global float *py) { \
    const unsigned idx = get_global_id(0); \
    if (idx >= size) return; \
    const unsigned base = get_group_id(1) * size; \
    const float lhs = pa[idx + mba * base]; \
    const float rhs = pb[idx + mbb * base]; \
    py[idx + base] = lhs op rhs; \
  }
// Unary forward kernels: py[i] is set to the given expression of px[i].
OPENCLDEV_KERNEL_FW_X(negate, -px[i])
OPENCLDEV_KERNEL_FW_X(sqrt, sqrt(px[i]))
OPENCLDEV_KERNEL_FW_X(exp, exp(px[i]))
OPENCLDEV_KERNEL_FW_X(log, log(px[i]))
OPENCLDEV_KERNEL_FW_X(tanh, tanh(px[i]))
// sigmoid(x) written in terms of tanh.
OPENCLDEV_KERNEL_FW_X(sigmoid, .5f + .5f * tanh(.5f * px[i]))
// softplus written as max(x, 0) + log(1 + exp(-|x|)).
OPENCLDEV_KERNEL_FW_X(
softplus, max(px[i], .0f) + log(1.f + exp(-fabs(px[i]))))
OPENCLDEV_KERNEL_FW_X(sin, sin(px[i]))
OPENCLDEV_KERNEL_FW_X(cos, cos(px[i]))
OPENCLDEV_KERNEL_FW_X(tan, tan(px[i]))
// Unary backward kernels: the given expression of px[i], py[i] and pgy[i] is
// added to pgx[i].
OPENCLDEV_KERNEL_BW_X(sqrt, .5f * pgy[i] / py[i])
OPENCLDEV_KERNEL_BW_X(exp, py[i] * pgy[i])
OPENCLDEV_KERNEL_BW_X(log, pgy[i] / px[i])
OPENCLDEV_KERNEL_BW_X(tanh, (1.f - py[i] * py[i]) * pgy[i])
OPENCLDEV_KERNEL_BW_X(sigmoid, py[i] * (1.f - py[i]) * pgy[i])
OPENCLDEV_KERNEL_BW_X(softplus, (.5f + .5f * tanh(.5f * px[i])) * pgy[i])
OPENCLDEV_KERNEL_BW_X(sin, cos(px[i]) * pgy[i])
OPENCLDEV_KERNEL_BW_X(cos, -sin(px[i]) * pgy[i])
OPENCLDEV_KERNEL_BW_X(tan, (1.f + py[i] * py[i]) * pgy[i])
// Forward kernels combining px[i] with the scalar kernel argument k.
// The _r / _l suffix tells whether k is the right or the left operand.
OPENCLDEV_KERNEL_FW_X_CONST(add_const, px[i] + k)
OPENCLDEV_KERNEL_FW_X_CONST(subtract_const_r, px[i] - k)
OPENCLDEV_KERNEL_FW_X_CONST(subtract_const_l, k - px[i])
OPENCLDEV_KERNEL_FW_X_CONST(multiply_const, px[i] * k)
OPENCLDEV_KERNEL_FW_X_CONST(divide_const_r, px[i] / k)
OPENCLDEV_KERNEL_FW_X_CONST(divide_const_l, k / px[i])
OPENCLDEV_KERNEL_FW_X_CONST(prelu, max(px[i], .0f) + k * min(px[i], .0f))
OPENCLDEV_KERNEL_FW_X_CONST(
elu, max(px[i], .0f) + k * (exp(min(px[i], .0f)) - 1.0f))
// Backward kernels matching the scalar-constant forward kernels above.
OPENCLDEV_KERNEL_BW_X_CONST(add_const, pgy[i])
OPENCLDEV_KERNEL_BW_X_CONST(subtract_const_r, pgy[i])
OPENCLDEV_KERNEL_BW_X_CONST(subtract_const_l, -pgy[i])
OPENCLDEV_KERNEL_BW_X_CONST(multiply_const, k * pgy[i])
OPENCLDEV_KERNEL_BW_X_CONST(divide_const_r, pgy[i] / k)
OPENCLDEV_KERNEL_BW_X_CONST(divide_const_l, -py[i] * pgy[i] / px[i])
// The comparisons below evaluate to 0 or 1 and select the branch of the
// piecewise derivative.
OPENCLDEV_KERNEL_BW_X_CONST(
prelu, pgy[i] * ((px[i] > .0f) + k * (px[i] <= .0f)))
OPENCLDEV_KERNEL_BW_X_CONST(
elu, pgy[i] * ((px[i] > .0f) + (py[i] + k) * (px[i] <= .0f)))
// Forward kernels combining px with a per-group scalar read from pk.
OPENCLDEV_KERNEL_FW_X_SCALAR_R_INFIX(add_scalar, +)
OPENCLDEV_KERNEL_FW_X_SCALAR_R_INFIX(subtract_scalar_r, -)
OPENCLDEV_KERNEL_FW_X_SCALAR_L_INFIX(subtract_scalar_l, -)
OPENCLDEV_KERNEL_FW_X_SCALAR_R_INFIX(multiply_scalar, *)
OPENCLDEV_KERNEL_FW_X_SCALAR_R_INFIX(divide_scalar_r, /)
OPENCLDEV_KERNEL_FW_X_SCALAR_L_INFIX(divide_scalar_l, /)
// Element-wise binary forward kernels on two arrays.
OPENCLDEV_KERNEL_FW_AB_INFIX(add, +)
OPENCLDEV_KERNEL_FW_AB_INFIX(subtract, -)
OPENCLDEV_KERNEL_FW_AB_INFIX(multiply, *)
OPENCLDEV_KERNEL_FW_AB_INFIX(divide, /)
// Retire every kernel-generator macro defined in this file now that all
// element-wise kernels have been instantiated. The last three names are the
// macros this file actually defines (the *_INFIX variants).
#undef OPENCLDEV_KERNEL_FW_X
#undef OPENCLDEV_KERNEL_BW_X
#undef OPENCLDEV_KERNEL_FW_X_CONST
#undef OPENCLDEV_KERNEL_BW_X_CONST
#undef OPENCLDEV_KERNEL_FW_X_SCALAR_R_INFIX
#undef OPENCLDEV_KERNEL_FW_X_SCALAR_L_INFIX
#undef OPENCLDEV_KERNEL_FW_AB_INFIX
// Backward pass of element-wise addition.
// For dimension-1 group g and i < size, pgy[i + g * size] is atomically added
// to pga[i + mba * g * size] and to pgb[i + mbb * g * size].
// pa, pb and py are accepted but not read.
kernel void add_bw_kernel(
    const global float *pa, const global float *pb,
    const global float *py, const global float *pgy,
    const unsigned size, const unsigned mba, const unsigned mbb,
    global float *pga, global float *pgb) {
  const unsigned idx = get_global_id(0);
  if (idx >= size) return;
  const unsigned base = get_group_id(1) * size;
  const float grad = pgy[idx + base];
  atomic_add_float(pga + idx + mba * base, grad);
  atomic_add_float(pgb + idx + mbb * base, grad);
}
// Backward pass of element-wise subtraction.
// For dimension-1 group g and i < size, pgy[i + g * size] is atomically added
// to pga[i + mba * g * size] and its negation to pgb[i + mbb * g * size].
// pa, pb and py are accepted but not read.
kernel void subtract_bw_kernel(
    const global float *pa, const global float *pb,
    const global float *py, const global float *pgy,
    const unsigned size, const unsigned mba, const unsigned mbb,
    global float *pga, global float *pgb) {
  const unsigned idx = get_global_id(0);
  if (idx >= size) return;
  const unsigned base = get_group_id(1) * size;
  const float grad = pgy[idx + base];
  atomic_add_float(pga + idx + mba * base, grad);
  atomic_add_float(pgb + idx + mbb * base, -grad);
}
// Backward pass of element-wise multiplication.
// For dimension-1 group g and i < size, with gy = pgy[i + g * size]:
//   pga[a] += gy * pb[b]   and   pgb[b] += gy * pa[a]   (both atomic),
// where a = i + mba * g * size and b = i + mbb * g * size.
// py is accepted but not read.
kernel void multiply_bw_kernel(
    const global float *pa, const global float *pb,
    const global float *py, const global float *pgy,
    const unsigned size, const unsigned mba, const unsigned mbb,
    global float *pga, global float *pgb) {
  const unsigned idx = get_global_id(0);
  if (idx >= size) return;
  const unsigned base = get_group_id(1) * size;
  const unsigned a_pos = idx + mba * base;
  const unsigned b_pos = idx + mbb * base;
  const float grad = pgy[idx + base];
  atomic_add_float(pga + a_pos, grad * pb[b_pos]);
  atomic_add_float(pgb + b_pos, grad * pa[a_pos]);
}
// Backward pass of element-wise division.
// For dimension-1 group g and i < size, with y = i + g * size,
// b = i + mbb * g * size and q = pgy[y] / pb[b]:
//   pga[i + mba * g * size] += q   and   pgb[b] += -q * py[y]   (both atomic).
// pa is accepted but not read.
kernel void divide_bw_kernel(
    const global float *pa, const global float *pb,
    const global float *py, const global float *pgy,
    const unsigned size, const unsigned mba, const unsigned mbb,
    global float *pga, global float *pgb) {
  const unsigned idx = get_global_id(0);
  if (idx >= size) return;
  const unsigned base = get_group_id(1) * size;
  const unsigned b_pos = idx + mbb * base;
  const unsigned y_pos = idx + base;
  const float quot = pgy[y_pos] / pb[b_pos];
  atomic_add_float(pga + idx + mba * base, quot);
  atomic_add_float(pgb + b_pos, -quot * py[y_pos]);
}
// Transposes one rows x cols block per dimension-2 group.
// With ofs = group * rows * cols, i < rows and j < cols:
//   py[ofs + j + i * cols] = px[ofs + i + j * rows].
kernel void transpose_fw_kernel(
    const global float *px, unsigned rows, unsigned cols, global float *py) {
  const unsigned r = get_global_id(0);
  const unsigned c = get_global_id(1);
  if (r >= rows || c >= cols) return;
  const unsigned base = get_group_id(2) * rows * cols;
  py[base + c + r * cols] = px[base + r + c * rows];
}
// Backward pass of the transpose: accumulates py into px with the indices
// swapped back. With ofs = group * rows * cols (dimension-2 group), i < rows
// and j < cols:
//   px[ofs + i + j * rows] += py[ofs + j + i * cols].
kernel void transpose_bw_kernel(
    const global float *py, const unsigned rows, const unsigned cols,
    global float *px) {
  const unsigned r = get_global_id(0);
  const unsigned c = get_global_id(1);
  if (r >= rows || c >= cols) return;
  const unsigned base = get_group_id(2) * rows * cols;
  px[base + r + c * rows] += py[base + c + r * cols];
}
// One step of the local-memory reduction used by SUM_FW_KERNEL.
// When the group has at least 2*k slots, each of the first k work-items adds
// the slot k positions higher into its own; all work-items then synchronise
// on local memory. Uses `tid` and `temp` from the scope it is expanded in.
#define REDUCE(k, GROUP_SIZE) \
  if ((GROUP_SIZE) >= ((k) << 1)) { \
    if (tid < (k)) { \
      temp[tid] += temp[tid + (k)]; \
    } \
    barrier(CLK_LOCAL_MEM_FENCE); \
  }
// Defines sum_fw_kernel_<GROUP_SIZE>, which writes one sum per work-group.
// Work-group `bid` reads n floats that start at
//   px + bid % skip + (bid / skip) * skip * n
// and are spaced `skip` elements apart, and stores their sum in py[bid].
// Each work-item accumulates positions tid, tid + GROUP_SIZE, ...; the REDUCE
// steps then fold the partial sums held in local memory into slot 0.
// NOTE(review): presumably enqueued with a local size equal to GROUP_SIZE --
// confirm at the host call site.
#define SUM_FW_KERNEL(GROUP_SIZE) \
  kernel void sum_fw_kernel_##GROUP_SIZE( \
      const global float *px, const unsigned skip, const unsigned n, \
      global float *py) { \
    const unsigned bid = get_group_id(0); \
    const unsigned tid = get_local_id(0); \
    local float temp[GROUP_SIZE]; \
    px += bid % skip + (bid / skip) * skip * n; \
    float partial = 0; \
    for (unsigned i = tid; i < n; i += GROUP_SIZE) { \
      partial += px[i * skip]; \
    } \
    temp[tid] = partial; \
    barrier(CLK_LOCAL_MEM_FENCE); \
    REDUCE(512, GROUP_SIZE) \
    REDUCE(256, GROUP_SIZE) \
    REDUCE(128, GROUP_SIZE) \
    REDUCE(64, GROUP_SIZE) \
    REDUCE(32, GROUP_SIZE) \
    REDUCE(16, GROUP_SIZE) \
    REDUCE(8, GROUP_SIZE) \
    REDUCE(4, GROUP_SIZE) \
    REDUCE(2, GROUP_SIZE) \
    REDUCE(1, GROUP_SIZE) \
    if (tid == 0) { \
      py[bid] = temp[0]; \
    } \
  }

// One kernel per supported power-of-two group size.
SUM_FW_KERNEL(1024)
SUM_FW_KERNEL(512)
SUM_FW_KERNEL(256)
SUM_FW_KERNEL(128)
SUM_FW_KERNEL(64)
SUM_FW_KERNEL(32)
SUM_FW_KERNEL(16)
SUM_FW_KERNEL(8)
SUM_FW_KERNEL(4)
SUM_FW_KERNEL(2)
SUM_FW_KERNEL(1)
#undef REDUCE
// Returns log(exp(a) + exp(b)), computed as hi + log1p(exp(lo - hi)) with
// hi/lo the larger/smaller argument so that exp never overflows.
//
// - log1p keeps the contribution of the smaller term when exp(lo - hi) is
//   far below 1, where log(1 + x) would round it away.
// - Two equal infinities are returned as-is; otherwise lo - hi would be
//   inf - inf = NaN.
// - A NaN argument still yields NaN.
inline float logsumexp2_fw_kernel(float a, float b) {
  const float hi = a > b ? a : b;
  const float lo = a > b ? b : a;
  if (hi == lo && isinf(hi)) return hi;
  return hi + log1p(exp(lo - hi));
}
// One step of the local-memory reduction used by LOGSUMEXP_FW_KERNEL.
// When the group has at least 2*k slots, each of the first k work-items
// replaces its slot with logsumexp2_fw_kernel(own slot, slot k higher); all
// work-items then synchronise on local memory.
// Uses `tid` and `temp` from the scope it is expanded in.
#define REDUCE(k, GROUP_SIZE) \
  if ((GROUP_SIZE) >= ((k) << 1)) { \
    if (tid < (k)) { \
      temp[tid] = logsumexp2_fw_kernel(temp[tid], temp[tid + (k)]); \
    } \
    barrier(CLK_LOCAL_MEM_FENCE); \
  }
// Defines logsumexp_fw_kernel_<GROUP_SIZE>, which writes one value per
// work-group. Work-group `bid` reads n floats that start at
//   px + bid % skip + (bid / skip) * skip * n
// and are spaced `skip` elements apart, and stores in py[bid] the result of
// folding them with logsumexp2_fw_kernel.
// Each work-item folds positions tid, tid + GROUP_SIZE, ... starting from the
// large negative seed -1e38; the REDUCE steps then merge the partial results
// held in local memory into slot 0.
// NOTE(review): presumably enqueued with a local size equal to GROUP_SIZE --
// confirm at the host call site.
#define LOGSUMEXP_FW_KERNEL(GROUP_SIZE) \
  kernel void logsumexp_fw_kernel_##GROUP_SIZE( \
      const global float *px, const unsigned skip, const unsigned n, \
      global float *py) { \
    const unsigned bid = get_group_id(0); \
    const unsigned tid = get_local_id(0); \
    local float temp[GROUP_SIZE]; \
    px += bid % skip + (bid / skip) * skip * n; \
    float partial = -1e38; \
    for (unsigned i = tid; i < n; i += GROUP_SIZE) { \
      partial = logsumexp2_fw_kernel(partial, px[i * skip]); \
    } \
    temp[tid] = partial; \
    barrier(CLK_LOCAL_MEM_FENCE); \
    REDUCE(512, GROUP_SIZE) \
    REDUCE(256, GROUP_SIZE) \
    REDUCE(128, GROUP_SIZE) \
    REDUCE(64, GROUP_SIZE) \
    REDUCE(32, GROUP_SIZE) \
    REDUCE(16, GROUP_SIZE) \
    REDUCE(8, GROUP_SIZE) \
    REDUCE(4, GROUP_SIZE) \
    REDUCE(2, GROUP_SIZE) \
    REDUCE(1, GROUP_SIZE) \
    if (tid == 0) { \
      py[bid] = temp[0]; \
    } \
  }

// One kernel per supported power-of-two group size.
LOGSUMEXP_FW_KERNEL(1024)
LOGSUMEXP_FW_KERNEL(512)
LOGSUMEXP_FW_KERNEL(256)
LOGSUMEXP_FW_KERNEL(128)
LOGSUMEXP_FW_KERNEL(64)
LOGSUMEXP_FW_KERNEL(32)
LOGSUMEXP_FW_KERNEL(16)
LOGSUMEXP_FW_KERNEL(8)
LOGSUMEXP_FW_KERNEL(4)
LOGSUMEXP_FW_KERNEL(2)
LOGSUMEXP_FW_KERNEL(1)
#undef REDUCE
// Expands px into py: for every i < size,
//   py[i] = px[i % skip1 + (i / skip2) * skip1].
kernel void broadcast_fw_kernel(
    const global float *px, const unsigned skip1, const unsigned skip2,
    const unsigned size, global float *py) {
  const unsigned idx = get_global_id(0);
  if (idx >= size) return;
  const unsigned inner = idx % skip1;
  const unsigned outer = idx / skip2;
  py[idx] = px[inner + outer * skip1];
}
// Sums `batch` consecutive blocks of `size` floats element by element:
// for every i < size, py[i] = px[i] + px[i + size] + ... (batch terms).
kernel void batch_sum_fw_kernel(
    const global float *px, const unsigned size,
    const unsigned batch, global float *py) {
  const unsigned idx = get_global_id(0);
  if (idx >= size) return;
  const global float *cursor = px + idx;
  float total = .0f;
  for (unsigned b = 0; b < batch; ++b) {
    total += *cursor;
    cursor += size;
  }
  py[idx] = total;
}
// Scales px in place: px[i] *= k for every i < size.
kernel void inplace_multiply_const_kernel(
    const float k, const unsigned size, global float *px) {
  const unsigned idx = get_global_id(0);
  if (idx >= size) return;
  px[idx] *= k;
}
// Atomically adds px into py. For dimension-1 group g and i < size:
//   py[i + mby * g * size] += px[i + mbx * g * size].
kernel void inplace_add_kernel(
    const global float *px, const unsigned size,
    const unsigned mbx, const unsigned mby, global float *py) {
  const unsigned idx = get_global_id(0);
  if (idx >= size) return;
  const unsigned base = get_group_id(1) * size;
  const float addend = px[idx + mbx * base];
  atomic_add_float(py + idx + mby * base, addend);
}
// Atomically subtracts px from py. For dimension-1 group g and i < size:
//   py[i + mby * g * size] += -px[i + mbx * g * size].
kernel void inplace_subtract_kernel(
    const global float *px, const unsigned size,
    const unsigned mbx, const unsigned mby, global float *py) {
  const unsigned idx = get_global_id(0);
  if (idx >= size) return;
  const unsigned base = get_group_id(1) * size;
  const float subtrahend = px[idx + mbx * base];
  atomic_add_float(py + idx + mby * base, -subtrahend);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment