Skip to content

Instantly share code, notes, and snippets.

@zmjjmz
Last active February 24, 2020 18:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zmjjmz/6c9184971f8601dd923193a47640d651 to your computer and use it in GitHub Desktop.
Save zmjjmz/6c9184971f8601dd923193a47640d651 to your computer and use it in GitHub Desktop.
JAX shuffle vs. NumPy shuffle
jax_jit_correlation jax_jit_time jax_nojit_correlation jax_nojit_time np_copy_correlation np_copy_time np_inplace_correlation np_inplace_time size
0 -0.087945 0.520224 -0.207465 0.515517 -0.086613 0.000577 -0.048701 0.000638 100
1 0.009647 0.409452 -0.009721 0.407500 -0.016213 0.000776 -0.032633 0.000694 1000
2 0.005357 0.652246 0.010250 0.633742 0.002910 0.003117 -0.001414 0.001872 10000
3 -0.002672 1.441106 0.000594 1.425177 0.004118 0.023031 -0.004778 0.017846 100000
4 0.001152 11.028948 0.000698 10.984871 -0.002028 0.228054 0.001034 0.163242 1000000
5 0.000007 250.970225 -0.000202 241.865425 0.000353 3.004593 -0.000141 3.533790 10000000
import time
from jax import random, jit
import pandas as pd
import numpy as np
from tqdm import tqdm
from scipy.stats import spearmanr
def jax_shuffle_nojit(arr, key):
    """Return a randomly shuffled copy of ``arr`` using JAX's PRNG.

    Args:
        arr: Array to shuffle (shuffled along its leading axis).
        key: JAX PRNG key that determines the permutation.

    Returns:
        A JAX array holding the elements of ``arr`` in shuffled order;
        ``arr`` itself is not modified.
    """
    # ``random.shuffle`` was deprecated and later removed from JAX;
    # ``random.permutation`` is the documented replacement for shuffling
    # an array.
    return random.permutation(key, arr)


# JIT-compiled variant of the same shuffle, benchmarked separately.
jax_shuffle = jit(jax_shuffle_nojit)
def numpy_shuffle_inplace(arr, key):
    """Shuffle ``arr`` in place with NumPy's global RNG and return it.

    The global NumPy RNG is reseeded from ``key[1]`` on every call, so the
    resulting order is fully determined by ``key``.

    Args:
        arr: Mutable array to shuffle; it is modified in place.
        key: Indexable pair whose second element is used as the integer seed.

    Returns:
        The same ``arr`` object, now shuffled.
    """
    seed = int(key[1])
    np.random.seed(seed)
    np.random.shuffle(arr)  # mutates the caller's array
    return arr
def numpy_shuffle_copy(arr, key):
    """Return a shuffled copy of ``arr`` using NumPy's global RNG.

    The global NumPy RNG is reseeded from ``key[1]`` on every call, so the
    resulting order is fully determined by ``key``.

    Args:
        arr: Array to shuffle; it is left untouched.
        key: Indexable pair whose second element is used as the integer seed.

    Returns:
        A new array containing the elements of ``arr`` in shuffled order.
    """
    seed = int(key[1])
    np.random.seed(seed)
    shuffled = arr.copy()  # work on a duplicate so the input stays intact
    np.random.shuffle(shuffled)
    return shuffled
# Benchmark registry: label -> shuffle implementation. Each callable takes
# ``(arr, key)`` and returns the shuffled array. The labels become the column
# prefixes of the results table, and the insertion order is the run order.
methods = {
    'jax_nojit': jax_shuffle_nojit,
    'jax_jit': jax_shuffle,
    'np_inplace': numpy_shuffle_inplace,
    'np_copy': numpy_shuffle_copy,
}
if __name__ == "__main__":
    # Benchmark driver: for each array size, time N_SHUFFLES consecutive
    # shuffles per method and record the Spearman rank correlation between the
    # final array and the original ordering (near 0 means well shuffled).
    records = []
    start_key = random.PRNGKey(10)
    # Several shuffles per method so that JAX's compilation cost is spread
    # over repeated calls rather than dominating a single measurement.
    N_SHUFFLES = 10
    for size in tqdm([10**i for i in range(2, 8)], desc='Sizes'):
        record = {'size': size}
        shuffle_array = np.random.rand(size)
        # Untouched reference ordering used for the correlation check.
        shuffle_array_copy = shuffle_array.copy()
        for method, method_fn in methods.items():
            # Split into exactly N_SHUFFLES per-shuffle keys plus one key that
            # is carried over to the next method. (Previously an extra key was
            # discarded, so only N_SHUFFLES - 1 shuffles were timed.)
            *keys, start_key = random.split(start_key, N_SHUFFLES + 1)
            # perf_counter is a monotonic, high-resolution clock intended for
            # measuring intervals.
            tic = time.perf_counter()
            for key in keys:
                shuffle_array = method_fn(shuffle_array, key)
                # Materialize the result as a NumPy array inside the timed
                # region: JAX dispatches work asynchronously, so the
                # conversion waits for the shuffle to finish. It also gives
                # the in-place NumPy method a mutable array to work on. Note
                # that the measured time therefore includes this copy.
                shuffle_array = np.array(shuffle_array)
            toc = time.perf_counter() - tic
            record['{0}_time'.format(method)] = toc
            record['{0}_correlation'.format(method)] = spearmanr(
                shuffle_array, shuffle_array_copy
            )[0]
        records.append(record)
    records_df = pd.DataFrame.from_records(records).sort_values(by='size')
    print(records_df)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment