Skip to content

Instantly share code, notes, and snippets.

@zwegner
Created January 10, 2022 23:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zwegner/616aeac9a49a7e854c0743f2d7094791 to your computer and use it in GitHub Desktop.
Save zwegner/616aeac9a49a7e854c0743f2d7094791 to your computer and use it in GitHub Desktop.
Generate a random number with a given popcount, bisect + PDEP
#include <immintrin.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>
// PRNG modified from the public domain RKISS by Bob Jenkins. See:
// http://www.burtleburtle.net/bob/rand/smallprng.html
// Generator state: four 64-bit words that rand_next() reads and rewrites on
// every call.
typedef struct {
    uint64_t a;
    uint64_t b;
    uint64_t c;
    uint64_t d;
} rand_ctx_t;
// Rotate the 64-bit value x left by k bit positions.
//
// The rotation count is reduced modulo 64, so k == 0 (or any multiple of 64)
// returns x unchanged. Both shift amounts are masked into 0..63 because
// shifting a 64-bit value by 64 or more is undefined behaviour in C.
uint64_t rotate_left(uint64_t x, uint64_t k) {
    const uint64_t left = k & 63;
    const uint64_t right = (64 - left) & 63;
    return (x << left) | (x >> right);
}
// Advance the generator by one step and return the new 64-bit output.
// All four state words are rewritten; the returned value is the new x->d.
uint64_t rand_next(rand_ctx_t *x) {
    const uint64_t old_d = x->d;
    const uint64_t mixed = x->a - rotate_left(x->b, 7);
    const uint64_t next_a = x->b ^ rotate_left(x->c, 13);
    const uint64_t next_b = x->c + rotate_left(old_d, 37);

    x->a = next_a;
    x->b = next_b;
    x->c = old_d + mixed;
    x->d = mixed + next_a;
    return x->d;
}
// Reset the generator to a fixed seed, then run and discard 1000 outputs
// (presumably a warm-up so the fixed seed pattern is mixed before use).
// Every call leaves the state identical, so sequences are reproducible.
void rand_init(rand_ctx_t *x) {
    const uint64_t seed_rest = 0xFEDCBA9876543210ULL;

    x->a = 0x89ABCDEF01234567ULL;
    x->b = seed_rest;
    x->c = seed_rest;
    x->d = seed_rest;

    int remaining = 1000;
    while (remaining-- > 0) {
        (void)rand_next(x);
    }
}
// Get the next (larger) value with the same number of bits
// from https://www.chessprogramming.org/Traversing_Subsets_of_a_Set#Snoobing_the_Universe
//
// x == 0 has no set bits and therefore no successor; it is returned as 0
// instead of reaching the division below with a zero divisor.
uint64_t snoob(uint64_t x) {
    if (x == 0) {
        return 0;
    }
    // Lowest set bit of x.
    uint64_t smallest = x & -x;
    // Adding it carries through the lowest run of ones, setting the bit above it.
    uint64_t ripple = x + smallest;
    // Bits that changed: the old run of ones plus the newly set bit.
    uint64_t ones = x ^ ripple;
    // Drop two of those bits and shift the rest down to bit 0 (smallest is a
    // power of two, so the division is a right shift).
    ones = (ones >> 2) / smallest;
    return ripple | ones;
}
// choose_idx[n][k]: offset into choose_table of the first n-bit pattern that
// has exactly k bits set. Filled in by main() before rand_popcnt_2() is used.
uint8_t choose_idx[8][8];
// choose_len[n][k]: how many consecutive choose_table entries belong to that
// (n, k) pair. Entries that main() never writes stay 0.
uint8_t choose_len[8][8];
// Flat storage for all the bit patterns, grouped by (n, k); read by
// rand_popcnt_2() as choose_table[choose_idx[n][k] + offset].
uint8_t choose_table[1024];
// Number of bisection-loop iterations performed by the most recent call to
// rand_popcnt() or rand_popcnt_2(); main() reads it to gather statistics.
int iter_count;
// based on https://github.com/falk-hueffner/randkbits/blob/master/randkbits.cc
//
// Return a random 64-bit value with exactly k bits set, found by bisection.
// `min` and `max` bracket the answer: every candidate keeps all bits of `min`
// and only uses bits of `max`. A candidate with too many bits becomes the new
// `max`, otherwise it becomes the new `min`, until the popcount equals k.
//
// r: generator state, advanced once per loop iteration.
// k: requested popcount, expected to be in 0..64. Out-of-range values can
//    never be matched (the loop would not terminate), so they saturate:
//    k < 0 returns 0 and k > 64 returns all ones, without drawing from r.
//
// Side effect: iter_count is set to the number of loop iterations.
uint64_t rand_popcnt(rand_ctx_t *r, int k) {
    uint64_t min = 0;
    uint64_t max = ~(uint64_t)0;
    int n = 0;
    iter_count = 0;
    if (k < 0) {
        return min;
    }
    if (k > 64) {
        return max;
    }
    while (n != k) {
        iter_count++;
        uint64_t x = rand_next(r);
        // Force the bits already committed in min, drop bits excluded by max.
        x = min | (x & max);
        n = (int)_popcnt64(x);
        if (n > k) {
            max = x;
        } else {
            min = x;
        }
    }
    return min;
}
// variant of above with PDEP
//
// Return a random 64-bit value with exactly k bits set. Bisection narrows the
// bracket [min, max] as in rand_popcnt(), but stops as soon as the two bounds
// differ by at most 7 bits. The remaining bits are chosen in one step: a
// random pattern with the right number of bits is looked up in the
// n-choose-k tables and deposited with PDEP onto the positions where min and
// max differ.
//
// r: generator state; advanced once per loop iteration plus once for the
//    table lookup.
// k: requested popcount, expected to be in 0..64. Out-of-range values would
//    index the tables out of bounds or divide by a zero table entry, so they
//    saturate: k < 0 returns 0 and k > 64 returns all ones, without drawing
//    from r.
//
// Precondition: choose_idx / choose_len / choose_table have been filled in
// (main() does this at start-up); with empty tables the modulo below would
// divide by zero.
//
// Side effect: iter_count is set to the number of bisection iterations (the
// final table-lookup draw is not counted).
uint64_t rand_popcnt_2(rand_ctx_t *r, int k) {
    uint64_t min = 0;
    uint64_t max = ~(uint64_t)0;
    int n = 0, min_n = 0, max_n = 64;
    iter_count = 0;
    if (k < 0) {
        return min;
    }
    if (k > 64) {
        return max;
    }
    while (max_n - min_n > 7) {
        iter_count++;
        uint64_t x = rand_next(r);
        // Force the bits already committed in min, drop bits excluded by max.
        x = min | (x & max);
        n = (int)_popcnt64(x);
        if (n > k) {
            max = x;
            max_n = n;
        } else {
            min = x;
            min_n = n;
        }
    }
    // Fill in extra bits: n free positions remain, k of them must be set.
    n = max_n - min_n;
    k = k - min_n;
    int offset = (int)(rand_next(r) % choose_len[n][k]);
    uint64_t bits = choose_table[offset + choose_idx[n][k]];
    // min ^ max has exactly the n undecided positions set; PDEP scatters the
    // low n bits of the chosen pattern onto them.
    uint64_t extra = _pdep_u64(bits, min ^ max);
    return min | extra;
}
// Current time as reported by gettimeofday(), converted to whole
// milliseconds (seconds * 1000 plus microseconds / 1000).
uint64_t get_ms() {
    struct timeval now;
    gettimeofday(&now, NULL);

    const uint64_t ms_from_seconds = (uint64_t)now.tv_sec * 1000;
    const uint64_t ms_from_micros = (uint64_t)now.tv_usec / 1000;
    return ms_from_seconds + ms_from_micros;
}
int main() {
// Initialize n-choose-k tables
int offset = 0;
for (int i = 0; i < 8; i++) {
int max = 1 << i;
choose_len[i][0] = 1;
choose_idx[i][0] = offset;
choose_table[offset++] = 0;
for (int j = 1; j <= i; j++) {
int x = (1 << (j + 0)) - 1;
int c = 0;
choose_idx[i][j] = offset;
while (x < max) {
choose_table[offset++] = x;
x = snoob(x);
c++;
}
choose_len[i][j] = c;
}
}
rand_ctx_t r[1];
rand_init(r);
const int TRIALS = 1<<21;
uint8_t results[TRIALS];
for (int i = 0; i < 64; i++) {
int sum = 0, min = 100000, max = 0;
for (int j = 0; j < TRIALS; j++) {
uint64_t x = rand_popcnt(r, i);
if (_popcnt64(x) != i) {
printf("FAIL\n");
return 1;
}
sum += iter_count;
if (iter_count < min) min = iter_count;
if (iter_count > max) max = iter_count;
results[j] = iter_count;
}
float mean = (float)sum / TRIALS;
float resid_sum = 0;
for (int j = 0; j < TRIALS; j++) {
float d = results[j] - mean;
resid_sum += d * d;
}
printf("k=%d min=%d max=%d mean=%.3f var=%.3f stddev=%.3f\n", i, min, max, mean,
resid_sum / TRIALS, sqrt(resid_sum / TRIALS));
}
rand_init(r);
uint64_t sum = 0; // to keep results from getting DCE'd
time_t start = get_ms();
for (int i = 0; i < 64; i++)
for (int j = 0; j < TRIALS; j++)
sum += rand_popcnt(r, i);
time_t end = get_ms();
printf("rand_popcnt: csum=%2x t=%.3fs\n", sum & 0xFF, (end - start) / 1000.);
rand_init(r);
sum = 0;
start = get_ms();
for (int i = 0; i < 64; i++)
for (int j = 0; j < TRIALS; j++)
sum += rand_popcnt_2(r, i);
end = get_ms();
printf("rand_popcnt_2: csum=%2x t=%.3fs\n", sum & 0xFF, (end - start) / 1000.);
}
k=0 min=2 max=7 mean=3.600 var=0.353 stddev=0.594
k=1 min=2 max=8 mean=3.606 var=0.351 stddev=0.593
k=2 min=2 max=7 mean=3.619 var=0.366 stddev=0.605
k=3 min=2 max=7 mean=3.647 var=0.388 stddev=0.623
k=4 min=2 max=8 mean=3.693 var=0.405 stddev=0.636
k=5 min=2 max=8 mean=3.762 var=0.396 stddev=0.630
k=6 min=2 max=7 mean=3.849 var=0.361 stddev=0.601
k=7 min=2 max=8 mean=3.952 var=0.301 stddev=0.548
k=8 min=2 max=8 mean=3.874 var=0.330 stddev=0.575
k=9 min=2 max=8 mean=3.820 var=0.372 stddev=0.610
k=10 min=2 max=8 mean=3.789 var=0.389 stddev=0.624
k=11 min=2 max=8 mean=3.778 var=0.404 stddev=0.636
k=12 min=2 max=8 mean=3.782 var=0.404 stddev=0.635
k=13 min=2 max=8 mean=3.797 var=0.387 stddev=0.622
k=14 min=2 max=7 mean=3.812 var=0.383 stddev=0.619
k=15 min=2 max=8 mean=3.822 var=0.377 stddev=0.614
k=16 min=2 max=8 mean=3.824 var=0.366 stddev=0.605
k=17 min=2 max=8 mean=3.818 var=0.380 stddev=0.616
k=18 min=2 max=8 mean=3.810 var=0.389 stddev=0.623
k=19 min=2 max=8 mean=3.803 var=0.388 stddev=0.623
k=20 min=2 max=8 mean=3.801 var=0.391 stddev=0.626
k=21 min=2 max=9 mean=3.803 var=0.389 stddev=0.623
k=22 min=2 max=8 mean=3.807 var=0.387 stddev=0.622
k=23 min=2 max=8 mean=3.811 var=0.390 stddev=0.625
k=24 min=2 max=7 mean=3.812 var=0.389 stddev=0.623
k=25 min=2 max=8 mean=3.811 var=0.389 stddev=0.623
k=26 min=2 max=9 mean=3.810 var=0.382 stddev=0.618
k=27 min=2 max=8 mean=3.807 var=0.385 stddev=0.620
k=28 min=2 max=8 mean=3.808 var=0.383 stddev=0.619
k=29 min=2 max=8 mean=3.809 var=0.380 stddev=0.617
k=30 min=2 max=8 mean=3.811 var=0.386 stddev=0.621
k=31 min=2 max=8 mean=3.813 var=0.383 stddev=0.619
k=32 min=2 max=8 mean=3.813 var=0.384 stddev=0.619
k=33 min=2 max=8 mean=3.811 var=0.385 stddev=0.621
k=34 min=2 max=7 mean=3.810 var=0.381 stddev=0.617
k=35 min=2 max=8 mean=3.807 var=0.384 stddev=0.619
k=36 min=2 max=8 mean=3.807 var=0.384 stddev=0.619
k=37 min=2 max=8 mean=3.808 var=0.383 stddev=0.619
k=38 min=2 max=7 mean=3.811 var=0.388 stddev=0.623
k=39 min=2 max=7 mean=3.812 var=0.388 stddev=0.623
k=40 min=2 max=7 mean=3.811 var=0.390 stddev=0.624
k=41 min=2 max=8 mean=3.807 var=0.386 stddev=0.621
k=42 min=2 max=8 mean=3.803 var=0.388 stddev=0.623
k=43 min=2 max=8 mean=3.801 var=0.391 stddev=0.625
k=44 min=2 max=8 mean=3.804 var=0.387 stddev=0.622
k=45 min=2 max=8 mean=3.809 var=0.382 stddev=0.618
k=46 min=2 max=8 mean=3.818 var=0.380 stddev=0.617
k=47 min=2 max=8 mean=3.824 var=0.365 stddev=0.604
k=48 min=2 max=7 mean=3.822 var=0.377 stddev=0.614
k=49 min=2 max=8 mean=3.812 var=0.384 stddev=0.619
k=50 min=2 max=8 mean=3.796 var=0.386 stddev=0.622
k=51 min=2 max=8 mean=3.783 var=0.404 stddev=0.636
k=52 min=2 max=8 mean=3.777 var=0.404 stddev=0.635
k=53 min=2 max=7 mean=3.788 var=0.391 stddev=0.625
k=54 min=2 max=8 mean=3.822 var=0.372 stddev=0.610
k=55 min=2 max=8 mean=3.874 var=0.331 stddev=0.575
k=56 min=2 max=8 mean=3.952 var=0.301 stddev=0.548
k=57 min=2 max=8 mean=3.850 var=0.361 stddev=0.601
k=58 min=2 max=8 mean=3.762 var=0.396 stddev=0.629
k=59 min=2 max=7 mean=3.693 var=0.405 stddev=0.636
k=60 min=2 max=8 mean=3.647 var=0.388 stddev=0.623
k=61 min=2 max=7 mean=3.619 var=0.366 stddev=0.605
k=62 min=2 max=8 mean=3.605 var=0.351 stddev=0.592
k=63 min=2 max=7 mean=3.599 var=0.347 stddev=0.589
rand_popcnt: csum=20 t=4.634s
rand_popcnt_2: csum=3c t=4.079s
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment