-
-
Save foolnotion/4af2c5695d1285536dff1e97b502d99f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
void Add(double x) | |
{ | |
if (n <= 0) { | |
n = 1; | |
s = x; | |
q = 0; | |
return; | |
} | |
double d = n * x - s; | |
n += 1; | |
s += x; | |
q += d * d / (n * (n-1)); | |
} | |
// Computes count (n), sum (s) and sum of squared deviations (q) over `values`
// in a single pass.
//
// NOTE: for a non-empty span this overload OVERWRITES n, s and q with the
// statistics of `values` alone; it does not merge with previously added data.
// An empty span is a no-op and leaves the current state untouched (previously
// it dereferenced the span's data pointer and iterated past the end, which is
// undefined behaviour).
void Add(gsl::span<const double> values) {
    if (values.empty()) {
        return;
    }

    auto p = values.data();
    auto end = p + values.size();

    // Seed the accumulators with the first element.
    n = 1;
    s = *p;
    q = 0;

    while (++p != end) {
        double x = *p;
        // d uses the count/sum from before x is included; the divisor uses
        // the count after x is included.
        double d = n * x - s;
        n = n + 1;
        s += x;
        q += d * d / (n * (n - 1));
    }
}
// Computes count (n), sum (s) and sum of squared deviations (q) over `values`
// by splitting the data into four equally sized contiguous partitions, running
// the single-pass update on all four at once in the lanes of an
// Eigen::Array4d, and then merging the four partial results pairwise.
//
// Behaviour:
//  - an empty span is a no-op;
//  - fewer than 16 values are delegated to Add(gsl::span<const double>);
//  - otherwise n, s and q are overwritten with the statistics of `values`.
//
// Fix: the up-to-three values left over after the partitioned part used to be
// passed to the span overload of Add, which re-seeds n, s and q from its input
// and therefore threw away the partitioned result whenever values.size() was
// not a multiple of 4. They are now folded in one by one through Add(double),
// which accumulates onto the existing state.
void AddSIMD(gsl::span<const double> values) {
    if (values.empty()) {
        return; // nothing to add; also keeps the small-input path from touching an empty buffer
    }
    if (values.size() < 16) {
        return Add(values);
    }

    size_t sz = values.size() - values.size() % 4; // largest multiple of 4 not exceeding size
    size_t ps = sz / 4;                            // partition size

    gsl::span<const double> parts[4] = {
        values.subspan(0 * ps, ps),
        values.subspan(1 * ps, ps),
        values.subspan(2 * ps, ps),
        values.subspan(3 * ps, ps)
    };

    Eigen::Array4d ss; // per-partition sums
    Eigen::Array4d qq; // per-partition sums of squared deviations
    Eigen::Array4d nn; // per-partition counts

    // Seed every lane with the first element of its partition.
    ss(0) = parts[0][0];
    ss(1) = parts[1][0];
    ss(2) = parts[2][0];
    ss(3) = parts[3][0];
    qq.fill(0);
    nn.fill(1);

    for (size_t i = 1; i < ps; ++i) {
        Eigen::Array4d xx {
            parts[0][i],
            parts[1][i],
            parts[2][i],
            parts[3][i]
        };
        // Same update as the scalar Add, applied lane-wise.
        Eigen::Array4d dd = nn * xx - ss;
        nn += 1.0;
        ss += xx;
        qq += dd * dd / (nn * (nn - 1));
    }

    s = ss.sum();
    n = nn.sum();

    // Reduction step: merging two partial results (na, sa, qa) and
    // (nb, sb, qb) gives q = qa + qb + (na * sb - nb * sa)^2 / (na * (na + nb) * nb).

    // r01 = reduce(part0, part1)
    double tmp = nn(0) * ss(1) - nn(1) * ss(0);
    double n01 = nn(0) + nn(1);
    double q01 = qq(0) + qq(1) + tmp * tmp / (nn(0) * n01 * nn(1));
    double s01 = ss(0) + ss(1);

    // r23 = reduce(part2, part3)
    tmp = nn(2) * ss(3) - nn(3) * ss(2);
    double n23 = nn(2) + nn(3);
    double q23 = qq(2) + qq(3) + tmp * tmp / (nn(2) * n23 * nn(3));
    double s23 = ss(2) + ss(3);

    // final result = reduce(r01, r23)
    tmp = n01 * s23 - n23 * s01;
    q = q01 + q23 + tmp * tmp / (n01 * (n01 + n23) * n23);

    // Fold in the values that did not fit into the four partitions. This must
    // use the accumulating single-value Add so the state computed above is kept.
    if (sz < values.size()) {
        for (double x : values.subspan(sz, values.size() - sz)) {
            Add(x);
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment