Skip to content

Instantly share code, notes, and snippets.

@Ending2015a
Last active April 26, 2022 07:41
Show Gist options
  • Save Ending2015a/aed404892de353083c148a74768c3445 to your computer and use it in GitHub Desktop.
Save Ending2015a/aed404892de353083c148a74768c3445 to your computer and use it in GitHub Desktop.
# --- built in ---
import abc
import math
# --- 3rd party ---
import numpy as np
class SegmentTree(metaclass=abc.ABCMeta):
    '''Sum segment tree over a fixed number of float64 elements.

    Supports vectorized point updates, range sums over ``[start, end)`` and
    prefix-sum lookups (``index``), each walking O(log N) tree levels.

    Layout: ``base`` is the smallest power of two >= ``size``. Leaves live
    at ``self._value[base:base + size]``; the padding leaves after them stay
    0. Node ``i`` holds the sum of its children ``2i`` and ``2i + 1``. The
    root is node 1 and node 0 is unused.
    '''

    def __init__(self, size: int):
        '''Create a tree of ``size`` elements, all initialised to 0.0.

        Args:
            size (int): Number of elements; must be a positive int.
        '''
        assert isinstance(size, int) and size > 0
        # Round up to a power of two so that every leaf sits at the same depth.
        base = 1 << (size - 1).bit_length()
        self._size = size
        self._base = base
        self._value = np.zeros([base * 2], dtype=np.float64)

    def _formalize_keys(self, key) -> np.ndarray:
        '''Convert an element key into leaf positions inside ``self._value``.

        ``int`` and ``slice`` keys are resolved through ``range(size)``, so
        they follow normal Python indexing rules (negative indices count from
        the end; an out-of-range ``int`` raises ``IndexError``). Any other key
        is converted to an int64 array and taken modulo ``size``. Note that
        out-of-range array indices are therefore wrapped, not rejected.
        '''
        if isinstance(key, (int, slice)):
            key = np.asarray(range(self._size)[key], dtype=np.int64)
        else:
            key = np.asarray(key, dtype=np.int64)
        # The arithmetic builds a new array, so the caller's key is never mutated.
        return key % self._size + self._base

    def __getitem__(self, key: np.ndarray):
        '''Return the stored value(s) at ``key`` (int, slice or array-like).'''
        return self._value[self._formalize_keys(key)]

    def __setitem__(self, key: np.ndarray, value: np.ndarray):
        '''``tree[key] = value`` is shorthand for ``tree.update(key, value)``.'''
        self.update(key, value)

    def update(self, key: np.ndarray, value: np.ndarray):
        '''Set element values and recompute every affected ancestor sum.

        Args:
            key: int, slice or array-like of element indices.
            value: value(s) broadcastable to the flattened ``key``.

        An empty key (e.g. ``tree[2:2] = []``) is a no-op.
        '''
        key = self._formalize_keys(key).flatten()
        value = np.asarray(value, dtype=np.float64).flatten()
        self._value[key] = value
        if key.size == 0:
            # No leaf changed, so there is nothing to propagate.
            return
        # All leaves share the same depth, so the whole batch climbs in lockstep.
        while key[0] > 1:
            self._value[key >> 1] = self._value[key] + self._value[key ^ 1]
            key >>= 1

    def sum(self, start: int = None, end: int = None):
        '''Return the sum of the elements in ``[start, end)``.

        ``start`` and ``end`` follow Python slice semantics: ``None`` means
        the corresponding end of the array and negatives count from the end.
        An empty range sums to 0.0.
        '''
        if start is None and end is None:
            # Shortcut: the root already holds the total.
            return self._value[1]
        start, end, _ = slice(start, end).indices(self._size)
        start += self._base
        end += self._base
        res = 0.0
        # Bottom-up walk over the half-open node range [start, end).
        while start < end:
            if start & 1:
                # start is a right child whose sibling lies outside the range.
                res += self._value[start]
            if end & 1:
                # end - 1 is a left child whose sibling lies outside the range.
                res += self._value[end - 1]
            start = (start + 1) >> 1
            end = end >> 1
        return res

    def index(self, value: np.ndarray):
        '''Find, for each query, the element whose prefix-sum bucket holds it.

        For a query ``v``, this returns the index ``i`` with
        ``sum(x[0:i]) <= v < sum(x[0:i + 1])``, where ``x`` are the stored
        elements (an inverse-CDF lookup). It assumes all stored values are
        non-negative. Under that assumption, elements equal to 0 own an empty
        bucket and are not selected, up to floating-point rounding.

        Args:
            value: scalar or array-like of queries, each in ``[0, total)``.

        Returns:
            A Python ``int`` for a scalar query. Otherwise an int64 array
            with the same shape as ``value`` (empty input gives empty output).
        '''
        # If the input is a scalar, the result should be a scalar too.
        one_value = np.isscalar(value)
        value = np.asarray(value, dtype=np.float64)
        orig_shape = value.shape
        if value.size == 0:
            return np.zeros(orig_shape, dtype=np.int64)
        assert value.min() >= 0.0
        assert value.max() < self._value[1]
        # flatten() copies, so the in-place updates below never touch caller data.
        value = value.flatten()
        inds = np.ones_like(value, dtype=np.int64)
        # Descend from the root; all queries reach leaf depth together.
        while inds[0] < self._base:
            inds <<= 1
            lsum = self._value[inds]
            # Go right when the left subtree's mass does not exceed the remainder.
            d = lsum <= value
            value -= lsum * d
            inds += d
        inds -= self._base
        # Safety net: float drift in the partial sums must never yield a padding index.
        np.minimum(inds, self._size - 1, out=inds)
        inds = inds.reshape(orig_shape)
        return inds.item() if one_value else inds
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment