Skip to content

Instantly share code, notes, and snippets.

@rkern
Last active June 4, 2020 21:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rkern/f46552e030e59b5f1ebbd3b3ec045759 to your computer and use it in GitHub Desktop.
Save rkern/f46552e030e59b5f1ebbd3b3ec045759 to your computer and use it in GitHub Desktop.
Verifying numpy#16313
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import json
import sys
import numpy as np
from numpy.random import PCG64
# Bit mask that keeps the low 128 bits of an integer.
MASK128 = 2 ** 128 - 1
# 128-bit multiplier, assembled from its high and low 64-bit halves; main()
# selects it for the plain PCG64 generators.
PCG64_MULT = (2549297995355413924 << 64) | 4865540595714422341
# 64-bit multiplier that main() selects when --dxsm is requested.
PCG64_CHEAP_MULT = 0xDA942042E4DD58B5
def pcg64_distance(bg0, bg1, mult=PCG64_MULT):
    """Return the step count that carries ``bg1``'s state onto ``bg0``'s.

    Both arguments must expose a ``state`` mapping whose ``'state'`` entry
    holds the integer fields ``'state'`` and ``'inc'``.

    When the two increments are equal, the answer is the plain distance from
    ``bg1``'s state to ``bg0``'s under that shared increment.  Otherwise each
    generator is reduced to the value ``inc + (mult - 1) * state`` (mod
    2**128) and the distance between those two values is measured with a zero
    increment.  The second value is negated first if the two disagree in
    their lowest two bits.

    NOTE(review): the different-increment branch presumably mirrors the PCG
    reference implementation's treatment of distinct streams -- confirm
    against that source.
    """
    inner0 = bg0.state['state']
    inner1 = bg1.state['state']
    state0, increment0 = inner0['state'], inner0['inc']
    state1, increment1 = inner1['state'], inner1['inc']

    # Same stream: measure directly along the shared sequence.
    if increment0 == increment1:
        return pcg64_state_distance(state1, state0, increment0, mult=mult)

    delta0 = ((mult - 1) * state0 + increment0) & MASK128
    delta1 = ((mult - 1) * state1 + increment1) & MASK128
    # Any disagreement in the two low bits is resolved by negating delta1.
    if (delta0 ^ delta1) & 3:
        delta1 = -delta1 & MASK128
    return pcg64_state_distance(delta1, delta0, 0, mult=mult)
def pcg64_state_distance(state, dest, plus, mult=PCG64_MULT, mask=MASK128):
    """Count the steps of ``x -> (x * mult + plus) mod 2**128`` from
    ``state`` to ``dest``.

    The distance is built one binary digit at a time.  Whenever the current
    bit of ``state`` differs from that of ``dest``, one step of the current
    map is applied and the bit is set in the result.  The map is then
    composed with itself, so the next step jumps twice as far.

    Parameters
    ----------
    state, dest : int
        Starting and target states.
    plus : int
        Additive constant of the recurrence.  ``0`` selects the purely
        multiplicative case: the scan starts at bit 2 and the result is
        shifted right by two.
    mult : int
        Multiplier of the recurrence.
    mask : int
        Bits of ``state`` and ``dest`` that must agree for the search to
        stop.

    Returns
    -------
    int
        The number of steps found.

    Raises
    ------
    ValueError
        If every bit position below 2**128 has been tried without the masked
        states agreeing, i.e. ``dest`` cannot be reached.  Without this check
        the loop would never terminate.
    """
    is_mcg = plus == 0
    the_bit = 4 if is_mcg else 1
    distance = 0
    while (state & mask) != (dest & mask):
        if the_bit > MASK128:
            # Stepped states are confined to 128 bits, so nothing further
            # can change once the probe bit has left that range.
            raise ValueError("dest is not reachable from state")
        if (state & the_bit) != (dest & the_bit):
            state = (state * mult + plus) & MASK128
            distance |= the_bit
        assert (state & the_bit) == (dest & the_bit)
        the_bit <<= 1
        # f(f(x)) = mult**2 * x + (mult + 1) * plus: double the stride.
        plus = ((mult + 1) * plus) & MASK128
        mult = (mult * mult) & MASK128
    return distance >> 2 if is_mcg else distance
def gen_interleaved_bytes(bitgens, n_per_gen=1024, output_dtype=np.uint32):
    """Yield an endless series of byte chunks interleaving several generators.

    Each chunk is produced by drawing ``n_per_gen`` raw values from every
    generator, casting them to ``output_dtype`` and laying them out so that
    the i-th values of all generators are adjacent, in the order the
    generators were given.

    Parameters
    ----------
    bitgens : iterable
        Objects providing ``random_raw(n)``.  The iterable is materialised
        once, so a one-shot iterator is acceptable.
    n_per_gen : int
        Number of values drawn from each generator per chunk.
    output_dtype : dtype
        Element type the raw draws are cast to before serialisation.

    Yields
    ------
    bytes
        ``len(bitgens) * n_per_gen`` elements of ``output_dtype``.

    Raises
    ------
    ValueError
        On first iteration, if ``bitgens`` is empty.
    """
    bitgens = list(bitgens)
    if not bitgens:
        raise ValueError("at least one BitGenerator is required")
    while True:
        draws = [g.random_raw(n_per_gen).astype(output_dtype) for g in bitgens]
        # One column per generator; C-order serialisation then walks row by
        # row, which alternates between the generators.
        yield np.column_stack(draws).tobytes()
def bitgen_from_state(state):
    """Rebuild a BitGenerator from a saved ``state`` dictionary.

    The class is looked up in ``numpy.random`` under the name stored in
    ``state['bit_generator']``, instantiated with no arguments, and then
    given ``state`` as its state.

    Raises
    ------
    AttributeError
        If ``numpy.random`` has no attribute of that name.
    ValueError
        If the named attribute is not a ``BitGenerator`` subclass.  The name
        may come from an external file, so arbitrary ``numpy.random``
        callables (e.g. ``seed``) must not be invoked.
    """
    name = state['bit_generator']
    cls = getattr(np.random, name)
    if not (isinstance(cls, type) and issubclass(cls, np.random.BitGenerator)):
        raise ValueError(f"{name!r} is not a numpy.random BitGenerator class")
    bitgen = cls()
    bitgen.state = state
    return bitgen
def dump_states(bitgens, file=None):
    """Print the ``state`` of every item in ``bitgens`` as one JSON list.

    The list is serialised with a 4-space indent and written to ``file``
    with ``print``.  When ``file`` is ``None`` the current ``sys.stderr`` is
    used; it is looked up at call time so that a redirected stderr is
    honoured.
    """
    if file is None:
        file = sys.stderr
    text = json.dumps([g.state for g in bitgens], indent=4)
    print(text, file=file)
def from_json(filename):
    """Load BitGenerators from a JSON file holding a list of saved states.

    Every entry of the decoded list is passed to ``bitgen_from_state``; the
    resulting generators are returned in file order.
    """
    with open(filename) as stream:
        saved_states = json.load(stream)
    return [bitgen_from_state(entry) for entry in saved_states]
def _build_arg_parser():
    """Create the argument parser describing this script's options."""
    import argparse

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-j', '--json',
                        help='Load BitGenerators from JSON file.')
    parser.add_argument('-i', '--same-increment', action='store_true',
                        help='Force the same increment.')
    parser.add_argument('-m', '--matching-bits', type=int, default=64,
                        help='The number of low-order bits to match.')
    parser.add_argument('-d', '--dxsm', action='store_true',
                        help='Use the DXSM output function.')
    parser.add_argument('-a', '--advanced', action='store_true',
                        help='Advance so that they are 0 distance apart.')
    return parser


def _make_bitgen_pair(args):
    """Create two fresh generators whose states share their low-order bits.

    The second generator's state is rewritten to keep its own high bits and
    take its low ``args.matching_bits`` bits from a reference state.  Both
    final states and their distance are reported on stderr.  Returns
    ``[bg0, bg1]``.
    """
    if args.matching_bits < 0:
        # A negative shift count would otherwise surface as a traceback.
        raise SystemExit("--matching-bits must be non-negative.")

    if args.dxsm:
        try:
            from numpy.random import PCG64DXSM
        except ImportError:
            raise SystemExit("PCG64DXSM only available on a branch.\n"
                             "https://github.com/rkern/numpy/commit/"
                             "6510712339e5539664ac24c0361c41b091514a69") from None
        bg0 = PCG64DXSM()
        bg1 = PCG64DXSM()
        mult = PCG64_CHEAP_MULT
    else:
        bg0 = PCG64()
        bg1 = PCG64()
        mult = PCG64_MULT

    # Snapshot bg1's state now; it is edited and written back further down.
    state1 = bg1.state
    s0 = bg0.state['state']['state']
    s1 = bg1.state['state']['state']
    # same_mask selects the low matching_bits bits, random_mask the rest of
    # the 128-bit word.
    same_mask = (1 << args.matching_bits) - 1
    random_mask = (MASK128 << args.matching_bits) & MASK128
    if args.same_increment:
        state1['state']['inc'] = bg0.state['state']['inc']
    else:
        # bg0 and bg1 use different increments. To get the proper lower-bit
        # pattern to apply to the second generator bg1, we first advance
        # its state so that it has 0 distance from the bg0 state. Then we
        # use that state to provide the bit-pattern to match.
        bg1.advance(pcg64_distance(bg0, bg1, mult=mult))
        s0 = bg1.state['state']['state']
    state1['state']['state'] = (s1 & random_mask) | (s0 & same_mask)
    bg1.state = state1
    if args.advanced:
        bg1.advance(pcg64_distance(bg0, bg1, mult=mult))

    s0 = bg0.state['state']['state']
    s1 = bg1.state['state']['state']
    print(f"s0 = 0b{s0:0128b}", file=sys.stderr)
    print(f"s1 = 0b{s1:0128b}", file=sys.stderr)
    dist = pcg64_distance(bg0, bg1, mult=mult)
    print(f"dist = 0x{dist:032x}", file=sys.stderr)
    return [bg0, bg1]


def main():
    """Command-line entry point.

    Obtains a list of BitGenerators, either loaded from the JSON file given
    with ``--json`` or as a freshly constructed pair, dumps their states,
    and then writes their interleaved output to stdout as raw bytes.  The
    output loop has no end condition; the process runs until it is
    interrupted.
    """
    args = _build_arg_parser().parse_args()
    if args.json is not None:
        bitgens = from_json(args.json)
    else:
        bitgens = _make_bitgen_pair(args)
    dump_states(bitgens)
    for chunk in gen_interleaved_bytes(bitgens):
        sys.stdout.buffer.write(chunk)
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment