Last active
June 4, 2020 21:04
-
-
Save rkern/f46552e030e59b5f1ebbd3b3ec045759 to your computer and use it in GitHub Desktop.
Verifying numpy#16313
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
import json | |
import sys | |
import numpy as np | |
from numpy.random import PCG64 | |
# Every state update in this file is reduced with this mask, i.e. taken
# modulo 2**128.
MASK128 = 2**128 - 1
# Default multiplier for the distance helpers, written as a high and a low
# 64-bit word; main() uses it for PCG64.
PCG64_MULT = (2549297995355413924 << 64) | 4865540595714422341
# 64-bit multiplier that main() selects when --dxsm is given.
PCG64_CHEAP_MULT = 0xDA942042E4DD58B5
def pcg64_distance(bg0, bg1, mult=PCG64_MULT):
    """Return the step count that carries ``bg1``'s state onto ``bg0``'s.

    Parameters
    ----------
    bg0, bg1
        Bit generators whose ``.state`` mapping exposes integer entries at
        ``['state']['state']`` and ``['state']['inc']``.
    mult : int
        Multiplier of the underlying recurrence
        ``x -> (mult * x + inc) mod 2**128``.

    Returns
    -------
    int
        Result of ``pcg64_state_distance`` walking from the ``bg1`` value
        towards the ``bg0`` value.
    """
    lcg0 = bg0.state['state']
    lcg1 = bg1.state['state']
    s0, inc0 = lcg0['state'], lcg0['inc']
    s1, inc1 = lcg1['state'], lcg1['inc']

    if inc0 == inc1:
        # Same stream: measure the distance directly on the raw states.
        return pcg64_state_distance(s1, s0, inc0, mult=mult)

    # Different increments.  For x' = mult*x + inc the quantity
    # d = inc + (mult - 1)*x satisfies d' = mult*d, so both generators are
    # compared through that purely multiplicative sequence (plus == 0).
    diff0 = (inc0 + (mult - 1) * s0) & MASK128
    diff1 = (inc1 + (mult - 1) * s1) & MASK128
    if (diff0 ^ diff1) & 3:
        # The multiplicative walk only starts matching at bit 2, so the two
        # lowest bits must already agree; negate to line them up.
        diff1 = -diff1 & MASK128
    return pcg64_state_distance(diff1, diff0, 0, mult=mult)
def pcg64_state_distance(state, dest, plus, mult=PCG64_MULT, mask=MASK128):
    """Count the steps of ``x -> (mult * x + plus) mod 2**128`` from
    ``state`` to ``dest``.

    The distance is assembled one binary digit at a time, lowest first.  At
    each stage ``mult``/``plus`` describe a jump of twice the previous
    length: composing ``x -> m*x + p`` with itself gives
    ``x -> m*m*x + (m + 1)*p``, which is exactly the update applied at the
    bottom of the loop.

    Parameters
    ----------
    state : int
        Starting state.
    dest : int
        State to reach.
    plus : int
        Additive constant of the recurrence.  ``0`` selects the purely
        multiplicative mode, in which matching starts at bit 2 and the
        result is shifted right by two bits.
    mult : int
        Multiplier of the recurrence.
    mask : int
        Only the bits selected by ``mask`` have to agree for the walk to
        stop.

    Returns
    -------
    int
        Number of steps needed for ``state`` to agree with ``dest`` on the
        masked bits.

    Raises
    ------
    ValueError
        If every bit position covered by ``mask`` has been processed and
        the masked states still differ, i.e. ``dest`` cannot be reached
        (for example ``plus == 0`` with differing low two bits).  Without
        this check the loop would never terminate.
    AssertionError
        If taking a step does not bring the current bit into agreement.
    """
    is_mcg = plus == 0
    the_bit = 4 if is_mcg else 1
    distance = 0
    while (state & mask) != (dest & mask):
        if the_bit > mask:
            # All bit positions inside the mask are exhausted; further
            # doubling of the jump length cannot make progress.
            raise ValueError(
                "dest is not reachable from state with the given "
                "mult/plus")
        if (state & the_bit) != (dest & the_bit):
            # Take one jump of the current length to fix this bit.
            state = (state * mult + plus) & MASK128
            distance |= the_bit
        assert (state & the_bit) == (dest & the_bit)
        the_bit <<= 1
        # Double the jump length for the next bit.
        plus = ((mult + 1) * plus) & MASK128
        mult = (mult * mult) & MASK128
    if is_mcg:
        # Matching started at bit 2, so rescale the accumulated bits.
        distance >>= 2
    return distance
def gen_interleaved_bytes(bitgens, n_per_gen=1024, output_dtype=np.uint32):
    """Yield an endless sequence of byte chunks mixing several generators.

    For every chunk each generator in ``bitgens`` is asked for
    ``n_per_gen`` raw values via ``random_raw``; the values are cast to
    ``output_dtype`` and emitted round-robin (first value of every
    generator, then the second value of every generator, and so on).

    Parameters
    ----------
    bitgens : sequence
        Objects providing ``random_raw(n)`` that returns a NumPy array.
    n_per_gen : int
        Number of raw values drawn from each generator per chunk.
    output_dtype : dtype
        Element type the raw values are cast to before serialisation.

    Yields
    ------
    bytes
        The interleaved values of one chunk.
    """
    while True:
        columns = []
        for gen in bitgens:
            raw = gen.random_raw(n_per_gen)
            columns.append(raw.astype(output_dtype))
        # One column per generator; flattening row by row interleaves them.
        mixed = np.column_stack(columns).ravel()
        yield mixed.tobytes()
def bitgen_from_state(state):
    """Recreate a bit generator from a saved state mapping.

    ``state['bit_generator']`` names the class, which is looked up on
    ``numpy.random`` and instantiated without arguments; the new
    instance's ``state`` is then replaced by ``state``.

    Returns the restored bit generator.
    """
    generator_type = getattr(np.random, state['bit_generator'])
    restored = generator_type()
    restored.state = state
    return restored
def dump_states(bitgens, file=sys.stderr):
    """Print the ``state`` of every generator as one indented JSON array.

    Parameters
    ----------
    bitgens : iterable
        Objects with a JSON-serialisable ``state`` attribute.
    file : text stream
        Destination of the output; standard error unless given.
    """
    snapshots = [bitgen.state for bitgen in bitgens]
    rendered = json.dumps(snapshots, indent=4)
    print(rendered, file=file)
def from_json(filename):
    """Load bit generators from a JSON file.

    The file is parsed with ``json.load`` and every element of the result
    is turned into a generator by ``bitgen_from_state``.

    Returns the list of restored bit generators, in file order.
    """
    with open(filename) as stream:
        saved_states = json.load(stream)
    return [bitgen_from_state(saved) for saved in saved_states]
def main():
    """Command-line entry point.

    Either loads bit generators from a JSON file (``--json``) or builds a
    pair of freshly seeded generators whose states share their
    ``--matching-bits`` low-order bits.  The generator states are dumped
    as JSON to standard error, then the interleaved raw output of the
    generators is written as binary data to standard output.  The output
    loop has no end condition of its own; the process runs until it is
    stopped from outside.

    A negative ``--matching-bits`` is rejected through ``parser.error``
    instead of surfacing as a ``ValueError`` traceback from the shift
    that builds the bit masks.
    """
    import argparse

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-j', '--json',
                        help='Load BitGenerators from JSON file.')
    parser.add_argument('-i', '--same-increment', action='store_true',
                        help='Force the same increment.')
    parser.add_argument('-m', '--matching-bits', type=int, default=64,
                        help='The number of low-order bits to match.')
    parser.add_argument('-d', '--dxsm', action='store_true',
                        help='Use the DXSM output function.')
    parser.add_argument('-a', '--advanced', action='store_true',
                        help='Advance so that they are 0 distance apart.')
    args = parser.parse_args()

    if args.json is not None:
        bitgens = from_json(args.json)
    else:
        if args.matching_bits < 0:
            # ``1 << n`` raises ValueError for negative n; report the bad
            # option as a normal usage error instead.
            parser.error('--matching-bits must be non-negative')

        if args.dxsm:
            try:
                from numpy.random import PCG64DXSM
            except ImportError:
                raise SystemExit("PCG64DXSM only available on a branch.\n"
                                 "https://github.com/rkern/numpy/commit/"
                                 "6510712339e5539664ac24c0361c41b091514a69")
            bg0 = PCG64DXSM()
            bg1 = PCG64DXSM()
            mult = PCG64_CHEAP_MULT
        else:
            bg0 = PCG64()
            bg1 = PCG64()
            mult = PCG64_MULT

        # Snapshot of bg1's state; it is edited below and written back.
        state1 = bg1.state
        s0 = bg0.state['state']['state']
        s1 = bg1.state['state']['state']
        # same_mask selects the low bits to copy, random_mask the high bits
        # of bg1 that are kept as they are.
        same_mask = (1 << args.matching_bits) - 1
        random_mask = (MASK128 << args.matching_bits) & MASK128
        if args.same_increment:
            state1['state']['inc'] = bg0.state['state']['inc']
        else:
            # bg0 and bg1 use different increments. To get the proper
            # lower-bit pattern to apply to the second generator bg1, we
            # first advance its state so that it has 0 distance from the
            # bg0 state. Then we use that state to provide the bit-pattern
            # to match.
            bg1.advance(pcg64_distance(bg0, bg1, mult=mult))
            s0 = bg1.state['state']['state']
        state1['state']['state'] = ((s1 & random_mask) | (s0 & same_mask))
        bg1.state = state1
        if args.advanced:
            bg1.advance(pcg64_distance(bg0, bg1, mult=mult))

        # Report the final states and their distance on stderr.
        s0 = bg0.state['state']['state']
        s1 = bg1.state['state']['state']
        print(f"s0 = 0b{s0:0128b}", file=sys.stderr)
        print(f"s1 = 0b{s1:0128b}", file=sys.stderr)
        dist = pcg64_distance(bg0, bg1, mult=mult)
        print(f"dist = 0x{dist:032x}", file=sys.stderr)
        bitgens = [bg0, bg1]

    dump_states(bitgens)
    for chunk in gen_interleaved_bytes(bitgens):
        sys.stdout.buffer.write(chunk)
# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment