Skip to content

Instantly share code, notes, and snippets.

@taku0
Last active November 6, 2016 02:19
Show Gist options
  • Save taku0/27359bddac782fdd76ab9170fba4b957 to your computer and use it in GitHub Desktop.
Save taku0/27359bddac782fdd76ab9170fba4b957 to your computer and use it in GitHub Desktop.
Box-Muller法による標準正規分布からのサンプリングと、Irwin–Hall分布による近似のベンチマーク
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <math.h>
#include <time.h>
/*
 * Forward declarations of the six generators defined later in this file.
 * Each returns a pseudo-random value drawn from a uniform distribution.
 */
static uint64_t uniform1(void);
static uint64_t uniform2(void);
static uint64_t uniform3(void);
static uint64_t uniform4(void);
static uint64_t uniform5(void);
static uint64_t uniform6(void);
/*
 * Converts an unsigned 64-bit integer into a double in the range [1, 2).
 *
 * The upper 52 bits of x are shifted down into the low 52 bits and OR-ed with
 * the bit pattern 0x3FF0000000000000; the resulting 64 bits are then copied
 * into a double. (Assumes double is IEEE-754 binary64, where that pattern is
 * 1.0 and the low 52 bits are the fraction -- TODO confirm for the target.)
 */
static double uint64_to_double_plus_1(uint64_t x) {
    const uint64_t bits_of_one = 0x3FF0000000000000ULL;
    const uint64_t bits = (x >> 12) | bits_of_one;
    double value;

    /* memcpy is used for the reinterpretation to avoid aliasing problems. */
    memcpy(&value, &bits, sizeof(bits));
    return value;
}
/*
 * Converts an unsigned 64-bit integer into a double in the range [0, 1).
 *
 * Implemented by producing a value in [1, 2) first and subtracting one.
 */
static double uint64_to_double(uint64_t x) {
    const double one_to_two = uint64_to_double_plus_1(x);
    return one_to_two - 1.0;
}
/*
 * Returns a pseudo-random value following the standard normal distribution,
 * generated with the Box-Muller method.
 *
 * Both uniform inputs are taken in [1, 2):
 *  - the radius uses 2 - r1, which lies in (0, 1], so log() never sees zero;
 *  - the angle uses r2 directly, i.e. an angle in [2*pi, 4*pi), which is
 *    equivalent to [0, 2*pi) because sin() is periodic.
 */
static double box_muller(void) {
    /*
     * 2*pi written out as a literal: M_PI is a POSIX extension rather than
     * ISO C, so it is not guaranteed to be declared by <math.h> when compiling
     * in a strict standard mode. This literal rounds to the same double as
     * 2 * M_PI, so the results are unchanged.
     */
    static const double two_pi = 6.283185307179586476925286766559;

    double r1 = uint64_to_double_plus_1(uniform1());
    double r2 = uint64_to_double_plus_1(uniform2());

    return sqrt(-2 * log(2 - r1)) * sin(two_pi * r2);
}
/*
 * Returns a pseudo-random value following the n=3 Irwin-Hall distribution,
 * rescaled to mean 0 and variance 1.
 * Uses three separate generators, once each.
 */
static double sum_of_3_uniforms(void) {
    double total = uint64_to_double_plus_1(uniform1());
    total += uint64_to_double_plus_1(uniform2());
    total += uint64_to_double_plus_1(uniform3());

    /*
     * Each term is in [1, 2), so the offset of 3 removes the three "+1"s and
     * 1.5 is the mean of the remaining sum. The sum of 3 uniforms has
     * variance 3/12, hence the scale factor sqrt(12/3).
     */
    return (total - 1.5 - 3) * sqrt(12.0 / 3.0);
}
/*
 * Returns a pseudo-random value following the n=3 Irwin-Hall distribution,
 * rescaled to mean 0 and variance 1.
 * Uses a single generator three times.
 */
static double sum_of_3_uniforms_seq(void) {
    double total = 0.0;

    for (int draw = 0; draw < 3; draw++) {
        total += uint64_to_double_plus_1(uniform1());
    }

    /*
     * Each term is in [1, 2): subtract the mean of the [0, 1) parts (1.5) and
     * the three "+1" offsets, then scale by sqrt(12/3) for unit variance.
     */
    return (total - 1.5 - 3) * sqrt(12.0 / 3.0);
}
/*
 * Returns a pseudo-random value following the n=6 Irwin-Hall distribution,
 * rescaled to mean 0 and variance 1.
 * Uses six separate generators, once each.
 */
static double sum_of_6_uniforms(void) {
    double total = uint64_to_double_plus_1(uniform1());
    total += uint64_to_double_plus_1(uniform2());
    total += uint64_to_double_plus_1(uniform3());
    total += uint64_to_double_plus_1(uniform4());
    total += uint64_to_double_plus_1(uniform5());
    total += uint64_to_double_plus_1(uniform6());

    /*
     * Each term is in [1, 2): subtract the mean of the [0, 1) parts (3) and
     * the six "+1" offsets, then scale by sqrt(12/6) for unit variance.
     */
    return (total - 3 - 6) * sqrt(12.0 / 6.0);
}
/*
 * Returns a pseudo-random value following the n=6 Irwin-Hall distribution,
 * rescaled to mean 0 and variance 1.
 * Uses a single generator six times.
 */
static double sum_of_6_uniforms_seq(void) {
    double total = 0.0;

    for (int draw = 0; draw < 6; draw++) {
        total += uint64_to_double_plus_1(uniform1());
    }

    /*
     * Each term is in [1, 2): subtract the mean of the [0, 1) parts (3) and
     * the six "+1" offsets, then scale by sqrt(12/6) for unit variance.
     */
    return (total - 3 - 6) * sqrt(12.0 / 6.0);
}
/*
 * Returns a pseudo-random value following the n=6 Irwin-Hall distribution,
 * rescaled to mean 0 and variance 1.
 * Does as much of the arithmetic as possible in integers.
 * Uses six separate generators, once each.
 */
static double sum_of_6_uniforms_int(void) {
    /* Largest value a draw can take after dropping its low 4 bits. */
    const int64_t scale = (int64_t) (UINT64_MAX >> 4);

    /*
     * Dropping 4 bits leaves 60-bit draws, so six of them can be summed in a
     * signed 64-bit integer without overflow.
     */
    int64_t acc = (int64_t) (uniform1() >> 4);
    acc += (int64_t) (uniform2() >> 4);
    acc += (int64_t) (uniform3() >> 4);
    acc += (int64_t) (uniform4() >> 4);
    acc += (int64_t) (uniform5() >> 4);
    acc += (int64_t) (uniform6() >> 4);

    /* Centre on zero (mean of six draws is 3 * scale) before converting. */
    const double centered = (double) (acc - 3 * scale);

    return centered / scale * sqrt(12.0 / 6.0);
}
/*
 * Returns a pseudo-random value following the n=6 Irwin-Hall distribution,
 * rescaled to mean 0 and variance 1.
 * Does as much of the arithmetic as possible in integers.
 * Uses a single generator six times.
 */
static double sum_of_6_uniforms_int_seq(void) {
    /* Largest value a draw can take after dropping its low 4 bits. */
    const int64_t scale = (int64_t) (UINT64_MAX >> 4);
    int64_t acc = 0;

    /*
     * Dropping 4 bits leaves 60-bit draws, so six of them can be summed in a
     * signed 64-bit integer without overflow.
     */
    for (int draw = 0; draw < 6; draw++) {
        acc += (int64_t) (uniform1() >> 4);
    }

    /* Centre on zero (mean of six draws is 3 * scale) before converting. */
    const double centered = (double) (acc - 3 * scale);

    return centered / scale * sqrt(12.0 / 6.0);
}
/*
 * Ziggurat method.
 * Ported from the following source:
 * https://github.com/komiya-atsushi/fast-rng-java/blob/master/fast-rng/src/main/java/biz/k11i/rng/GaussianRNG.java
 */
