Skip to content

Instantly share code, notes, and snippets.

@kernel1994
Last active October 26, 2021 01:06
Show Gist options
  • Save kernel1994/c8929d3d1465a22a3d55996a445ddbc7 to your computer and use it in GitHub Desktop.
Save kernel1994/c8929d3d1465a22a3d55996a445ddbc7 to your computer and use it in GitHub Desktop.
Top-k element indices in a NumPy array
import numpy as np
def largest_indices(array: np.ndarray, n: int) -> tuple:
    """Returns the indices of the n largest values in a numpy array.

    The indices are ordered from the largest value to the smallest.
    The relative order of equal values is not specified.

    Arguments:
        array {np.ndarray} -- data array
        n {int} -- number of elements to select (0 <= n <= array.size)

    Returns:
        tuple[np.ndarray, ...] -- tuple of ndarray, one per dimension
            of ``array``; each ndarray holds the index along that
            dimension (same layout as ``np.unravel_index``)

    Raises:
        ValueError -- if n is negative or larger than array.size
    """
    # ravel() avoids the unconditional copy made by flatten(); the data is
    # only read here, never modified.
    flat = array.ravel()
    if not 0 <= n <= flat.size:
        raise ValueError(
            f"n must be between 0 and array.size ({flat.size}), got {n}"
        )
    if n == 0:
        # The slice [-0:] is the same as [0:] and would select everything,
        # so an empty selection has to be returned explicitly.
        return np.unravel_index(np.empty(0, dtype=np.intp), array.shape)

    # Partial sort: the last n positions hold the n largest values (unordered).
    indices = np.argpartition(flat, -n)[-n:]
    # Sort ascending and reverse instead of negating the values: negation
    # wraps around for unsigned integers and is not supported for booleans.
    indices = indices[np.argsort(flat[indices])[::-1]]
    return np.unravel_index(indices, array.shape)
def least_indices(array: np.ndarray, n: int) -> tuple:
    """Returns the indices of the n smallest values in a numpy array.

    The indices are ordered from the smallest value to the largest.
    The relative order of equal values is not specified.

    Arguments:
        array {np.ndarray} -- data array
        n {int} -- number of elements to select (0 <= n <= array.size)

    Returns:
        tuple[np.ndarray, ...] -- tuple of ndarray, one per dimension
            of ``array``; each ndarray holds the index along that
            dimension (same layout as ``np.unravel_index``)

    Raises:
        ValueError -- if n is negative or larger than array.size
    """
    # ravel() avoids the unconditional copy made by flatten(); the data is
    # only read here, never modified.
    flat = array.ravel()
    if not 0 <= n <= flat.size:
        raise ValueError(
            f"n must be between 0 and array.size ({flat.size}), got {n}"
        )
    if n == 0:
        return np.unravel_index(np.empty(0, dtype=np.intp), array.shape)

    # Partition around position n - 1 (not n): the first n positions then
    # hold the n smallest values, and n == array.size stays within bounds.
    indices = np.argpartition(flat, n - 1)[:n]
    # Order the selected candidates from smallest to largest.
    indices = indices[np.argsort(flat[indices])]
    return np.unravel_index(indices, array.shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment