Skip to content

Instantly share code, notes, and snippets.

@bkj
Created November 12, 2019 16:42
Show Gist options
  • Save bkj/c4d26d9cd4710abd26c35ff1cfee1368 to your computer and use it in GitHub Desktop.
Save bkj/c4d26d9cd4710abd26c35ff1cfee1368 to your computer and use it in GitHub Desktop.
fast_argmax.py
#!/usr/bin/env python
"""
fast_argmax.py
"""
import numpy as np
from time import time
from numba import jit, prange
@jit(nopython=True, parallel=True)
def fast_argmax(x, axis=1):
    """Return the indices of the maximum values of a 2-D array along an axis.

    A Numba-parallelised replacement for ``x.argmax(axis=axis)`` on 2-D
    input: the loop over the kept dimension runs under ``prange``.

    Parameters
    ----------
    x : 2-D array
        Input array; non-2-D input is rejected.
    axis : int, optional
        1 (default) reduces each row, giving one index per row; 0 reduces
        each column, giving one index per column. -1 and -2 are accepted as
        aliases for 1 and 0, as in NumPy.

    Returns
    -------
    out : 1-D int32 array
        ``out[i]`` is the position of the first maximum in row ``i``
        (axis=1) or column ``i`` (axis=0). The axis=0 scan treats a NaN as
        the maximum (the first NaN's position wins), like NumPy; the axis=1
        path delegates to ``np.argmax`` on each row.

    Raises
    ------
    ValueError
        If ``axis`` is not one of 1, 0, -1, -2, or the reduced dimension is
        empty while the kept dimension is not (NumPy raises in that case too).
    """
    if x.ndim != 2:
        raise ValueError("fast_argmax expects a 2-D array")
    if axis < 0:
        # Map NumPy-style negative axes onto 0/1 for a 2-D array.
        axis += 2
    rows, cols = x.shape

    if axis == 1:
        # Validate before the parallel loop rather than raising inside it.
        if cols == 0 and rows > 0:
            raise ValueError("attempt to get argmax of an empty sequence")
        # Every element is written below, so no zero-fill is needed.
        out = np.empty(rows, dtype=np.int32)
        for row in prange(rows):
            out[row] = np.argmax(x[row])
        return out

    if axis == 0:
        # Without this check the scan below would read x[0, col] out of
        # bounds (Numba does not bounds-check by default).
        if rows == 0 and cols > 0:
            raise ValueError("attempt to get argmax of an empty sequence")
        out = np.empty(cols, dtype=np.int32)
        for col in prange(cols):
            best_val = x[0, col]
            best_idx = 0
            for row in range(rows):
                val = x[row, col]
                if val != val:
                    # NaN (the only value unequal to itself): NumPy reports
                    # the first NaN as the argmax, so stop scanning here.
                    best_idx = row
                    break
                # Strict '>' keeps the first occurrence of a tied maximum.
                if val > best_val:
                    best_val = val
                    best_idx = row
            out[col] = best_idx
        return out

    raise ValueError("axis must be 0, 1, -1 or -2 for a 2-D array")
# --
# Test: check fast_argmax against NumPy and report wall-clock timings.
# Guarded so that importing this module does not run a 4000x4000 benchmark;
# it still runs under `python fast_argmax.py` and IPython's `%run`.

if __name__ == "__main__":
    from time import perf_counter  # monotonic, high-resolution benchmark clock

    z = np.random.uniform(0, 1, (4000, 4000))

    # The first call to a @jit function compiles it for the given argument
    # types. Warm up on a tiny float64, C-contiguous array (the same type as
    # `z`) so compilation time is not counted in the fast_argmax timings.
    fast_argmax(np.zeros((2, 2)), axis=1)

    t = perf_counter()
    a1 = z.argmax(axis=-1)
    print("np.argmax   axis=1: %.4f s" % (perf_counter() - t))

    t = perf_counter()
    a0 = z.argmax(axis=0)
    print("np.argmax   axis=0: %.4f s" % (perf_counter() - t))

    t = perf_counter()
    b1 = fast_argmax(z, axis=1)
    print("fast_argmax axis=1: %.4f s" % (perf_counter() - t))

    t = perf_counter()
    b0 = fast_argmax(z, axis=0)
    print("fast_argmax axis=0: %.4f s" % (perf_counter() - t))

    # Unlike bare `assert`, these checks are not stripped under `python -O`.
    np.testing.assert_array_equal(a1, b1)
    np.testing.assert_array_equal(a0, b0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment