Skip to content

Instantly share code, notes, and snippets.

@kirked
Created September 1, 2017 19:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kirked/49d3a63a14b344dfab9cde02b5f99c27 to your computer and use it in GitHub Desktop.
Save kirked/49d3a63a14b344dfab9cde02b5f99c27 to your computer and use it in GitHub Desktop.
Basic tree data structure in Scala
/**
 * An immutable rose tree: a value together with an ordered vector of child trees.
 *
 * A tree is either a node carrying a value, or the empty tree.
 *
 * @tparam A the type of value stored at each node
 */
sealed trait Tree[+A] {

  //------ Inspection

  /** The value at the root of this tree; the empty tree throws `NoSuchElementException`. */
  def value: A

  /** The direct sub-trees of the root, in order. */
  def children: Vector[Tree[A]]

  /** The root value wrapped in `Some`, or `None` for the empty tree. */
  def get: Option[A]

  /** `true` only for the empty tree. */
  def isEmpty: Boolean

  /** Negation of [[isEmpty]]. */
  final def nonEmpty: Boolean = !isEmpty

  /** `true` when this tree has no children. */
  def isLeaf: Boolean

  /** Negation of [[isLeaf]]. */
  final def nonLeaf: Boolean = !isLeaf

  /** Whether `a` equals the value of the root or of any descendant. */
  def contains[B >: A](a: B): Boolean

  /** Return a depth-first node vector for this tree (each root precedes its children). */
  def nodes: Vector[A]

  //------ Traversal

  /** Apply `f` to every value, keeping the shape of the tree. */
  def map[B](f: A => B): Tree[B]

  /** Replace every value by the tree that `f` produces for it. */
  def flatMap[B](f: A => Tree[B]): Tree[B]

  /**
   * Keep only the nodes whose value satisfies `predicate`. A node that fails the predicate is
   * replaced, together with its whole subtree, by the empty tree; use [[flatten]] to drop
   * those placeholders.
   */
  def filter(predicate: A => Boolean): Tree[A]

  /** Like [[filter]], with the predicate negated. */
  def filterNot(predicate: A => Boolean): Tree[A]

  /** Remove empty sub-trees at every level. */
  def flatten: Tree[A]

  /** Run `f` on every value, each root before its children. */
  def foreach(f: A => Unit): Unit

  /** Fold all values into a single result, starting from `zero`, each root before its children. */
  def foldLeft[B](zero: B)(f: (B, A) => B): B

  //------ Manipulation

  /** A tree with `child` appended as the last child of the root. */
  def +[B >: A](child: Tree[B]): Tree[B]

  /** A tree with the sub-trees equal to `child` removed. */
  def -[B >: A](child: Tree[B]): Tree[B]
}
/** Constructors and implementations for [[Tree]]: the shared empty tree and the [[Tree.Node]] class. */
object Tree {

  /**
   * The empty tree: it carries no value and has no children.
   *
   * Traversal operations return the empty tree itself (or the fold's `zero`). It also serves
   * as the placeholder that `filter` and `flatMap` leave behind for dropped sub-trees.
   */
  val empty: Tree[Nothing] = new Tree[Nothing] {
    /** There is no value to return, so this always throws `NoSuchElementException`. */
    override def value: Nothing = throw new NoSuchElementException("empty tree")
    override val children: Vector[Tree[Nothing]] = Vector.empty
    override def get: Option[Nothing] = None
    override def isEmpty: Boolean = true
    override def isLeaf: Boolean = true
    override def contains[B >: Nothing](a: B): Boolean = false
    override def nodes: Vector[Nothing] = Vector.empty

    override def map[B](f: Nothing => B): Tree[B] = this
    override def flatMap[B](f: Nothing => Tree[B]): Tree[B] = this
    override def filter(predicate: Nothing => Boolean): Tree[Nothing] = this
    override def filterNot(predicate: Nothing => Boolean): Tree[Nothing] = this
    override def flatten: Tree[Nothing] = this
    override def foreach(f: Nothing => Unit): Unit = ()
    override def foldLeft[B](zero: B)(f: (B, Nothing) => B): B = zero

    /** The empty tree has no root to attach a child to, so this throws `NoSuchElementException`. */
    override def +[B >: Nothing](child: Tree[B]): Tree[B] =
      throw new NoSuchElementException("empty tree")

    /**
     * Removing anything from the empty tree leaves the empty tree.
     *
     * This must not throw: `Node.-` recurses into every child, and a filtered tree
     * still contains empty placeholders among its children.
     */
    override def -[B >: Nothing](child: Tree[B]): Tree[B] = this
  }

  /**
   * Build a tree node.
   *
   * @param a        the value at the root
   * @param children the sub-trees of the root; defaults to none, producing a leaf
   * @return a [[Node]] holding `a` and `children`
   */
  def apply[A](a: A, children: Vector[Tree[A]] = Vector.empty): Tree[A] = Node(a, children)

  /**
   * A non-empty tree.
   *
   * @param value    the value at this node
   * @param children the sub-trees below this node, in order
   */
  case class Node[+A](value: A, children: Vector[Tree[A]]) extends Tree[A] {
    def get: Option[A] = Some(value)
    def isEmpty: Boolean = false
    def isLeaf: Boolean = children.isEmpty

    /** Whether `a` equals this node's value or any descendant's; stops at the first match. */
    def contains[B >: A](a: B): Boolean = a == value || children.exists(_.contains(a))

    /** Return a depth-first node vector for this tree */
    def nodes: Vector[A] = value +: children.flatMap(_.nodes)

    def map[B](f: A => B): Tree[B] = Node(f(value), children.map(_.map(f)))

    /**
     * Replace this node by the tree `f(value)`. The children of that tree come first, followed
     * by the flat-mapped original children.
     *
     * When `f(value)` is the empty tree, this node and its whole subtree are dropped and the
     * empty tree is returned, mirroring what `filter` does for a rejected node.
     */
    def flatMap[B](f: A => Tree[B]): Tree[B] = {
      val replacement = f(value)
      // An empty replacement has no value to build a node from; reading it would throw.
      if (replacement.isEmpty) Tree.empty
      else Node(replacement.value, replacement.children ++ children.map(_.flatMap(f)))
    }

    /**
     * Keep this node only if its value satisfies `predicate`. Rejected descendants are replaced
     * by the empty tree in place; call `flatten` to remove those placeholders.
     */
    def filter(predicate: A => Boolean): Tree[A] =
      if (predicate(value)) {
        val filtered = children.map(_.filter(predicate))
        // Reuse this instance when nothing below changed.
        if (filtered == children) this else Node(value, filtered)
      }
      else Tree.empty

    def filterNot(predicate: A => Boolean): Tree[A] = filter(a => !predicate(a))

    /** Drop empty sub-trees at every level below this node. */
    def flatten: Tree[A] = Node(value, children.map(_.flatten).filterNot(_.isEmpty))

    def foreach(f: A => Unit): Unit = {
      f(value)
      children.foreach(_.foreach(f))
    }

    /** Fold this node's value first, then each child's sub-tree from left to right. */
    def foldLeft[B](zero: B)(f: (B, A) => B): B =
      children.foldLeft(f(zero, value))((acc, subtree) => subtree.foldLeft(acc)(f))

    /** Append `child` as the last child; adding the empty tree is a no-op. */
    def +[B >: A](child: Tree[B]): Tree[B] =
      if (child.isEmpty) this else Node(value, children :+ child)

    /**
     * Remove `child`. If it is a direct child, every direct child equal to it is removed;
     * otherwise the removal is attempted recursively in each sub-tree.
     */
    def -[B >: A](child: Tree[B]): Tree[B] =
      if (child.isEmpty) this
      else if (children.contains(child)) Node(value, children.filterNot(_ == child))
      else Node(value, children.map(_ - child))
  }
}
/**
 * A collection of trees.
 *
 * @param roots the trees that make up this forest
 * @tparam A the type of value stored in the trees
 */
case class Forest[+A](
  roots: Seq[Tree[A]]
)
/** Constructors for [[Forest]]. */
object Forest {
  import Tree.Node

  /** A forest that contains no trees. */
  def empty[A]: Forest[A] = Forest(Seq.empty[Tree[A]])

  /**
   * Construct a forest from a sequence of node values, given identifier and parent identification functions.
   *
   * Values whose parent is `None` become the roots, in their original order. Every other value
   * is attached below the value whose ID matches its parent ID. Values whose parent ID is not
   * reachable from any root are left out of the result. An input with no parentless value
   * (including an empty input) yields an empty forest.
   *
   * NOTE(review): IDs are presumably unique; a value that shares its ID with one of its own
   * ancestors would make the recursion loop forever — confirm against callers.
   *
   * @param identify A function that provides an ID for the provided node value.
   * @param parent A function that provides the parent ID for the provided node value.
   * @param values A sequence of `A` values to place into the forest.
   * @return A forest of nodes.
   */
  def fromValues[A, B](identify: (A) => B, parent: (A) => Option[B])(values: Seq[A]): Forest[A] = {
    // Index every value by its parent ID; the `None` bucket holds the roots.
    val byParent: Map[Option[B], Seq[A]] = values.groupBy(parent)

    // Build the node for `rootValue` and, recursively, for everything that names it as parent.
    def mkTree(rootValue: A): Node[A] = {
      val childValues = byParent.getOrElse(Some(identify(rootValue)), Seq.empty[A])
      Node(rootValue, childValues.map(mkTree).toVector)
    }

    // `getOrElse` rather than `head`: there may be no parentless values at all.
    Forest(byParent.getOrElse(None, Seq.empty[A]).map(mkTree))
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment