Skip to content

Instantly share code, notes, and snippets.

@foolnotion
Last active December 5, 2020 01:00
Show Gist options
  • Save foolnotion/4af2c5695d1285536dff1e97b502d99f to your computer and use it in GitHub Desktop.
Save foolnotion/4af2c5695d1285536dff1e97b502d99f to your computer and use it in GitHub Desktop.
// Folds a single observation `x` into the running statistics.
//
// State touched (declared outside this block):
//   n - number of observations seen so far
//   s - running sum of the observations
//   q - accumulator updated with d^2 / (n * (n - 1)), where d = n_old * x - s_old;
//       algebraically this is the running sum of squared deviations from the mean
//
// When no observation has been recorded yet (n <= 0) the state is simply
// initialised from `x`.
void Add(double x)
{
    const bool first_sample = (n <= 0);
    if (first_sample) {
        q = 0;
        s = x;
        n = 1;
    } else {
        // d must be computed from the *old* count and sum, before they are bumped.
        const double delta = n * x - s;
        s += x;
        n += 1;
        // n now holds the new count, so this is n_new * n_old.
        const auto pairs = n * (n - 1);
        q += delta * delta / pairs;
    }
}
// Computes the running statistics over a whole batch of values in one pass.
//
// NOTE: unlike Add(double), this overload does not accumulate onto existing
// state: n, s and q are overwritten with the statistics of `values` alone
// (n = number of values, s = their sum, q = accumulated d^2 / (n * (n - 1))
// terms, i.e. the sum of squared deviations from the mean).
//
// An empty span is a no-op and leaves the current state untouched. Previously
// an empty span dereferenced values.data() and then walked past the end of the
// range, which is undefined behaviour.
void Add(gsl::span<const double> values) {
    if (values.empty()) {
        return;
    }
    auto p = values.data();
    auto end = p + values.size();
    // Seed the state with the first value; the loop starts at the second one.
    n = 1;
    s = *p;
    q = 0;
    while (++p != end) {
        double x = *p;
        // d uses the count and sum from *before* x is folded in.
        double d = n * x - s;
        n = n + 1;
        s += x;
        q += d * d / (n * (n-1));
    }
}
// Computes the statistics of `values` (count n, sum s, and the squared-deviation
// accumulator q) using four interleaved lanes, overwriting any previous state.
//
// Strategy:
//   1. Inputs shorter than 16 values are delegated to the scalar batch Add().
//   2. The largest prefix whose length is a multiple of 4 is split into four
//      equally sized contiguous partitions; element i of every partition is
//      processed in the same Eigen::Array4d operation.
//   3. The four partial results are merged pairwise:
//        q_ab = q_a + q_b + (n_a * s_b - n_b * s_a)^2 / (n_a * n_b * (n_a + n_b))
//   4. The up-to-three leftover values are folded in one at a time.
//
// An empty span is a no-op and leaves the current state untouched.
void AddSIMD(gsl::span<const double> values) {
    if (values.empty()) {
        return; // nothing to process; avoids reading from an empty range below
    }
    // the general idea is to partition the data and perform this computation in parallel
    if (values.size() < 16) {
        return Add(values);
    }
    size_t sz = values.size() - values.size() % 4; // closest multiple of 4
    size_t ps = sz / 4; // partition size
    gsl::span<const double> parts[4] = {
        values.subspan(0 * ps, ps),
        values.subspan(1 * ps, ps),
        values.subspan(2 * ps, ps),
        values.subspan(3 * ps, ps)
    };
    Eigen::Array4d ss; // sums
    Eigen::Array4d qq; // sums of squares
    Eigen::Array4d nn; // counts
    // Seed every lane with the first element of its partition.
    ss(0) = parts[0][0];
    ss(1) = parts[1][0];
    ss(2) = parts[2][0];
    ss(3) = parts[3][0];
    qq.fill(0);
    nn.fill(1);
    for (size_t i = 1; i < ps; ++i) {
        Eigen::Array4d xx {
            parts[0][i],
            parts[1][i],
            parts[2][i],
            parts[3][i]
        };
        // Same update as the scalar Add(double), applied to all four lanes at once.
        Eigen::Array4d dd = nn * xx - ss;
        nn += 1.0;
        ss += xx;
        qq += dd * dd / (nn * (nn - 1));
    }
    s = ss.sum();
    n = nn.sum();
    // reduction step
    // r01 = reduce(part0, part1)
    double tmp = nn(0) * ss(1) - nn(1) * ss(0);
    double n01 = nn(0) + nn(1);
    double q01 = qq(0) + qq(1) + tmp * tmp / (nn(0) * n01 * nn(1));
    double s01 = ss(0) + ss(1);
    // r23 = reduce(part2, part3)
    tmp = nn(2) * ss(3) - nn(3) * ss(2);
    double n23 = nn(2) + nn(3);
    double q23 = qq(2) + qq(3) + tmp * tmp / (nn(2) * n23 * nn(3));
    double s23 = ss(2) + ss(3);
    // final result = reduce(r01, r23)
    tmp = n01 * s23 - n23 * s01;
    q = q01 + q23 + tmp * tmp / (n01 * (n01 + n23) * n23);
    // deal with remaining values
    // The span overload of Add() overwrites n, s and q, which would throw away
    // everything computed above. Fold the leftovers in through the scalar
    // overload instead, which accumulates onto the existing state.
    if (sz < values.size()) {
        for (double x : values.subspan(sz, values.size() - sz)) {
            Add(x);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment