@GrahamLea
Last active December 14, 2015 05:39
A Scala script that takes a CSV as input on stdin, produces a pivot table on two columns using a count() function and outputs the pivot table to stdout.
import collection.mutable
import io.Source
import java.io.FileInputStream
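// Parses a string as an Int, returning None when it is not a valid integer.
def toIntOption(s: String): Option[Int] = try { Some(s.toInt) } catch { case _: NumberFormatException => None }
val intArgs = args.map(toIntOption).flatten
// Require exactly two arguments, both of which must be integers.
if (args.length != 2 || intArgs.length != 2) {
  System.err.println("usage: <x_column_number> <y_column_number>")
  sys.exit(1)
}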
// You can change these if you need different behaviour.
val containsColumnHeaders = true
val containsRowHeaders = false
val inputStream = System.in
//val inputStream = new FileInputStream("survey.csv")
val lines = Source.fromInputStream(inputStream).getLines().toBuffer
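// Naive CSV split: quoted fields containing commas are not handled.
// The -1 limit keeps trailing empty cells, so rows ending in ',' keep their full width.
var cells = lines map { _.split(",", -1).toBuffer }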
if (containsColumnHeaders) {
  cells = cells.drop(1)
}
if (containsRowHeaders) {
  cells = cells map { _.drop(1) }
}
val argToIndex = if (containsRowHeaders) 2 else 1
val (column1Index, column2Index) = (intArgs(0) - argToIndex, intArgs(1) - argToIndex)
val map = new mutable.HashMap[(String, String), Int]
def pairInRow(row: mutable.Buffer[String]): (String, String) = (row(column1Index), row(column2Index))
// You could change the pivot aggregation function from count() to something else here
// by changing 'rows.size' to some other calculation, e.g. rows.map(_(n).toInt).sum
// to total a numeric column, where n is a zero-based column index of your choosing.
val pairCounts: Map[(String, String), Int] = (for ((pair, rows) <- (cells groupBy pairInRow)) yield (pair, rows.size)).toMap
val outputColumns = (pairCounts.keys map {_._1}).toSet.toSeq.sorted
val outputRows = (pairCounts.keys map {_._2}).toSet.toSeq.sorted
for (column <- outputColumns) {
  print(',')
  print(column)
}
println()
for (row <- outputRows) {
  print(row)
  for (column <- outputColumns) {
    print(',')
    print(pairCounts.getOrElse((column, row), ""))
  }
  println()
}
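
The comment in the script suggests swapping count() for another aggregation. A minimal sketch of how that might look is below, using the same groupBy approach on in-memory rows. The `pivot` helper, its parameters, and the sample data are all hypothetical and not part of the script above.

```scala
// Hypothetical helper, not part of the original script: pivots in-memory rows.
// `aggregate` receives every row that shares the same (x, y) pair of values.
def pivot(
    rows: Seq[Seq[String]],
    xIndex: Int,
    yIndex: Int,
    aggregate: Seq[Seq[String]] => Int
): Map[(String, String), Int] =
  rows
    .groupBy(row => (row(xIndex), row(yIndex)))
    .map { case (pair, group) => (pair, aggregate(group)) }

// Made-up sample data: region, product, units sold.
val sample = Seq(
  Seq("north", "apples", "3"),
  Seq("north", "apples", "4"),
  Seq("south", "apples", "5"),
  Seq("south", "pears", "1")
)

// count(): the aggregation the script uses (number of rows per pair).
val counts = pivot(sample, 0, 1, group => group.size)

// sum(): totals the third column (zero-based index 2) for each pair instead.
val sums = pivot(sample, 0, 1, group => group.map(row => row(2).toInt).sum)

println(counts(("north", "apples"))) // 2
println(sums(("north", "apples")))   // 7
```

Passing the aggregation as a function keeps the grouping logic unchanged, so only the last argument differs between the count and sum variants. Pairs that never occur together, such as ("north", "pears") here, are simply absent from the map, which is why the script prints an empty cell for them.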