/* Number of entries in each lookup table. */
#define ZIGGURAT_TABLE_SIZE 256
/* Bit mask that reduces an integer to a valid table index. */
#define ZIGGURAT_TABLE_MASK (ZIGGURAT_TABLE_SIZE - 1)
/*
 * NOTE(review): presumably R is the x-coordinate where the tail begins and V
 * is the area of each ziggurat layer for a 256-entry table -- confirm against
 * the reference implementation.
 */
#define ZIGGURAT_R 3.6541528853610088
#define ZIGGURAT_V 0.00492867323399
/* Integer thresholds compared against |j| for the fast acceptance test. */
static int64_t ziggurat_k[ZIGGURAT_TABLE_SIZE];
/* Scale factors that turn a signed 64-bit integer into an x-coordinate. */
static double ziggurat_w[ZIGGURAT_TABLE_SIZE];
/* Values of gaussian() at the layer boundaries. */
static double ziggurat_f[ZIGGURAT_TABLE_SIZE];
static double gaussian(double x) {
return exp(-0.5 * x * x);
}
/*
 * Builds the lookup tables used by the ziggurat algorithm
 * (ziggurat_k, ziggurat_w and ziggurat_f).
 * Must be called once before ziggurat() is used.
 */
static void init_ziggurat_table(void) {
/* Density value at the tail boundary R. */
double fr = gaussian(ZIGGURAT_R);
/* Entry 0 is special: it is derived from R, V and fr rather than the recurrence below. */
ziggurat_k[0] = (int64_t) (INT64_MAX * ZIGGURAT_R * fr / ZIGGURAT_V);
ziggurat_k[1] = 0;
/* The w entries are divided by INT64_MAX so that j * w[i] maps a signed 64-bit j to an x-coordinate. */
ziggurat_w[0] = ZIGGURAT_V / fr / INT64_MAX;
ziggurat_w[ZIGGURAT_TABLE_SIZE - 1] = ZIGGURAT_R / INT64_MAX;
ziggurat_f[0] = 1;
ziggurat_f[ZIGGURAT_TABLE_SIZE - 1] = fr;
/* dn is the current boundary x-coordinate; tn remembers the previous one. */
double dn = ZIGGURAT_R;
double tn = ZIGGURAT_R;
/* Walk from the second-to-last entry down to entry 1, deriving each boundary from the previous one. */
for (int i = ZIGGURAT_TABLE_SIZE - 2; i >= 1; i--) {
dn = sqrt(-2 * log(ZIGGURAT_V / dn + gaussian(dn)));
/* Threshold for entry i + 1 is the ratio of the new boundary to the previous one, scaled to the integer range. */
ziggurat_k[i + 1] = (int64_t) (INT64_MAX * dn / tn);
tn = dn;
ziggurat_w[i] = dn / INT64_MAX;
ziggurat_f[i] = gaussian(dn);
}
}
/*
 * Returns a pseudo-random value following the standard normal distribution,
 * generated with the ziggurat algorithm.
 * Requires init_ziggurat_table() to have been called first.
 * Every random draw comes from uniform1(); the loop repeats until a candidate
 * is accepted.
 */
static double ziggurat(void) {
while (1) {
/* Reinterpret the 64-bit draw as signed: its sign gives the sign of the result. */
int64_t j = (int64_t) uniform1();
/* Skip INT64_MIN, presumably because llabs() below cannot represent its absolute value. */
if (j == INT64_MIN) {
continue;
}
/* The low bits of the same draw select the table entry. */
int i = (int) (j & ZIGGURAT_TABLE_MASK);
/* Fast path: accept immediately when |j| is below the threshold for this entry. */
if (llabs(j) < ziggurat_k[i]) {
return j * ziggurat_w[i];
}
/* Entry 0 handles the tail beyond ZIGGURAT_R. */
if (i == 0) {
double x, y;
/*
 * Draw x and y from negated logarithms of uniforms and retry while
 * 2 * y < x * x.
 * NOTE(review): uint64_to_double() can return exactly 0, in which case
 * log() yields -infinity -- confirm this is acceptable.
 */
do {
x = -log(uint64_to_double(uniform1())) / ZIGGURAT_R;
y = -log(uint64_to_double(uniform1()));
} while (2 * y < x * x);
/* Place the tail sample on the side indicated by the sign of j. */
return (j > 0) ? (ZIGGURAT_R + x) : -(ZIGGURAT_R + x);
}
/* Slow path: accept x only if a uniformly chosen height between f[i] and f[i - 1] lies below gaussian(x). */
double x = j * ziggurat_w[i];
if ((ziggurat_f[i - 1] - ziggurat_f[i]) * uint64_to_double(uniform1()) < gaussian(x) - ziggurat_f[i]) {
return x;
}
/* Rejected: start over with a fresh draw. */
}
}
/*
 * Runs one timing pass over `sampler`: calls it `draws` times, accumulating
 * the returned values, then prints the elapsed clock() time in milliseconds
 * together with the accumulated sum.
 *
 * This is a macro rather than a helper taking a function pointer so that each
 * sampler is still called directly by name inside its own loop, exactly as in
 * a hand-written loop.
 */
