Skip to content

Instantly share code, notes, and snippets.

@rkern
Created June 6, 2022 16:51
Show Gist options
  • Save rkern/9361aa15a28ae8c6dced01840209cdbb to your computer and use it in GitHub Desktop.
Save rkern/9361aa15a28ae8c6dced01840209cdbb to your computer and use it in GitHub Desktop.
from itertools import cycle
import re
from secrets import randbits
import numpy as np
cimport numpy as np
# Initialize NumPy's C-level API for this extension module; required before
# using the cimported numpy declarations (e.g. the typed ndarray buffers below).
np.import_array()
# Digit-prefix test for decimal seed strings. _coerce_to_uint32_array calls
# .match(), which only anchors at the start of the string; int() then performs
# the full parse of the string.
DECIMAL_RE = re.compile(r'[0-9]+')
# Default entropy-pool length in 32-bit words (4 words = 128 bits).
cdef uint32_t DEFAULT_POOL_SIZE = 4 # Appears also in docstring for pool_size
# Hash constants. INIT_A / MULT_A are the starting value and per-call
# multiplier of the running hash constant fed to hashmix() while building and
# splitting the pool; INIT_B / MULT_B play the same roles for the output hash
# in SplitSeed.generate_state(). Changing any of them changes every generated
# state.
cdef uint32_t INIT_A = 0x43b0d7e5
cdef uint32_t MULT_A = 0x931e8875
cdef uint32_t INIT_B = 0x8b51f9dd
cdef uint32_t MULT_B = 0x58f38ded
# Multipliers used by mix() to combine two pool words.
cdef uint32_t MIX_MULT_L = 0xca01f9dd
cdef uint32_t MIX_MULT_R = 0x4973f715
# Half the bit width of a uint32 (4 bytes * 8 // 2 == 16): the xorshift
# amount applied after each multiply.
cdef uint32_t XSHIFT = np.dtype(np.uint32).itemsize * 8 // 2
# Mask selecting the low 32 bits of an arbitrary-size integer.
cdef uint32_t MASK32 = 0xFFFFFFFF
def _int_to_uint32_array(n):
    """Split a non-negative integer into `uint32` words, lowest bits first.

    Parameters
    ----------
    n : int or numpy integer scalar
        Must be non-negative.

    Returns
    -------
    words : 1D uint32 array
        Little-endian word decomposition of `n`. Zero is represented by a
        single ``0`` word rather than an empty array.

    Raises
    ------
    ValueError
        If `n` is negative.
    """
    if n < 0:
        raise ValueError("expected non-negative integer")
    # Always work on an arbitrary-precision Python int. Mixing a NumPy
    # integer scalar with the large Python ints used below can overflow:
    # under NumPy 2 promotion rules ``np.int32(5) & 0xFFFFFFFF`` raises
    # OverflowError, so converting only *unsigned* scalars is not enough.
    n = int(n)
    if n == 0:
        return np.array([0], dtype=np.uint32)
    words = []
    while n:
        words.append(n & MASK32)
        n >>= 32
    return np.array(words, dtype=np.uint32)
def _coerce_to_uint32_array(x):
    """Normalize a seed-like value into a 1D `uint32` array.

    Inputs are recognized in this order:

    * a `uint32` ndarray -- returned as a copy;
    * a string -- ``"0x..."`` is parsed as hexadecimal; a string that begins
      with a decimal digit is parsed with ``int()``; anything else is
      rejected;
    * a Python or NumPy integer -- split into `uint32` words, least
      significant word first;
    * a floating-point scalar -- rejected;
    * anything else is treated as a sequence: each element is coerced by the
      same rules and the pieces are concatenated in order.

    Because elements are decomposed one at a time, an `int64` or `uint64`
    array is *not* a plain reinterpretation of its memory as `uint32`: a value
    that fits in 32 bits contributes exactly one output word. This keeps the
    result independent of the platform's default integer width.

    Parameters
    ----------
    x : int, str, sequence of int or str, or ndarray

    Returns
    -------
    seed_array : 1D uint32 array

    Raises
    ------
    ValueError
        For an unrecognized string or a negative integer.
    TypeError
        For a floating-point scalar.

    Examples
    --------
    >>> import numpy as np
    >>> _coerce_to_uint32_array(12345)
    array([12345], dtype=uint32)
    >>> _coerce_to_uint32_array('12345')
    array([12345], dtype=uint32)
    >>> _coerce_to_uint32_array('0x12345')
    array([74565], dtype=uint32)
    >>> _coerce_to_uint32_array([12345, '67890'])
    array([12345, 67890], dtype=uint32)
    >>> _coerce_to_uint32_array(np.array([12345, 67890], dtype=np.uint32))
    array([12345, 67890], dtype=uint32)
    >>> _coerce_to_uint32_array(np.array([12345, 67890], dtype=np.int64))
    array([12345, 67890], dtype=uint32)
    >>> _coerce_to_uint32_array([12345, 0x10deadbeef, 67890, 0xdeadbeef])
    array([     12345, 3735928559,         16,      67890, 3735928559],
          dtype=uint32)
    >>> _coerce_to_uint32_array(1234567890123456789012345678901234567890)
    array([3460238034, 2898026390, 3235640248, 2697535605,          3],
          dtype=uint32)
    """
    # Already canonical: hand back a private copy.
    if isinstance(x, np.ndarray) and x.dtype == np.dtype(np.uint32):
        return x.copy()

    if isinstance(x, str):
        if x.startswith('0x'):
            base = 16
        elif DECIMAL_RE.match(x):
            base = 10
        else:
            raise ValueError("unrecognized seed string")
        return _int_to_uint32_array(int(x, base))

    if isinstance(x, (int, np.integer)):
        return _int_to_uint32_array(x)

    if isinstance(x, (float, np.inexact)):
        raise TypeError('seed must be integer')

    # Remaining case: a sequence of interpretable-as-int items.
    if len(x) == 0:
        return np.array([], dtype=np.uint32)
    return np.concatenate([_coerce_to_uint32_array(item) for item in x])
cdef uint32_t hashmix(uint32_t value, uint32_t * hash_const):
    # Hash a single 32-bit word. `hash_const` is an in/out running
    # multiplier: it is advanced by MULT_A on every call, so repeated calls
    # sharing the same pointer hash identical inputs differently.
    # All arithmetic is on uint32_t and wraps modulo 2**32.
    cdef uint32_t multiplier = hash_const[0] * MULT_A
    value = (value ^ hash_const[0]) * multiplier
    hash_const[0] = multiplier
    return value ^ (value >> XSHIFT)
cdef uint32_t mix(uint32_t x, uint32_t y):
    # Combine two 32-bit words into one: a weighted difference of the inputs
    # (wrapping uint32 arithmetic) followed by an xorshift of the high half
    # into the low half.
    cdef uint32_t left = MIX_MULT_L * x
    cdef uint32_t right = MIX_MULT_R * y
    cdef uint32_t combined = left - right
    return combined ^ (combined >> XSHIFT)
cdef class SplitSeed():
    """
    SplitSeed(entropy=None, *, pool_size=4)

    `SplitSeed` mixes sources of entropy in a reproducible way to set the
    initial state for independent and very probably non-overlapping
    BitGenerators.

    Once the `SplitSeed` is instantiated, you can call the `generate_state`
    method to get an appropriately sized seed. Calling `split(n) <split>` will
    create ``n`` SplitSeeds that can be used to seed independent
    BitGenerators, i.e. for different threads. Unlike `SeedSequence.spawn`,
    calling `split(n) <split>` multiple times will return the *same* results
    for a more pure functional API.

    Parameters
    ----------
    entropy : {None, int, sequence[int]}, optional
        The entropy for initially creating a `SplitSeed`. If None, fresh
        entropy of ``pool_size * 32`` bits is drawn with `secrets.randbits`.
        If `pool` is provided, this will be stored but not used, and will
        simply reflect the value that was used at the root of the split tree.
        The splitting path is not stored.
    pool_size : {int}, optional
        Size of the pooled entropy to store, in 32-bit words. Must be at least
        1. Default is 4 to give a 128-bit entropy pool. 8 (for 256 bits) is
        another reasonable choice if working with larger PRNGs, but there is
        very little to be gained by selecting another value.
    pool : uint32 array, optional
        The internal hash pool. Only pass this if reconstructing a `SplitSeed`
        from a serialized form. A `uint32` copy of it is stored.
    hash_const : uint32, optional
        The internal hash constant for mixing in new entropy. Only pass this if
        reconstructing a `SplitSeed` from a serialized form. It has no effect
        unless `pool` is also given, because building the pool from `entropy`
        restarts the hash constant from its initial value.

    Raises
    ------
    ValueError
        If `pool_size` is less than 1.
    TypeError
        If `entropy` is not None, an integer, or a sequence/array of integers.
    """

    def __init__(self, entropy=None, *, pool_size=DEFAULT_POOL_SIZE, pool=None,
                 hash_const=None):
        # FIXME: the intended lower bound of DEFAULT_POOL_SIZE words is
        # deliberately not enforced so smaller pool sizes can be experimented
        # with. The pool must still hold at least one word: generate_state()
        # cycles over it, and an empty pool would fail there with a bare
        # StopIteration.
        if pool_size < 1:
            raise ValueError("The size of the entropy pool must be at least 1")
        if entropy is None:
            entropy = randbits(pool_size * 32)
        elif not isinstance(entropy, (int, np.integer, list, tuple, range,
                                      np.ndarray)):
            raise TypeError(f'{type(self).__name__} expects int or sequence '
                            f'of ints for entropy not {entropy}')
        self.entropy = entropy
        self.pool_size = pool_size
        if hash_const is None:
            hash_const = INIT_A
        self.hash_const = hash_const
        if pool is None:
            self.pool = np.zeros(pool_size, dtype=np.uint32)
            # mix_entropy() restarts the hash constant at INIT_A and stores
            # its final value in self.hash_const.
            self.mix_entropy(self.pool, self.get_assembled_entropy())
        else:
            # Store our own uint32 copy: mix_split_key() binds self.pool to a
            # typed uint32 buffer, which rejects lists and other dtypes.
            self.pool = np.array(pool, dtype=np.uint32)

    def __repr__(self):
        lines = [
            f'{type(self).__name__}(',
            f'    entropy={self.entropy!r},',
            f'    pool={self.pool!r},',
            f'    hash_const={self.hash_const!r},',
        ]
        # Omit some entries if they are left as the defaults in order to
        # simplify things.
        if self.pool_size != DEFAULT_POOL_SIZE:
            lines.append(f'    pool_size={self.pool_size!r},')
        lines.append(')')
        text = '\n'.join(lines)
        return text

    @property
    def state(self):
        """Dict of the constructor arguments needed to rebuild this object.

        Entries whose value is None are omitted.
        """
        return {k: getattr(self, k) for k in
                ['entropy', 'pool_size', 'pool', 'hash_const']
                if getattr(self, k) is not None}

    cdef mix_entropy(self, np.ndarray[np.npy_uint32, ndim=1] mixer,
                     np.ndarray[np.npy_uint32, ndim=1] entropy_array):
        """ Mix in the given entropy to mixer.

        The running hash constant always restarts from ``INIT_A``; its final
        value is stored in ``self.hash_const`` so that later calls to
        ``mix_split_key`` continue the same hash stream.

        Parameters
        ----------
        mixer : 1D uint32 array, modified in-place
        entropy_array : 1D uint32 array
        """
        cdef uint32_t hash_const[1]
        cdef Py_ssize_t i, i_src, i_dst
        cdef Py_ssize_t n_pool = len(mixer)
        cdef Py_ssize_t n_entropy = len(entropy_array)
        hash_const[0] = INIT_A
        # Add in the entropy up to the pool size.
        for i in range(n_pool):
            if i < n_entropy:
                mixer[i] = hashmix(entropy_array[i], hash_const)
            else:
                # Our pool size is bigger than our entropy, so just keep
                # running the hash out.
                mixer[i] = hashmix(0, hash_const)
        # Mix all bits together so late bits can affect earlier bits.
        for i_src in range(n_pool):
            for i_dst in range(n_pool):
                if i_src != i_dst:
                    mixer[i_dst] = mix(mixer[i_dst],
                                       hashmix(mixer[i_src], hash_const))
        # Add any remaining entropy, mixing each new entropy word with each
        # pool word.
        for i_src in range(n_pool, n_entropy):
            for i_dst in range(n_pool):
                mixer[i_dst] = mix(mixer[i_dst],
                                   hashmix(entropy_array[i_src], hash_const))
        self.hash_const = hash_const[0]

    cdef mix_split_key(self, uint32_t i_split):
        # Fold the child index `i_split` into every pool word in place,
        # continuing the hash stream from self.hash_const and saving where it
        # ends so the child's state is fully captured by (pool, hash_const).
        cdef int i_dst
        cdef uint32_t hash_const[1]
        cdef np.ndarray[np.npy_uint32, ndim=1] mixer = self.pool
        hash_const[0] = self.hash_const
        for i_dst in range(len(mixer)):
            mixer[i_dst] = mix(mixer[i_dst],
                               hashmix(i_split, hash_const))
        self.hash_const = hash_const[0]

    cpdef get_assembled_entropy(self):
        """ Convert and assemble all entropy sources into a uniform uint32
        array.

        The result is zero-padded up to ``pool_size`` words.

        Returns
        -------
        entropy_array : 1D uint32 array
        """
        # Unlike SeedSequence there is no spawn key to append: the split path
        # is folded directly into the pool by mix_split_key(), so the
        # assembled entropy is just the run entropy. __init__ guarantees it
        # is never None.
        assert self.entropy is not None
        run_entropy = _coerce_to_uint32_array(self.entropy)
        if len(run_entropy) < self.pool_size:
            # Zero-pad short entropy out to the pool size so every pool word
            # is seeded from an explicit entropy word.
            diff = self.pool_size - len(run_entropy)
            run_entropy = np.concatenate(
                [run_entropy, np.zeros(diff, dtype=np.uint32)])
        return run_entropy

    @np.errstate(over='ignore')
    def generate_state(self, n_words, dtype=np.uint32):
        """
        generate_state(n_words, dtype=np.uint32)

        Return the requested number of words for PRNG seeding.

        A BitGenerator should call this method in its constructor with
        an appropriate `n_words` parameter to properly seed itself.

        Parameters
        ----------
        n_words : int
        dtype : np.uint32 or np.uint64, optional
            The size of each word. This should only be either `uint32` or
            `uint64`. Strings (`'uint32'`, `'uint64'`) are fine. Note that
            requesting `uint64` will draw twice as many bits as `uint32` for
            the same `n_words`. This is a convenience for `BitGenerator`s that
            express their states as `uint64` arrays.

        Returns
        -------
        state : uint32 or uint64 array, shape=(n_words,)

        Raises
        ------
        ValueError
            If `dtype` is neither uint32 nor uint64.
        """
        cdef uint32_t hash_const = INIT_B
        cdef uint32_t data_val

        out_dtype = np.dtype(dtype)
        if out_dtype == np.dtype(np.uint32):
            pass
        elif out_dtype == np.dtype(np.uint64):
            n_words *= 2
        else:
            raise ValueError("only support uint32 or uint64")
        state = np.zeros(n_words, dtype=np.uint32)
        # Reuse pool words round-robin when more output words are requested
        # than the pool holds; the evolving hash constant keeps them distinct.
        src_cycle = cycle(self.pool)
        for i_dst in range(n_words):
            data_val = next(src_cycle)
            data_val ^= hash_const
            hash_const *= MULT_B
            data_val *= hash_const
            data_val ^= data_val >> XSHIFT
            state[i_dst] = data_val
        if out_dtype == np.dtype(np.uint64):
            # For consistency across different endiannesses, view first as
            # little-endian then convert the values to the native endianness.
            state = state.astype('<u4').view('<u8').astype(np.uint64)
        return state

    def split(self, n_children):
        """
        split(n_children)

        Split off a number of child `SplitSeed` s by mixing in different
        numbers into the entropy pool for each sub-stream.

        Unlike `SeedSequence.spawn`, this method is idempotent. Calling it
        multiple times will return the same `SplitSeed` values, and `self` is
        not modified (each child works on its own copy of the pool).

        Parameters
        ----------
        n_children : int

        Returns
        -------
        seqs : list of `SplitSeed` s
        """
        cdef uint32_t i_split
        cdef SplitSeed ss

        seqs = []
        for i_split in range(n_children):
            ss = type(self)(
                self.entropy,
                pool=self.pool,
                hash_const=self.hash_const,
                pool_size=self.pool_size,
            )
            ss.mix_split_key(i_split)
            seqs.append(ss)
        return seqs
# Register SplitSeed as a virtual subclass of NumPy's ISeedSequence ABC, so
# isinstance() checks against that interface accept SplitSeed instances
# (presumably so NumPy BitGenerators take it as a seed sequence -- confirm
# against the NumPy version in use).
np.random.bit_generator.ISeedSequence.register(SplitSeed)
cdef inline uint32_t rotate_left32(uint32_t x, uint32_t r):
    # Rotate the 32-bit word `x` left by `r` bits.
    # Both shift counts are masked into [0, 31]. The unmasked form
    # ``x >> (32 - r)`` shifts by 32 when r == 0, which is undefined behavior
    # in C. For 1 <= r <= 31 the result is unchanged; r == 0 now returns x.
    r &= 31
    return (x << r) | (x >> ((32 - r) & 31))
cdef apply_round(uint32_t *x1, uint32_t *x2, uint32_t r):
    # One add-rotate-xor mixing round on the word pair, updated in place
    # through the pointers:
    #     x1 <- x1 + x2            (wrapping uint32 add)
    #     x2 <- rotl(x2, r) ^ x1   (using the *new* x1)
    cdef uint32_t total = x1[0] + x2[0]
    cdef uint32_t rotated = rotate_left32(x2[0], r)
    x1[0] = total
    x2[0] = rotated ^ total
cpdef threefry2x32(uint32_t key1, uint32_t key2, uint32_t x1, uint32_t x2):
    # Threefry-2x32 block function, fully unrolled: 20 rounds (5 groups of 4)
    # of apply_round() over the counter pair (x1, x2) under key (key1, key2).
    # Returns the two output words as a tuple of Python ints.
    # All arithmetic is on uint32_t and wraps modulo 2**32.
    # NOTE(review): the rotation constants and key schedule are meant to match
    # the reference (Random123 / JAX) implementation -- check against it
    # before changing anything here.
    #
    # Extended key schedule ks = (key1, key2, key3). key3 is key1 ^ key2 ^
    # the constant 0x1BD11BDA.
    cdef uint32_t key3 = key1 ^ key2 ^ <uint32_t>(0x1BD11BDA)
    # Injection 0: add (ks[0], ks[1]).
    x1 += key1
    x2 += key2
    # Rounds 1-4, rotations (13, 15, 26, 6).
    apply_round(&x1, &x2, 13)
    apply_round(&x1, &x2, 15)
    apply_round(&x1, &x2, 26)
    apply_round(&x1, &x2, 6)
    # Injection s (s = 1..5) adds (ks[s % 3], ks[(s + 1) % 3] + s).
    x1 += key2
    x2 += key3 + <uint32_t>(1)
    # Rounds 5-8, rotations (17, 29, 16, 24).
    apply_round(&x1, &x2, 17)
    apply_round(&x1, &x2, 29)
    apply_round(&x1, &x2, 16)
    apply_round(&x1, &x2, 24)
    x1 += key3
    x2 += key1 + <uint32_t>(2)
    apply_round(&x1, &x2, 13)
    apply_round(&x1, &x2, 15)
    apply_round(&x1, &x2, 26)
    apply_round(&x1, &x2, 6)
    x1 += key1
    x2 += key2 + <uint32_t>(3)
    apply_round(&x1, &x2, 17)
    apply_round(&x1, &x2, 29)
    apply_round(&x1, &x2, 16)
    apply_round(&x1, &x2, 24)
    x1 += key2
    x2 += key3 + <uint32_t>(4)
    apply_round(&x1, &x2, 13)
    apply_round(&x1, &x2, 15)
    apply_round(&x1, &x2, 26)
    apply_round(&x1, &x2, 6)
    # Final key injection (s = 5) produces the output words.
    x1 += key3
    x2 += key1 + <uint32_t>(5)
    return (x1, x2)
cpdef iterate_jax_key(np.ndarray[np.uint32_t, ndim=1] key):
    """Faithful implementation of ``jax.random.split(key)[0]``.

    Evaluates `threefry2x32` under `key` at the counter pairs (0, 2) and
    (1, 3), and keeps the first output word of each as the new two-word key.
    `key` must hold exactly two words.
    """
    cdef uint32_t key1, key2
    key1, key2 = key
    first_word = threefry2x32(key1, key2, 0, 2)[0]
    second_word = threefry2x32(key1, key2, 1, 3)[0]
    return np.array([first_word, second_word], dtype=np.uint32)
cpdef fixed_iterate_jax_key(np.ndarray[np.uint32_t, ndim=1] key):
    """JAX key iteration if the ``threefry_random_bits()`` quirk were fixed.

    The new key is both output words of a single `threefry2x32` block
    evaluated under `key` at counter (0, 0). `key` must hold exactly two
    words.
    """
    cdef uint32_t key1, key2
    key1, key2 = key
    return np.array(threefry2x32(key1, key2, 0, 0), dtype=np.uint32)
cpdef bijective_iterate_jax_key(np.ndarray[np.uint32_t, ndim=1] key):
    """Putative bijective key splitting (left branch).

    The two key words are encrypted as the *counter* under a fixed all-zero
    Threefry key. The map is presumably bijective because, for a fixed key,
    a block cipher permutes its input block. `key` must hold exactly two
    words.
    """
    cdef uint32_t key1, key2
    key1, key2 = key
    return np.array(threefry2x32(0, 0, key1, key2), dtype=np.uint32)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment