Created May 3, 2023 03:55
Save 338rajesh/1fe47cb8f6b372e5c74e94f51f1682c6 to your computer and use it in GitHub Desktop.
Shuffles multiple arrays in unison (every array is reordered with the same row permutation)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def shuffle_arrays(*arr, order=1):
    """Shuffle several arrays with one shared random row permutation.

    Every array is reordered along its first axis with the *same*
    permutation, so row ``i`` of each output comes from the same original
    row. Randomness is drawn from NumPy's global generator (``np.random``),
    so ``np.random.seed`` makes the result reproducible.

    Parameters
    ----------
    *arr : numpy.ndarray
        Arrays (anything with ``.shape`` that supports integer-array
        indexing) sharing the same length along axis 0.
    order : int, optional
        Number of times the shared index permutation is reshuffled.
        Any value ``>= 1`` yields one random permutation of the rows.
        Reshuffling more often does not change that. ``order <= 0`` returns
        the input arrays unchanged (the very same objects, not copies).

    Returns
    -------
    tuple
        The shuffled arrays, in the order they were given. This is an empty
        tuple when no arrays are passed.

    Raises
    ------
    ValueError
        If the arrays do not all have the same length along axis 0.
    """
    if not arr:
        return ()
    num_rows = arr[0].shape[0]
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if any(a.shape[0] != num_rows for a in arr):
        lengths = [a.shape[0] for a in arr]
        raise ValueError(
            f"all arrays must have the same length along axis 0; got {lengths}"
        )
    if order < 1:
        return tuple(arr)
    indices = np.arange(num_rows, dtype=np.int64)
    for _ in range(order):
        np.random.shuffle(indices)
    # Index once with the final permutation. Re-indexing on every pass (as
    # before) produced the same result but did the copying `order` times.
    return tuple(a[indices] for a in arr)
if __name__ == '__main__':
    # Demo: two arrays with the same number of rows. Printing them side by
    # side before and after the joint shuffle shows that rows stay paired.
    X = np.random.rand(5, 4)
    Y = np.random.rand(5, 2)
    print(np.hstack((X, Y)))

    # order=5 reshuffles the shared index permutation five times. Every pass
    # indexes the original arrays, so the result is still a single random
    # permutation of the rows. Passing order=0 would return X and Y untouched.
    X, Y = shuffle_arrays(X, Y, order=5)
    print(np.hstack((X, Y)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.