Last active
June 19, 2020 13:54
Obtain ranks for each row in a 2D array in NumPy
import numpy as np


def get_row_ranks(xs):
    """
    Args:
        xs: array of shape :math:`(N, E)`
    Returns:
        Ranks for each row of ``xs``. A rank of 0 is the highest value, and a rank of ``E - 1`` is the lowest value.
    Examples:
        >>> get_row_ranks(np.array([
        ...     [1, 2, 3],
        ...     [1, 3, 2],
        ...     [2, 1, 3],
        ...     [2, 3, 1],
        ...     [3, 1, 2],
        ...     [3, 2, 1],
        ... ]))
        array([[2, 1, 0],
               [2, 0, 1],
               [1, 2, 0],
               [1, 0, 2],
               [0, 2, 1],
               [0, 1, 2]])
    """
    assert xs.ndim <= 2
    n_elements = xs.shape[-1]
    # Indices that would sort each row in ascending order.
    sorted_idxs = xs.argsort(axis=-1)
    # Argsorting those indices gives each entry's ascending rank;
    # subtracting from E - 1 flips it so the highest value gets rank 0.
    return (n_elements - 1) - sorted_idxs.argsort(axis=-1)
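
The double `argsort` above can be cross-checked against a more explicit formulation: sort each row in descending order once, then scatter each position back to the column it came from. The following is only a sketch under a few assumptions. The name `get_row_ranks_scatter` is hypothetical and not part of the gist, the input is assumed to have a signed or floating dtype (negating an unsigned array would wrap around), and `kind="stable"` is my choice so that tied values are ranked in column order. On rows with ties it may therefore disagree with `get_row_ranks`, which uses NumPy's default sort.

```python
import numpy as np


def get_row_ranks_scatter(xs):
    """Hypothetical alternative: one argsort followed by a scatter.

    Assumes ``xs`` has a signed or floating dtype so that ``-xs`` is safe.
    """
    assert xs.ndim <= 2
    n_elements = xs.shape[-1]
    # Column indices of each row, ordered from highest to lowest value.
    order = np.argsort(-xs, axis=-1, kind="stable")
    # positions[..., k] == k, i.e. the rank of the k-th entry in `order`.
    positions = np.broadcast_to(np.arange(n_elements), order.shape)
    ranks = np.empty_like(order)
    # Scatter: ranks[..., order[..., k]] = k
    np.put_along_axis(ranks, order, positions, axis=-1)
    return ranks


if __name__ == "__main__":
    xs = np.array([
        [1, 2, 3],
        [2, 3, 1],
        [3, 1, 2],
    ])
    print(get_row_ranks_scatter(xs))
```

For inputs without ties, such as the permutations in the docstring example, this should produce the same ranks as `get_row_ranks`, while sorting each row only once.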