Skip to content

Instantly share code, notes, and snippets.

@mrchnk
Last active May 18, 2021 08:09
Show Gist options
  • Save mrchnk/154d695e866d465ab3427c79cf93573b to your computer and use it in GitHub Desktop.
Save mrchnk/154d695e866d465ab3427c79cf93573b to your computer and use it in GitHub Desktop.
class SegmentTree:
    """Recursive segment tree over a fixed-size array.

    Every node stores ``merge`` folded over the elements of its segment, so
    ``query(l, r)`` returns ``merge`` folded over elements ``l..r`` (both ends
    inclusive) using O(log n) merges, and ``update`` replaces one element in
    O(log n). ``merge`` must be associative (e.g. ``min``, ``max``, addition,
    ``math.gcd``); left operands always come before right operands, so it
    need not be commutative.

    Nodes are stored heap-style in ``self.tree``: the root is index 0 and
    node ``v`` has children ``2 * v + 1`` and ``2 * v + 2``.
    """

    def __init__(self, size, merge=min, default=0):
        """Create a tree able to hold ``size`` elements.

        Args:
            size: number of elements; must be non-negative.
            merge: associative binary function combining two segment values.
            default: initial value of every node until ``build``/``update``
                overwrite it.

        Raises:
            ValueError: if ``size`` is negative.
        """
        if size < 0:
            raise ValueError(f"size must be non-negative, got {size}")
        self.size = size
        self.merge = merge
        # 4 * size nodes always suffice for the recursive midpoint split.
        self.tree = [default] * (size * 4)

    def build(self, values):
        """Initialise the tree from ``values[0:size]`` using O(size) merges.

        Elements of ``values`` beyond index ``size - 1`` are ignored.

        Raises:
            IndexError: if ``values`` has fewer than ``size`` elements.
        """
        if self.size == 0:
            # Nothing to store; _build on the empty range [0, -1] would
            # never reach a leaf and recurse forever.
            return
        self._build(0, 0, self.size - 1, values)

    def _build(self, v, tl, tr, values):
        """Fill node ``v`` covering ``[tl, tr]`` from ``values``; return its value."""
        if tl == tr:
            self.tree[v] = values[tl]
        else:
            tm = (tl + tr) // 2
            left = self._build(2 * v + 1, tl, tm, values)
            right = self._build(2 * v + 2, tm + 1, tr, values)
            self.tree[v] = self.merge(left, right)
        return self.tree[v]

    def update(self, index, value):
        """Set element ``index`` to ``value`` and recompute its ancestors.

        Raises:
            IndexError: if ``index`` is outside ``[0, size)``. Without this
                check an out-of-range index would silently overwrite the
                first or last element instead.
        """
        if not 0 <= index < self.size:
            raise IndexError(
                f"index {index} out of range for size {self.size}")
        self._update(0, 0, self.size - 1, index, value)

    def _update(self, v, tl, tr, index, value):
        """Write ``value`` at ``index`` below node ``v`` covering ``[tl, tr]``.

        Returns the node's new aggregated value so the caller can merge it
        with the untouched sibling without re-reading the tree.
        """
        if tl != tr:
            tm = (tl + tr) // 2
            if index <= tm:
                left = self._update(2 * v + 1, tl, tm, index, value)
                right = self.tree[2 * v + 2]
            else:
                left = self.tree[2 * v + 1]
                right = self._update(2 * v + 2, tm + 1, tr, index, value)
            value = self.merge(left, right)
        self.tree[v] = value
        return value

    def query(self, l, r):
        """Return ``merge`` folded over elements ``l..r`` (inclusive).

        Raises:
            IndexError: if ``l`` or ``r`` is outside ``[0, size)``.
            ValueError: if ``l > r``.
        """
        if not (0 <= l < self.size and 0 <= r < self.size):
            raise IndexError(
                f"query range [{l}, {r}] out of bounds for size {self.size}")
        if l > r:
            raise ValueError(f"empty query range: l={l} > r={r}")
        return self._query(0, 0, self.size - 1, l, r)

    def _query(self, v, tl, tr, l, r):
        """Fold ``merge`` over ``[l, r]``, which must lie inside node ``v``'s ``[tl, tr]``."""
        if l == tl and r == tr:
            return self.tree[v]
        tm = (tl + tr) // 2
        if r <= tm:
            # Range lies entirely in the left child.
            return self._query(2 * v + 1, tl, tm, l, r)
        if l > tm:
            # Range lies entirely in the right child.
            return self._query(2 * v + 2, tm + 1, tr, l, r)
        # Range straddles the midpoint: split it and merge left-to-right.
        left = self._query(2 * v + 1, tl, tm, l, tm)
        right = self._query(2 * v + 2, tm + 1, tr, tm + 1, r)
        return self.merge(left, right)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment