Skip to content

Instantly share code, notes, and snippets.

@TrustNoOne
Created November 14, 2016 08:22
Show Gist options
  • Save TrustNoOne/01b46830e2f4db141ebb412e23775fbb to your computer and use it in GitHub Desktop.
Save TrustNoOne/01b46830e2f4db141ebb412e23775fbb to your computer and use it in GitHub Desktop.
import cats._
import Tree._
import scala.annotation.tailrec
/** A tree in which every node carries a value and a list of subtrees. */
sealed trait Tree[+A] {

  /** The value stored at this node. */
  def elem: A

  /** The direct subtrees of this node. */
  def children: List[Tree[A]]

  /** Every value in the tree, in the order in which `TreeFoldable.foldLeft` visits them. */
  def flattenNodes: List[A] = {
    val filled = TreeFoldable.foldLeft(this, List.newBuilder[A]) { (builder, value) ⇒
      builder += value
    }
    filled.result()
  }
}
object Tree {
// Builders: private node representations (TLeaf, TNode) plus the public constructors/extractors (Leaf, Node)
// Node representation without subtrees. Private: instances are created through Leaf.apply / Node.apply.
private case class TLeaf[+A] private (elem: A) extends Tree[A] {
  val children: List[Tree[A]] = Nil
}
/** Public constructor and extractor for trees represented as a `TLeaf`. */
object Leaf {

  /** Builds a tree holding `elem` and no children. */
  def apply[A](elem: A): Tree[A] = TLeaf(elem)

  /** Yields the stored value when `tree` is a `TLeaf`; does not match a `TNode`. */
  def unapply[A](tree: Tree[A]): Option[A] =
    tree match {
      case TLeaf(value) ⇒ Some(value)
      case _: TNode[_]  ⇒ None
    }
}
// Node representation carrying a value and its subtrees. Private: instances are created through Node.apply.
private case class TNode[+A](
  elem: A,
  children: List[Tree[A]]
) extends Tree[A]
/** Public constructor and extractor for trees represented as a `TNode`. */
object Node {

  /**
   * Builds a tree from a value and its subtrees.
   * An empty `children` list produces a `TLeaf`, so a `TNode` built here always has children.
   */
  def apply[A](elem: A, children: List[Tree[A]]): Tree[A] =
    children match {
      case Nil ⇒ TLeaf(elem)
      case _   ⇒ TNode(elem, children)
    }

  /** Yields the value and subtrees when `tree` is a `TNode`; does not match a `TLeaf`. */
  def unapply[A](tree: Tree[A]): Option[(A, List[Tree[A]])] =
    tree match {
      case TNode(value, subtrees) ⇒ Some((value, subtrees))
      case _: TLeaf[_]            ⇒ None
    }
}
/**
 * Monad instance for [[Tree]].
 *
 * `flatMap` and `tailRecM` are expressed through `cats.Eval` (`defer` / `flatMap`) instead of
 * direct recursion, so the descent into deep trees and long `Left` chains is driven by Eval's
 * evaluator rather than by nested JVM stack frames.
 * NOTE(review): stack safety therefore relies on `Eval` being stack-safe in the cats version
 * this file is built against — confirm for old (pre-1.0) releases.
 */
implicit object TreeMonad extends Monad[Tree] {

  /** Lifts a value into a tree consisting of a single leaf. */
  override def pure[A](x: A): Tree[A] = Leaf(x)

  /**
   * Replaces every node of `fa` by the tree `f` produces for its value.
   *
   * For each node, the root of `f(node.elem)` becomes the new node; its children are the
   * children of `f(node.elem)` followed by the transformed children of the original node.
   * `f` is applied to a node before its children, and children are processed left to right.
   */
  override def flatMap[A, B](fa: Tree[A])(f: (A) ⇒ Tree[B]): Tree[B] =
    flatMapDeferred[A, B](fa)(a ⇒ Eval.now(f(a))).value

  /**
   * Repeatedly applies `f` until every position holds a `Right`.
   *
   * Equivalent to `flatMap(f(a)) { case Right(b) ⇒ pure(b); case Left(n) ⇒ tailRecM(n)(f) }`,
   * but each `Left` step is suspended with `Eval.defer` instead of being a nested method call.
   */
  override def tailRecM[A, B](a: A)(f: (A) ⇒ Tree[Either[A, B]]): Tree[B] = {
    def step(next: Either[A, B]): Eval[Tree[B]] = next match {
      case Right(b) ⇒ Eval.now(pure(b))
      // Suspended: the next expansion only runs when Eval's evaluator reaches it.
      case Left(nextA) ⇒ Eval.defer(flatMapDeferred[Either[A, B], B](f(nextA))(step))
    }
    flatMapDeferred[Either[A, B], B](f(a))(step).value
  }

  /** Shared core of `flatMap` / `tailRecM`: a monadic bind whose recursion is suspended in `Eval`. */
  private def flatMapDeferred[A, B](fa: Tree[A])(f: A ⇒ Eval[Tree[B]]): Eval[Tree[B]] =
    f(fa.elem).flatMap { root ⇒
      // Children are transformed left to right; results are accumulated reversed and flipped once.
      val grafted = fa.children.foldLeft(Eval.now(List.empty[Tree[B]])) { (done, child) ⇒
        done.flatMap(acc ⇒ Eval.defer(flatMapDeferred[A, B](child)(f)).map(_ :: acc))
      }
      grafted.map(reversed ⇒ Node(root.elem, root.children ++ reversed.reverse))
    }
}
/** Foldable instance for [[Tree]]: a node's value is visited before its children, children in list order. */
implicit object TreeFoldable extends Foldable[Tree] {

  /** Eager left fold, driven by an explicit worklist of subtrees still to visit. */
  override def foldLeft[A, B](fa: Tree[A], b: B)(f: (B, A) ⇒ B): B = {
    @tailrec def go(pending: List[Tree[A]], acc: B): B = pending match {
      case Nil ⇒ acc
      // Consume the node's value, then continue with its children ahead of the remaining work.
      case node :: rest ⇒ go(node.children ::: rest, f(acc, node.elem))
    }
    go(fa :: Nil, b)
  }

  /** Lazy right fold; every recursive step is wrapped in `Eval.defer`, so `f` runs only when forced. */
  override def foldRight[A, B](fa: Tree[A], lb: Eval[B])(f: (A, Eval[B]) ⇒ Eval[B]): Eval[B] = {
    def go(pending: List[Tree[A]]): Eval[B] = pending match {
      case Nil ⇒ lb
      case node :: rest ⇒ f(node.elem, Eval.defer(go(node.children ::: rest)))
    }
    Eval.defer(go(fa :: Nil))
  }
}
/** Equality for trees: same representation (leaf/node), same child counts, and values equal under `ev`. */
class TreeEq[A](implicit ev: Eq[A]) extends Eq[Tree[A]] {

  /** Compares both trees in lock-step using two worklists of subtrees still to compare. */
  def eqv(xs: Tree[A], ys: Tree[A]): Boolean = {
    @tailrec def go(left: List[Tree[A]], right: List[Tree[A]]): Boolean =
      (left, right) match {
        case (Nil, Nil) ⇒ true
        case (l :: restL, r :: restR) ⇒
          // Same instance on both sides: skip it without comparing values.
          if (l eq r) go(restL, restR)
          else
            (l, r) match {
              case (TLeaf(a), TLeaf(b)) ⇒
                if (ev.eqv(a, b)) go(restL, restR) else false
              case (TNode(a, kidsA), TNode(b, kidsB)) ⇒
                // Equal child counts keep the two worklists aligned.
                if (ev.eqv(a, b) && kidsA.size == kidsB.size) go(kidsA ::: restL, kidsB ::: restR)
                else false
              case _ ⇒ false
            }
        case _ ⇒ false
      }
    go(xs :: Nil, ys :: Nil)
  }
}
/**
 * Derives an `Eq` for `Tree[A]` from an `Eq[A]`.
 * The result type is spelled out (it is the type that was previously inferred) because implicit
 * definitions without an explicit type are fragile in Scala 2 and rejected by Scala 3.
 */
implicit def treeEq[A: Eq]: TreeEq[A] = new TreeEq[A]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment