// Scala implementation of SplayTree.
// This implementation is based on:
// - D. D. Sleator and R. E. Tarjan, "Self-Adjusting Binary Search Trees", J. ACM 32(3), 1985:
//   https://doi.org/10.1145/3828.3835
/** An immutable splay tree mapping keys of type `K` to values of type `V`.
  *
  * A tree is either [[SplayTreeEmpty]] or a [[SplayTreeNode]]. Nothing is mutated in place:
  * every operation that changes the contents returns a new tree. Operations that compare keys
  * take an implicit `Ordering` for some supertype `K1` of `K`, which lets the empty tree
  * (a `SplayTree[Nothing, V]`) accept keys of any type.
  *
  * @tparam K key type (covariant)
  * @tparam V value type
  */
sealed trait SplayTree[+K, V] {

  // ----- Shape --------------------------------------------------------------------------

  /** `true` only for the empty tree. */
  val isEmpty: Boolean

  /** Number of entries stored in the tree. */
  val size: Int

  /** Left subtree (the empty tree returns itself). */
  val left: SplayTree[K, V]

  /** Right subtree (the empty tree returns itself). */
  val right: SplayTree[K, V]

  /** Key stored at the root; the empty tree throws `IllegalArgumentException`. */
  def key: K

  /** Value stored at the root; the empty tree throws `IllegalArgumentException`. */
  def value: V

  // ----- Queries ------------------------------------------------------------------------

  /** The entry with the smallest key, or `None` for the empty tree. */
  val top: Option[(K, V)]

  /** Looks up the value stored under `key`, if any. */
  def find[K1 >: K](key: K1)(implicit ord: Ordering[K1]): Option[V]

  // ----- Transformations ----------------------------------------------------------------

  /** Removes the entry with the smallest key, returning that key together with the remaining
    * tree. The empty tree throws `IllegalArgumentException`.
    */
  def pop: (K, SplayTree[K, V])

  /** Returns a tree that maps `key` to `value`. */
  def insert[K1 >: K](key: K1, value: V)(implicit ord: Ordering[K1]): SplayTree[K1, V]

  /** Returns a tree without the entry for `key` (the same contents if `key` is absent). */
  def remove[K1 >: K](key: K1)(implicit ord: Ordering[K1]): SplayTree[K1, V]

  /** Rotates the node holding `key` towards the root when it is a child or grandchild of the
    * root; otherwise returns the tree as is.
    */
  def splay[K1 >: K](key: K1)(implicit ord: Ordering[K1]): SplayTree[K1, V]

  /** Splits the tree around `pivot` into a pair `(smaller keys, larger keys)`. */
  def cut[K1 >: K](pivot: K1)(implicit ord: Ordering[K1]): (SplayTree[K1, V], SplayTree[K1, V])

  /** Combines the entries of this tree and `that` into one tree. */
  def merge[K1 >: K](that: SplayTree[K1, V])(implicit ord: Ordering[K1]): SplayTree[K1, V]
}
/** The empty splay tree.
  *
  * It holds no entries: `left` and `right` are the tree itself, `top` is `None`, lookups miss,
  * `remove` and `splay` return it unchanged, and `merge` returns the other tree. Reading
  * `key`, `value` or `pop` throws `IllegalArgumentException("empty")`.
  *
  * @tparam V value type
  */
sealed case class SplayTreeEmpty[V]() extends SplayTree[Nothing, V] {

  // Shape: no entries, and both children are this same empty tree.
  override val isEmpty: Boolean = true
  override val size: Int = 0
  override val top: Option[Nothing] = None
  override val left: SplayTree[Nothing, V] = this
  override val right: SplayTree[Nothing, V] = this

  // There is no root entry to read or remove.
  override def key: Nothing = throw new IllegalArgumentException("empty")
  override def value: V = throw new IllegalArgumentException("empty")
  override def pop: (Nothing, SplayTree[Nothing, V]) = throw new IllegalArgumentException("empty")

  // Lookups always miss.
  override def find[K1 >: Nothing](key: K1)(implicit ord: Ordering[K1]): Option[V] = None

  // Restructuring or deleting from the empty tree leaves it as it is.
  override def splay[K1 >: Nothing](key: K1)(implicit ord: Ordering[K1]): SplayTree[K1, V] = this
  override def remove[K1 >: Nothing](key: K1)(implicit ord: Ordering[K1]): SplayTree[K1, V] = this

  // Merging with the empty tree yields the other tree unchanged.
  override def merge[K1 >: Nothing](that: SplayTree[K1, V])(implicit ord: Ordering[K1]): SplayTree[K1, V] = that

  // Splitting produces two fresh empty trees.
  override def cut[K1 >: Nothing](pivot: K1)(implicit ord: Ordering[K1]): (SplayTree[K1, V], SplayTree[K1, V]) =
    (new SplayTreeEmpty[V], new SplayTreeEmpty[V])

  // Inserting creates a single node whose children are both empty.
  override def insert[K1 >: Nothing](key: K1, value: V)(implicit ord: Ordering[K1]): SplayTree[K1, V] =
    SplayTreeNode[K1, V](new SplayTreeEmpty[V], key, value, new SplayTreeEmpty[V])
}
/** A non-empty splay tree: one key/value entry plus a left and a right subtree.
  *
  * Keys in `left` are ordered before `key` and keys in `right` after it, under the implicit
  * `Ordering[K]`. `insert` keeps keys unique by replacing the value of an equivalent key.
  *
  * @param left  subtree holding the keys ordered before `key`
  * @param key   key stored at this node
  * @param value value associated with `key`
  * @param right subtree holding the keys ordered after `key`
  */
final case class SplayTreeNode[K : Ordering, V](left: SplayTree[K, V], key: K, value: V, right: SplayTree[K, V]) extends SplayTree[K, V] {
  val isEmpty: Boolean = false
  val size: Int = left.size + 1 + right.size

  /** The entry with the smallest key, i.e. the leftmost node. */
  val top: Option[(K, V)] =
    if (left.isEmpty) Some((key, value))
    else left.top

  /** Removes the entry with the smallest key, returning that key and the remaining tree.
    *
    * Walks the left spine two nodes at a time, rotating as it goes. This is computed on demand:
    * as an eager `val` it ran for every node ever constructed, and because it builds new nodes
    * itself, each construction triggered further `pop` evaluations.
    */
  def pop: (K, SplayTree[K, V]) =
    if (left.isEmpty) (key, right)
    else if (left.left.isEmpty) (left.key, SplayTreeNode(left.right, key, value, right))
    else {
      val (min, tree) = left.left.pop
      (min, SplayTreeNode(tree, left.key, left.value, SplayTreeNode(left.right, key, value, right)))
    }

  /** Merges `that` into this tree.
    *
    * `that` is cut around this node's key, and each half is merged into the matching subtree.
    * NOTE(review): if both trees contain equivalent keys, both entries are kept.
    */
  override def merge[K1 >: K](that: SplayTree[K1, V])(implicit ord: Ordering[K1]): SplayTree[K1, V] =
    // Without this short-circuit, merging with an empty tree rebuilt every node of this tree.
    if (that.isEmpty) this
    else {
      val (thatLeft, thatRight) = that cut key
      SplayTreeNode(left merge thatLeft, key, value, right merge thatRight)
    }

  /** Inserts `newKey -> newValue`, replacing the value if an equivalent key is already present.
    *
    * The entry is inserted into the appropriate subtree, and `splay` is applied at every level
    * on the way back up, so the new node is rotated upward until it becomes the root.
    * NOTE(review): because the new node is always a direct child when `splay` runs here, only
    * single (zig) rotations fire. This is rotate-to-root, not Sleator–Tarjan pairwise splaying.
    */
  override def insert[K1 >: K](newKey: K1, newValue: V)(implicit ord: Ordering[K1]): SplayTree[K1, V] = {
    val cmp = ord.compare(newKey, key)
    // An equivalent key replaces this entry instead of adding a duplicate node.
    if (cmp == 0) SplayTreeNode(left, newKey, newValue, right)
    else if (cmp < 0) SplayTreeNode(left.insert(newKey, newValue), key, value, right).splay(newKey)
    else SplayTreeNode(left, key, value, right.insert(newKey, newValue)).splay(newKey)
  }

  /** Splits this tree around `pivot` into `(smaller, rest)`.
    *
    * `smaller` holds the entries whose keys are ordered before `pivot`; `rest` holds those whose
    * keys are equivalent to or after it. The descent looks two levels ahead and rotates pairs of
    * nodes on the way down.
    */
  override def cut[K1 >: K](pivot: K1)(implicit ord: Ordering[K1]): (SplayTree[K1, V], SplayTree[K1, V]) =
    ord.lt(key, pivot) match {
      case true if right.isEmpty =>
        (this, SplayTreeEmpty())
      case true if ord.lt(right.key, pivot) =>
        val (small, big) = right.right cut pivot
        (SplayTreeNode(SplayTreeNode(left, key, value, right.left), right.key, right.value, small), big)
      case true =>
        val (small, big) = right.left cut pivot
        (SplayTreeNode(left, key, value, small), SplayTreeNode(big, right.key, right.value, right.right))
      case false if left.isEmpty =>
        (SplayTreeEmpty(), this)
      // Mirror of the right-hand test: keys strictly before the pivot go to `smaller`, so a
      // left child equal to the pivot lands in `rest`, just like an equal root does.
      case false if ord.lt(left.key, pivot) =>
        val (small, big) = left.right cut pivot
        (SplayTreeNode(left.left, left.key, left.value, small), SplayTreeNode(big, key, value, right))
      case false =>
        // Both halves of the recursive split are needed. Reusing `left.left` whole here would
        // copy its keys below `pivot` into `rest` as well as `smaller`.
        val (small, big) = left.left cut pivot
        (small, SplayTreeNode(big, left.key, left.value, SplayTreeNode(left.right, key, value, right)))
    }

  /** Performs one splay step toward `target`.
    *
    * If `target` is a grandchild of this node, a zig-zig or zig-zag double rotation brings it to
    * the root. If it is a child, a single zig rotation does. Otherwise the tree is returned
    * unchanged. The parameter is named `target` so that `key` below always means this node's key.
    */
  override def splay[K1 >: K](target: K1)(implicit ord: Ordering[K1]): SplayTree[K1, V] =
    (left, right) match {
      // zig-zig: target is the left child's left child
      case (SplayTreeNode(SplayTreeNode(lll, llk, llv, llr), lk, lv, lr), _) if ord.equiv(llk, target) =>
        SplayTreeNode(lll, llk, llv, SplayTreeNode(llr, lk, lv, SplayTreeNode(lr, key, value, right)))
      // zig-zig: target is the right child's right child
      case (_, SplayTreeNode(rl, rk, rv, SplayTreeNode(rrl, rrk, rrv, rrr))) if ord.equiv(rrk, target) =>
        SplayTreeNode(SplayTreeNode(SplayTreeNode(left, key, value, rl), rk, rv, rrl), rrk, rrv, rrr)
      // zig-zag: target is the left child's right child
      case (SplayTreeNode(ll, lk, lv, SplayTreeNode(lrl, lrk, lrv, lrr)), _) if ord.equiv(lrk, target) =>
        SplayTreeNode(SplayTreeNode(ll, lk, lv, lrl), lrk, lrv, SplayTreeNode(lrr, key, value, right))
      // zig-zag: target is the right child's left child
      case (_, SplayTreeNode(SplayTreeNode(rll, rlk, rlv, rlr), rk, rv, rr)) if ord.equiv(rlk, target) =>
        SplayTreeNode(SplayTreeNode(left, key, value, rll), rlk, rlv, SplayTreeNode(rlr, rk, rv, rr))
      // zig: target is the left child
      case (SplayTreeNode(ll, lk, lv, lr), _) if ord.equiv(lk, target) =>
        SplayTreeNode(ll, lk, lv, SplayTreeNode(lr, key, value, right))
      // zig: target is the right child
      case (_, SplayTreeNode(rl, rk, rv, rr)) if ord.equiv(rk, target) =>
        SplayTreeNode(SplayTreeNode(left, key, value, rl), rk, rv, rr)
      case _ =>
        this
    }

  /** Looks up `searchKey` by binary search. The tree is not restructured: it is immutable and
    * only the value is returned.
    */
  override def find[K1 >: K](searchKey: K1)(implicit ord: Ordering[K1]): Option[V] =
    ord.compare(searchKey, key) match {
      case 0 =>
        Some(value)
      case x if x < 0 =>
        left.find(searchKey)
      case _ =>
        right.find(searchKey)
    }

  /** Returns a tree without `removeKey`. The removed node's two subtrees are joined with
    * `merge`. The nodes on the search path are rebuilt even when the key is absent.
    */
  override def remove[K1 >: K](removeKey: K1)(implicit ord: Ordering[K1]): SplayTree[K1, V] = {
    ord.compare(removeKey, key) match {
      case 0 =>
        left merge right
      case x if x < 0 =>
        SplayTreeNode(left remove removeKey, key, value, right)
      case _ =>
        SplayTreeNode(left, key, value, right remove removeKey)
    }
  }
}
/** Factory methods for [[SplayTree]]. */
object SplayTree {

  /** Returns an empty tree for keys of type `A` and values of type `V`. */
  def empty[A: Ordering, V]: SplayTree[A, V] = new SplayTreeEmpty[V]

  /** Builds a tree holding every entry of `xs`, inserted one at a time in `xs`'s iteration
    * order.
    */
  def apply[A: Ordering, V](xs: Map[A, V]): SplayTree[A, V] = {
    val seed = empty[A, V]
    xs.foldLeft(seed)((acc, entry) => acc.insert(entry._1, entry._2))
  }
}
import org.scalatest._
/** Behavioral checks for [[SplayTree]]: building from a map, lookup, removal, and in-order draining. */
class SplayTreeTest extends WordSpec with Matchers {
  "SplayTree" should {
    "works at simple case" in {
      val map = (-666 to 666).map { k =>
        (k, k.toString)
      }.toMap
      val tree = SplayTree(map)
      tree.size should be(map.size)
      map.foreach { case (key, value) =>
        tree.find(key) should be(Some(value))
      }
      tree.find(777) should be(None)
      map.foreach { case (key, _) =>
        val removedTree = tree.remove(key)
        // The removed key must actually be gone; checking only the survivors would let a
        // no-op `remove` pass.
        removedTree.find(key) should be(None)
        removedTree.size should be(map.size - 1)
        map.filterNot(_._1 == key).foreach { case (sKey, value) =>
          removedTree.find(sKey) should be(Some(value))
        }
      }
    }
    "yield entries in ascending key order through top and pop" in {
      val keys = Seq(5, -3, 12, 0, 7, -8, 1)
      val tree = SplayTree(keys.map(k => (k, k.toString)).toMap)
      tree.top should be(Some((-8, "-8")))
      // Repeatedly pop the minimum until the tree is empty, collecting the popped keys.
      val drained = Iterator.iterate(tree)(_.pop._2).takeWhile(!_.isEmpty).map(_.pop._1).toList
      drained should be(keys.sorted)
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment