Skip to content

Instantly share code, notes, and snippets.

@kevinwright
Created September 30, 2011 08:34
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kevinwright/1253107 to your computer and use it in GitHub Desktop.
Save kevinwright/1253107 to your computer and use it in GitHub Desktop.
Deeper than flatMap, taking on whole trees in one bite!
/**
 * Walks the tree rooted at `node` with `traverse` and concatenates the
 * collections that `pfn` produces into a single flat sequence.
 *
 * @param node root of the tree to walk
 * @param pfn  partial function producing a collection of results for each
 *             node it is defined at
 * @tparam T   element type of the collections produced by `pfn`
 * @return the elements of every produced collection, in the order in which
 *         `traverse` returned the collections
 */
def flatTraverse[T](node: Node)(pfn: PartialFunction[Node, TraversableOnce[T]]): Seq[T] = {
  // One collection per node that `pfn` accepted; flattened below.
  val perMatch = traverse(node)(pfn)
  perMatch.flatten
}
/**
 * Collects the results of applying `pfn` across the tree rooted at `node`.
 *
 * The walk is depth-first, visiting children in the order they appear in
 * `child`. For each node:
 *  - if `pfn` is defined at it, its result is emitted and the node's children
 *    are NOT visited;
 *  - otherwise, if it is an `Elem`, its children are walked recursively;
 *  - otherwise it contributes nothing.
 *
 * @param node root of the tree to walk
 * @param pfn  partial function selecting nodes and mapping them to results
 * @tparam T   result type produced by `pfn`
 * @return the results in visiting order; empty when nothing matched
 */
def traverse[T](node: Node)(pfn: PartialFunction[Node, T]): Seq[T] = {
  // `lift` turns "defined or not" into an Option, so the call site no longer
  // pairs an `isDefinedAt` guard with a separate `apply` on the same node.
  val attempt = pfn.lift

  def collectFrom(n: Node): List[T] = attempt(n) match {
    case Some(result) => List(result)
    case None =>
      n match {
        case e: Elem => e.child.toList.flatMap(c => collectFrom(c))
        case _       => Nil
      }
  }

  collectFrom(node)
}
/**
 * Rewrites the tree rooted at `n` by replacing every node that `pfn` is
 * defined at with the result of `pfn`.
 *
 * A replaced node is returned as-is: the replacement is not transformed
 * again and the original node's children are not visited. An `Elem` that
 * `pfn` does not handle is copied with each of its children transformed
 * recursively. Any other node is returned unchanged.
 *
 * @param n   root of the tree to rewrite
 * @param pfn partial function mapping selected nodes to their replacements
 * @return the rewritten tree
 */
def transform(n: Node)(pfn: PartialFunction[Node, Node]): Node =
  // `lift` + `getOrElse` replaces the `isDefinedAt` guard followed by a
  // separate `apply`; the fallback block is by-name, so it only runs when
  // `pfn` does not handle `n`.
  pfn.lift(n).getOrElse {
    n match {
      case e: Elem => e.copy(child = e.child.toSeq map { c => transform(c)(pfn) })
      case other   => other
    }
  }
...
/**
 * Checks that `flatTraverse` reaches `<b>` elements at different depths
 * (directly under `<a>`, under `<c>`, and under `<c><e>`) and returns the
 * text of their `<d>` children as one flat sequence in document order.
 */
@Test
def shouldFlatTraverse(): Unit = {
  val input =
    <a>
      <b>
        <d>1</d>
        <d>2</d>
        <d>3</d>
      </b>
      <c>
        <e>
          <b>
            <d>4</d>
            <d>5</d>
          </b>
        </e>
      </c>
      <c>
        <b>
          <d>6</d>
          <d>7</d>
          <d>8</d>
          <d>9</d>
        </b>
      </c>
    </a>
  // For every <b>, keep only its element children (the whitespace between
  // tags shows up as non-Elem nodes and is dropped) and take their trimmed text.
  val output = flatTraverse(input) {
    case x @ <b>{_*}</b> => x.child collect { case e: Elem => e.text.trim() }
  }
  output should be (Seq("1", "2", "3", "4", "5", "6", "7", "8", "9"))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment