A Python class implementing a segment tree with lazy propagation, supporting range-sum queries and range-increment updates
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class segment_tree:
    """A range-sum segment tree with lazy propagation.

    On a fixed-length array of numbers it supports

    * ``query(x, y)``     -- the sum of elements x..y (both ends inclusive), and
    * ``update(x, y, z)`` -- add z to every element x..y (both ends inclusive),

    each touching O(log N) nodes.

    Layout: nodes live in the flat lists ``seg`` / ``lazy`` as a 1-indexed
    implicit binary tree.  Node ``v`` covers a sub-array ``[L, R]``; its
    children are ``2*v`` (covering ``[L, mid]``) and ``2*v + 1`` (covering
    ``[mid + 1, R]``).  ``seg[v]`` is the sum of its range *excluding* any
    increment still pending in ``lazy[v]``; ``_propagate`` settles that.
    """

    def __init__(self, some_list):
        """
        Initialises the tree from ``some_list``.

        ``some_list`` may be any iterable of numbers (it is read once into a
        list).  An empty input gives an empty tree, on which every
        query/update is rejected by the range check.
        """
        some_list = list(some_list)
        self.N = len(some_list)
        # 4 * N slots is a safe upper bound on the node count for N leaves
        # (slot 0 is unused because the root is node 1).
        self.seg = [0] * (4 * self.N)
        self.lazy = [0] * (4 * self.N)
        # Building over the empty range [0, -1] would never reach a leaf and
        # recurse forever, so only build when there is data.
        if self.N:
            self._build(some_list, 0, self.N - 1, 1)

    def query(self, x, y):
        """
        Return the sum of the sub-array x..y, both ends inclusive.

        Requires integers with 0 <= x <= y < N; otherwise AssertionError is
        raised (note: assertions are skipped under ``python -O``).
        """
        assert 0 <= x <= y < self.N, "invalid range [%r, %r] for size %r" % (x, y, self.N)
        return self._query(0, self.N - 1, 1, x, y)

    def update(self, x, y, z):
        """
        Add z to every element of the sub-array x..y, both ends inclusive.

        Same range requirements as ``query``.  Returns None.
        """
        assert 0 <= x <= y < self.N, "invalid range [%r, %r] for size %r" % (x, y, self.N)
        self._update(0, self.N - 1, 1, x, y, z)

    def _build(self, some_list, L, R, v):
        """
        Fill node v, which covers some_list[L..R], and its whole subtree.
        Called internally.
        """
        if L == R:
            self.seg[v] = some_list[L]
        else:
            # Floor division: plain ``/`` gives a float index on Python 3.
            mid = (L + R) // 2
            self._build(some_list, L, mid, 2 * v)
            self._build(some_list, mid + 1, R, 2 * v + 1)
            self.seg[v] = self.seg[2 * v] + self.seg[2 * v + 1]

    def _propagate(self, L, R, v):
        """
        Settle node v's pending increment (the 'propagate' step of lazy
        propagation).  Called internally.

        The increment in ``lazy[v]`` is applied once per element of [L, R]
        to ``seg[v]``.  For an internal node it is then deferred to both
        children's ``lazy`` entries.  Finally ``lazy[v]`` is cleared.
        """
        if self.lazy[v]:
            self.seg[v] += (R - L + 1) * self.lazy[v]
            if L < R:
                self.lazy[2 * v] += self.lazy[v]
                self.lazy[2 * v + 1] += self.lazy[v]
            self.lazy[v] = 0

    def _query(self, L, R, v, x, y):
        """
        Return the part of the sum over [x, y] that lies inside node v's
        range [L, R].  Called internally.
        """
        self._propagate(L, R, v)
        if R < x or y < L:
            # Disjoint: contributes nothing.
            return 0
        if x <= L and R <= y:
            # Fully covered: seg[v] is exact after _propagate, so stop here.
            return self.seg[v]
        mid = (L + R) // 2
        return (self._query(L, mid, 2 * v, x, y)
                + self._query(mid + 1, R, 2 * v + 1, x, y))

    def _update(self, L, R, v, x, y, z):
        """
        Add z to the elements of [x, y] that lie inside node v's range
        [L, R].  Called internally.
        """
        # Settle v first, even when it turns out to be disjoint from [x, y]:
        # the parent re-sums seg[] of both children, so each must be exact.
        self._propagate(L, R, v)
        if R < x or y < L:
            return
        if x <= L and R <= y:
            # Fully covered: tag this node and let _propagate apply the tag
            # here and defer it to the children, instead of descending.
            self.lazy[v] += z
            self._propagate(L, R, v)
        else:
            mid = (L + R) // 2
            self._update(L, mid, 2 * v, x, y, z)
            self._update(mid + 1, R, 2 * v + 1, x, y, z)
            self.seg[v] = self.seg[2 * v] + self.seg[2 * v + 1]
def main():
    """
    Interactive test driver for ``segment_tree``; runs on Python 2 and 3.

    Reads from standard input:

    1. one line of space-separated integers -- the array to build on;
    2. the number M of operations (prompted for);
    3. M non-blank operation lines, each either

       * ``query x y``    -- prints the sum of elements x..y (inclusive), or
       * ``<word> x y z`` -- any other first word adds z to elements x..y.

    Blank operation lines are skipped and do not count toward M.
    """
    # Python 2's ``input`` evaluates what the user types, so use
    # ``raw_input`` there; Python 3 renamed that function to ``input``.
    try:
        read_line = raw_input
    except NameError:
        read_line = input

    print("\nThis is a test-interface to test the segment tree class")
    print("Enter a space separated list of integers "
          "(the array to build the segment tree on)")
    inp = [int(tok) for tok in read_line().split()]
    st = segment_tree(inp)
    M = int(read_line("Enter the number of queries/updates:\n"))
    done = 0
    while done < M:
        words = read_line().split()
        if not words:
            # A stray empty line previously crashed on words[0].
            continue
        args = [int(w) for w in words[1:]]
        if words[0] == "query":
            print(st.query(*args))
        else:
            st.update(*args)
        done += 1


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment