#!/usr/bin/env python2
# Efficient means of determining which ranges from a set overlap a given query
# range. See http://www.cise.ufl.edu/~sahni/cop5536/powerpoint/lec34.ppt for
# details.
class Node(object):
    """One node of a centred interval tree.

    A node picks a centre point ``pt`` (the lower median of the distinct
    endpoints of the intervals it is given), keeps every interval that
    contains ``pt`` in ``_olap``, and hands the remaining intervals to a
    left child (intervals lying entirely below ``pt``) or a right child
    (intervals lying entirely above ``pt``).

    The interval objects only need ``l`` and ``r`` attributes and support
    for ``point in interval``.

    Attributes:
        pt: centre point of this node, or ``None`` when the node was built
            from no intervals.
        l: left child ``Node`` or ``None``.
        r: right child ``Node`` or ``None``.
    """

    def __init__(self, intervals):
        """Build the subtree for ``intervals`` (any iterable of intervals)."""
        # The intervals are walked twice (centre selection, then
        # distribution), so a one-shot iterator must be materialised first.
        intervals = list(intervals)
        self._olap = []
        self.l = None
        self.r = None
        self.pt = None
        self._set_centre(intervals)
        self._add_intervals(intervals)

    def _set_centre(self, intervals):
        """Set ``pt`` to the median endpoint; leave it ``None`` if empty."""
        endpoints = set()
        for interval in intervals:
            endpoints.add(interval.l)
            endpoints.add(interval.r)
        # With no intervals there is no endpoint to choose from; keeping
        # ``pt`` as None marks this node as empty instead of raising.
        if endpoints:
            self.pt = self._find_median(endpoints)

    def _add_intervals(self, intervals):
        """Store intervals containing ``pt`` here and recurse on the rest."""
        to_left = []
        to_right = []
        for interval in intervals:
            if self.pt in interval:
                self._olap.append(interval)
            elif self.pt < interval.l:
                to_right.append(interval)
            elif self.pt > interval.r:
                to_left.append(interval)
        if to_left:
            self.l = Node(to_left)
        if to_right:
            self.r = Node(to_right)

    def _find_median(self, points):
        """Return the lower median of the non-empty collection ``points``."""
        # Floor division gives the same index under Python 2 and Python 3.
        midpoint = (len(points) - 1) // 2
        return sorted(points)[midpoint]

    def find_overlapping(self, interval):
        """Return the stored intervals that overlap ``interval``.

        Endpoints are treated as inclusive: intervals that merely touch the
        query are reported. Results from this node come first, followed by
        those of the left and then the right subtree.
        """
        if self.pt is None:
            # Empty node: nothing stored, nothing can overlap.
            return []
        olap = []
        if self.pt in interval:
            # Every interval stored here contains pt, and pt lies inside the
            # query, so all of them overlap; both subtrees may also hold hits.
            olap += self._olap
            if self.l is not None:
                olap += self.l.find_overlapping(interval)
            if self.r is not None:
                olap += self.r.find_overlapping(interval)
        elif self.pt < interval.l:
            # Query lies entirely above pt: stored intervals overlap only if
            # they reach up to the query's left end. The left subtree holds
            # intervals ending below pt, so it is skipped.
            olap += [c for c in self._olap if c.r >= interval.l]
            if self.r is not None:
                olap += self.r.find_overlapping(interval)
        elif self.pt > interval.r:
            # Mirror case: query lies entirely below pt.
            olap += [c for c in self._olap if c.l <= interval.r]
            if self.l is not None:
                olap += self.l.find_overlapping(interval)
        return olap
class Interval(object):
    """A closed interval ``[l, r]``.

    Attributes:
        l: left endpoint.
        r: right endpoint.
    """

    def __init__(self, l, r):
        """Store the two endpoints as given (no ordering check is made)."""
        self.l = l
        self.r = r

    def __contains__(self, p):
        """Return True when ``p`` lies within the interval, ends included."""
        return self.l <= p <= self.r

    def __repr__(self):
        """Render the interval as ``[l, r]``."""
        return '[%s, %s]' % (self.l, self.r)
class IntervalTree(object):
    """Interval tree built from ``(left, right)`` pairs.

    Each pair is wrapped in an ``Interval`` and the whole collection is
    handed to a root ``Node``. An empty collection yields a tree with no
    root, for which every query returns an empty list.
    """

    def __init__(self, intervals):
        """Build the tree from an iterable of indexable ``(l, r)`` pairs."""
        pairs = list(intervals)
        # No pairs means no endpoints to choose a centre from, so no root
        # node is built at all.
        if pairs:
            self._root = Node([Interval(p[0], p[1]) for p in pairs])
        else:
            self._root = None

    def find_overlapping(self, interval):
        """Return the list of stored intervals overlapping ``interval``."""
        if self._root is None:
            return []
        return self._root.find_overlapping(interval)
def main():
    """Demo: build a small tree and print the intervals overlapping [30, 60]."""
    tree = IntervalTree([(30, 50), (60, 200), (1150, 1300)])
    print(tree.find_overlapping(Interval(30, 60)))


# Run the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment