Skip to content

Instantly share code, notes, and snippets.

@parthmittal
Last active June 20, 2018 10:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save parthmittal/da9da297ad5e7d287053 to your computer and use it in GitHub Desktop.
Save parthmittal/da9da297ad5e7d287053 to your computer and use it in GitHub Desktop.
A class implementing the segment tree data structure
class segment_tree:
    """Range-sum segment tree with lazy propagation for range-add updates.

    Nodes are numbered heap-style starting at 1: node ``v`` has children
    ``2 * v`` and ``2 * v + 1``.  ``seg[v]`` stores the sum of the interval
    covered by node ``v`` and ``lazy[v]`` stores an increment that still has
    to be added to every element of that interval.
    """

    def __init__(self, some_list):
        """Initialise the tree from ``some_list`` (a list of numbers).

        An empty list is accepted and produces a tree with ``N == 0``.
        """
        self.N = len(some_list)
        # The recursive layout rooted at index 1 is allocated 4 * N slots;
        # keep at least 4 so the root slot exists even for an empty list.
        size = 4 * max(self.N, 1)
        self.seg = [0] * size
        self.lazy = [0] * size
        if self.N:
            self._build(some_list, 0, self.N - 1, 1)

    def query(self, x, y):
        """Return the sum of the elements at indices ``x`` to ``y`` inclusive.

        The indices must satisfy ``x <= y < N``; this is checked with an
        ``assert``.
        """
        assert x <= y < self.N
        return self._query(0, self.N - 1, 1, x, y)

    def update(self, x, y, z):
        """Add ``z`` to every element at indices ``x`` to ``y`` inclusive.

        The indices must satisfy ``x <= y < N``; this is checked with an
        ``assert``.
        """
        assert x <= y < self.N
        return self._update(0, self.N - 1, 1, x, y, z)

    def _build(self, some_list, L, R, v):
        """Fill node ``v``, which covers ``some_list[L..R]``, and its subtree."""
        if L == R:
            self.seg[v] = some_list[L]
        else:
            # Floor division keeps ``mid`` an integer on Python 2 and 3 alike.
            mid = (L + R) // 2
            self._build(some_list, L, mid, 2 * v)
            self._build(some_list, mid + 1, R, 2 * v + 1)
            self.seg[v] = self.seg[2 * v] + self.seg[2 * v + 1]

    def _propagate(self, L, R, v):
        """Apply the pending increment of node ``v`` (covering ``[L, R]``).

        The increment is folded into ``seg[v]`` and, for internal nodes,
        handed down to both children's ``lazy`` entries.
        """
        pending = self.lazy[v]
        if pending:
            # Each of the R - L + 1 elements in the interval grows by pending.
            self.seg[v] += (R - L + 1) * pending
            if L < R:
                self.lazy[2 * v] += pending
                self.lazy[2 * v + 1] += pending
            self.lazy[v] = 0

    def _query(self, L, R, v, x, y):
        """Return the part of the sum over ``[x, y]`` that lies in node ``v``.

        Node ``v`` covers ``[L, R]``.
        """
        self._propagate(L, R, v)
        if R < x or y < L:
            # No overlap with the requested range.
            return 0
        if x <= L and R <= y:
            # Node lies entirely inside the range: its stored sum is the
            # answer, so there is no need to descend further.
            return self.seg[v]
        mid = (L + R) // 2
        return (self._query(L, mid, 2 * v, x, y)
                + self._query(mid + 1, R, 2 * v + 1, x, y))

    def _update(self, L, R, v, x, y, z):
        """Add ``z`` to the part of ``[x, y]`` that lies in node ``v``.

        Node ``v`` covers ``[L, R]``.
        """
        self._propagate(L, R, v)
        if R < x or y < L:
            # No overlap: nothing below this node changes.
            return
        if x <= L and R <= y:
            # Fully covered: record the increment lazily and apply it to
            # this node right away so ``seg[v]`` is current for the parent.
            self.lazy[v] += z
            self._propagate(L, R, v)
            return
        mid = (L + R) // 2
        self._update(L, mid, 2 * v, x, y, z)
        self._update(mid + 1, R, 2 * v + 1, x, y, z)
        # Both children were propagated by the recursive calls above.
        self.seg[v] = self.seg[2 * v] + self.seg[2 * v + 1]
def main():
    """Interactive console harness for exercising the segment tree class.

    Reads the initial array as one line of space-separated integers, then the
    number of commands, then one command per line.  A line whose first word is
    ``query`` is followed by two integers and prints the range sum; any other
    first word is treated as an update and is followed by three integers.
    """
    # ``raw_input`` exists only on Python 2; Python 3 names it ``input``.
    try:
        read_line = raw_input
    except NameError:
        read_line = input

    # Single-argument print(...) produces the same output on Python 2 and 3.
    print("\nThis is a test-interface to test the segment tree class")
    print("Enter a space separated list of integers (the array to build the segment tree on)")
    st = segment_tree([int(token) for token in read_line().split()])
    M = int(read_line("Enter the number of queries/updates:\n"))
    for _ in range(M):
        inp = read_line().split()
        args = [int(token) for token in inp[1:]]
        if inp[0] == "query":
            print(st.query(*args))
        else:
            st.update(*args)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment