Last active
February 24, 2020 18:50
-
-
Save zmjjmz/6c9184971f8601dd923193a47640d651 to your computer and use it in GitHub Desktop.
JAX shuffle vs. NumPy shuffle
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| | jax_jit_correlation | jax_jit_time | jax_nojit_correlation | jax_nojit_time | np_copy_correlation | np_copy_time | np_inplace_correlation | np_inplace_time | size |
|---|---|---|---|---|---|---|---|---|---|
| 0 | -0.087945 | 0.520224 | -0.207465 | 0.515517 | -0.086613 | 0.000577 | -0.048701 | 0.000638 | 100 |
| 1 | 0.009647 | 0.409452 | -0.009721 | 0.407500 | -0.016213 | 0.000776 | -0.032633 | 0.000694 | 1000 |
| 2 | 0.005357 | 0.652246 | 0.010250 | 0.633742 | 0.002910 | 0.003117 | -0.001414 | 0.001872 | 10000 |
| 3 | -0.002672 | 1.441106 | 0.000594 | 1.425177 | 0.004118 | 0.023031 | -0.004778 | 0.017846 | 100000 |
| 4 | 0.001152 | 11.028948 | 0.000698 | 10.984871 | -0.002028 | 0.228054 | 0.001034 | 0.163242 | 1000000 |
| 5 | 0.000007 | 250.970225 | -0.000202 | 241.865425 | 0.000353 | 3.004593 | -0.000141 | 3.533790 | 10000000 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
from jax import random, jit | |
import pandas as pd | |
import numpy as np | |
from tqdm import tqdm | |
from scipy.stats import spearmanr | |
def jax_shuffle_nojit(arr, key):
    """Return a shuffled version of ``arr`` using JAX's PRNG (no JIT).

    Args:
        arr: Array to shuffle along its first axis.
        key: JAX PRNG key that drives the shuffle.

    Returns:
        The shuffled array as produced by JAX.
    """
    # ``jax.random.shuffle`` was deprecated and then removed from JAX.  Use it
    # when the installed version still has it (keeps the original results),
    # otherwise fall back to its documented replacement.
    legacy_shuffle = getattr(random, "shuffle", None)
    if legacy_shuffle is not None:
        return legacy_shuffle(key, arr)
    return random.permutation(key, arr, independent=True)


# JIT-compiled variant of the same shuffle, benchmarked separately.
jax_shuffle = jit(jax_shuffle_nojit)
def numpy_shuffle_inplace(arr, key):
    """Shuffle ``arr`` in place with NumPy's global RNG and return it.

    Args:
        arr: Array to shuffle; it is modified in place.
        key: Indexable key whose element at index 1 is converted to ``int``
            and used to reseed NumPy's global random state.

    Returns:
        The same ``arr`` object, now shuffled.
    """
    seed = int(key[1])
    # Reseeds the *global* NumPy generator, so later ``np.random`` calls are
    # affected by this call as well.
    np.random.seed(seed)
    np.random.shuffle(arr)
    return arr
def numpy_shuffle_copy(arr, key):
    """Return a shuffled copy of ``arr``, leaving ``arr`` itself untouched.

    Args:
        arr: Array to shuffle; only a copy of it is modified.
        key: Indexable key whose element at index 1 is converted to ``int``
            and used to reseed NumPy's global random state.

    Returns:
        A new array holding the shuffled contents of ``arr``.
    """
    seed = int(key[1])
    # Reseeds the *global* NumPy generator, so later ``np.random`` calls are
    # affected by this call as well.
    np.random.seed(seed)
    shuffled = arr.copy()
    np.random.shuffle(shuffled)
    return shuffled
# Registry of the shuffle implementations under test.  The key is the label
# used to build result column names; the value is a callable taking
# ``(arr, key)``.  Insertion order is the order in which they are benchmarked.
methods = dict(
    jax_nojit=jax_shuffle_nojit,
    jax_jit=jax_shuffle,
    np_inplace=numpy_shuffle_inplace,
    np_copy=numpy_shuffle_copy,
)
if __name__ == "__main__":
    # Benchmark every entry in `methods` on arrays of 10**2 .. 10**7 random
    # floats.  For each size and method, record the wall-clock time of
    # N_SHUFFLES consecutive shuffles and the Spearman correlation between the
    # shuffled result and the original ordering.
    records = []
    start_key = random.PRNGKey(10)
    N_SHUFFLES = 10  # we'll do multiple shuffles to hopefully take advantage of jax compilation time
    for size in tqdm([10**i for i in range(2, 8)], desc='Sizes'):
        record = {'size': size}
        shuffle_array = np.random.rand(size)
        # Pristine copy kept as the reference ordering for the correlation.
        shuffle_array_copy = shuffle_array.copy()
        for method, method_fn in methods.items():
            # Split into exactly N_SHUFFLES per-shuffle keys plus one key that
            # is carried forward.  (Discarding an extra leading key here would
            # leave only N_SHUFFLES - 1 shuffles per method.)
            *keys, start_key = random.split(start_key, N_SHUFFLES + 1)
            # perf_counter is monotonic and high-resolution, unlike time.time.
            tic = time.perf_counter()
            for key in keys:
                shuffle_array = method_fn(shuffle_array, key)
            # Convert back to a NumPy array inside the timed region: JAX
            # dispatches asynchronously, so this forces the result to be
            # materialised before the clock stops.
            # NOTE(review): placed after the inner loop; the original
            # indentation was lost, so confirm it was not per-iteration.
            shuffle_array = np.array(shuffle_array)
            toc = time.perf_counter() - tic
            record['{0}_time'.format(method)] = toc
            record['{0}_correlation'.format(method)] = spearmanr(
                shuffle_array, shuffle_array_copy
            )[0]
        records.append(record)
    records_df = pd.DataFrame.from_records(records).sort_values(by='size')
    print(records_df)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment