Skip to content

Instantly share code, notes, and snippets.

@mtomko
Last active February 9, 2018 21:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mtomko/2e90ce76087f595daf8b2334553fbd91 to your computer and use it in GitHub Desktop.
Save mtomko/2e90ce76087f595daf8b2334553fbd91 to your computer and use it in GitHub Desktop.
An attempt to create an interval tree based on the red-black tree implementation given by Okasaki in his book on persistent data structures
/*
* Copyright 2015. Genomic Perturbation Platform, The Broad Institute of Harvard and MIT.
* http://www.broadinstitute.org
*/
package org.broadinstitute.rnai.scalamari.interval
import scala.annotation.tailrec
/** A representation of a closed interval in the domain of `A`.
  *
  * The two endpoints are not required to be supplied in sorted order; the helpers in the
  * companion object (`min`, `max`, `ordered`) normalize them using an `Ordering[A]`.
  */
trait Interval[+A] {
  /** One endpoint of the interval */
  def e1: A
  /** The other endpoint of the interval */
  def e2: A
}

object Interval {

  /** Creates an `Ordering[Interval[A]]` from an `Ordering[A]` suitable for use by the
    * IntervalTree.
    *
    * Intervals are compared by their lower endpoint first and their upper endpoint second.
    * The endpoints are normalized before comparison so that an interval given as `(22, 4)`
    * sorts exactly like `(4, 22)`; comparing the raw `(e1, e2)` pair would sort it by 22.
    */
  def ordering[A](implicit ordering: Ordering[A]): Ordering[Interval[A]] =
    Ordering.by { i: Interval[A] => (Interval.min(i), Interval.max(i)) }

  /** Returns the lower endpoint of the Interval[A] based on the implicit ordering */
  def min[A](i: Interval[A])(implicit ordering: Ordering[A]): A = ordering.min(i.e1, i.e2)

  /** Returns the higher endpoint of the Interval[A] based on the implicit ordering */
  def max[A](i: Interval[A])(implicit ordering: Ordering[A]): A = ordering.max(i.e1, i.e2)

  /** Returns a tuple representing the endpoints of `i` sorted by the implicit ordering */
  def ordered[A](i: Interval[A])(implicit ordering: Ordering[A]): (A, A) = {
    import ordering._
    if (i.e1 > i.e2) (i.e2, i.e1)
    else (i.e1, i.e2)
  }

  /** Returns the tuple `(x, y)` sorted by the implicit ordering, lowest value first */
  def ordered[A](x: A, y: A)(implicit ordering: Ordering[A]): (A, A) = {
    import ordering._
    if (x > y) (y, x)
    else (x, y)
  }

  /** Returns true iff the closed intervals `x` and `y` overlap */
  def overlaps[A](x: Interval[A], y: Interval[A])(implicit ordering: Ordering[A]): Boolean =
    overlaps(x.e1, x.e2, y.e1, y.e2)

  /** Returns true iff the closed intervals `(e1, e2)` and `i` overlap */
  def overlaps[A](e1: A, e2: A, i: Interval[A])(implicit ordering: Ordering[A]): Boolean =
    overlaps(e1, e2, i.e1, i.e2)

  /** Returns true iff the closed intervals `(x1, x2)` and `(y1, y2)` overlap.
    *
    * Each pair of endpoints may be given in either order. Because the intervals are closed,
    * two intervals that merely touch at an endpoint are considered overlapping.
    */
  def overlaps[A](x1: A, x2: A, y1: A, y2: A)(implicit ordering: Ordering[A]): Boolean = {
    import ordering._
    val (xl, xh) = Interval.ordered(x1, x2)
    val (yl, yh) = Interval.ordered(y1, y2)
    // Two closed intervals overlap exactly when each one starts no later than the other ends.
    // Testing only whether an endpoint of x falls inside y misses the case where x strictly
    // contains y:
    //   |-------- x --------|
    //        |-- y --|
    xl <= yh && yl <= xh
  }
}
/** An immutable red-black tree representation of a collection of closed intervals, providing
  * efficient methods for finding intervals that overlap a given query interval.
  *
  * Each node of the underlying tree carries an interval together with a `max` value that is
  * used to prune subtrees during queries.
  *
  * @param t        the root of the underlying tree
  * @param ordering the ordering on interval endpoints
  * @tparam E the endpoint type
  * @tparam I the type of interval stored in the tree
  */
class IntervalTree[E, I <: Interval[E]] private[interval] (val t: IntervalTree.Tree[E, I])(implicit ordering: Ordering[E]) {
  import IntervalTree._

  /** Returns a new `IntervalTree[E, I]` containing all the intervals in this tree, as well as the
    * provided interval */
  def +(x: I): IntervalTree[E, I] = new IntervalTree(insert(x, t))

  /** Returns an interval that overlaps `i` from the tree, if one exists.
    *
    * Only a single root-to-leaf path is followed: the left subtree is taken whenever its
    * stored max reaches the low end of `i`, otherwise the right subtree is taken.
    */
  def search(i: Interval[E]): Option[I] = {
    import ordering._
    val iLow = Interval.min(i)
    @tailrec
    def loop(x: Tree[E, I]): Option[I] = {
      x match {
        case E => None
        case T(_, left, ti, _, right) =>
          if (Interval.overlaps(i, ti)) Some(ti)
          else {
            left match {
              case T(_, _, _, leftMax, _) if leftMax >= iLow => loop(left)
              case _ => loop(right)
            }
          }
      }
    }
    loop(t)
  }

  /** Returns all intervals that overlap `(min(x, y), max(x, y))` in the tree.
    *
    * The query is treated as a closed interval, so stored intervals that merely touch one of
    * its endpoints are included. No particular order of the results is promised.
    */
  def overlapping(x: E, y: E): Seq[I] = {
    import ordering._
    val (iLow, iHi) = Interval.ordered(x, y)
    def loop(x: Tree[E, I]): List[I] = {
      x match {
        case E => Nil
        case T(_, left, ti, _, right) =>
          val lefts = left match {
            // explore the left subtree if `i` begins at or below the max upper bound in the left subtree
            case T(_, _, _, leftMax, _) if leftMax >= iLow => loop(left)
            case _ => Nil
          }
          // Explore the right subtree if this interval starts at or before the point where `i`
          // ends. The comparison must be inclusive: the intervals are closed, and the right
          // subtree may hold intervals whose low endpoint equals this node's low endpoint, so
          // an interval starting exactly at `iHi` can live there and still overlap the query.
          // (Assumes the tree is ordered by low endpoint.)
          val rights = if (Interval.min(ti) <= iHi) loop(right) else Nil
          if (Interval.overlaps(iLow, iHi, ti)) ti :: (rights ++ lefts)
          else rights ++ lefts
      }
    }
    loop(t)
  }

  /** Returns all intervals that overlap `i` in the tree */
  def overlapping(i: Interval[E]): Seq[I] = overlapping(i.e1, i.e2)
}
/** Companion of the `IntervalTree` class: public factory methods plus the internal red-black
  * tree representation and the insert/rebalance operations that maintain it. */
object IntervalTree {
/////////////////////////////////////////////////////////////////////////////////////////////////
// Public API
/////////////////////////////////////////////////////////////////////////////////////////////////
/** Construct a new empty IntervalTree */
def apply[E, I <: Interval[E]]()(implicit ordering: Ordering[E]): IntervalTree[E, I] =
new IntervalTree(E)
/** Construct a new IntervalTree containing the single provided interval */
def apply[E, I <: Interval[E]](x: I)(implicit ordering: Ordering[E]) =
new IntervalTree(insert[E, I](x, E))
/** Construct a new IntervalTree containing all of the provided intervals, inserted one at a
* time in traversal order */
def apply[E, I <: Interval[E]](xs: Traversable[I])(implicit ordering: Ordering[E]) =
new IntervalTree(insertAll[E, I](xs, E))
/////////////////////////////////////////////////////////////////////////////////////////////////
// Internal tree representation
/////////////////////////////////////////////////////////////////////////////////////////////////
/** The color of a node in the red-black tree */
private[interval] sealed trait Color
/** Red */
private[interval] case object R extends Color
/** Black */
private[interval] case object B extends Color
/** A tree of intervals of type `I` with endpoints of type `E`: either the empty tree `E` or a
* node `T` */
private[interval] sealed trait Tree[+E, +I <: Interval[E]]
/** The empty tree */
private[interval] case object E extends Tree[Nothing, Nothing]
/** A tree node holding `interval`, its `color`, its two children, and `max`, the value that
* `subtreeMax` computes for the node: the largest of the interval's upper endpoint and the
* `max` values of its non-empty children */
private[interval] case class T[+E, +I <: Interval[E]](color: Color,
left: Tree[E, I],
interval: I,
max: E,
right: Tree[E, I])
(implicit ordering: Ordering[E]) extends Tree[E, I] {
// this check was necessary for debugging; leaving it in place now
// It only constrains nodes with two empty children: for those, `max` must equal the upper
// endpoint of the node's own interval. Nodes with at least one child pass unconditionally.
require(left != E || right != E || (left == E && right == E && max == Interval.max(interval)),
"Empty tree must have the correct max subtree value")
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// Private operations
/////////////////////////////////////////////////////////////////////////////////////////////////
/** Inserts every element of `xs` into `t`, left to right, returning the resulting tree */
private[this] def insertAll[E, I <: Interval[E]](xs: Traversable[I], t: Tree[E, I])
(implicit ordering: Ordering[E]): Tree[E, I] =
xs.foldLeft(t) { (t: Tree[E, I], x: I) => insert(x, t) }
/** Inserts the given element into the tree, returning the new tree.
*
* The new interval is placed according to `Interval.ordering`; an interval that compares less
* than or equal to a node's interval descends into that node's left subtree, so equal
* intervals are kept rather than replaced. Every node on the insertion path is passed through
* `balance`, and the root of the result is always colored black. */
private[IntervalTree] def insert[E, I <: Interval[E]](x: I, t: Tree[E, I])
(implicit ordering: Ordering[E]): Tree[E, I] = {
val intervalOrdering = Interval.ordering
import intervalOrdering._
// recursively inserts `x`; a new node is always created red
def ins(t: Tree[E, I]): T[E, I] = t match {
case E => T(R, E, x, Interval.max(x), E)
case s@T(color, a, y, ymax, b) =>
if (x <= y) balance((color, ins(a), y, ymax, b))
else balance((color, a, y, ymax, ins(b)))
}
ins(t).copy(color = B)
}
/** Computes the `max` value for a node holding interval `k` with left child `l` and right
* child `r`: the largest of `k`'s upper endpoint and the `max` stored at each non-empty
* child. */
private[interval] def subtreeMax[E](l: Tree[E, Interval[E]],
k: Interval[E],
r: Tree[E, Interval[E]])
(implicit ordering: Ordering[E]): E = {
// brings `ordering.max` into scope as `max` for the cases below
import ordering._
val upper = Interval.max(k)
(l, r) match {
// k is an interval at a leaf node
case (E, E) => upper
// k is an interval at a node with a left child only
case (T(_, _, _, lmax, _), E) => max(lmax, upper)
// k is an interval at a node with a right child only
case (E, T(_, _, _, rmax, _)) => max(rmax, upper)
// k is an interval at a node with 2 children
case (T(_, _, _, lmax, _), T(_, _, _, rmax, _)) => max(max(lmax, upper), rmax)
}
}
// DON'T PANIC
/** Builds a node from the tuple `(color, left, interval, max, right)`, rewriting the four
* shapes in which a black node has a red child with a red grandchild into a red node with
* two black children (Okasaki, Fig. 3.5). The `max` values of the rebuilt nodes are
* recomputed with `subtreeMax` wherever a node's children change. */
private[this] def balance[E, I <: Interval[E]](t: (Color, Tree[E, I], I, E, Tree[E, I]))
(implicit ordering: Ordering[E]): T[E, I] = {
val balanced = t match {
// corresponds to the graph at 9:00 in Fig. 3.5 (Okasaki)
case (B, T(R, T(R, a, x, xmax, b), y, _, c), z, zmax, d) =>
// this subtree only changes color from R -> B; its max does not change
val left = T(B, a, x, xmax.asInstanceOf[E], b)
// z is the new right child, max should be chosen from c, z, and d
val right: T[E, I] = T(B, c, z, subtreeMax(c, z, d), d)
T(R, left, y, subtreeMax(left, y, right), right)
// corresponds to the graph at 12:00 in Fig. 3.5 (Okasaki)
case (B, T(R, a, x, _, T(R, b, y, _, c)), z, zmax, d) =>
// x no longer contains c, so compute max from a, x, b
val left: T[E, I] = T(B, a, x, subtreeMax(a, x, b), b)
// z no longer contains a or b, compute max from c, z, d
val right: T[E, I] = T(B, c, z, subtreeMax(c, z, d), d)
T(R, left, y, subtreeMax(left, y, right), right)
// corresponds to the graph at 6:00 in Fig. 3.5 (Okasaki)
case (B, a, x, xmax, T(R, T(R, b, y, _, c), z, _, d)) =>
// x no longer contains c or d, compute max from a, x, b
val left: T[E, I] = T(B, a, x, subtreeMax(a, x, b), b)
// z no longer contains b, compute max from c, z, d
val right: T[E, I] = T(B, c, z, subtreeMax(c, z, d), d)
T(R, left, y, subtreeMax(left, y, right), right)
// corresponds to the graph at 3:00 in Fig. 3.5 (Okasaki)
case (B, a, x, xmax, T(R, b, y, _, T(R, c, z, zmax, d))) =>
// x no longer contains c or d; compute max from a, x, b
val left: T[E, I] = T(B, a, x, subtreeMax(a, x, b), b)
// this subtree only changes color from R -> B; its max does not change
val right = T(B, c, z, zmax.asInstanceOf[E], d)
T(R, left, y, subtreeMax(left, y, right), right)
// the default case
case (color, a, e, _, b) =>
//compute `subtreeMax(a, e, b)` because the old max may have been changed by ins()
T(color, a, e, subtreeMax(a, e, b), b)
}
balanced
}
}
/*
* Copyright 2015. Genomic Perturbation Platform, The Broad Institute of Harvard and MIT.
* http://www.broadinstitute.org
*/
package org.broadinstitute.rnai.scalamari.interval
import org.scalatest.{FlatSpec, Matchers}
import scala.util.Random
/** Tests for `Interval` and `IntervalTree`: the endpoint helpers, the overlap predicate, the
  * red-black and interval-tree invariants after insertion, and the overlap queries. */
class IntervalTreeSpec extends FlatSpec with Matchers {
  import IntervalTree._

  /** Fails the current test if `tree` violates any of the red-black tree invariants or any of
    * the interval tree invariants checked below. */
  def checkInvariants[A](tree: IntervalTree[A, Interval[A]])(implicit ordering: Ordering[A]) = {
    // No red node has a red child.
    def redBlackTreeInvariant1(tree: Tree[A, Interval[A]]): Unit = {
      tree match {
        case E => ()
        case T(R, T(R, _, _, _, _), _, _, _)
             | T(R, _, _, _, T(R, _, _, _, _)) =>
          fail("Invariant check failed: tree has a red node that has a red child.")
        case T(_, a, _, _, b) =>
          redBlackTreeInvariant1(a)
          redBlackTreeInvariant1(b)
      }
    }
    // Every path from the root to an empty node contains the same number of black nodes.
    def redBlackTreeInvariant2(tree: Tree[_, _]): Unit = {
      // returns the number of black nodes on every path from the root down to an empty node
      // below `tree`, failing if two such paths disagree
      def traverse(tree: Tree[_, _], blackNodes: Int): Int = {
        tree match {
          case E => blackNodes
          case T(color, a, _, _, b) =>
            val subBlackNodes = if (color == B) blackNodes + 1 else blackNodes
            val left = traverse(a, subBlackNodes)
            val right = traverse(b, subBlackNodes)
            if (left != right)
              fail("Invariant check failed: tree contains 2 paths to empty nodes with different numbers of black nodes")
            else left
        }
      }
      traverse(tree, 0)
    }
    // The nodes in the tree are stored in order of the lowest point in the interval
    def intervalTreeInvariant1(tree: Tree[A, Interval[A]])(implicit ordering: Ordering[A]): Unit = {
      import ordering._
      // in-order walk that threads through the low endpoint of the previously visited node
      def inOrder(tree: Tree[A, Interval[A]], prev: Option[A]): Option[A] = {
        tree match {
          case E => prev
          case T(_, a, k, _, b) =>
            val pa = inOrder(a, prev)
            pa.foreach { p =>
              if (p > Interval.min(k))
                fail("Invariant check failed: tree is not sorted in order of the lowest point in the interval")
            }
            inOrder(b, Some(Interval.min(k)))
        }
      }
      inOrder(tree, None)
    }
    // The max value stored at each node should be at least as large as any value in the subtree
    def intervalTreeInvariant2(tree: Tree[A, Interval[A]])(implicit ordering: Ordering[A]): Unit = {
      import ordering._
      def traverse(tree: Tree[A, Interval[A]], parentMax: Option[A]): Unit = {
        tree match {
          case E => // fine
          case T(_, a, k, subtreeMax, b) =>
            if (Interval.max(k) > subtreeMax) {
              fail("Invariant check failed: tree has inconsistent interval and max")
            }
            parentMax.foreach { m =>
              if (subtreeMax > m) fail("Invariant check failed: tree has a subtree with a max larger than the parent's max")
            }
            traverse(a, Some(subtreeMax))
            traverse(b, Some(subtreeMax))
        }
      }
      traverse(tree, None)
    }
    redBlackTreeInvariant1(tree.t)
    redBlackTreeInvariant2(tree.t)
    intervalTreeInvariant1(tree.t)
    intervalTreeInvariant2(tree.t)
  }

  // Distinct interval implementations, used to check that queries accept any `Interval[Int]`
  case class TomsInterval(override val e1: Int, override val e2: Int) extends Interval[Int]
  case class MarksInterval(override val e1: Int, override val e2: Int) extends Interval[Int]
  case class NamedInterval(override val e1: Int, override val e2: Int, name: String) extends Interval[Int]

  // this set of intervals is borrowed from the interval tree example in CLR ...
  val Intervals = Seq(TomsInterval(0, 3),
                      TomsInterval(5, 8),
                      TomsInterval(6, 10),
                      TomsInterval(8, 9),
                      TomsInterval(15, 23),
                      TomsInterval(16, 21),
                      TomsInterval(17, 19),
                      TomsInterval(19, 20),
                      TomsInterval(25, 30),
                      TomsInterval(26, 26))

  "Interval.min" should "find the smallest endpoint in an interval" in {
    Interval.min(TomsInterval(4, 22)) should be (4)
    Interval.min(TomsInterval(22, 4)) should be (4)
    Interval.min(TomsInterval(0, -3)) should be (-3)
    Interval.min(TomsInterval(-3, 1)) should be (-3)
  }

  "Interval.max" should "find the largest endpoint in an interval" in {
    Interval.max(TomsInterval(4, 22)) should be (22)
    Interval.max(TomsInterval(22, 4)) should be (22)
    Interval.max(TomsInterval(0, -3)) should be (0)
    Interval.max(TomsInterval(-3, 1)) should be (1)
  }

  "Interval.ordered" should "sort the elements in an interval" in {
    Interval.ordered(4, 22) should be ((4, 22))
    Interval.ordered(TomsInterval(4, 22)) should be ((4, 22))
    Interval.ordered(22, 4) should be ((4, 22))
    Interval.ordered(TomsInterval(22, 4)) should be ((4, 22))
    Interval.ordered(0, -3) should be ((-3, 0))
    // the expected tuple is written explicitly rather than relying on argument auto-tupling
    Interval.ordered(TomsInterval(0, -3)) should be ((-3, 0))
    Interval.ordered(-3, 0) should be ((-3, 0))
    Interval.ordered(TomsInterval(-3, 0)) should be ((-3, 0))
  }

  "Interval.overlaps" should "determine whether two intervals overlap" in {
    val i1 = TomsInterval(4, 9)
    val i2 = MarksInterval(9, 22)
    val i3 = TomsInterval(0, 3)
    val i4 = MarksInterval(15, 33)
    val i5 = TomsInterval(15, 30)
    Interval.overlaps(i1, i2) should be (true)
    Interval.overlaps(i2, i1) should be (true)
    Interval.overlaps(i1, i3) should be (false)
    Interval.overlaps(i1, i4) should be (false)
    Interval.overlaps(i2, i3) should be (false)
    Interval.overlaps(i2, i4) should be (true)
    Interval.overlaps(i3, i4) should be (false)
    Interval.overlaps(i4, i5) should be (true)
  }

  "IntervalTree" should "hold intervals without violating its invariants" in {
    // ... because we're perverse, we're going to try all permutations
    val intervalPermutations = Intervals.permutations
    intervalPermutations.foreach { is =>
      val tree = IntervalTree[Int, Interval[Int]](is)
      checkInvariants(tree)
    }
  }

  it should "find an overlapping interval" in {
    // insertion order is randomized so that different tree shapes get exercised across runs
    val tree = IntervalTree[Int, TomsInterval](Random.shuffle(Intervals))
    val x = tree.search(TomsInterval(24, 26))
    x.isDefined should be (true)
    x map Interval.ordered[Int] foreach { case (a, b) =>
      (a >= 24 && a <= 26) || (b >= 24 && b <= 26) should be (true)
    }
  }

  it should "find all overlapping intervals" in {
    val randomizedIntervals: Seq[TomsInterval] = Random.shuffle(Intervals)
    val tree = IntervalTree[Int, TomsInterval](randomizedIntervals)
    val x = tree.overlapping(TomsInterval(24, 26))
    x should have size 2
    x should contain (TomsInterval(25, 30))
    x should contain (TomsInterval(26, 26))
    val y = tree.overlapping(TomsInterval(9, 15))
    // print the insertion order that triggered the failure so it can be reproduced
    if (y.size != 3) {
      println(randomizedIntervals)
    }
    y should have size 3
    y should contain (TomsInterval(8, 9))
    y should contain (TomsInterval(6, 10))
    y should contain (TomsInterval(15, 23))
    val z = tree.overlapping(TomsInterval(27, 29))
    z should have size 1
    z should contain (TomsInterval(25, 30))
    val zz = tree.overlapping(MarksInterval(12, 14))
    zz should have size 0
  }

  it should "find overlapping intervals even if it needs to try both sides of the tree" in {
    // two fixed insertion orders, kept under the name `brokenIntervalSets`
    val brokenIntervalSets = Seq(List(TomsInterval(17,19),
                                      TomsInterval(0,3),
                                      TomsInterval(16,21),
                                      TomsInterval(26,26),
                                      TomsInterval(5,8),
                                      TomsInterval(19,20),
                                      TomsInterval(6,10),
                                      TomsInterval(8,9),
                                      TomsInterval(25,30),
                                      TomsInterval(15,23)),
                                 List(TomsInterval(19,20),
                                      TomsInterval(26,26),
                                      TomsInterval(0,3),
                                      TomsInterval(5,8),
                                      TomsInterval(25,30),
                                      TomsInterval(6,10),
                                      TomsInterval(16,21),
                                      TomsInterval(8,9),
                                      TomsInterval(17,19),
                                      TomsInterval(15,23)))
    brokenIntervalSets.foreach { brokenIntervals =>
      val tree = IntervalTree[Int, TomsInterval](brokenIntervals)
      val o = tree.overlapping(MarksInterval(9, 15))
      o should have size 3
      o should contain (TomsInterval(8, 9))
      o should contain (TomsInterval(6, 10))
      o should contain (TomsInterval(15, 23))
    }
  }

  it should "store multiple intervals with the same span" in {
    val intervals = Seq(NamedInterval(4, 9, "4 to 9"),
                        NamedInterval(1, 4, "one to four"),
                        NamedInterval(4, 9, "four to nine"),
                        NamedInterval(0, 3, "foo"),
                        NamedInterval(-3, 12, "big"))
    val tree = IntervalTree[Int, NamedInterval](intervals)
    val o = tree.overlapping(MarksInterval(5, 8))
    o should have size 3
    o should contain (NamedInterval(4, 9, "4 to 9"))
    o should contain (NamedInterval(4, 9, "four to nine"))
    o should contain (NamedInterval(-3, 12, "big"))
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment