Skip to content

Instantly share code, notes, and snippets.

@zmjjmz
Created January 24, 2020 17:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zmjjmz/099afcefe9ea3cf3e8015dc8c2e7a340 to your computer and use it in GitHub Desktop.
Save zmjjmz/099afcefe9ea3cf3e8015dc8c2e7a340 to your computer and use it in GitHub Desktop.
Jax random shuffle vs. numpy random shuffle
import numpy as onp
import jax.numpy as jnp
import jax.random as jrand
from jax import jit
@jit
def jax_shuffler(all_inputs, key):
    """Shuffle ``all_inputs`` along axis 0 with JAX's explicit-key PRNG.

    Each 1-D slice along axis 0 is permuted independently of the others, so
    for a 2-D input every column gets its own ordering and rows are NOT kept
    together. This differs from ``numpy.random.shuffle``, which moves whole
    rows.

    Args:
        all_inputs: Array-like with at least one dimension.
        key: A JAX PRNG key, e.g. ``jax.random.PRNGKey(0)``.

    Returns:
        A JAX array with the same shape as ``all_inputs``.
    """
    # ``jax.random.shuffle`` is deprecated; ``permutation`` with
    # ``independent=True`` is its documented replacement and keeps the same
    # per-slice shuffling semantics.
    return jrand.permutation(key, all_inputs, axis=0, independent=True)
def onp_shuffler(all_inputs):
    """Return a shuffled copy of ``all_inputs`` without modifying the argument.

    ``onp.random.shuffle`` operates in place and only reorders entries along
    the first axis, so for a 2-D array whole rows move together.

    Args:
        all_inputs: A mutable sequence that provides ``.copy()`` (a NumPy
            array or a list).

    Returns:
        A new object of the same type with its first-axis entries shuffled.
    """
    # Slicing with ``[:]`` yields a *view* for NumPy arrays, so shuffling it
    # would also reorder the caller's data; ``.copy()`` gives independent
    # storage (and is still a shallow copy for plain lists).
    copied_inputs = all_inputs.copy()
    onp.random.shuffle(copied_inputs)
    return copied_inputs
if __name__ == "__main__":
    # Each row of ``inputs`` is (2k, 2k + 1), so column 0 minus column 1 is -1
    # for every row. A shuffle that keeps rows intact leaves the sorted
    # differences unchanged; one that shuffles columns separately generally
    # does not.
    inputs = onp.arange(10).reshape(5, 2)
    og_diffs = tuple(sorted(inputs[:, 0] - inputs[:, 1]))

    jax_shuffled = jax_shuffler(inputs, jrand.PRNGKey(0))
    jax_diffs = tuple(sorted(jax_shuffled[:, 0] - jax_shuffled[:, 1]))

    onp.random.seed(0)
    onp_shuffled = onp_shuffler(inputs)
    onp_diffs = tuple(sorted(onp_shuffled[:, 0] - onp_shuffled[:, 1]))

    # Compare explicitly rather than via ``assert``: assertions are stripped
    # under ``python -O``, which would silently hide a mismatch.
    if og_diffs != onp_diffs:
        print("ONP shuffle did not behave as expected!")
        print("ONP shuffled: {0}".format(onp_shuffled))
        print("Original: {0}".format(inputs))

    if og_diffs != jax_diffs:
        print("Jax shuffle did not behave as expected!")
        print("Jax shuffled: {0}".format(jax_shuffled))
        print("Original: {0}".format(inputs))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment