Created January 10, 2022 23:12
-
-
Save zwegner/616aeac9a49a7e854c0743f2d7094791 to your computer and use it in GitHub Desktop.
Generate a random number with a given popcount, bisect + PDEP
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <immintrin.h>
#include <inttypes.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>
// PRNG modified from the public domain RKISS by Bob Jenkins. See:
// http://www.burtleburtle.net/bob/rand/smallprng.html
//
// Generator state: four 64-bit words. rand_init() seeds them and
// rand_next() advances them.
typedef struct rand_ctx {
    uint64_t a;
    uint64_t b;
    uint64_t c;
    uint64_t d;
} rand_ctx_t;
// Rotate x left by k bit positions (k is taken modulo 64).
//
// Both shift counts are masked to [0, 63]: shifting a 64-bit value by 64
// or more is undefined behavior in C, which the previous form
// `x >> (64 - k)` hit for k == 0. For k in [1, 63] the result is unchanged.
// Compilers recognize this masked pattern and emit a single rotate instruction.
uint64_t rotate_left(uint64_t x, uint64_t k) {
    k &= 63;
    return (x << k) | (x >> ((64 - k) & 63));
}
// Advance the generator one step and return the new value of the d word.
// All four words are read up front so the update order is explicit: every
// new word is computed from the previous state, except d, which uses the
// freshly computed a.
uint64_t rand_next(rand_ctx_t *x) {
    const uint64_t old_a = x->a, old_b = x->b, old_c = x->c, old_d = x->d;
    const uint64_t e = old_a - rotate_left(old_b, 7);
    const uint64_t new_a = old_b ^ rotate_left(old_c, 13);

    x->a = new_a;
    x->b = old_c + rotate_left(old_d, 37);
    x->c = old_d + e;
    x->d = e + new_a;
    return x->d;
}
// Seed the generator with a fixed state, so every run produces the same
// stream, then discard the first 1000 outputs before use.
void rand_init(rand_ctx_t *x) {
    x->a = 0x89ABCDEF01234567ULL;
    x->b = 0xFEDCBA9876543210ULL;
    x->c = x->b;
    x->d = x->b;

    int rounds = 1000;
    while (rounds-- > 0)
        (void)rand_next(x);
}
// Return the next larger value with the same number of set bits as x.
// From https://www.chessprogramming.org/Traversing_Subsets_of_a_Set#Snoobing_the_Universe
//
// x == 0 has no such successor and returns 0. Without the guard, the
// division below would be by zero.
// If x's set bits already fill the top of the word, `x + smallest` wraps to 0
// and the result no longer has the same popcount, so callers must stop
// iterating before reaching that point.
uint64_t snoob(uint64_t x) {
    if (x == 0)
        return 0;
    uint64_t smallest = x & -x;     // lowest set bit
    uint64_t ripple = x + smallest; // carry the lowest run of ones one place up
    uint64_t ones = x ^ ripple;     // bits that changed: that run plus the carry bit
    ones = (ones >> 2) / smallest;  // right-justify the ones that must move back down
    return ripple | ones;
}
// n-choose-k lookup tables used by rand_popcnt_2, filled in by main().
//   choose_idx[n][k]  offset in choose_table of the first n-bit value with k bits set
//   choose_len[n][k]  how many n-bit values have exactly k bits set
//   choose_table[]    those values, listed in increasing order for each (n, k)
uint8_t choose_idx[8][8], choose_len[8][8];
uint8_t choose_table[1024];
// Number of loop iterations performed by the most recent rand_popcnt /
// rand_popcnt_2 call (each resets it to 0 on entry).
int iter_count;
// Return a random 64-bit value with exactly k bits set, by bisection.
// Based on https://github.com/falk-hueffner/randkbits/blob/master/randkbits.cc
//
// Invariant: min is a subset of max (as bit sets) and
// popcount(min) <= k <= popcount(max). Each step draws x with
// min <= x <= max (as sets). If x has too many bits it becomes the new upper
// bound, otherwise the new lower bound. The loop ends as soon as a draw has
// exactly k bits. For k == 0 it returns 0 immediately without drawing.
//
// No 64-bit value has k < 0 or k > 64 bits set. Previously such k looped
// forever. Now k < 0 returns 0 and k > 64 returns all ones.
//
// Sets the global iter_count to the number of random draws used.
uint64_t rand_popcnt(rand_ctx_t *r, int k) {
    iter_count = 0;
    if (k < 0)
        return 0;
    if (k > 64)
        return ~(uint64_t)0;

    uint64_t min = 0;            // bits that must be set
    uint64_t max = ~(uint64_t)0; // bits that may be set
    int n = 0;                   // popcount of the latest draw (starts at popcount(min))
    while (n != k) {
        iter_count++;
        uint64_t x = min | (rand_next(r) & max);
        n = _popcnt64(x);
        if (n > k)
            max = x;
        else
            min = x;
    }
    return min;
}
// Variant of rand_popcnt that finishes with a table lookup plus PDEP.
//
// Bisection keeps the same invariant as rand_popcnt: min is a subset of max,
// and popcount(min) <= k <= popcount(max). It runs only until at most
// MAX_FREE_BITS bits separate min and max. The remaining k - popcount(min)
// bits are then placed among those free bits: a random entry is picked from
// the n-choose-k tables and scattered into the free positions with PDEP.
// Requires choose_idx / choose_len / choose_table to be filled in first
// (main() does this).
//
// Previously, k < 0 or k > 64 indexed the choose tables out of bounds. Now
// k < 0 returns 0 and k > 64 returns all ones.
//
// Sets iter_count to the number of bisection steps (not counting the final
// draw used to pick a table entry).
uint64_t rand_popcnt_2(rand_ctx_t *r, int k) {
    // Largest number of free bits the 8x8 choose tables can finish off.
    enum { MAX_FREE_BITS = 7 };

    iter_count = 0;
    if (k < 0)
        return 0;
    if (k > 64)
        return ~(uint64_t)0;

    uint64_t min = 0;
    uint64_t max = ~(uint64_t)0;
    int min_n = 0, max_n = 64;
    while (max_n - min_n > MAX_FREE_BITS) {
        iter_count++;
        uint64_t x = min | (rand_next(r) & max);
        int n = _popcnt64(x);
        if (n > k) {
            max = x;
            max_n = n;
        } else {
            min = x;
            min_n = n;
        }
    }

    // min ^ max holds exactly free_bits positions. Choose `need` of them.
    // The invariant guarantees 0 <= need <= free_bits <= MAX_FREE_BITS.
    int free_bits = max_n - min_n;
    int need = k - min_n;
    int offset = rand_next(r) % choose_len[free_bits][need];
    uint64_t bits = choose_table[choose_idx[free_bits][need] + offset];
    return min | _pdep_u64(bits, min ^ max);
}
// Current wall-clock time in milliseconds since the Unix epoch, read via
// gettimeofday(). Sub-millisecond precision is truncated.
uint64_t get_ms() {
    struct timeval now;
    gettimeofday(&now, NULL);
    uint64_t whole_ms = (uint64_t)now.tv_sec * 1000;
    uint64_t frac_ms = (uint64_t)now.tv_usec / 1000;
    return whole_ms + frac_ms;
}
int main() { | |
// Initialize n-choose-k tables | |
int offset = 0; | |
for (int i = 0; i < 8; i++) { | |
int max = 1 << i; | |
choose_len[i][0] = 1; | |
choose_idx[i][0] = offset; | |
choose_table[offset++] = 0; | |
for (int j = 1; j <= i; j++) { | |
int x = (1 << (j + 0)) - 1; | |
int c = 0; | |
choose_idx[i][j] = offset; | |
while (x < max) { | |
choose_table[offset++] = x; | |
x = snoob(x); | |
c++; | |
} | |
choose_len[i][j] = c; | |
} | |
} | |
rand_ctx_t r[1]; | |
rand_init(r); | |
const int TRIALS = 1<<21; | |
uint8_t results[TRIALS]; | |
for (int i = 0; i < 64; i++) { | |
int sum = 0, min = 100000, max = 0; | |
for (int j = 0; j < TRIALS; j++) { | |
uint64_t x = rand_popcnt(r, i); | |
if (_popcnt64(x) != i) { | |
printf("FAIL\n"); | |
return 1; | |
} | |
sum += iter_count; | |
if (iter_count < min) min = iter_count; | |
if (iter_count > max) max = iter_count; | |
results[j] = iter_count; | |
} | |
float mean = (float)sum / TRIALS; | |
float resid_sum = 0; | |
for (int j = 0; j < TRIALS; j++) { | |
float d = results[j] - mean; | |
resid_sum += d * d; | |
} | |
printf("k=%d min=%d max=%d mean=%.3f var=%.3f stddev=%.3f\n", i, min, max, mean, | |
resid_sum / TRIALS, sqrt(resid_sum / TRIALS)); | |
} | |
rand_init(r); | |
uint64_t sum = 0; // to keep results from getting DCE'd | |
time_t start = get_ms(); | |
for (int i = 0; i < 64; i++) | |
for (int j = 0; j < TRIALS; j++) | |
sum += rand_popcnt(r, i); | |
time_t end = get_ms(); | |
printf("rand_popcnt: csum=%2x t=%.3fs\n", sum & 0xFF, (end - start) / 1000.); | |
rand_init(r); | |
sum = 0; | |
start = get_ms(); | |
for (int i = 0; i < 64; i++) | |
for (int j = 0; j < TRIALS; j++) | |
sum += rand_popcnt_2(r, i); | |
end = get_ms(); | |
printf("rand_popcnt_2: csum=%2x t=%.3fs\n", sum & 0xFF, (end - start) / 1000.); | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
k=0 min=2 max=7 mean=3.600 var=0.353 stddev=0.594 | |
k=1 min=2 max=8 mean=3.606 var=0.351 stddev=0.593 | |
k=2 min=2 max=7 mean=3.619 var=0.366 stddev=0.605 | |
k=3 min=2 max=7 mean=3.647 var=0.388 stddev=0.623 | |
k=4 min=2 max=8 mean=3.693 var=0.405 stddev=0.636 | |
k=5 min=2 max=8 mean=3.762 var=0.396 stddev=0.630 | |
k=6 min=2 max=7 mean=3.849 var=0.361 stddev=0.601 | |
k=7 min=2 max=8 mean=3.952 var=0.301 stddev=0.548 | |
k=8 min=2 max=8 mean=3.874 var=0.330 stddev=0.575 | |
k=9 min=2 max=8 mean=3.820 var=0.372 stddev=0.610 | |
k=10 min=2 max=8 mean=3.789 var=0.389 stddev=0.624 | |
k=11 min=2 max=8 mean=3.778 var=0.404 stddev=0.636 | |
k=12 min=2 max=8 mean=3.782 var=0.404 stddev=0.635 | |
k=13 min=2 max=8 mean=3.797 var=0.387 stddev=0.622 | |
k=14 min=2 max=7 mean=3.812 var=0.383 stddev=0.619 | |
k=15 min=2 max=8 mean=3.822 var=0.377 stddev=0.614 | |
k=16 min=2 max=8 mean=3.824 var=0.366 stddev=0.605 | |
k=17 min=2 max=8 mean=3.818 var=0.380 stddev=0.616 | |
k=18 min=2 max=8 mean=3.810 var=0.389 stddev=0.623 | |
k=19 min=2 max=8 mean=3.803 var=0.388 stddev=0.623 | |
k=20 min=2 max=8 mean=3.801 var=0.391 stddev=0.626 | |
k=21 min=2 max=9 mean=3.803 var=0.389 stddev=0.623 | |
k=22 min=2 max=8 mean=3.807 var=0.387 stddev=0.622 | |
k=23 min=2 max=8 mean=3.811 var=0.390 stddev=0.625 | |
k=24 min=2 max=7 mean=3.812 var=0.389 stddev=0.623 | |
k=25 min=2 max=8 mean=3.811 var=0.389 stddev=0.623 | |
k=26 min=2 max=9 mean=3.810 var=0.382 stddev=0.618 | |
k=27 min=2 max=8 mean=3.807 var=0.385 stddev=0.620 | |
k=28 min=2 max=8 mean=3.808 var=0.383 stddev=0.619 | |
k=29 min=2 max=8 mean=3.809 var=0.380 stddev=0.617 | |
k=30 min=2 max=8 mean=3.811 var=0.386 stddev=0.621 | |
k=31 min=2 max=8 mean=3.813 var=0.383 stddev=0.619 | |
k=32 min=2 max=8 mean=3.813 var=0.384 stddev=0.619 | |
k=33 min=2 max=8 mean=3.811 var=0.385 stddev=0.621 | |
k=34 min=2 max=7 mean=3.810 var=0.381 stddev=0.617 | |
k=35 min=2 max=8 mean=3.807 var=0.384 stddev=0.619 | |
k=36 min=2 max=8 mean=3.807 var=0.384 stddev=0.619 | |
k=37 min=2 max=8 mean=3.808 var=0.383 stddev=0.619 | |
k=38 min=2 max=7 mean=3.811 var=0.388 stddev=0.623 | |
k=39 min=2 max=7 mean=3.812 var=0.388 stddev=0.623 | |
k=40 min=2 max=7 mean=3.811 var=0.390 stddev=0.624 | |
k=41 min=2 max=8 mean=3.807 var=0.386 stddev=0.621 | |
k=42 min=2 max=8 mean=3.803 var=0.388 stddev=0.623 | |
k=43 min=2 max=8 mean=3.801 var=0.391 stddev=0.625 | |
k=44 min=2 max=8 mean=3.804 var=0.387 stddev=0.622 | |
k=45 min=2 max=8 mean=3.809 var=0.382 stddev=0.618 | |
k=46 min=2 max=8 mean=3.818 var=0.380 stddev=0.617 | |
k=47 min=2 max=8 mean=3.824 var=0.365 stddev=0.604 | |
k=48 min=2 max=7 mean=3.822 var=0.377 stddev=0.614 | |
k=49 min=2 max=8 mean=3.812 var=0.384 stddev=0.619 | |
k=50 min=2 max=8 mean=3.796 var=0.386 stddev=0.622 | |
k=51 min=2 max=8 mean=3.783 var=0.404 stddev=0.636 | |
k=52 min=2 max=8 mean=3.777 var=0.404 stddev=0.635 | |
k=53 min=2 max=7 mean=3.788 var=0.391 stddev=0.625 | |
k=54 min=2 max=8 mean=3.822 var=0.372 stddev=0.610 | |
k=55 min=2 max=8 mean=3.874 var=0.331 stddev=0.575 | |
k=56 min=2 max=8 mean=3.952 var=0.301 stddev=0.548 | |
k=57 min=2 max=8 mean=3.850 var=0.361 stddev=0.601 | |
k=58 min=2 max=8 mean=3.762 var=0.396 stddev=0.629 | |
k=59 min=2 max=7 mean=3.693 var=0.405 stddev=0.636 | |
k=60 min=2 max=8 mean=3.647 var=0.388 stddev=0.623 | |
k=61 min=2 max=7 mean=3.619 var=0.366 stddev=0.605 | |
k=62 min=2 max=8 mean=3.605 var=0.351 stddev=0.592 | |
k=63 min=2 max=7 mean=3.599 var=0.347 stddev=0.589 | |
rand_popcnt: csum=20 t=4.634s | |
rand_popcnt_2: csum=3c t=4.079s |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.