Skip to content

Instantly share code, notes, and snippets.

@SethTisue
Last active May 20, 2021 14:40
Show Gist options
  • Save SethTisue/2c84c855221bc5a31e129226ade2cb81 to your computer and use it in GitHub Desktop.
Save SethTisue/2c84c855221bc5a31e129226ade2cb81 to your computer and use it in GitHub Desktop.

From scala.xml.pull to javax.xml.stream.events

Want to process XML that's too large to fit in memory?

The Scala standard library used to offer scala.xml.pull for this. It became part of the scala-xml library when scala-xml became separate. But then scala.xml.pull got deprecated (in scala-xml 1.1.1) and finally removed entirely (in scala-xml 2.0.0-M1). The recommended replacement is javax.xml.stream.events.

I had some old code that used scala.xml.pull to digest my iTunes Music Library.xml and print various statistics. Recently I converted it to Scala 3, so I decided to get off the deprecated API at the same time.

So, here is the before-and-after. Perhaps this will help other users of scala.xml.pull who want to convert.

Before (Scala 2 + scala.xml.pull)

/** Nicer API/DSL for dealing with XMLEventReader.
  *
  * Wraps a `scala.xml.pull.XMLEventReader` over `source` and exposes small helpers that
  * consume or peek at one event at a time. The helpers that expect a particular kind of
  * event use a single-case `match`, so any other event fails with a `MatchError`.
  */

class Reader(source: io.Source) {
  import scala.xml.pull._
  // preserveWS is overridden to false — presumably so whitespace-only text is not reported;
  // confirm against the scala.xml.pull documentation.
  private val reader =
    new XMLEventReader(source) { override val preserveWS = false }
  // Buffered so that `it.head` can look at the next event without consuming it.
  private val it: collection.BufferedIterator[XMLEvent] = reader.buffered
  /** Consumes the next event, which must be an element start, and returns the name bound
    * from the second field of `EvElemStart`. */
  def start() = it.next() match { case EvElemStart(_, s, _, _) => s }
  /** Consumes the next event, which must be an element start whose name equals `s`. */
  def start(s: String) = { it.next() match { case EvElemStart(_, `s`, _, _) => } }
  /** Consumes the next event, which must be an element end, and returns its name. */
  def end() = it.next() match { case EvElemEnd(_, s) => s}
  /** Consumes the next event, which must be an element end whose name equals `s`. */
  def end(s: String) = { it.next() match { case EvElemEnd(_, `s`) => } }
  // it would be nice to use BufferedIterator.takeWhile here, but it advances the iterator one too
  // far; see https://issues.scala-lang.org/browse/SI-3581 - ST 7/15/10
  /** Consumes the run of consecutive text events at the head of the stream and returns their
    * contents concatenated ("" when the next event is not text).
    *
    * Each step peeks at `it.head`; only when it is an `EvText` does the `collect` body call
    * `it.next()` to consume it, so the first non-text event is left in place.
    */
  def text() = Iterator.continually(it.head)
                 .takeWhile(_.isInstanceOf[EvText])
                 .collect{case EvText(x) => it.next(); x}
                 .mkString
  /** Consumes one whole element (start, text, end) and returns its text.
    * The tuple is only used to sequence the three calls left to right; `._2` is the text. */
  def slurp() = (start(), text(), end())._2
  /** Like `slurp()`, but the element's start and end must both be named `s`. */
  def slurp(s: String) = (start(s), text(), end(s))._2
  /** True when the next event (not consumed) is an element end. */
  def atEnd = it.head.isInstanceOf[EvElemEnd]
  /** True when the next event (not consumed) is the end of an element named `s`. */
  def atEnd(s: String) = it.head match {
    case EvElemEnd(_, `s`) => true
    case _ => false
  }
  /** Delegates to the underlying reader's `stop()`. */
  def stop() = reader.stop()
}

After (Scala 3 + javax.xml.stream.events)

import javax.xml.stream.{ XMLInputFactory, XMLEventReader, XMLStreamConstants }
import javax.xml.stream.events.XMLEvent

/** Nicer API/DSL for dealing with `javax.xml.stream.XMLEventReader`.
  *
  * Reads XML events from `stream` one at a time, so documents too large to hold in memory
  * can be processed. Start-of-document events, DTD events and character events flagged as
  * ignorable whitespace are filtered out before the DSL sees them.
  *
  * The methods that expect a particular kind of event (`start`, `end`) consume the next
  * event unconditionally; an event of the wrong kind makes the `asStartElement` /
  * `asEndElement` cast fail, and a wrong element name fails a `require`
  * (`IllegalArgumentException`) whose message names both the expected and the found element.
  *
  * NOTE(review): the `XMLInputFactory` is used with its default settings, so DTD and
  * external-entity handling is whatever the JDK defaults to — confirm that is acceptable
  * before feeding this untrusted input.
  *
  * @param stream the XML input; it is not closed by [[close]]
  */
class Reader(stream: java.io.InputStream):

  // The raw event reader wrapped in a filter that hides events the DSL never wants.
  private val reader: XMLEventReader =
    val factory = XMLInputFactory.newFactory()
    val unfiltered = factory.createXMLEventReader(stream)
    val wanted: javax.xml.stream.EventFilter = ev =>
      !ev.isStartDocument &&
        ev.getEventType != XMLStreamConstants.DTD &&
        !(ev.isCharacters && ev.asCharacters.isIgnorableWhiteSpace)
    factory.createFilteredReader(unfiltered, wanted)

  // Adapts the Java reader to a Scala BufferedIterator so the DSL can peek (`head`)
  // without consuming. `XMLEventReader.peek` returns null once the stream is exhausted,
  // so `head` must not be relied on past the end of the document.
  private object it extends collection.BufferedIterator[XMLEvent]:
    override def hasNext: Boolean = reader.hasNext
    override def head: XMLEvent = reader.peek()
    override def next(): XMLEvent = reader.nextEvent()

  /** Releases the underlying event reader. Per the `XMLEventReader.close` contract this
    * does not close the input stream; the caller remains responsible for it. */
  def close(): Unit =
    reader.close()

  /** Consumes the next event, which must be an element start, and returns its name. */
  def start(): javax.xml.namespace.QName =
    it.next().asStartElement.getName

  /** Consumes the next event, which must be the start of an element whose local name is `s`.
    * @throws IllegalArgumentException if the element has a different local name
    */
  def start(s: String): Unit =
    val found = it.next().asStartElement.getName.getLocalPart
    require(s == found, s"expected start of <$s> but found start of <$found>")

  /** Consumes the next event, which must be an element end, and returns its name. */
  def end(): javax.xml.namespace.QName =
    it.next().asEndElement.getName

  /** Consumes the next event, which must be the end of an element whose local name is `s`.
    * @throws IllegalArgumentException if the element has a different local name
    */
  def end(s: String): Unit =
    val found = it.next().asEndElement.getName.getLocalPart
    require(s == found, s"expected end of <$s> but found end of <$found>")

  /** True when the next event (not consumed) is an element end. */
  def atEnd: Boolean =
    it.head.isEndElement

  /** True when the next event (not consumed) is the end of an element with local name `s`. */
  def atEnd(s: String): Boolean =
    val upcoming = it.head // peek once and reuse
    upcoming.isEndElement && s == upcoming.asEndElement.getName.getLocalPart

  /** Consumes the run of consecutive character events at the head of the stream and returns
    * their data concatenated ("" when the next event is not character data). */
  def text(): String =
    val collected = new StringBuilder
    // Peek before consuming so the first non-character event is left in place.
    while it.head.isCharacters do
      collected.append(it.next().asCharacters.getData)
    collected.toString

  /** Consumes one whole element (start, text, end) and returns its text. */
  def slurp(): String =
    start()
    val content = text()
    end()
    content

  /** Like `slurp()`, but the element's start and end must both have local name `s`. */
  def slurp(s: String): String =
    start(s)
    val content = text()
    end(s)
    content

Client code

Here's an example of actually using this code to process some XML. I've only included the Scala 3 version of the client code, since it barely needed to change from the Scala 2 version.

/** One track's metadata, as assembled by `foreachTrack` from a per-track dict of
  * key/value strings.
  *
  * @param name        the "Name" entry
  * @param artist      the "Sort Artist" entry when present, otherwise the "Artist" entry
  * @param album       the "Album" entry, or "" when absent
  * @param trackNumber the "Track Number" entry, or 0 when absent
  * @param time        the "Total Time" entry (presumably milliseconds — TODO confirm)
  * @param plays       the "Play Count" entry, or 0 when absent
  * @param lastPlayed  the "Play Date" entry as a number, or 0 when absent
  *                    (unit/epoch is not shown here — TODO confirm)
  * @param stars       the "Rating" entry (0 when absent) integer-divided by 20
  *                    (presumably mapping a 0–100 rating to 0–5 stars — TODO confirm)
  */
case class Track(
  name: String,
  artist: String,
  album: String,
  trackNumber: Int,
  time: Long,
  plays: Int,
  lastPlayed: Long,
  stars: Int,
)

/** Aggregate statistics over all tracks, as built by `read`.
  *
  * Track keys have the form "artist - name" and album keys "artist - album".
  *
  * @param tracks        every track that was read
  * @param trackNames    track key -> `time * plays` of the most recently read track with that key
  * @param artists       artist -> sum of `time * plays` over that artist's tracks
  * @param artistLengths artist -> sum of `time` over that artist's tracks
  * @param albums        album key -> sum of `time * plays` over that album's tracks
  * @param albumLengths  album key -> sum of `time` over that album's tracks
  * @param playDates     artist -> largest `lastPlayed` among that artist's tracks
  */
case class Library(
  tracks: Set[Track],
  trackNames: Map[String, Long],
  artists: Map[String, Long],
  artistLengths: Map[String, Long],
  albums: Map[String, Long],
  albumLengths: Map[String, Long],
  playDates: Map[String, Long],
)

/** Streams the track entries out of a plist-style XML document and calls `fn` once per track.
  *
  * Expected shape: a `plist` element containing a `dict`; inside it, key/value pairs are
  * skipped until the key "Tracks", whose value is a `dict` of (key, per-track `dict`) pairs.
  * Each per-track dict is read into a map of strings and converted to a [[Track]].
  *
  * When a required entry ("Name", "Total Time", or both "Sort Artist" and "Artist") is
  * missing, the track is dropped; its entries are printed unless `skippable` says the
  * entry is of a kind that is expected to lack them. Note that the same handler also
  * covers a `NoSuchElementException` thrown by `fn` itself.
  *
  * The reader is closed even if parsing or `fn` throws. The stream itself is left for the
  * caller to close.
  *
  * @param stream the XML input
  * @param fn     callback invoked for every successfully built track, in document order
  */
def foreachTrack(stream: java.io.InputStream)(fn: Track => Unit): Unit =
  val reader = new Reader(stream)
  import reader.*
  try
    start("plist")
    start("dict")
    // Skip top-level key/value pairs until we have just read the "Tracks" key.
    while slurp("key") != "Tracks" do
      slurp()
    start("dict")
    while !atEnd do
      slurp("key") // the key for this track's dict; its value is not used
      start("dict")
      val fields = collection.mutable.Map[String, String]()
      while !atEnd("dict") do
        fields += slurp("key") -> slurp()
      end("dict")
      val entries = fields.toMap
      try
        fn(Track(
          name = entries("Name"),
          trackNumber = entries.getOrElse("Track Number", "0").toInt,
          artist = entries.getOrElse("Sort Artist", entries("Artist")),
          album = entries.getOrElse("Album", ""),
          time = entries("Total Time").toLong,
          plays = entries.getOrElse("Play Count", "0").toInt,
          lastPlayed = entries.getOrElse("Play Date", "0").toLong,
          stars = entries.getOrElse("Rating", "0").toInt / 20,
        ))
      catch
        case _: java.util.NoSuchElementException =>
          if !skippable(entries) then println(entries)
  finally
    // Previously skipped whenever anything above threw.
    reader.close()

/** Decides whether a library entry that could not be turned into a `Track` may be ignored
  * silently.
  *
  * An entry is skippable when it has no "Kind", when its "Kind" ends with " app" or " book"
  * or equals "Book", or when its "Genre" is "Podcast". A missing "Genre" counts as
  * "not a podcast" rather than raising `NoSuchElementException`.
  *
  * @param entries the raw key/value pairs of one library entry
  * @return true if the entry can be dropped without being reported
  */
def skippable(entries: Map[String, String]): Boolean =
  entries.get("Kind") match
    case None => true
    case Some(kind) =>
      kind.endsWith(" app") ||
        kind.endsWith(" book") ||
        kind == "Book" ||
        entries.get("Genre").contains("Podcast")

/** Reads every track from `stream` via `foreachTrack` and folds them into a [[Library]].
  *
  * Per track it records the track itself, its `time * plays` under "artist - name"
  * (a later track with the same key replaces the earlier value), running sums of
  * `time * plays` and of `time` per artist and per "artist - album", and the largest
  * `lastPlayed` seen per artist.
  *
  * @param stream the XML input handed to `foreachTrack`
  * @return immutable snapshots of everything accumulated
  */
def read(stream: java.io.InputStream): Library =
  import collection.mutable

  val allTracks = mutable.Set.empty[Track]
  val playTimeByTrack = mutable.Map.empty[String, Long]
  val playTimeByArtist = mutable.Map.empty[String, Long]
  val lengthByArtist = mutable.Map.empty[String, Long]
  val playTimeByAlbum = mutable.Map.empty[String, Long]
  val lengthByAlbum = mutable.Map.empty[String, Long]
  val lastPlayedByArtist = mutable.Map.empty[String, Long]

  // Adds `amount` to the running total stored under `key`, treating a missing key as 0.
  def accumulate(totals: mutable.Map[String, Long], key: String, amount: Long): Unit =
    totals.update(key, amount + totals.getOrElse(key, 0L))

  // Kept as a local method so the disabled early exit below has something to return from.
  def loop(): Unit =
    foreachTrack(stream) { track =>
      allTracks += track
      val artist = track.artist
      val listened = track.time * track.plays
      playTimeByTrack.update(artist + " - " + track.name, listened)
      accumulate(playTimeByArtist, artist, listened)
      accumulate(lengthByArtist, artist, track.time)
      val albumKey = artist + " - " + track.album
      accumulate(playTimeByAlbum, albumKey, listened)
      accumulate(lengthByAlbum, albumKey, track.time)
      // Record the play date when the artist is new or this one is strictly later.
      if lastPlayedByArtist.get(artist).forall(_ < track.lastPlayed) then
        lastPlayedByArtist.update(artist, track.lastPlayed)
      // if artists.size >= 200 then return
    }

  loop()
  Library(
    tracks = allTracks.toSet,
    trackNames = playTimeByTrack.toMap,
    artists = playTimeByArtist.toMap,
    artistLengths = lengthByArtist.toMap,
    albums = playTimeByAlbum.toMap,
    albumLengths = lengthByAlbum.toMap,
    playDates = lastPlayedByArtist.toMap,
  )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment