Skip to content

Instantly share code, notes, and snippets.

@rgommers
Created June 1, 2022 17:40
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rgommers/997ab5a287e9a19c345771d2f5712574 to your computer and use it in GitHub Desktop.
Save rgommers/997ab5a287e9a19c345771d2f5712574 to your computer and use it in GitHub Desktop.
Comparing JAX and NumPy APIs for random number generation - serial and parallel
"""
Implement `jax.random`-style APIs on top of `numpy.random`, next to equivalent
serial and parallel (multiprocessing) examples written directly against each
library. The purpose is to make the two APIs easier to compare, and to clarify
where they are and aren't similar.
"""
import secrets
import multiprocessing
import numpy as np
import jax
# Set to True to get the same random numbers on every run.
USE_FIXED_SEED = False
if USE_FIXED_SEED:
    # `jax.random.PRNGKey` can only be relied on for 32-bit seeds: with
    # `jax_enable_x64` disabled, wider seeds may have bits truncated, and NumPy
    # and JAX would then effectively be seeded differently. Keep only the low
    # 32 bits of the fixed seed so both libraries see the same value.
    seed = 38968222334307 % 2**32
else:
    # Generate a random high-entropy seed for use in the below examples;
    # jax.random.PRNGKey doesn't accept None to do this automatically.
    seed = secrets.randbits(32)  # JAX can't deal with >32-bits
# NumPy serial: one Generator, drawn from repeatedly (its state advances).
rng = np.random.default_rng(seed)
vals, val = rng.uniform(size=3), rng.uniform(size=1)
# NumPy parallel: spawn 4 child SeedSequences and build one Generator per child.
sseq = np.random.SeedSequence(entropy=seed)
child_seeds = sseq.spawn(4)
rngs = list(map(np.random.default_rng, child_seeds))
def use_rngs_numpy(rng):
    """Draw 3 uniforms and then 1 more from `rng`, and print both arrays."""
    # The generator expression is consumed left to right, so the size-3 draw
    # happens before the size-1 draw.
    print(*(rng.uniform(size=n) for n in (3, 1)))
def main_numpy():
    """Run `use_rngs_numpy` on each Generator in `rngs` using a pool of 4 processes."""
    n_workers = 4
    with multiprocessing.Pool(processes=n_workers) as workers:
        workers.map(use_rngs_numpy, rngs)
# JAX serial (also auto-parallelizes fine by design)
key = jax.random.PRNGKey(seed)
# This first split could be left out, but best practice is probably to always
# `split` first and draw only from the resulting subkey.
key, subkey = jax.random.split(key)
vals = jax.random.uniform(subkey, shape=(3,))
# Don't forget this! Split again so the second draw gets its own fresh subkey
# instead of reusing the subkey already used for `vals`.
key, subkey = jax.random.split(key)
val = jax.random.uniform(subkey, shape=(1,))
# JAX parallel with multiprocessing
def use_rngs_jax(key):
    """Draw 3 uniforms and then 1 more from `key`, and print both arrays.

    A new subkey is split off for each draw, so no subkey is used twice.
    """
    draws = []
    for shape in ((3,), (1,)):
        key, subkey = jax.random.split(key)
        draws.append(jax.random.uniform(subkey, shape=shape))
    print(*draws)
def main_jax():
    """Split a root key (from `seed`) into 4 subkeys and run `use_rngs_jax` on them in a pool of 4 processes."""
    # Gotcha: asking `split` for 5 keys yields only 4 subkeys for the workers,
    # because the first key takes the role of the new parent key (unused here).
    _, *subkeys = jax.random.split(jax.random.PRNGKey(seed), 5)
    with multiprocessing.Pool(processes=4) as workers:
        workers.map(use_rngs_jax, subkeys)
# An API matching JAX on top of `numpy.random`
##############################################
def PRNGKey(seed):
    """
    Create a key from a seed.

    The key is a 2-tuple ``(SeedSequence, Generator)``. The `SeedSequence` is
    built from `seed` and carries the entropy. The `Generator` comes from
    `np.random.default_rng` and is used by the other functions only for its
    BitGenerator type, i.e. as the algorithm selector. `seed` is passed
    straight to `np.random.SeedSequence`, so it may be any value that class
    accepts as entropy (non-negative ints, including ones wider than 32 bits).
    """
    # Note: in JAX, selecting a non-default PRNG algorithm is done via a global
    # config flag (not good, should be a keyword or similar ...)
    seed_seq = np.random.SeedSequence(seed)
    return seed_seq, np.random.default_rng(seed_seq)
def split(key, num=2):
    """
    Split a key into `num` new keys, deterministically.

    Like `jax.random.split`, this is a pure function of `key`: splitting the
    same key twice returns the same child keys, and `key` is not modified.
    (`SeedSequence.spawn` on its own is stateful, because each call advances
    the parent's `n_children_spawned` counter. So the parent is re-created
    from its public state before spawning.)

    Parameters
    ----------
    key : tuple
        Size-2 tuple as returned by `PRNGKey` or `split`. The first element is
        a `SeedSequence` instance. The second is a `Generator` whose
        BitGenerator type acts as the algorithm selector.
    num : int, optional
        The number of keys to produce (default: 2).

    Returns
    -------
    keys : tuple of 2-tuples
        `num` number of keys (each key being a 2-tuple). This is a real tuple,
        not a lazy generator, so it can be indexed, passed to `len`, and
        iterated more than once.
    """
    seed, rng = key
    # Fresh copy with the same entropy, spawn_key and pool_size but a spawn
    # counter of zero, so that repeated splits of `key` agree and the caller's
    # SeedSequence is never mutated.
    parent = np.random.SeedSequence(
        seed.entropy, spawn_key=seed.spawn_key, pool_size=seed.pool_size
    )
    return tuple((child, rng) for child in parent.spawn(num))
def uniform(key, shape=(), dtype=np.float64, minval=0.0, maxval=1.0):
    """
    Draw uniform samples in ``[minval, maxval)`` using the randomness in `key`.

    Mirrors `jax.random.uniform`: the result depends only on `key`, not on any
    mutable generator state, so calling this twice with the same key returns
    the same values.

    Parameters
    ----------
    key : tuple
        2-tuple ``(SeedSequence, Generator)`` as returned by `PRNGKey`/`split`.
        Only the type of the Generator's BitGenerator is used (as the algorithm
        selector). The Generator's own state is neither read nor advanced.
    shape : tuple of int, optional
        Output shape (default: ``()``).
    dtype : dtype, optional
        Output dtype. Samples are drawn as float64 and then cast.
    minval, maxval : float, optional
        Bounds of the half-open interval passed to `Generator.uniform`.

    Returns
    -------
    numpy.ndarray
        Samples with the requested `shape` and `dtype`.
    """
    seed, rng = key
    # Build a brand-new Generator of the same BitGenerator type, seeded from
    # the key's SeedSequence. `Generator.bit_generator` is public API, so the
    # private `_bit_generator` attribute is not needed.
    bit_generator_cls = type(rng.bit_generator)
    fresh = np.random.Generator(bit_generator_cls(seed))
    return fresh.uniform(low=minval, high=maxval, size=shape).astype(dtype)
def use_jaxlike_api(key=None):
    """
    NumPy-backed counterpart of `use_rngs_jax`, using the JAX-like functions.

    If `key` is None, a root key is created from the module-level `seed`.
    Prints two arrays: 3 uniforms, then 1.
    """
    key = PRNGKey(seed) if key is None else key
    draws = []
    for size in ((3,), (1,)):
        key, subkey = split(key)  # fresh subkey for every draw
        draws.append(uniform(subkey, shape=size))
    print(*draws)
def use_jaxlike_api_mp():
    """Split a root key (from `seed`) into 4 subkeys and run `use_jaxlike_api` on them in a pool of 4 processes."""
    root = PRNGKey(seed)
    # Requesting 5 keys gives 4 subkeys for the workers; the first is the new parent.
    _, *subkeys = split(root, 5)
    with multiprocessing.Pool(processes=4) as workers:
        workers.map(use_jaxlike_api, subkeys)
if __name__ == '__main__':
    # JAX does not work with the default `fork` (due to internal threading).
    # `forkserver` is not available on every platform (e.g. Windows), so fall
    # back to `spawn` there. Neither method forks this process directly.
    if 'forkserver' in multiprocessing.get_all_start_methods():
        multiprocessing.set_start_method('forkserver')
    else:
        multiprocessing.set_start_method('spawn')
    print('\nNumPy with multiprocessing:\n')
    main_numpy()
    print('\n\nJAX with multiprocessing:\n')
    main_jax()
    print('\n\nUse JAX-like API (serial):\n')
    use_jaxlike_api()
    print('\n\nUse JAX-like API (multiprocessing):\n')
    use_jaxlike_api_mp()
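# Gotcha with seed creation: jax.random.PRNGKey casts int seeds to np.int64, so
# 64-bit seeds >= 2**63 raise OverflowError, and 128-bit seeds aren't even
# recognized as integers (TypeError). Only the 32-bit seed works, as shown below: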
"""
In [24]: seed = secrets.randbits(64)
In [25]: jax.random.PRNGKey(seed)
---------------------------------------------------------------------------
OverflowError Traceback (most recent call last)
<ipython-input-25-7a8d328c270c> in <module>
----> 1 jax.random.PRNGKey(seed)
~/anaconda3/envs/many-libs/lib/python3.7/site-packages/jax/_src/random.py in PRNGKey(seed)
57 # Explicitly cast to int64 for JIT invariance of behavior on large ints.
58 if isinstance(seed, int):
---> 59 seed = np.int64(seed)
60 # Converting to jnp.array may truncate bits when jax_enable_x64=False, but this
61 # is necessary for the sake of JIT invariance of the result for such values.
OverflowError: Python int too large to convert to C long
In [26]: seed = secrets.randbits(128)
In [27]: jax.random.PRNGKey(seed)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-27-7a8d328c270c> in <module>
----> 1 jax.random.PRNGKey(seed)
~/anaconda3/envs/many-libs/lib/python3.7/site-packages/jax/_src/random.py in PRNGKey(seed)
53 raise TypeError(f"PRNGKey seed must be a scalar; got {seed!r}.")
54 if not np.issubdtype(np.result_type(seed), np.integer):
---> 55 raise TypeError(f"PRNGKey seed must be an integer; got {seed!r}")
56
57 # Explicitly cast to int64 for JIT invariance of behavior on large ints.
TypeError: PRNGKey seed must be an integer; got 67681183633192462759155065893448052088
In [28]: seed = secrets.randbits(64)
In [29]: jax.random.PRNGKey(seed)
---------------------------------------------------------------------------
OverflowError Traceback (most recent call last)
<ipython-input-29-7a8d328c270c> in <module>
----> 1 jax.random.PRNGKey(seed)
~/anaconda3/envs/many-libs/lib/python3.7/site-packages/jax/_src/random.py in PRNGKey(seed)
57 # Explicitly cast to int64 for JIT invariance of behavior on large ints.
58 if isinstance(seed, int):
---> 59 seed = np.int64(seed)
60 # Converting to jnp.array may truncate bits when jax_enable_x64=False, but this
61 # is necessary for the sake of JIT invariance of the result for such values.
OverflowError: Python int too large to convert to C long
In [30]: seed = secrets.randbits(32)
In [31]: jax.random.PRNGKey(seed)
Out[31]: DeviceArray([ 0, 3279739543], dtype=uint32)
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment