Create a gist now

Instantly share code, notes, and snippets.

#!/usr/bin/env python2
# Efficient means of determining which ranges from a set overlap a given query
# range. See http://www.cise.ufl.edu/~sahni/cop5536/powerpoint/lec34.ppt for
# details.
class Node(object):
    """One node of an interval tree.

    A node picks a centre point ``pt`` (the lower median of the distinct
    endpoints of its intervals), keeps every interval containing that point
    in ``_olap``, and passes the intervals lying entirely below / above the
    centre to the ``l`` / ``r`` child nodes.

    Intervals may be any objects that expose ``l`` and ``r`` attributes and
    support ``point in interval``.
    """

    def __init__(self, intervals):
        """Build the subtree holding ``intervals``.

        An empty ``intervals`` produces an empty node (``pt`` stays ``None``)
        whose queries return ``[]``.

        Raises:
            ValueError: if any interval has ``l > r``. Such an interval
                contains no point, so it could never be stored at a node
                and the construction would recurse without end.
        """
        # Materialise once: the input is traversed several times below, and a
        # one-shot iterator would be exhausted after the first pass.
        intervals = list(intervals)
        self._olap = []
        self.l = None
        self.r = None
        self.pt = None
        if not intervals:
            return
        for interval in intervals:
            if interval.l > interval.r:
                raise ValueError('interval %r has l > r' % (interval,))
        self._set_centre(intervals)
        self._add_intervals(intervals)

    def _set_centre(self, intervals):
        """Set ``pt`` to the median of the distinct endpoints of ``intervals``."""
        endpoints = set()
        for interval in intervals:
            endpoints.add(interval.l)
            endpoints.add(interval.r)
        self.pt = self._find_median(endpoints)

    def _add_intervals(self, intervals):
        """Store intervals containing ``pt`` here; push the rest to children."""
        to_left = []
        to_right = []
        for interval in intervals:
            if self.pt in interval:
                self._olap.append(interval)
            elif self.pt < interval.l:
                to_right.append(interval)
            elif self.pt > interval.r:
                to_left.append(interval)
        if to_left:
            self.l = Node(to_left)
        if to_right:
            self.r = Node(to_right)

    def _find_median(self, points):
        """Return the lower median of the non-empty collection ``points``."""
        midpoint = (len(points) - 1) // 2
        return sorted(points)[midpoint]

    def find_overlapping(self, interval):
        """Return the stored intervals that overlap the query ``interval``.

        Results are ordered: this node's matches, then the left subtree's,
        then the right subtree's.
        """
        if self.pt is None:
            # Empty node: nothing stored, nothing can overlap.
            return []
        olap = []
        if self.pt in interval:
            # Everything stored here contains pt, which lies inside the
            # query, so all of it overlaps; both subtrees may also overlap.
            olap += self._olap
            if self.l is not None:
                olap += self.l.find_overlapping(interval)
            if self.r is not None:
                olap += self.r.find_overlapping(interval)
        elif self.pt < interval.l:
            # Query lies entirely above pt: a stored interval overlaps only
            # if it reaches up to the query's lower bound.
            olap += [c for c in self._olap if c.r >= interval.l]
            if self.r is not None:
                olap += self.r.find_overlapping(interval)
        elif self.pt > interval.r:
            # Query lies entirely below pt: mirror image of the case above.
            olap += [c for c in self._olap if c.l <= interval.r]
            if self.l is not None:
                olap += self.l.find_overlapping(interval)
        return olap
class Interval(object):
    """Closed range ``[l, r]``; both bounds count as members."""

    def __init__(self, l, r):
        """Record the lower bound ``l`` and the upper bound ``r``."""
        self.l, self.r = l, r

    def __contains__(self, p):
        """Tell whether ``p`` lies between the two bounds, inclusive."""
        above_lower = self.l <= p
        return above_lower and p <= self.r

    def __repr__(self):
        """Render the interval as ``[l, r]``."""
        bounds = (self.l, self.r)
        return '[%s, %s]' % bounds
class IntervalTree(object):
    """Wrapper that builds a tree of ``Node`` objects from endpoint pairs."""

    def __init__(self, intervals):
        """Wrap each item's first two elements in an ``Interval`` and build
        the root node from them."""
        wrapped = []
        for pair in intervals:
            wrapped.append(Interval(pair[0], pair[1]))
        self._root = Node(wrapped)

    def find_overlapping(self, interval):
        """Forward the query to the root node and return its result."""
        root = self._root
        return root.find_overlapping(interval)
def main():
    """Demo: build a small tree and print the intervals overlapping [30, 60]."""
    tree = IntervalTree([(30, 50), (60, 200), (1150, 1300)])
    print(tree.find_overlapping(Interval(30, 60)))
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment