Created
May 6, 2017 14:25
-
-
Save Ushio/0203917f5d9e599ede5bdac579afd00e to your computer and use it in GitHub Desktop.
インクリメンタル分散計算
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stdio.h>

#include <cmath>
#include <cstddef>
#include <random>
#include <vector>
/// Running sum of doubles using Kahan (compensated) summation.
///
/// Alongside the running total, the accumulator keeps a correction term
/// holding the low-order bits lost by the previous addition and feeds it
/// back into the next one. The object converts implicitly to/from double so
/// it can be used as a drop-in replacement for a plain double accumulator.
class Kahan {
public:
    /// Starts at zero with no pending correction.
    Kahan() = default;

    /// Starts the sum at @p value. Intentionally non-explicit so that
    /// `Kahan s = 0.0;` keeps working.
    Kahan(double value) : _sum(value) {}

    /// Adds @p x to the running total, carrying the rounding error forward.
    void add(double x) {
        const double y = x - _c;     // apply the correction from the last step
        const double t = _sum + y;   // low-order bits of y may be lost here
        _c = (t - _sum) - y;         // recover what was lost (negated)
        _sum = t;
    }

    /// Resets the total to @p x and discards any pending correction.
    /// @return *this, to allow chaining.
    Kahan& operator=(double x) {
        _sum = x;
        _c = 0.0;
        return *this;
    }

    /// Same as add(x).
    /// @return *this, to allow chaining.
    Kahan& operator+=(double x) {
        add(x);
        return *this;
    }

    /// Current value of the running total.
    operator double() const {
        return _sum;
    }

private:
    double _sum = 0.0;  // running total
    double _c = 0.0;    // compensation for lost low-order bits
};
class IncrementalStatatics { | |
public: | |
void addSample(double value) { | |
_sum_of_squared += value * value; | |
_sum += value; | |
_n++; | |
} | |
double variance() const { | |
return (_sum_of_squared - _sum * _sum / _n) / _n; | |
} | |
double avarage() const { | |
return _sum / _n; | |
} | |
private: | |
Kahan _sum_of_squared = 0.0; | |
Kahan _sum = 0.0; | |
int _n = 0; | |
}; | |
class OnlineVariance { | |
public: | |
void addSample(double x) { | |
_n++; | |
double delta = x - _mean; | |
_mean += delta / _n; | |
double delta2 = x - _mean; | |
_M2 += delta * delta2; | |
} | |
double variance() const { | |
// return _M2 / (_n - 1); | |
return _M2 / _n; | |
} | |
double avarage() const { | |
return _mean; | |
} | |
OnlineVariance merge(const OnlineVariance &rhs) const { | |
OnlineVariance r; | |
double ma = _mean; | |
double mb = rhs._mean; | |
double N = _n; | |
double M = rhs._n; | |
double N_M = N + M; | |
double a = N / N_M; | |
double b = M / N_M; | |
r._mean = a * ma + b * mb; | |
r._M2 = _M2 + rhs._M2; | |
r._n = N_M; | |
return r; | |
} | |
private: | |
Kahan _mean = 0.0; | |
Kahan _M2 = 0.0; | |
int _n = 0; | |
}; | |
int main() | |
{ | |
std::vector<double> data; | |
std::mt19937 engine; | |
std::uniform_real_distribution<> random(-1.0, 1.0); | |
for (int i = 0; i < 100000; ++i) { | |
data.push_back(4.0 + random(engine) + random(engine) + random(engine) + random(engine)); | |
} | |
{ | |
double sum = 0.0; | |
for (int i = 0; i < data.size(); ++i) { | |
sum += data[i]; | |
} | |
double average = sum / data.size(); | |
double variance_sum = 0.0; | |
for (int i = 0; i < data.size(); ++i) { | |
variance_sum += (average - data[i]) * (average - data[i]); | |
} | |
double variance = variance_sum / data.size(); | |
printf("avg = %.4f\n", average); | |
printf("variance = %.4f\n", variance); | |
printf("sd = %.4f\n\n", sqrt(variance)); | |
} | |
{ | |
int n = 0; | |
double sum_of_squared = 0.0; | |
double sum = 0.0; | |
for (double value : data) { | |
sum_of_squared += value * value; | |
sum += value; | |
n++; | |
} | |
double variance = (sum_of_squared - sum * sum / n) / n; | |
double average = sum / n; | |
printf("avg = %.4f\n", average); | |
printf("variance = %.4f\n", variance); | |
printf("sd = %.4f\n\n", sqrt(variance)); | |
} | |
{ | |
IncrementalStatatics statatics; | |
for (double value : data) { | |
statatics.addSample(value); | |
} | |
printf("avg = %.4f\n", statatics.avarage()); | |
printf("variance = %.4f\n", statatics.variance()); | |
printf("sd = %.4f\n\n", sqrt(statatics.variance())); | |
} | |
{ | |
OnlineVariance statatics; | |
for (double value : data) { | |
statatics.addSample(value); | |
} | |
printf("avg = %.4f\n", statatics.avarage()); | |
printf("variance = %.4f\n", statatics.variance()); | |
printf("sd = %.4f\n\n", sqrt(statatics.variance())); | |
} | |
{ | |
std::vector<double> data_a; | |
std::vector<double> data_b; | |
for (int i = 0; i < 30000; ++i) { | |
data_a.push_back(data[i]); | |
} | |
for (int i = 0; i < 70000; ++i) { | |
data_b.push_back(data[30000 + i]); | |
} | |
OnlineVariance statatics_a; | |
for (double value : data_a) { | |
statatics_a.addSample(value); | |
} | |
OnlineVariance statatics_b; | |
for (double value : data_b) { | |
statatics_b.addSample(value); | |
} | |
OnlineVariance statatics = statatics_a.merge(statatics_b); | |
printf("avg = %.4f\n", statatics.avarage()); | |
printf("variance = %.4f\n", statatics.variance()); | |
printf("sd = %.4f\n\n", sqrt(statatics.variance())); | |
} | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment