Skip to content

Instantly share code, notes, and snippets.

@Shaunwei
Created December 12, 2015 08:18
Show Gist options
  • Save Shaunwei/cfbe6b462d4d68bea5cc to your computer and use it in GitHub Desktop.
Save Shaunwei/cfbe6b462d4d68bea5cc to your computer and use it in GitHub Desktop.
Python SegmentTree Implementation
#!/usr/bin/python
# -*- coding: utf-8 -*-
class SegmentTree(object):
    """Sum segment tree over a fixed-length, mutable sequence of numbers.

    The tree is built top-down by recursively splitting the index range in
    half; sums are filled in bottom-up while the recursion unwinds.

    Operations:
    - build:  construct the tree from a non-empty sequence, O(n)
    - query:  sum of an inclusive index range, O(log(n))
    - update: overwrite the value at one index, O(log(n))
    """

    class Node:
        """Tree node covering the inclusive index range [st, ed]."""

        def __init__(self, start, end, val=0):
            self.st = start    # first index covered by this node
            self.ed = end      # last index covered by this node (inclusive)
            self.sum = val     # sum of the values in [st, ed]
            self.left = self.right = None  # children; both None on a leaf

    def __init__(self, nums):
        """Build the tree from ``nums``.

        Raises:
            ValueError: if ``nums`` is empty.
        """
        if not nums:
            raise ValueError('Input data could not be empty.')
        self.root = self.build(nums, 0, len(nums) - 1)

    def build(self, nums, st, ed):
        """Recursively build and return the subtree for ``nums[st..ed]``."""
        if st == ed:
            return SegmentTree.Node(st, ed, nums[st])
        # Floor division keeps the midpoint an int on Python 2 and Python 3.
        mid = (st + ed) // 2
        root = SegmentTree.Node(st, ed)
        root.left = self.build(nums, st, mid)
        root.right = self.build(nums, mid + 1, ed)
        root.sum = root.left.sum + root.right.sum
        return root

    def query(self, start, end):
        """Return the sum of the values at indices ``start..end`` inclusive.

        Raises:
            IndexError: if the range is empty (``start > end``) or reaches
                outside the indices the tree was built over.
        """
        if start > end or start < self.root.st or end > self.root.ed:
            raise IndexError(
                'Query range [{0}, {1}] is invalid for indices [{2}, {3}].'
                .format(start, end, self.root.st, self.root.ed))
        return self._q(self.root, start, end)

    def _q(self, root, start, end):
        """Sum ``start..end`` within ``root``.

        Precondition: ``root.st <= start <= end <= root.ed``.
        """
        if root.st == start and root.ed == end:
            return root.sum
        # Split on the NODE's midpoint: that is where its children divide.
        mid = (root.st + root.ed) // 2
        if end <= mid:
            # Query lies entirely in the left child.
            return self._q(root.left, start, end)
        if start > mid:
            # Query lies entirely in the right child.
            return self._q(root.right, start, end)
        # Query straddles the midpoint: clip it to each child's range.
        lsum = self._q(root.left, start, mid)
        rsum = self._q(root.right, mid + 1, end)
        return lsum + rsum

    def update(self, index, value):
        """Set the element at ``index`` to ``value`` and refresh the sums.

        Raises:
            IndexError: if ``index`` is outside the tree's index range.
        """
        if index < self.root.st or index > self.root.ed:
            raise IndexError(
                'Index {0} is out of range [{1}, {2}].'
                .format(index, self.root.st, self.root.ed))
        self._u(self.root, index, value)

    def _u(self, root, index, value):
        """Descend to the leaf for ``index``, then recompute sums on the way up."""
        if root.st == root.ed == index:
            root.sum = value
            return
        mid = (root.st + root.ed) // 2
        if index <= mid:
            self._u(root.left, index, value)
        else:
            self._u(root.right, index, value)
        root.sum = root.left.sum + root.right.sum
class Solution:
    """Range-sum interface that delegates all work to a SegmentTree."""

    def __init__(self, nums):
        """Build the backing segment tree from ``nums``."""
        tree = SegmentTree(nums)
        self.st = tree  # backing SegmentTree instance

    def sumRange(self, start, end):
        """Return the tree's query result for the range ``start..end``."""
        total = self.st.query(start, end)
        return total

    def update(self, index, value):
        """Forward the change of element ``index`` to ``value`` to the tree."""
        backing_tree = self.st
        backing_tree.update(index, value)
if __name__ == '__main__':
    # Question:
    #   Given an integer array nums, find the sum of the elements between
    #   indices i and j (i <= j), inclusive.
    #   The update(i, val) function modifies nums by updating the element
    #   at index i to val.
    # Example:
    #   Given nums = [1, 3, 5]
    #   sumRange(0, 2) -> 9
    #   update(1, 2)
    #   sumRange(0, 2) -> 8
    sample = [1, 3, 5]
    solver = Solution(sample)
    before = solver.sumRange(0, 2)
    print(before)
    solver.update(1, 2)
    after = solver.sumRange(0, 2)
    print(after)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment