Skip to content

Instantly share code, notes, and snippets.

@dadeba
Created March 23, 2012 05:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dadeba/2167470 to your computer and use it in GitHub Desktop.
Save dadeba/2167470 to your computer and use it in GitHub Desktop.
OpenCL: A vectorized kernel for gravity interaction inspired by Tanikawa et al. 2012
/*
 * Declarator suffix for read-only pointer parameters. The expansion already
 * contains the `*`, so `__global float4 READONLY_P p` declares `p` as a
 * restrict-qualified pointer to const float4 in __global memory.
 */
#define READONLY_P const * restrict
/*
 * Fold an 8-wide vector into a 4-wide one by adding its upper half
 * (lanes s4..s7) to its lower half (lanes s0..s3), lane by lane.
 */
float4 sum(float8 x)
{
    return x.lo + x.hi;
}
/*
 * Accumulate the contribution of two j-particles onto four i-particles.
 *
 * xp, yp, zp, mp hold the values of the two j-particles (component .x is the
 * first, .y the second). Each is broadcast over four lanes so that lanes 0-3
 * pair the first j-particle with the four i-particles and lanes 4-7 pair the
 * second j-particle with the same four i-particles (xi, yi, zi carry the
 * i-values duplicated in both halves).
 *
 * e2 is added to every squared distance. The accumulators are updated in
 * place: a_x/a_y/a_z gain d * m * r^-3 and p_t gains -m * r^-1, where r^-1
 * comes from native_rsqrt (implementation-defined precision).
 */
void grav3_pair(
    float2 xp, float2 yp, float2 zp, float2 mp,
    float8 xi, float8 yi, float8 zi, float8 e2,
    float8 *a_x, float8 *a_y, float8 *a_z, float8 *p_t)
{
    float8 xj = (float8)((float4)(xp.x), (float4)(xp.y));
    float8 yj = (float8)((float4)(yp.x), (float4)(yp.y));
    float8 zj = (float8)((float4)(zp.x), (float4)(zp.y));
    float8 mj = (float8)((float4)(mp.x), (float4)(mp.y));

    float8 dx = xj - xi;
    float8 dy = yj - yi;
    float8 dz = zj - zi;

    float8 r2 = dx*dx + dy*dy + dz*dz + e2;
    float8 r1i = native_rsqrt(r2);
    float8 r2i = r1i*r1i;
    float8 r1im = mj*r1i;
    float8 r3im = r1im*r2i;

    *a_x += dx*r3im;
    *a_y += dy*r3im;
    *a_z += dz*r3im;
    *p_t += -r1im;
}

/*
 * Pairwise interaction kernel: each work-item handles the four i-particles
 * stored in element i of x/y/z and sums over all j-particles.
 *
 * x, y, z : coordinates, packed four particles per float4; read both as the
 *           i-particles (element i) and as the j-particles (elements
 *           0 .. n/4-1).
 * m       : per-j-particle weight (presumably mass), packed the same way.
 * ax, ay, az, pt : outputs, one float4 (four i-particles) written at index i.
 *           pt receives the sum of -m * rsqrt(r2) (presumably the potential).
 * n       : number of j-particles; only n/4 whole float4 elements are
 *           visited, so a remainder of n % 4 particles is ignored.
 * eps2    : constant added to every squared distance (acts as a softening
 *           term).
 *
 * The work-item index is the linearised 2-D global id,
 * i = get_global_id(1) * get_global_size(0) + get_global_id(0).
 * No bounds check is made on i, so the launch must cover only float4
 * elements that exist in the input and output buffers.
 *
 * No pair is skipped: if the j range contains the i-particle itself, its
 * displacement is zero, so it adds nothing to ax/ay/az when eps2 > 0, but it
 * still adds -m * rsqrt(eps2) to pt. With eps2 == 0 that pair evaluates
 * rsqrt(0).
 */
__kernel
void
grav3(
    __global float4 READONLY_P x,
    __global float4 READONLY_P y,
    __global float4 READONLY_P z,
    __global float4 READONLY_P m,
    __global float4 *ax,
    __global float4 *ay,
    __global float4 *az,
    __global float4 *pt,
    const int n,
    const float eps2
)
{
    unsigned int g_xid = get_global_id(0);
    unsigned int g_yid = get_global_id(1);
    unsigned int g_w = get_global_size(0);
    unsigned int i = g_yid*g_w + g_xid;

    /* Duplicate the four i-particles into both halves of an 8-wide vector. */
    float8 xi = (float8)(x[i], x[i]);
    float8 yi = (float8)(y[i], y[i]);
    float8 zi = (float8)(z[i], z[i]);
    float8 e2 = (float8)(eps2);

    float8 a_x = (float8)(0.0f);
    float8 a_y = (float8)(0.0f);
    float8 a_z = (float8)(0.0f);
    float8 p_t = (float8)(0.0f);

    /*
     * Number of float4 elements to visit. A negative n would otherwise be
     * converted to a huge unsigned bound in the comparison with j, so clamp
     * non-positive n to zero iterations.
     */
    const unsigned int nj4 = (n > 0) ? ((unsigned int)n / 4u) : 0u;

    for (unsigned int j = 0; j < nj4; j++) {
        float4 xxj = x[j];
        float4 yyj = y[j];
        float4 zzj = z[j];
        float4 mmj = m[j];

        /* j-particles 0 and 1 of this element, then 2 and 3. */
        grav3_pair(xxj.xy, yyj.xy, zzj.xy, mmj.xy,
                   xi, yi, zi, e2, &a_x, &a_y, &a_z, &p_t);
        grav3_pair(xxj.zw, yyj.zw, zzj.zw, mmj.zw,
                   xi, yi, zi, e2, &a_x, &a_y, &a_z, &p_t);
    }

    /* Combine the two halves so each i-particle gets its full sum. */
    ax[i] = sum(a_x);
    ay[i] = sum(a_y);
    az[i] = sum(a_z);
    pt[i] = sum(p_t);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment