@ekreutz
Created January 18, 2024 15:42
Fast O(n) running maximum using numba
```python
import numpy as np
from numba import njit
from numpy.typing import NDArray as array


@njit
def running_max(values: array, w: int) -> array:
    """Fast O(n) running maximum over a sliding window of width `w`.

    max_vals[i] is the maximum of values[max(0, i - w + 1) : i + 1], so the
    first w - 1 entries are maxima of partial windows. Requires w >= 1.
    For large values of `w`, this solution is 100x faster or more than the
    naive version.
    """
    n: int = len(values)
    # We'll fake a queue using a ring buffer of indices, since numba doesn't have
    # queues or linked lists with efficient pops/pushes. At most w indices are
    # live at any time, so a buffer of size w is enough.
    queue: array = np.zeros(w, np.int64)
    bl: int = 0  # index of oldest element inserted
    br: int = -1  # index of newest element inserted
    bn: int = 0  # number of elements currently in the queue
    max_vals: array = np.zeros(n, values.dtype)

    for i in range(n):
        # Remove (pop left) elements that fell out of the window
        while bn > 0 and queue[bl] < i - w + 1:
            bl = (bl + 1) % w
            bn -= 1

        # Remove (pop right) elements whose values are less than or equal to
        # the current value found at values[i]; they can never be a maximum again
        while bn > 0 and values[queue[br]] <= values[i]:
            br = (w + br - 1) % w
            bn -= 1

        # Add the current index (on the right)
        br = (br + 1) % w
        bn += 1
        queue[br] = i

        # The max element is always at the left (bl) of the queue
        max_vals[i] = values[queue[bl]]

    return max_vals
```
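The ring buffer in `running_max` is a hand-rolled double-ended queue of indices whose values are kept in non-increasing order. As a sketch of the same monotonic-deque technique without numba, assuming only NumPy and the standard library, here is a plain-Python version using `collections.deque`. It is not part of the original gist; `running_max_deque` is a hypothetical name, and it should produce the same output as `running_max` for the same inputs, which makes it handy for checking the numba version on small arrays.

```python
from collections import deque

import numpy as np


def running_max_deque(values: np.ndarray, w: int) -> np.ndarray:
    """Hypothetical pure-Python sketch of the monotonic-deque running maximum.

    out[i] == max(values[max(0, i - w + 1) : i + 1]), including partial
    windows at the start, as in the numba version.
    """
    if w < 1:
        raise ValueError("w must be >= 1")
    out = np.empty(len(values), dtype=values.dtype)
    dq = deque()  # indices; values along the deque are non-increasing
    for i, v in enumerate(values):
        # Pop left: the oldest index has left the window [i - w + 1, i]
        while dq and dq[0] < i - w + 1:
            dq.popleft()
        # Pop right: smaller-or-equal values can never be a window maximum again
        while dq and values[dq[-1]] <= v:
            dq.pop()
        dq.append(i)
        # The front of the deque always holds the index of the window maximum
        out[i] = values[dq[0]]
    return out
```

For example, `running_max_deque(np.array([1, 3, 2, 5, 4, 1, 0]), 3)` gives 1, 3, 3, 5, 5, 5, 4. Each index is appended once and popped at most once, which is why the total work stays O(n) regardless of `w`.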
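The docstring compares against "the naive version" without showing it. One plausible reading, assumed here, is a per-window rescan costing O(n·w); `running_max_naive` is a hypothetical name, not code from the gist. Besides serving as the slow baseline for timing, it makes a simple correctness oracle.

```python
import numpy as np


def running_max_naive(values: np.ndarray, w: int) -> np.ndarray:
    """Hypothetical O(n * w) baseline: rescan each (possibly partial) window."""
    if w < 1:
        raise ValueError("w must be >= 1")
    out = np.empty(len(values), dtype=values.dtype)
    for i in range(len(values)):
        # Window ending at i, clipped at the start of the array
        out[i] = values[max(0, i - w + 1) : i + 1].max()
    return out
```

A quick cross-check on random data is `np.array_equal(running_max(x, 50), running_max_naive(x, 50))` for some float array `x`; the first call to `running_max` also pays numba's compilation cost, so exclude it from any timing.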