#define BENCHMARK_RUN(sampler)                                                \
    do {                                                                      \
        double total = 0.0;                                                   \
        clock_t began = clock();                                              \
        for (int iter = 0; iter < draws; iter++) {                            \
            total += sampler();                                               \
        }                                                                     \
        double elapsed_ms =                                                   \
            ((double) (clock() - began)) / CLOCKS_PER_SEC * 1000;             \
        printf(#sampler ": %f ms, sum: %f\n", elapsed_ms, total);             \
    } while (0)

/*
 * Benchmarks every sampler in turn with the same number of draws and prints
 * one result line per sampler.
 */
static void benchmark(void) {
    int draws = 10000000;

    BENCHMARK_RUN(box_muller);
    BENCHMARK_RUN(sum_of_3_uniforms);
    BENCHMARK_RUN(sum_of_3_uniforms_seq);
    BENCHMARK_RUN(sum_of_6_uniforms);
    BENCHMARK_RUN(sum_of_6_uniforms_seq);
    BENCHMARK_RUN(sum_of_6_uniforms_int);
    BENCHMARK_RUN(sum_of_6_uniforms_int_seq);
    BENCHMARK_RUN(ziggurat);
}

#undef BENCHMARK_RUN
/*
 * Entry point: prepares the ziggurat tables, then runs the full benchmark
 * 100 times, printing a blank line after each run.
 */
int main(void) {
    init_ziggurat_table();

    int round = 0;
    while (round < 100) {
        benchmark();
        printf("\n");
        round++;
    }

    return EXIT_SUCCESS;
}
/*
 * xorshift128+ implementation.
 * Generator #1: 128 bits of state and the function that advances it.
 */
static uint64_t state1[2] = {
    0xDFE872B1EF784CA2ULL,
    0x6C74F51762B3F1D4ULL,
};

/* Advances state1 by one step and returns the next 64-bit output. */
static uint64_t uniform1(void) {
    uint64_t s1 = state1[0];
    const uint64_t s0 = state1[1];

    s1 ^= s1 << 23;
    const uint64_t next = s1 ^ s0 ^ (s1 >> 17) ^ (s0 >> 26);

    state1[0] = s0;
    state1[1] = next;
    return next + s0;
}
/* Generator #2: 128 bits of state and the function that advances it. */
static uint64_t state2[2] = {
    0x750B4FE530F72E0DULL,
    0xEE7112F206FAEF26ULL,
};

/* Advances state2 by one step and returns the next 64-bit output. */
static uint64_t uniform2(void) {
    uint64_t s1 = state2[0];
    const uint64_t s0 = state2[1];

    s1 ^= s1 << 23;
    const uint64_t next = s1 ^ s0 ^ (s1 >> 17) ^ (s0 >> 26);

    state2[0] = s0;
    state2[1] = next;
    return next + s0;
}
/* Generator #3: 128 bits of state and the function that advances it. */
static uint64_t state3[2] = {
    0xE1FBC56568D00BB2ULL,
    0xF896E69B67FFF376ULL,
};

/* Advances state3 by one step and returns the next 64-bit output. */
static uint64_t uniform3(void) {
    uint64_t s1 = state3[0];
    const uint64_t s0 = state3[1];

    s1 ^= s1 << 23;
    const uint64_t next = s1 ^ s0 ^ (s1 >> 17) ^ (s0 >> 26);

    state3[0] = s0;
    state3[1] = next;
    return next + s0;
}
/* Generator #4: 128 bits of state and the function that advances it. */
static uint64_t state4[2] = {
    0x13705DFECA924ABAULL,
    0x65EA06054CE6623DULL,
};

/* Advances state4 by one step and returns the next 64-bit output. */
static uint64_t uniform4(void) {
    uint64_t s1 = state4[0];
    const uint64_t s0 = state4[1];

    s1 ^= s1 << 23;
    const uint64_t next = s1 ^ s0 ^ (s1 >> 17) ^ (s0 >> 26);

    state4[0] = s0;
    state4[1] = next;
    return next + s0;
}
/* Generator #5: 128 bits of state and the function that advances it. */
static uint64_t state5[2] = {
    0xDAAE2E4BA1DE5DE3ULL,
    0x4ABB307D19322228ULL,
};

/* Advances state5 by one step and returns the next 64-bit output. */
static uint64_t uniform5(void) {
    uint64_t s1 = state5[0];
    const uint64_t s0 = state5[1];

    s1 ^= s1 << 23;
    const uint64_t next = s1 ^ s0 ^ (s1 >> 17) ^ (s0 >> 26);

    state5[0] = s0;
    state5[1] = next;
    return next + s0;
}
/* Generator #6: 128 bits of state and the function that advances it. */
static uint64_t state6[2] = {
    0xB6047B679F17A720ULL,
    0x79B7CEF95A0A0C53ULL,
};

/* Advances state6 by one step and returns the next 64-bit output. */
static uint64_t uniform6(void) {
    uint64_t s1 = state6[0];
    const uint64_t s0 = state6[1];

    s1 ^= s1 << 23;
    const uint64_t next = s1 ^ s0 ^ (s1 >> 17) ^ (s0 >> 26);

    state6[0] = s0;
    state6[1] = next;
    return next + s0;
}
@taku0
Copy link
Author

taku0 commented Nov 4, 2016

実行例 (GCC 5.4.0, -O3 -march=native, Linux, Core i7-6770HQ):

box_muller:                700.413000 ms, sum: -2039.357388
sum_of_3_uniforms:         43.305000 ms, sum: -3055.950091
sum_of_3_uniforms_seq:     43.656000 ms, sum: -2063.446809
sum_of_6_uniforms:         87.254000 ms, sum: -577.620068
sum_of_6_uniforms_seq:     79.455000 ms, sum: 5416.853414
sum_of_6_uniforms_int:     79.684000 ms, sum: -7334.328077
sum_of_6_uniforms_int_seq: 75.294000 ms, sum: 1507.350777

関数からstaticをはずした際の実行例:

box_muller:                713.401000 ms, sum: -2039.357388
sum_of_3_uniforms:         123.184000 ms, sum: -3055.950091
sum_of_3_uniforms_seq:     122.231000 ms, sum: -2063.446809
sum_of_6_uniforms:         243.478000 ms, sum: -577.620068
sum_of_6_uniforms_seq:     272.963000 ms, sum: 5416.853414
sum_of_6_uniforms_int:     107.300000 ms, sum: -7334.328077
sum_of_6_uniforms_int_seq: 132.558000 ms, sum: 1507.350777

追記 (ziggurat法追加。分散の補正の方法を修正):

box_muller:                698.166000 ms, sum: -1100.683034
sum_of_3_uniforms:         43.021000 ms, sum: 978.590960
sum_of_3_uniforms_seq:     43.208000 ms, sum: 95.209094
sum_of_6_uniforms:         84.783000 ms, sum: -1134.027693
sum_of_6_uniforms_seq:     80.352000 ms, sum: 6595.133271
sum_of_6_uniforms_int:     78.527000 ms, sum: -566.475028
sum_of_6_uniforms_int_seq: 79.541000 ms, sum: 3035.383612
ziggurat:                  37.752000 ms, sum: -6942.589633

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment