-
-
Save mtomko/2e90ce76087f595daf8b2334553fbd91 to your computer and use it in GitHub Desktop.
An attempt to create an interval tree based on the red-black tree implementation given by Chris Okasaki in his book *Purely Functional Data Structures*
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* Copyright 2015. Genomic Perturbation Platform, The Broad Institute of Harvard and MIT. | |
* http://www.broadinstitute.org | |
*/ | |
package org.broadinstitute.rnai.scalamari.interval | |
import scala.annotation.tailrec | |
/** A representation of a closed interval in the domain of `A`.
  *
  * The endpoints are not required to be given in order: `e1` may be greater than `e2`. The
  * helpers in the companion object normalize the endpoints with an `Ordering[A]` before use.
  *
  * @tparam A the type of the interval's endpoints
  */
trait Interval[+A] {
  /** One endpoint of the interval (not necessarily the lower one). */
  def e1: A
  /** The other endpoint of the interval (not necessarily the higher one). */
  def e2: A
}

object Interval {
  /** Creates an `Ordering[Interval[A]]` from an `Ordering[A]`, suitable for use by the IntervalTree.
    *
    * Intervals are compared by their lower endpoint first and by their upper endpoint second. The
    * endpoints are normalized first, so the order agrees with `min`/`max` even for intervals whose
    * `e1` is greater than `e2`. The tree's query pruning relies on that agreement.
    */
  def ordering[A](implicit ordering: Ordering[A]): Ordering[Interval[A]] =
    Ordering.by((i: Interval[A]) => Interval.ordered(i))

  /** Returns the lower endpoint of the Interval[A] based on the implicit ordering */
  def min[A](i: Interval[A])(implicit ordering: Ordering[A]): A = ordering.min(i.e1, i.e2)

  /** Returns the higher endpoint of the Interval[A] based on the implicit ordering */
  def max[A](i: Interval[A])(implicit ordering: Ordering[A]): A = ordering.max(i.e1, i.e2)

  /** Returns a tuple representing the endpoints of `i` sorted by the implicit ordering */
  def ordered[A](i: Interval[A])(implicit ordering: Ordering[A]): (A, A) =
    ordered(i.e1, i.e2)

  /** Returns `(x, y)` sorted by the implicit ordering, i.e. `(lower, higher)` */
  def ordered[A](x: A, y: A)(implicit ordering: Ordering[A]): (A, A) = {
    import ordering._
    if (x > y) (y, x)
    else (x, y)
  }

  /** Returns true iff the closed intervals `x` and `y` overlap */
  def overlaps[A](x: Interval[A], y: Interval[A])(implicit ordering: Ordering[A]): Boolean =
    overlaps(x.e1, x.e2, y.e1, y.e2)

  /** Returns true iff the closed intervals `(e1, e2)` and `i` overlap */
  def overlaps[A](e1: A, e2: A, i: Interval[A])(implicit ordering: Ordering[A]): Boolean =
    overlaps(e1, e2, i.e1, i.e2)

  /** Returns true iff the closed intervals `(x1, x2)` and `(y1, y2)` overlap.
    *
    * The endpoints of each interval may be given in either order. Two closed intervals intersect
    * exactly when each one starts no later than the other one ends. This covers partial overlaps,
    * shared endpoints, and one interval lying entirely inside the other.
    */
  def overlaps[A](x1: A, x2: A, y1: A, y2: A)(implicit ordering: Ordering[A]): Boolean = {
    import ordering._
    val (xl, xh) = Interval.ordered(x1, x2)
    val (yl, yh) = Interval.ordered(y1, y2)
    xl <= yh && yl <= xh
  }
}
/** An immutable red-black tree representation of a collection of closed intervals, providing
  * efficient methods for finding intervals that overlap a given query interval.
  *
  * Every node records the largest endpoint found anywhere in its subtree. Queries use it to skip
  * subtrees that cannot contain an overlapping interval.
  *
  * @tparam E the endpoint type of the stored intervals
  * @tparam I the concrete interval type stored in the tree
  */
class IntervalTree[E, I <: Interval[E]] private[interval] (val t: IntervalTree.Tree[E, I])(implicit ordering: Ordering[E]) {
  import IntervalTree._

  /** Returns a new `IntervalTree[E, I]` containing all the intervals in this tree, as well as the
    * provided interval. This tree is not modified. */
  def +(x: I): IntervalTree[E, I] = new IntervalTree(insert(x, t))

  /** Returns an interval from the tree that overlaps `i`, or `None` if there is none. */
  def search(i: Interval[E]): Option[I] = {
    import ordering._
    val iLow = Interval.min(i)

    @tailrec
    def loop(node: Tree[E, I]): Option[I] = node match {
      case E => None
      case T(_, left, ti, _, right) =>
        if (Interval.overlaps(i, ti)) Some(ti)
        else left match {
          // If some interval in the left subtree ends at or after `iLow`, descend left;
          // otherwise only the right subtree can contain an overlap.
          case T(_, _, _, leftMax, _) if leftMax >= iLow => loop(left)
          case _ => loop(right)
        }
    }

    loop(t)
  }

  /** Returns all intervals in the tree that overlap the closed interval `(min(x, y), max(x, y))`. */
  def overlapping(x: E, y: E): Seq[I] = {
    import ordering._
    val (iLow, iHi) = Interval.ordered(x, y)

    def loop(node: Tree[E, I]): List[I] = node match {
      case E => Nil
      case T(_, left, ti, _, right) =>
        // Explore the left subtree only if some interval in it ends at or after `iLow`.
        val lefts = left match {
          case T(_, _, _, leftMax, _) if leftMax >= iLow => loop(left)
          case _ => Nil
        }
        // Intervals in the right subtree start no earlier than `ti` (the tree is ordered by
        // lower endpoint). So the right subtree can only hold overlaps if `ti` starts at or
        // before `iHi`. The comparison must be inclusive, because a closed interval that starts
        // exactly at `iHi` still overlaps the query.
        val rights = if (Interval.min(ti) <= iHi) loop(right) else Nil
        if (Interval.overlaps(iLow, iHi, ti)) ti :: (rights ++ lefts)
        else rights ++ lefts
    }

    loop(t)
  }

  /** Returns all intervals that overlap `i` in the tree */
  def overlapping(i: Interval[E]): Seq[I] = overlapping(i.e1, i.e2)
}
object IntervalTree {
  /////////////////////////////////////////////////////////////////////////////////////////////////
  // Public API
  /////////////////////////////////////////////////////////////////////////////////////////////////

  /** Constructs a new, empty `IntervalTree`. */
  def apply[E, I <: Interval[E]]()(implicit ordering: Ordering[E]): IntervalTree[E, I] =
    new IntervalTree(E)

  /** Constructs a new `IntervalTree` containing only the interval `x`. */
  def apply[E, I <: Interval[E]](x: I)(implicit ordering: Ordering[E]): IntervalTree[E, I] =
    new IntervalTree(insert[E, I](x, E))

  /** Constructs a new `IntervalTree` containing every interval in `xs`, inserted in iteration order. */
  def apply[E, I <: Interval[E]](xs: Traversable[I])(implicit ordering: Ordering[E]): IntervalTree[E, I] =
    new IntervalTree(insertAll[E, I](xs, E))

  /////////////////////////////////////////////////////////////////////////////////////////////////
  // Internal tree representation
  /////////////////////////////////////////////////////////////////////////////////////////////////

  /** Node color in the red-black tree. */
  private[interval] sealed trait Color
  private[interval] case object R extends Color
  private[interval] case object B extends Color

  /** A red-black tree of intervals: either empty (`E`) or a node (`T`). */
  private[interval] sealed trait Tree[+E, +I <: Interval[E]]
  private[interval] case object E extends Tree[Nothing, Nothing]

  /** A tree node.
    *
    * @param color    the node's color
    * @param left     subtree of intervals ordered at or before `interval`
    * @param interval the interval stored at this node
    * @param max      the largest endpoint of any interval in this subtree, including `interval`
    * @param right    subtree of intervals ordered at or after `interval`
    */
  private[interval] case class T[+E, +I <: Interval[E]](color: Color,
                                                        left: Tree[E, I],
                                                        interval: I,
                                                        max: E,
                                                        right: Tree[E, I])
                                                       (implicit ordering: Ordering[E]) extends Tree[E, I] {
    // This check was necessary for debugging; it is left in place. It only constrains nodes with
    // no children: such a node's `max` must equal its own interval's upper endpoint.
    require(left != E || right != E || (left == E && right == E && max == Interval.max(interval)),
      "Empty tree must have the correct max subtree value")
  }

  /////////////////////////////////////////////////////////////////////////////////////////////////
  // Private operations
  /////////////////////////////////////////////////////////////////////////////////////////////////

  /** Inserts every interval in `xs` into `t`, in iteration order. */
  private[this] def insertAll[E, I <: Interval[E]](xs: Traversable[I], t: Tree[E, I])
                                                 (implicit ordering: Ordering[E]): Tree[E, I] =
    xs.foldLeft(t) { (acc: Tree[E, I], x: I) => insert(x, acc) }

  /** Returns a new tree containing the intervals of `t` plus `x`; the returned root is black. */
  private[IntervalTree] def insert[E, I <: Interval[E]](x: I, t: Tree[E, I])
                                                      (implicit ordering: Ordering[E]): Tree[E, I] = {
    val intervalOrdering = Interval.ordering
    import intervalOrdering._
    // Okasaki's `ins`: descend to an empty position, then rebalance on the way back up.
    // Intervals that compare equal go left, so duplicates are kept.
    def ins(t: Tree[E, I]): T[E, I] = t match {
      case E => T(R, E, x, Interval.max(x), E)
      case T(color, a, y, ymax, b) =>
        if (x <= y) balance((color, ins(a), y, ymax, b))
        else balance((color, a, y, ymax, ins(b)))
    }
    // Recolor the root black.
    ins(t).copy(color = B)
  }

  /** Returns the largest endpoint among interval `k` and the recorded maxima of subtrees `l` and `r`. */
  private[interval] def subtreeMax[E](l: Tree[E, Interval[E]],
                                      k: Interval[E],
                                      r: Tree[E, Interval[E]])
                                     (implicit ordering: Ordering[E]): E = {
    import ordering._
    val upper = Interval.max(k)
    (l, r) match {
      // k is an interval at a leaf node
      case (E, E) => upper
      // k is an interval at a node with a left child only
      case (T(_, _, _, lmax, _), E) => max(lmax, upper)
      // k is an interval at a node with a right child only
      case (E, T(_, _, _, rmax, _)) => max(rmax, upper)
      // k is an interval at a node with 2 children
      case (T(_, _, _, lmax, _), T(_, _, _, rmax, _)) => max(max(lmax, upper), rmax)
    }
  }

  /** Okasaki's red-black `balance`, extended to keep each node's `max` up to date.
    *
    * The four rotation cases rebuild the affected nodes and recompute their maxima from their new
    * children. The default case recomputes the max because `ins` may have changed a child.
    */
  // DON'T PANIC
  private[this] def balance[E, I <: Interval[E]](t: (Color, Tree[E, I], I, E, Tree[E, I]))
                                                (implicit ordering: Ordering[E]): T[E, I] = {
    val balanced = t match {
      // corresponds to the graph at 9:00 in Fig. 3.5 (Okasaki)
      case (B, T(R, T(R, a, x, xmax, b), y, _, c), z, _, d) =>
        // this subtree only changes color from R -> B; its max does not change
        val left = T(B, a, x, xmax.asInstanceOf[E], b)
        // z is the new right child, max should be chosen from c, z, and d
        val right: T[E, I] = T(B, c, z, subtreeMax(c, z, d), d)
        T(R, left, y, subtreeMax(left, y, right), right)
      // corresponds to the graph at 12:00 in Fig. 3.5 (Okasaki)
      case (B, T(R, a, x, _, T(R, b, y, _, c)), z, _, d) =>
        // x no longer contains c, so compute max from a, x, b
        val left: T[E, I] = T(B, a, x, subtreeMax(a, x, b), b)
        // z no longer contains a or b, compute max from c, z, d
        val right: T[E, I] = T(B, c, z, subtreeMax(c, z, d), d)
        T(R, left, y, subtreeMax(left, y, right), right)
      // corresponds to the graph at 6:00 in Fig. 3.5 (Okasaki)
      case (B, a, x, _, T(R, T(R, b, y, _, c), z, _, d)) =>
        // x no longer contains c or d, compute max from a, x, b
        val left: T[E, I] = T(B, a, x, subtreeMax(a, x, b), b)
        // z no longer contains b, compute max from c, z, d
        val right: T[E, I] = T(B, c, z, subtreeMax(c, z, d), d)
        T(R, left, y, subtreeMax(left, y, right), right)
      // corresponds to the graph at 3:00 in Fig. 3.5 (Okasaki)
      case (B, a, x, _, T(R, b, y, _, T(R, c, z, zmax, d))) =>
        // x no longer contains c or d; compute max from a, x, b
        val left: T[E, I] = T(B, a, x, subtreeMax(a, x, b), b)
        // this subtree only changes color from R -> B; its max does not change
        val right = T(B, c, z, zmax.asInstanceOf[E], d)
        T(R, left, y, subtreeMax(left, y, right), right)
      // the default case
      case (color, a, e, _, b) =>
        // compute `subtreeMax(a, e, b)` because the old max may have been changed by ins()
        T(color, a, e, subtreeMax(a, e, b), b)
    }
    balanced
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* Copyright 2015. Genomic Perturbation Platform, The Broad Institute of Harvard and MIT. | |
* http://www.broadinstitute.org | |
*/ | |
package org.broadinstitute.rnai.scalamari.interval | |
import org.scalatest.{FlatSpec, Matchers} | |
import scala.util.Random | |
class IntervalTreeSpec extends FlatSpec with Matchers {
  import IntervalTree._

  /** Fails the current test unless `tree` satisfies the red-black and interval tree invariants. */
  def checkInvariants[A](tree: IntervalTree[A, Interval[A]])(implicit ordering: Ordering[A]): Unit = {
    // No red node has a red child.
    def redBlackTreeInvariant1(tree: Tree[A, Interval[A]]): Unit = {
      tree match {
        case E => ()
        case T(R, T(R, _, _, _, _), _, _, _)
           | T(R, _, _, _, T(R, _, _, _, _)) =>
          fail("Invariant check failed: tree has a red node that has a red child.")
        case T(_, a, _, _, b) =>
          redBlackTreeInvariant1(a)
          redBlackTreeInvariant1(b)
      }
    }

    // Every path from the root to an empty node contains the same number of black nodes.
    def redBlackTreeInvariant2(tree: Tree[_, _]): Unit = {
      def traverse(tree: Tree[_, _], blackNodes: Int): Int = {
        tree match {
          case E => blackNodes
          case T(color, a, _, _, b) =>
            val subBlackNodes = if (color == B) blackNodes + 1 else blackNodes
            val left = traverse(a, subBlackNodes)
            val right = traverse(b, subBlackNodes)
            if (left != right)
              fail("Invariant check failed: tree contains 2 paths to empty nodes with different numbers of black nodes")
            else left
        }
      }
      traverse(tree, 0)
    }

    // The nodes in the tree are stored in order of the lowest point in the interval
    def intervalTreeInvariant1(tree: Tree[A, Interval[A]])(implicit ordering: Ordering[A]): Unit = {
      import ordering._
      def inOrder(tree: Tree[A, Interval[A]], prev: Option[A]): Option[A] = {
        tree match {
          case E => prev
          case T(_, a, k, _, b) =>
            val pa = inOrder(a, prev)
            pa.foreach { p =>
              if (p > Interval.min(k))
                fail("Invariant check failed: tree is not sorted in order of the lowest point in the interval")
            }
            inOrder(b, Some(Interval.min(k)))
        }
      }
      inOrder(tree, None)
    }

    // The max value stored at each node must be exactly the largest endpoint in its subtree.
    // In particular it is never smaller than the node's own upper endpoint, and never larger than
    // the parent's max. Exactness matters: `search` assumes a left subtree's max is attained by
    // some interval in that subtree.
    def intervalTreeInvariant2(tree: Tree[A, Interval[A]])(implicit ordering: Ordering[A]): Unit = {
      import ordering._
      def traverse(tree: Tree[A, Interval[A]], parentMax: Option[A]): Unit = {
        tree match {
          case E => // fine
          case T(_, a, k, subtreeMax, b) =>
            if (Interval.max(k) > subtreeMax) {
              fail("Invariant check failed: tree has inconsistent interval and max")
            }
            parentMax.foreach { m =>
              if (subtreeMax > m) fail("Invariant check failed: tree has a subtree with a max larger than the parent's max")
            }
            val childMaxes = List(a, b).collect { case T(_, _, _, childMax, _) => childMax }
            val expectedMax = (Interval.max(k) :: childMaxes).max
            if (!equiv(subtreeMax, expectedMax)) {
              fail("Invariant check failed: tree has a max that is not the largest endpoint in its subtree")
            }
            traverse(a, Some(subtreeMax))
            traverse(b, Some(subtreeMax))
        }
      }
      traverse(tree, None)
    }

    redBlackTreeInvariant1(tree.t)
    redBlackTreeInvariant2(tree.t)
    intervalTreeInvariant1(tree.t)
    intervalTreeInvariant2(tree.t)
  }

  case class TomsInterval(override val e1: Int, override val e2: Int) extends Interval[Int]
  case class MarksInterval(override val e1: Int, override val e2: Int) extends Interval[Int]
  case class NamedInterval(override val e1: Int, override val e2: Int, name: String) extends Interval[Int]

  // this set of intervals is borrowed from the interval tree example in CLR ...
  val Intervals = Seq(TomsInterval(0, 3),
                      TomsInterval(5, 8),
                      TomsInterval(6, 10),
                      TomsInterval(8, 9),
                      TomsInterval(15, 23),
                      TomsInterval(16, 21),
                      TomsInterval(17, 19),
                      TomsInterval(19, 20),
                      TomsInterval(25, 30),
                      TomsInterval(26, 26))

  "Interval.min" should "find the smallest endpoint in an interval" in {
    Interval.min(TomsInterval(4, 22)) should be (4)
    Interval.min(TomsInterval(22, 4)) should be (4)
    Interval.min(TomsInterval(0, -3)) should be (-3)
    Interval.min(TomsInterval(-3, 1)) should be (-3)
  }

  "Interval.max" should "find the largest endpoint in an interval" in {
    Interval.max(TomsInterval(4, 22)) should be (22)
    Interval.max(TomsInterval(22, 4)) should be (22)
    Interval.max(TomsInterval(0, -3)) should be (0)
    Interval.max(TomsInterval(-3, 1)) should be (1)
  }

  "Interval.ordered" should "sort the elements in an interval" in {
    Interval.ordered(4, 22) should be ((4, 22))
    Interval.ordered(TomsInterval(4, 22)) should be ((4, 22))
    Interval.ordered(22, 4) should be ((4, 22))
    Interval.ordered(TomsInterval(22, 4)) should be ((4, 22))
    Interval.ordered(0, -3) should be ((-3, 0))
    Interval.ordered(TomsInterval(0, -3)) should be ((-3, 0))
    Interval.ordered(-3, 0) should be ((-3, 0))
    Interval.ordered(TomsInterval(-3, 0)) should be ((-3, 0))
  }

  "Interval.overlaps" should "determine whether two intervals overlap" in {
    val i1 = TomsInterval(4, 9)
    val i2 = MarksInterval(9, 22)
    val i3 = TomsInterval(0, 3)
    val i4 = MarksInterval(15, 33)
    val i5 = TomsInterval(15, 30)
    Interval.overlaps(i1, i2) should be (true)
    Interval.overlaps(i2, i1) should be (true)
    Interval.overlaps(i1, i3) should be (false)
    Interval.overlaps(i1, i4) should be (false)
    Interval.overlaps(i2, i3) should be (false)
    Interval.overlaps(i2, i4) should be (true)
    Interval.overlaps(i3, i4) should be (false)
    Interval.overlaps(i4, i5) should be (true)
  }

  "IntervalTree" should "hold intervals without violating its invariants" in {
    // ... because we're perverse, we're going to try all permutations (10! of them)
    val intervalPermutations = Intervals.permutations
    intervalPermutations.foreach { is =>
      val tree = IntervalTree[Int, Interval[Int]](is)
      checkInvariants(tree)
    }
  }

  it should "find an overlapping interval" in {
    val shuffled = Random.shuffle(Intervals)
    val tree = IntervalTree[Int, TomsInterval](shuffled)
    val x = tree.search(TomsInterval(24, 26))
    x.isDefined should be (true)
    x.foreach { found =>
      val (lo, hi) = Interval.ordered(found)
      // closed intervals overlap iff each starts no later than the other ends
      withClue(s"$found should overlap (24, 26): ") {
        (lo <= 26 && hi >= 24) should be (true)
      }
    }
  }

  it should "find all overlapping intervals" in {
    val randomizedIntervals: Seq[TomsInterval] = Random.shuffle(Intervals)
    val tree = IntervalTree[Int, TomsInterval](randomizedIntervals)
    // report the insertion order on failure so that a failing case can be reproduced
    withClue(s"intervals inserted in the order $randomizedIntervals: ") {
      val x = tree.overlapping(TomsInterval(24, 26))
      x should have size 2
      x should contain (TomsInterval(25, 30))
      x should contain (TomsInterval(26, 26))

      val y = tree.overlapping(TomsInterval(9, 15))
      y should have size 3
      y should contain (TomsInterval(8, 9))
      y should contain (TomsInterval(6, 10))
      y should contain (TomsInterval(15, 23))

      val z = tree.overlapping(TomsInterval(27, 29))
      z should have size 1
      z should contain (TomsInterval(25, 30))

      val zz = tree.overlapping(MarksInterval(12, 14))
      zz should have size 0
    }
  }

  it should "find overlapping intervals even if it needs to try both sides of the tree" in {
    val brokenIntervalSets = Seq(List(TomsInterval(17,19),
                                      TomsInterval(0,3),
                                      TomsInterval(16,21),
                                      TomsInterval(26,26),
                                      TomsInterval(5,8),
                                      TomsInterval(19,20),
                                      TomsInterval(6,10),
                                      TomsInterval(8,9),
                                      TomsInterval(25,30),
                                      TomsInterval(15,23)),
                                 List(TomsInterval(19,20),
                                      TomsInterval(26,26),
                                      TomsInterval(0,3),
                                      TomsInterval(5,8),
                                      TomsInterval(25,30),
                                      TomsInterval(6,10),
                                      TomsInterval(16,21),
                                      TomsInterval(8,9),
                                      TomsInterval(17,19),
                                      TomsInterval(15,23)))
    brokenIntervalSets.foreach { brokenIntervals =>
      val tree = IntervalTree[Int, TomsInterval](brokenIntervals)
      val o = tree.overlapping(MarksInterval(9, 15))
      o should have size 3
      o should contain (TomsInterval(8, 9))
      o should contain (TomsInterval(6, 10))
      o should contain (TomsInterval(15, 23))
    }
  }

  it should "store multiple intervals with the same span" in {
    val intervals = Seq(NamedInterval(4, 9, "4 to 9"),
                        NamedInterval(1, 4, "one to four"),
                        NamedInterval(4, 9, "four to nine"),
                        NamedInterval(0, 3, "foo"),
                        NamedInterval(-3, 12, "big"))
    val tree = IntervalTree[Int, NamedInterval](intervals)
    val o = tree.overlapping(MarksInterval(5, 8))
    o should have size 3
    o should contain (NamedInterval(4, 9, "4 to 9"))
    o should contain (NamedInterval(4, 9, "four to nine"))
    o should contain (NamedInterval(-3, 12, "big"))
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment