Created
March 6, 2024 07:04
-
-
Save makenowjust/18e70035dcf6dba6c6a7791893637c8f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package rbtree | |
import scala.annotation.tailrec | |
// References: | |
// - https://abhiroop.github.io/Haskell-Red-Black-Tree/ | |
// - https://github.com/Abhiroop/okasaki/blob/b4e8b6261cf9c44b7b273116be3da6efde76232d/src/RedBlackTree.hs | |
/** Colour of a red-black tree node: `R` is red, `B` is black. */
enum Color:
  /** Red node. */
  case R

  /** Black node. */
  case B
end Color
/** An immutable red-black tree carrying keys of type `A`.
  *
  * `E` is the empty tree; `T` is a node holding a colour, a left subtree,
  * a key and a right subtree.
  */
enum Tree[+A]:
  /** The empty tree. */
  case E

  /** A node of colour `color` storing `key` between `left` and `right`. */
  case T(color: Color, left: Tree[A], key: A, right: Tree[A])
end Tree
/** Operations on red-black trees.
  *
  * `balance` is Okasaki's insertion rebalancing step; the deletion helpers
  * follow the Haskell implementation referenced in this file's header
  * comments. The tree-returning operations `insert` and `delete` always
  * return either `E` or a tree whose root is black.
  */
object Tree:
  import Color.*

  /** The empty tree. */
  def empty[A]: Tree[A] = Tree.E

  /** Rebuilds a node, resolving a red-red violation beneath a black node.
    *
    * When `color` is `B` and one child is red with a red child of its own, the
    * three nodes involved are rearranged into a red node with two black
    * children. In every other situation the node is rebuilt unchanged.
    */
  private def balance[A](color: Color, left: Tree[A], key: A, right: Tree[A]): Tree.T[A] =
    (color, left, key, right) match
      case (B, T(R, T(R, a, x, b), y, c), z, d) => T(R, T(B, a, x, b), y, T(B, c, z, d))
      case (B, T(R, a, x, T(R, b, y, c)), z, d) => T(R, T(B, a, x, b), y, T(B, c, z, d))
      case (B, a, x, T(R, T(R, b, y, c), z, d)) => T(R, T(B, a, x, b), y, T(B, c, z, d))
      case (B, a, x, T(R, b, y, T(R, c, z, d))) => T(R, T(B, a, x, b), y, T(B, c, z, d))
      case (c, a, x, b)                          => T(c, a, x, b)

  extension [A](tree: Tree[A])
    /** Returns `true` iff some key in the tree compares equal to `value`
      * under the given `Ordering`.
      */
    @tailrec
    def contains(value: A)(using Ordering[A]): Boolean = tree match
      case E => false
      case T(_, left, key, right) =>
        val o = summon[Ordering[A]].compare(value, key)
        if o < 0 then left.contains(value)
        else if o > 0 then right.contains(value)
        else true

    /** Returns a tree that also contains `value`.
      *
      * If a key comparing equal to `value` is already present, the existing
      * node is reused unchanged (the stored key is kept). The root of the
      * result is black.
      */
    def insert(value: A)(using Ordering[A]): Tree[A] =
      val ord = summon[Ordering[A]]
      // New keys enter as red leaves; `balance` repairs red-red pairs on the
      // way back up, and the root is recoloured black at the end.
      def ins(t: Tree[A]): Tree.T[A] = t match
        case E => T(R, E, value, E)
        case node @ T(color, left, key, right) =>
          val o = ord.compare(value, key)
          if o < 0 then balance(color, ins(left), key, right)
          else if o > 0 then balance(color, left, key, ins(right))
          else node // already present: no need to allocate a copy
      val T(_, left, key, right) = ins(tree)
      T(B, left, key, right)

    /** Returns a tree with the key comparing equal to `value` removed.
      *
      * The result is either `E` or a tree whose root is black.
      */
    def delete(value: A)(using Ordering[A]): Tree[A] =
      val ord = summon[Ordering[A]]

      // Removes `value` from the subtree `t`.
      def del(t: Tree[A]): Tree[A] = t match
        case E => E
        case node @ T(_, left, key, right) =>
          val o = ord.compare(value, key)
          if o < 0 then delL(node)
          else if o > 0 then delR(node)
          else fuse(left, right)

      // Deletes from the left subtree. When that subtree is black the node is
      // repaired with `balL`; otherwise it is rebuilt red.
      def delL(t: Tree.T[A]): Tree.T[A] = t match
        case T(_, t1 @ T(B, _, _, _), y, t2) => balL(del(t1), y, t2)
        case T(_, t1, y, t2)                 => T(R, del(t1), y, t2)

      // Repairs a node after a deletion from its (formerly black) left side.
      def balL(left: Tree[A], key: A, right: Tree[A]): Tree.T[A] =
        (left, key, right) match
          case (T(R, t1, x, t2), y, t3) => T(R, T(B, t1, x, t2), y, t3)
          case (t1, y, T(B, t2, z, t3)) => balance(B, t1, y, T(R, t2, z, t3))
          case (t1, y, T(R, T(B, t2, u, t3), z, T(B, t4, w, t5))) =>
            T(R, T(B, t1, y, t2), u, balance(B, t3, z, T(R, t4, w, t5)))
          case _ => sys.error("unreachable")

      // Deletes from the right subtree; mirror image of `delL`.
      def delR(t: Tree.T[A]): Tree.T[A] = t match
        case T(_, t1, y, t2 @ T(B, _, _, _)) => balR(t1, y, del(t2))
        case T(_, t1, y, t2)                 => T(R, t1, y, del(t2))

      // Repairs a node after a deletion from its (formerly black) right side;
      // mirror image of `balL`.
      def balR(left: Tree[A], key: A, right: Tree[A]): Tree.T[A] =
        (left, key, right) match
          case (t1, y, T(R, t2, x, t3)) => T(R, t1, y, T(B, t2, x, t3))
          case (T(B, t1, z, t2), y, t3) => balance(B, T(R, t1, z, t2), y, t3)
          case (T(R, T(B, t0, w, t1), z, T(B, t2, u, t3)), y, t4) =>
            T(R, balance(B, T(R, t0, w, t1), z, t2), u, T(B, t3, y, t4))
          case _ => sys.error("unreachable")

      // Joins the left and right subtrees of the node being removed into a
      // single subtree.
      def fuse(left: Tree[A], right: Tree[A]): Tree[A] =
        (left, right) match
          case (E, t) => t
          case (t, E) => t
          case (t1 @ T(B, _, _, _), T(R, t3, y, t4)) =>
            T(R, fuse(t1, t3), y, t4)
          case (T(R, t1, x, t2), t3 @ T(B, _, _, _)) =>
            T(R, t1, x, fuse(t2, t3))
          case (T(R, t1, x, t2), T(R, t3, y, t4)) =>
            fuse(t2, t3) match
              case T(R, s1, z, s2) => T(R, T(R, t1, x, s1), z, T(R, s2, y, t4))
              case s               => T(R, t1, x, T(R, s, y, t4))
          case (T(B, t1, x, t2), T(B, t3, y, t4)) =>
            fuse(t2, t3) match
              case T(R, s1, z, s2) => T(R, T(B, t1, x, s1), z, T(B, s2, y, t4))
              case s               => balL(t1, x, T(B, s, y, t4))

      // The result's root is always recoloured black.
      del(tree) match
        case E                      => E
        case T(_, left, key, right) => T(B, left, key, right)

    /** Collects every key stored in the tree into an immutable `Set`.
      *
      * Keys are added to a single accumulator during one in-order traversal,
      * so no intermediate sets are built and then merged.
      */
    def toSet: Set[A] =
      def loop(t: Tree[A], acc: Set[A]): Set[A] = t match
        case E                      => acc
        case T(_, left, key, right) => loop(right, loop(left, acc) + key)
      loop(tree, Set.empty)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment