Skip to content

Instantly share code, notes, and snippets.

@dkirkby
Last active January 2, 2022 19:46
Show Gist options
  • Save dkirkby/a36ef2b097710db5500b7a17a01028bf to your computer and use it in GitHub Desktop.
Save dkirkby/a36ef2b097710db5500b7a17a01028bf to your computer and use it in GitHub Desktop.
# See https://stackoverflow.com/questions/6910641/how-do-i-get-indices-of-n-maximum-values-in-a-numpy-array
# argpartition requires numpy >= 1.8.0
# See also http://seanlaw.github.io/2020/01/03/finding-top-or-bottom-k-in-a-numpy-array/
import time

import numpy as np
def argkmax1(a, k, axis=-1):
    """Return the indices of the k largest elements of a along an axis.

    Uses one ``np.argpartition`` pass to collect the k largest elements and
    then sorts only those k values, so the cost depends only weakly on k.
    Requires numpy >= 1.15 for ``np.take_along_axis``.

    Parameters
    ----------
    a : array_like
        Input values.
    k : int
        Number of largest elements to select, ``0 <= k <= a.shape[axis]``.
    axis : int, optional
        Axis along which to select; defaults to the last axis.

    Returns
    -------
    ndarray of int
        Same shape as ``a`` except ``k`` along ``axis``.  Indices are ordered
        by increasing value, so the index of the maximum is last.  With k=1
        and 1-D input this is identical to ``argmax`` (up to the choice among
        tied maxima) except that it returns an array of length 1 instead of
        a scalar.

    Raises
    ------
    ValueError
        If k is outside ``[0, a.shape[axis]]``.
    """
    a = np.asarray(a)
    n = a.shape[axis]
    if not 0 <= k <= n:
        raise ValueError(f"k={k} must be in the range [0, {n}]")
    if k == 0:
        # Handled separately: with k=0 a ``[-k:]`` slice would select
        # everything instead of nothing.
        shape = list(a.shape)
        shape[axis] = 0
        return np.empty(shape, dtype=np.intp)
    # After partitioning, the last k positions along axis hold the indices of
    # the k largest values, in no particular order.
    part = np.argpartition(a, -k, axis=axis)
    top = np.take(part, np.arange(n - k, n), axis=axis)
    # Sort just those k values (ascending) and reorder their indices to match.
    order = np.argsort(np.take_along_axis(a, top, axis=axis), axis=axis)
    return np.take_along_axis(top, order, axis=axis)
def argkmax2(a, k, axis=-1):
    """Return the indices of the k largest elements of a.

    Makes one ``np.argmax`` pass per requested element, temporarily
    overwriting each selected maximum with ``np.min(a)`` so that the next
    pass finds the next-largest value.  The cost grows linearly with k, so
    this form is only competitive with ``argkmax1`` for small k.

    ``a`` is modified in place while the function runs and is restored
    before it returns, including when an exception interrupts the search,
    so it must be writable and nothing else should observe it during the
    call.  NaN values are not supported: ``np.min`` would return NaN and
    the overwrite-with-minimum trick would no longer exclude selected
    elements.

    Parameters
    ----------
    a : array_like, 1-D
        Input values.  An ndarray is searched (and temporarily modified) in
        place; other sequences are converted to a temporary array.
    k : int
        Number of largest elements to select, ``0 <= k <= len(a)``.
    axis : int, optional
        Accepted for signature compatibility with ``argkmax1``; only 1-D
        input is supported, so it is not used.

    Returns
    -------
    ndarray of int, shape (k,)
        Indices ordered by increasing value, so the index of the maximum is
        last.  With k=1 this is identical to ``argmax`` except that it
        returns an array of length 1 instead of a scalar.

    Raises
    ------
    ValueError
        If ``a`` is not 1-D or k is outside ``[0, len(a)]``.
    """
    a = np.asarray(a)
    if a.ndim != 1:
        raise ValueError(f"argkmax2 requires a 1-D array, got ndim={a.ndim}")
    n = a.shape[0]
    if not 0 <= k <= n:
        raise ValueError(f"k={k} must be in the range [0, {n}]")
    idx = np.empty(k, int)
    if k == 0:
        return idx
    amin = np.min(a)
    save = np.empty(k, a.dtype)
    # Slots are filled from the end of idx backwards, so the selected
    # indices always occupy idx[k - filled:].
    filled = 0
    try:
        for i in range(-1, -(k + 1), -1):
            j = np.argmax(a)
            if a[j] == amin:
                # Every unselected element now equals amin (selected ones were
                # overwritten with it), so argmax could return an index that is
                # already taken.  Stop and fill the remaining slots below.
                break
            idx[i] = j
            save[i] = a[j]
            # Count the slot before overwriting so an interruption right after
            # the overwrite is still undone by the finally clause.
            filled += 1
            a[j] = amin
    finally:
        # Undo the temporary overwrites even if the loop was interrupted.
        if filled:
            a[idx[k - filled:]] = save[k - filled:]
    remaining = k - filled
    if remaining:
        # The remaining picks all tie at amin; take the lowest indices that
        # have not been selected yet.
        taken = np.zeros(n, dtype=bool)
        taken[idx[remaining:]] = True
        idx[:remaining] = np.flatnonzero(~taken)[:remaining]
    return idx
if __name__ == "__main__":
    # Compare the two implementations on 10 million normal deviates.
    # argkmax1 does a single argpartition pass regardless of k, while argkmax2
    # does one argmax pass per requested element, so the first form is faster
    # for large k and the second is faster for small k.  The timings quoted
    # below were originally reported with IPython's %time; expect different
    # absolute numbers on other machines.
    def _timed(label, func, *args):
        """Call func(*args), print its wall-clock time, and return its result."""
        start = time.perf_counter()
        result = func(*args)
        elapsed_ms = 1e3 * (time.perf_counter() - start)
        print(f"{label}: wall time {elapsed_ms:.1f} ms")
        return result

    rng = np.random.RandomState(123)
    R = rng.normal(size=10000000)
    K = _timed("argkmax1(R, 5)", argkmax1, R, 5)    # reported: 117 ms
    K = _timed("argkmax2(R, 5)", argkmax2, R, 5)    # reported: 66.6 ms
    K = _timed("argkmax1(R, 15)", argkmax1, R, 15)  # reported: 118 ms
    K = _timed("argkmax2(R, 15)", argkmax2, R, 15)  # reported: 165 ms
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment