Skip to content

Instantly share code, notes, and snippets.

@timothyklim
Forked from johnynek/TreeList.scala
Created December 29, 2018 03:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save timothyklim/f1b7e842754ae6e821c8f0aa9ece6df8 to your computer and use it in GitHub Desktop.
Save timothyklim/f1b7e842754ae6e821c8f0aa9ece6df8 to your computer and use it in GitHub Desktop.
Implementation of "Purely Functional Random Access Lists" by Chris Okasaki in Scala. This gives O(1) cons and uncons, and O(log N) lookup (at most roughly 2 log_2 N steps).
package org.bykn.list
import cats.Applicative
import cats.implicits._
/**
 * Implementation of "Purely Functional Random Access Lists" by Chris Okasaki.
 *
 * Elements are stored as a list of complete binary trees. This gives O(1)
 * `cons` and `uncons`, and O(log N) `get` (at most roughly 2 log_2 N steps:
 * one walk to find the right tree, one descent inside it).
 *
 * @tparam A the element type
 */
sealed abstract class TreeList[+A] {
  /** The first element and the remaining list, or `None` when empty. */
  def uncons: Option[(A, TreeList[A])]

  /** Prepends `a1` to the front of this list. */
  def cons[A1 >: A](a1: A1): TreeList[A1]

  /** The element at zero-based position `idx`, or `None` if out of range. */
  def get(idx: Long): Option[A]

  /** The number of elements in this list. */
  def size: Long

  /** Folds the elements from first to last. */
  def foldLeft[B](init: B)(fn: (B, A) => B): B

  /** Folds the elements from last to first. */
  def foldRight[B](fin: B)(fn: (A, B) => B): B

  /** Applies `fn` to every element, preserving order. */
  def map[B](fn: A => B): TreeList[B]

  /** Removes the first `n` elements. */
  def drop(n: Long): TreeList[A]

  /**
   * Split the list roughly in half.
   *
   * The first component holds the leading elements and the second the
   * trailing ones, so concatenating them in order gives back this list.
   */
  def split: (TreeList[A], TreeList[A])

  /** Right-associative alias for `cons`: `a :: list` prepends `a`. */
  def ::[A1 >: A](a1: A1): TreeList[A1] = cons(a1)

  /**
   * Renders the list as `TreeList(e1, e2, ...)`.
   *
   * `null` elements are rendered as `null` instead of throwing, matching
   * the standard Scala collections.
   */
  override def toString: String = {
    val strb = new java.lang.StringBuilder
    strb.append("TreeList(")
    @annotation.tailrec
    def loop(first: Boolean, l: TreeList[A]): Unit =
      l.uncons match {
        case None => ()
        case Some((h, t)) =>
          if (!first) strb.append(", ")
          // String.valueOf tolerates null; h.toString would throw an NPE.
          strb.append(String.valueOf(h))
          loop(false, t)
      }
    loop(true, this)
    strb.append(")")
    strb.toString
  }
}
/**
 * Companion for [[TreeList]]: the depth-indexed tree representation, the
 * single concrete implementation, and constructors.
 */
object TreeList {
  /**
   * Type-level natural numbers tracking tree depth, so a [[Balanced]] node can
   * only be built from two subtrees of the same depth.
   */
  sealed trait Nat {
    /** The integer this natural represents (0 for `Zero`, +1 per `Succ`). */
    def value: Int
  }

  /** Evidence that naturals `A` and `B` are the same, used to turn `F[A]` into `F[B]`. */
  sealed abstract class NatEq[A <: Nat, B <: Nat] {
    def subst[F[_ <: Nat]](f: F[A]): F[B]
  }

  object NatEq {
    /** Reflexivity: every natural equals itself. */
    implicit def refl[A <: Nat]: NatEq[A, A] =
      new NatEq[A, A] {
        def subst[F[_ <: Nat]](f: F[A]): F[A] = f
      }
  }

  object Nat {
    case class Succ[P <: Nat](prev: P) extends Nat {
      val value: Int = prev.value + 1
    }
    case object Zero extends Nat {
      def value: Int = 0
    }

    /**
     * Runtime equality test on two naturals that yields type-equality
     * evidence when their `value`s agree.
     *
     * `Nat` is sealed with only `Zero` and `Succ`, so equal `value`s mean the
     * same natural; the cast records that fact at the type level.
     */
    def maybeEq[N1 <: Nat, N2 <: Nat](n1: N1, n2: N2): Option[NatEq[N1, N2]] =
      // I don't see how to prove this in scala, but it is true
      if (n1.value == n2.value) Some(NatEq.refl[N1].asInstanceOf[NatEq[N1, N2]])
      else None
  }

  /**
   * A complete binary tree of depth `N`. Elements are ordered pre-order:
   * the node's value, then the left subtree, then the right subtree.
   */
  sealed abstract class Tree[+N <: Nat, +A] {
    def value: A
    def depth: N
    def size: Long // this is 2^(depth + 1) - 1
    def get(idx: Long): Option[A]
    def map[B](fn: A => B): Tree[N, B]
    def foldRight[B](fin: B)(fn: (A, B) => B): B
  }

  /** A depth-0 tree holding exactly one element. */
  case class Root[A](value: A) extends Tree[Nat.Zero.type, A] {
    def depth: Nat.Zero.type = Nat.Zero
    def size = 1L
    def get(idx: Long): Option[A] =
      if(idx == 0L) Some(value) else None
    def map[B](fn: A => B) = Root(fn(value))
    def foldRight[B](fin: B)(fn: (A, B) => B): B = fn(value, fin)
  }

  /** A node joining two subtrees of equal depth `N` under a new value. */
  case class Balanced[N <: Nat, A](value: A, left: Tree[N, A], right: Tree[N, A]) extends Tree[Nat.Succ[N], A] {
    val depth: Nat.Succ[N] = Nat.Succ(left.depth)
    val size = 1L + left.size + right.size
    /** Index 0 is this node; 1..left.size is the left subtree; the rest is the right. */
    def get(idx: Long): Option[A] =
      if (idx == 0L) Some(value)
      else if (idx <= left.size) left.get(idx - 1)
      else right.get(idx - (left.size + 1))
    def map[B](fn: A => B) = Balanced[N, B](fn(value), left.map(fn), right.map(fn))
    def foldRight[B](fin: B)(fn: (A, B) => B): B = {
      val rightB = right.foldRight(fin)(fn)
      val leftB = left.foldRight(rightB)(fn)
      fn(value, leftB)
    }
  }

  /**
   * Applicative traversal of a single tree, keeping its shape and depth.
   * Effects are sequenced in element (pre-)order: node, left, right.
   */
  def traverseTree[F[_]: Applicative, A, B, N <: Nat](ta: Tree[N, A], fn: A => F[B]): F[Tree[N, B]] =
    ta match {
      case Root(a) => fn(a).map(Root(_))
      case Balanced(a, left, right) =>
        (fn(a), traverseTree(left, fn), traverseTree(right, fn)).mapN { (b, l, r) =>
          Balanced(b, l, r)
        }
    }

  /**
   * The sole [[TreeList]] implementation. The list's elements are the
   * pre-order contents of each tree in `treeList`, concatenated in order.
   */
  private case class Trees[A](treeList: List[Tree[Nat, A]]) extends TreeList[A] {
    /** Merges the first two trees under `a1` when their depths match, else prepends a `Root`. */
    def cons[A1 >: A](a1: A1): TreeList[A1] =
      treeList match {
        case h1 :: h2 :: rest =>
          def go[N1 <: Nat, N2 <: Nat, A2 <: A](t1: Tree[N1, A2], t2: Tree[N2, A2]): TreeList[A1] =
            Nat.maybeEq[N1, N2](t1.depth, t2.depth) match {
              case Some(eqv) =>
                type T[N <: Nat] = Tree[N, A2]
                Trees(Balanced[N2, A1](a1, eqv.subst[T](t1), t2) :: rest)
              case None =>
                Trees(Root(a1) :: treeList)
            }
          go(h1, h2)
        case lessThan2 => Trees(Root(a1) :: lessThan2)
      }

    /** Takes the head tree's value; a `Balanced` head is replaced by its two children. */
    def uncons: Option[(A, TreeList[A])] =
      treeList match {
        case Nil => None
        case Root(a) :: rest => Some((a, Trees(rest)))
        case Balanced(a, l, r) :: rest => Some((a, Trees(l :: r :: rest)))
      }

    /** Skips whole trees until `idx` falls inside one, then indexes into it. */
    def get(idx: Long): Option[A] = {
      @annotation.tailrec
      def loop(idx: Long, treeList: List[Tree[Nat, A]]): Option[A] =
        if (idx < 0L) None
        else
          treeList match {
            case Nil => None
            case h :: tail =>
              if (h.size <= idx) loop(idx - h.size, tail)
              else h.get(idx)
          }
      loop(idx, treeList)
    }

    /** Sum of the tree sizes. */
    def size: Long = {
      @annotation.tailrec
      def loop(treeList: List[Tree[Nat, A]], acc: Long): Long =
        treeList match {
          case Nil => acc
          case h :: tail => loop(tail, acc + h.size)
        }
      loop(treeList, 0L)
    }

    /** Stack-safe left fold: each `Balanced` is unpacked into its children in place. */
    def foldLeft[B](init: B)(fn: (B, A) => B): B = {
      @annotation.tailrec
      def loop(acc: B, pending: List[Tree[Nat, A]]): B =
        pending match {
          case Nil => acc
          case Root(a) :: tail => loop(fn(acc, a), tail)
          case Balanced(a, l, r) :: tail => loop(fn(acc, a), l :: r :: tail)
        }
      loop(init, treeList)
    }

    /** Folds the trees from last to first so elements combine right-to-left. */
    def foldRight[B](fin: B)(fn: (A, B) => B): B =
      treeList.reverse.foldLeft(fin) { (b, treea) =>
        treea.foldRight(b)(fn)
      }

    def map[B](fn: A => B): TreeList[B] = Trees(treeList.map(_.map(fn)))

    /**
     * Drops the first `n` elements. Whole trees are skipped by size; a
     * `Balanced` tree that is only partly dropped is split into its children.
     * A non-positive `n` returns the list unchanged, like `List.drop`.
     */
    def drop(n: Long): TreeList[A] = {
      @annotation.tailrec
      def loop(n: Long, treeList: List[Tree[Nat, A]]): TreeList[A] =
        treeList match {
          case Nil => empty
          // `<= 0` rather than `== 0`: a negative count must drop nothing
          // instead of falling through to the branches below and losing elements.
          case _ if n <= 0L => Trees(treeList)
          case h :: tail =>
            if (h.size <= n) loop(n - h.size, tail)
            else {
              h match {
                case Root(_) =>
                  // Unreachable: here n >= 1 and a Root has size 1, so h.size <= n above.
                  loop(n - 1, tail)
                case Balanced(_, l, r) =>
                  if (n > l.size + 1L) loop(n - l.size - 1L, r :: tail)
                  else if (n > 1L) loop(n - 1L, l :: r :: tail)
                  else Trees(l :: r :: tail)
              }
            }
        }
      loop(n, treeList)
    }

    /** Splits off the last tree, or breaks up a lone `Balanced` tree; order is preserved. */
    def split: (TreeList[A], TreeList[A]) =
      treeList match {
        case Nil => (empty, empty)
        case Root(_) :: Nil => (this, empty)
        case Balanced(a, l, r) :: Nil => (Trees(Root(a) :: l :: Nil), Trees(r :: Nil))
        case moreThanOne => (Trees(moreThanOne.init), Trees(moreThanOne.last :: Nil))
      }
  }

  /** Extension syntax adding `traverse` to [[TreeList]]. */
  implicit class InvariantTreeList[A](val treeList: TreeList[A]) extends AnyVal {
    /** Applicative traversal; effects run in element order and the tree shapes are preserved. */
    def traverse[F[_]: Applicative, B](fn: A => F[B]): F[TreeList[B]] =
      treeList match {
        case Trees(tls) => tls.traverse { tree => traverseTree(tree, fn) }.map(Trees(_))
      }
  }

  /** The empty list. */
  val empty: TreeList[Nothing] = Trees[Nothing](Nil)

  /** Builds a [[TreeList]] with the same elements and order as `list`, by consing from the back. */
  def fromList[A](list: List[A]): TreeList[A] = {
    @annotation.tailrec
    def loop(rev: List[A], acc: TreeList[A]): TreeList[A] =
      rev match {
        case Nil => acc
        case h :: tail => loop(tail, acc.cons(h))
      }
    loop(list.reverse, empty)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment