Skip to content

Instantly share code, notes, and snippets.

@lucidrains
Created June 2, 2021 17:08
Show Gist options
  • Save lucidrains/b46418ec1e7e7d976c20b29397325cb5 to your computer and use it in GitHub Desktop.
Save lucidrains/b46418ec1e7e7d976c20b29397325cb5 to your computer and use it in GitHub Desktop.
faster rng for jax
def hardware_uniform(rng_key: PRNGKey,
                     shape: Shape = (),
                     dtype: Dtype = np.float32,
                     minval: Array = np.float32(0),
                     maxval: Array = np.float32(1)) -> Array:
    """Draw uniform samples via ``lax.rng_uniform`` instead of a keyed PRNG.

    The parameter list mirrors ``jax.random.uniform`` so this function can be
    bound in its place, but the key is discarded and sampling is delegated
    to ``lax.rng_uniform``.

    Args:
        rng_key: Accepted for signature compatibility only. It is deleted
            immediately, so the output cannot be reproduced from the key.
        shape: Shape of the returned array. Defaults to ``()`` so that
            callers written against ``jax.random.uniform`` that omit the
            shape keep working.
        dtype: Element type that ``minval`` and ``maxval`` are converted to
            before sampling.
        minval: Lower bound handed to ``lax.rng_uniform``.
        maxval: Upper bound handed to ``lax.rng_uniform``.

    Returns:
        The result of ``lax.rng_uniform(minval, maxval, shape)`` with both
        bounds converted to ``dtype``.
    """
    del rng_key  # non-deterministic prng.
    low = lax.convert_element_type(minval, dtype)
    high = lax.convert_element_type(maxval, dtype)
    return lax.rng_uniform(low, high, shape)
# For dropout-only hardware rng.
def hardware_bernoulli(rng_key: PRNGKey,
                       p: np.ndarray = np.float32(0.5),
                       shape: Shape = None) -> Array:
    """Sample a boolean mask by thresholding ``lax.rng_uniform`` against ``p``.

    The parameter list mirrors ``jax.random.bernoulli`` so this function can
    be bound in its place, but the key is discarded.

    Args:
        rng_key: Accepted for signature compatibility only. It is deleted
            immediately, so the output cannot be reproduced from the key.
        p: Threshold compared against the uniform draw; an element of the
            result is true where the draw is strictly less than ``p``.
        shape: Shape of the uniform draw. ``None`` (the default) means
            "use the shape of ``p``".

    Returns:
        The boolean result of ``lax.rng_uniform(0.0, 1.0, shape) < p``.
    """
    del rng_key  # non-deterministic prng.
    if shape is None:
        # lax.rng_uniform needs a concrete shape; passing None through would
        # fail, so fall back to p's own shape.
        shape = np.shape(p)
    return lax.rng_uniform(0.0, 1.0, shape) < p
def set_hardware_rng():
    """Rebind JAX's sampling entry points to the hardware-RNG versions.

    Assigns ``hardware_bernoulli`` to ``jax.random.bernoulli`` and
    ``hardware_uniform`` to both ``jax.random.uniform`` and
    ``jax._src.random.uniform``. These are plain module-attribute
    assignments; nothing in this function restores the previous values.
    """
    jax.random.bernoulli = hardware_bernoulli
    # The public module and the private implementation module each hold
    # their own ``uniform`` attribute, so both are rebound.
    for module in (jax.random, jax._src.random):
        module.uniform = hardware_uniform
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment