Created January 24, 2020 17:22
-
-
Save zmjjmz/099afcefe9ea3cf3e8015dc8c2e7a340 to your computer and use it in GitHub Desktop.
Jax random shuffle vs. numpy random shuffle
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as onp | |
import jax.numpy as jnp | |
import jax.random as jrand | |
from jax import jit | |
@jit
def jax_shuffler(all_inputs, key):
    """Return a copy of ``all_inputs`` with its rows (axis 0) randomly permuted.

    ``jax.random.shuffle(key, x, axis=0)`` shuffled each 1-D slice along the
    axis independently (equivalent to ``permutation(..., independent=True)``),
    so for a 2-D input every column was reordered on its own and rows were
    torn apart. That API is deprecated/removed. ``jax.random.permutation``
    applied to an array permutes whole slices along the leading axis, which
    matches the row-wise semantics of ``numpy.random.shuffle``.

    Args:
        all_inputs: Array to shuffle; its leading axis is the one permuted.
        key: A ``jax.random`` PRNG key (e.g. from ``jax.random.PRNGKey``).

    Returns:
        A new array with the same shape and dtype whose rows are reordered.
    """
    return jrand.permutation(key, all_inputs)
def onp_shuffler(all_inputs):
    """Return a row-shuffled copy of ``all_inputs`` using NumPy's global RNG.

    ``numpy.random.shuffle`` works in place and only along the first axis, so
    whole rows move together. The input is copied first: slicing a NumPy
    array with ``[:]`` produces a *view* that shares the same buffer, so the
    previous version shuffled the caller's array as well. ``.copy()`` yields
    an independent ndarray (or a shallow copy for a list), leaving the
    argument untouched.

    Args:
        all_inputs: Array (or list) to shuffle; it is not modified.

    Returns:
        A new object of the same type with the first-axis elements in random
        order; reproducible via ``numpy.random.seed``.
    """
    copied_inputs = all_inputs.copy()
    onp.random.shuffle(copied_inputs)
    return copied_inputs
if __name__ == "__main__":

    def _sorted_row_diffs(arr):
        """Return sorted ``arr[:, 0] - arr[:, 1]`` as a tuple of Python scalars.

        A shuffle that moves whole rows leaves this multiset unchanged; one
        that mixes elements between rows generally alters it. ``tolist()``
        turns NumPy/JAX scalars into plain Python numbers so results from
        both libraries compare as ordinary tuples.
        """
        arr = onp.asarray(arr)
        return tuple(sorted((arr[:, 0] - arr[:, 1]).tolist()))

    # Rows are [0, 1], [2, 3], ..., [8, 9]: every row difference is -1.
    inputs = onp.arange(10).reshape(5, 2)
    # Snapshot for reporting, in case a shuffler mutates ``inputs`` in place.
    original = inputs.copy()
    og_diffs = _sorted_row_diffs(inputs)

    jax_shuffled = jax_shuffler(inputs, jrand.PRNGKey(0))
    jax_diffs = _sorted_row_diffs(jax_shuffled)

    onp.random.seed(0)
    onp_shuffled = onp_shuffler(inputs)
    onp_diffs = _sorted_row_diffs(onp_shuffled)

    # Explicit comparisons instead of `assert`: asserts are stripped under
    # `python -O`, which would silently skip these checks.
    if og_diffs != onp_diffs:
        print("ONP shuffle did not behave as expected!")
        print("ONP shuffled: {0}".format(onp_shuffled))
        print("Original: {0}".format(original))

    if og_diffs != jax_diffs:
        print("Jax shuffle did not behave as expected!")
        print("Jax shuffled: {0}".format(jax_shuffled))
        print("Original: {0}".format(original))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.