Skip to content

Instantly share code, notes, and snippets.

@k3kaimu
Last active September 16, 2015 12:59
Show Gist options
  • Save k3kaimu/ff72c891af9c95dffe93 to your computer and use it in GitHub Desktop.
Save k3kaimu/ff72c891af9c95dffe93 to your computer and use it in GitHub Desktop.
SSEを使った複素数の内積
import core.simd;
import std.stdio;
import std.datetime;
import std.math;
import std.parallelism;
import std.range;
import std.random;
/**
 * Non-conjugated complex dot product of two equally long SIMD arrays.
 *
 * Every `Vector!(float[4])` is interpreted as two interleaved complex
 * numbers `[re0, im0, re1, im1]`. The function returns
 * `sum(a[k] * b[k])` over all complex elements.
 *
 * Params:
 *   a = first operand
 *   b = second operand; must have the same length as `a`
 *       (checked by the `in` contract only)
 *
 * Returns: the accumulated product as a `cfloat`; `0+0i` for empty input.
 *
 * Uses the DMD `__simd` intrinsics SHUFPS and DPPS (DPPS is an SSE4.1
 * instruction).
 */
cfloat dotProduct(in Vector!(float[4])[] a, in Vector!(float[4])[] b)
in{
    assert(a.length == b.length);
}
body{
    // Per-lane partial sums:
    //   reLanes = [re*re, im*im, re*re, im*im]  (real part needs alternating signs)
    //   imLanes = [im*re, re*im, im*re, re*im]  (imaginary part is a plain sum)
    Vector!(float[4]) reLanes = 0;
    Vector!(float[4]) imLanes = 0;

    for(auto pa = a.ptr, pb = b.ptr, stop = a.ptr + a.length; pa != stop; ++pa, ++pb)
    {
        Vector!(float[4]) lhs = *pa;
        Vector!(float[4]) rhs = *pb;

        reLanes += lhs * rhs;

        // Swap real and imaginary parts inside each complex pair:
        // [re0, im0, re1, im1] -> [im0, re0, im1, re1]
        Vector!(float[4]) lhsSwapped = __simd(XMM.SHUFPS, lhs, lhs, 0b10_11_00_01);
        imLanes += lhsSwapped * rhs;
    }

    Vector!(float[4]) alternating, allOnes;
    alternating.array = [1.0f, -1.0f, 1.0f, -1.0f];
    allOnes.array = [1.0f, 1.0f, 1.0f, 1.0f];

    // Horizontal reduction: DPPS with mask 0xFF multiplies all four lanes
    // and broadcasts their sum to every lane of the result.
    reLanes = __simd(XMM.DPPS, reLanes, alternating, 0b11111111);
    imLanes = __simd(XMM.DPPS, imLanes, allOnes, 0b11111111);

    return reLanes.array[0] + imLanes.array[0]*1i;
}
/**
 * Benchmarks five ways of computing the same complex dot product and
 * prints the elapsed time, the throughput and the five results.
 *
 *   1. index-based scalar loop
 *   2. pointer-based scalar loop
 *   3. pointer-based scalar loop, unrolled 8x
 *   4. SIMD `dotProduct`
 *   5. SIMD `dotProduct`, repetitions distributed with `std.parallelism`
 */
void main()
{
    // Monotonic clock: unlike the wall clock it cannot jump while a
    // benchmark is running.
    import core.time : MonoTime;

    enum size_t N = 1024;           // number of 4-float vectors per operand
    enum size_t Times = 1024*32;    // repetitions of every benchmark

    // The unrolled benchmark advances 8 elements per step and terminates on
    // pointer equality, so the element count has to be a multiple of 8.
    static assert((N * 2) % 8 == 0, "N*2 must be a multiple of 8 for the unrolled loop");

    Vector!(float[4])[] xs = new Vector!(float[4])[N],
                        hs = xs.dup;

    // Complex views over the same memory: each vector holds two cfloat values.
    cfloat[] cxs = (cast(cfloat*)xs.ptr)[0 .. N*2],
             chs = (cast(cfloat*)hs.ptr)[0 .. N*2];

    foreach(ref e; cxs)
        e = uniform01() + uniform01()*1i;

    // Note the factor 2i: the imaginary parts of chs are scaled by two.
    foreach(ref e; chs)
        e = uniform01() + uniform01()*2i;

    cfloat[5] res;

    // 1. index-based scalar loop
    auto start = MonoTime.currTime;
    {
        foreach(times; 0 .. Times)
        {
            res[0] = 0+0i;
            foreach(i; 0 .. N*2)
                res[0] += cxs[i] * chs[i];
        }
    }
    auto bnch1 = MonoTime.currTime - start;

    // 2. pointer-based scalar loop
    start = MonoTime.currTime;
    {
        foreach(times; 0 .. Times)
        {
            res[1] = 0+0i;

            auto px = cxs.ptr,
                 qx = px + cxs.length,
                 ph = chs.ptr;

            while(px != qx)
            {
                res[1] += (*px) * (*ph);
                ++px;
                ++ph;
            }
        }
    }
    auto bnch2 = MonoTime.currTime - start;

    // 3. pointer-based scalar loop, unrolled 8x
    start = MonoTime.currTime;
    {
        foreach(times; 0 .. Times)
        {
            res[2] = 0+0i;

            auto px = cxs.ptr,
                 qx = px + cxs.length,
                 ph = chs.ptr;

            while(px != qx)
            {
                res[2] += *(px+0) * *(ph+0);
                res[2] += *(px+1) * *(ph+1);
                res[2] += *(px+2) * *(ph+2);
                res[2] += *(px+3) * *(ph+3);
                res[2] += *(px+4) * *(ph+4);
                res[2] += *(px+5) * *(ph+5);
                res[2] += *(px+6) * *(ph+6);
                res[2] += *(px+7) * *(ph+7);
                px += 8;
                ph += 8;
            }
        }
    }
    auto bnch3 = MonoTime.currTime - start;

    // 4. SIMD dot product
    start = MonoTime.currTime;
    {
        foreach(times; 0 .. Times)
            res[3] = dotProduct(xs, hs);
    }
    auto bnch4 = MonoTime.currTime - start;

    // 5. SIMD dot product, repetitions run in parallel.
    // Every iteration writes to its own slot, so worker threads never store
    // to the same memory location. The buffer is allocated before the timer
    // starts so the allocation is not measured.
    auto parallelResults = new cfloat[Times];
    start = MonoTime.currTime;
    {
        foreach(times; parallel(iota(0, Times)))
            parallelResults[times] = dotProduct(xs, hs);
    }
    auto bnch5 = MonoTime.currTime - start;
    res[4] = parallelResults[Times - 1];

    /*
    Sample output of one recorded run (benchmarks 1..5, then the five results):

    459[ms], 146.067[Msps]
    257[ms], 260.873[Msps]
    212[ms], 316.248[Msps]
    92[ms], 728.747[Msps]
    21[ms], 3192.77[Msps]
    -556.742+1556.79i == -556.742+1556.79i == -556.742+1556.79i == -556.741+1556.79i == -556.741+1556.79i
    */

    // Throughput in complex samples per microsecond, i.e. mega-samples per second.
    foreach(elapsed; [bnch1, bnch2, bnch3, bnch4, bnch5])
        writefln("%s[ms], %s[Msps]", elapsed.total!"msecs", cxs.length * Times * 1.0 / (elapsed.total!"usecs"));

    writefln("%(%s == %)", res[]);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment