Skip to content

Instantly share code, notes, and snippets.

@jenglert
Created July 9, 2013 20:19
Show Gist options
  • Save jenglert/5960901 to your computer and use it in GitHub Desktop.
Save jenglert/5960901 to your computer and use it in GitHub Desktop.
Scala handwriting analysis
import scala.actors.Future
import scala.actors.Futures
import java.util.concurrent.atomic.AtomicInteger
/**
 * Command-line entry point that classifies every row of the "check" CSV by
 * finding the closest row (per `Row.distance`) in the "sample" CSV, then
 * reports how many predicted labels matched the expected ones and the
 * elapsed wall-clock time in milliseconds.
 *
 * Each check row is evaluated in its own `scala.actors` future; the shared
 * counter is an `AtomicInteger` so concurrent increments are not lost.
 */
object Main extends App {
  /**
   * Parses every line except the first from `source` into a `Row`
   * (the first line is presumably a column header -- confirm against the data),
   * and always closes `source`, even when parsing throws.
   */
  private def readAndClose(source: scala.io.Source): List[Row] = {
    try {
      source.getLines().drop(1).map(Row.fromLine(_)).toList
    } finally {
      source.close()
    }
  }

  val start = System.currentTimeMillis()

  val sample = scala.io.Source.fromFile("src/main/scala/digitssample.csv")
  val sampleRows = readAndClose(sample)
  println("Done reading sample file.")

  val check = scala.io.Source.fromFile("src/main/scala/digitscheck.csv")
  val checkRows = readAndClose(check)

  // The reference never changes; only the counter's internal value does.
  val correct = new AtomicInteger
  val total = checkRows.size

  val futures = checkRows.map { rowToCheck =>
    Futures.future {
      // Linear scan for the nearest sample row. The strict `<` keeps the
      // first-encountered row on ties; "NONE" survives only if no sample
      // row has a distance below Int.MaxValue (e.g. the sample is empty).
      val result = sampleRows.foldLeft(("NONE", Int.MaxValue)) { (best, candidate) =>
        val dist = rowToCheck.distance(candidate)
        if (dist < best._2) (candidate.label, dist) else best
      }
      if (rowToCheck.label == result._1) {
        correct.incrementAndGet()
      }
      println("Expected: " + rowToCheck.label + " Result: " + result._1)
    }
  }

  // awaitAll yields None for every future that did not complete within the
  // timeout; without this check a timeout would silently produce a partial
  // (and therefore misleading) accuracy figure.
  private val outcomes = Futures.awaitAll(100000, futures.toArray: _*)
  private val unfinished = outcomes.count(_.isEmpty)
  if (unfinished > 0) {
    println("Warning: " + unfinished + " of " + total + " checks did not finish before the timeout; the count below is partial.")
  }

  println("Correct: " + correct.get + " out of: " + total + " in: " + (System.currentTimeMillis - start))
}
/**
 * A labelled vector of integer features.
 *
 * @param label the class label attached to this row
 * @param data  the integer feature values
 */
case class Row(label: String, data: Array[Int]) {
  /**
   * Sum of squared element-wise differences between this row's data and
   * `row`'s data. Only the positions present in both arrays are compared,
   * so any surplus elements of the longer array are ignored.
   *
   * @param row the row to compare against
   * @return the accumulated squared difference as an `Int`
   */
  def distance(row: Row): Int = {
    val other = row.data
    val overlap = math.min(data.length, other.length)
    var total = 0
    var i = 0
    while (i < overlap) {
      val delta = data(i) - other(i)
      total += delta * delta
      i += 1
    }
    total
  }
}
/** Factory helpers for building `Row` values from text. */
object Row {
  /**
   * Builds a `Row` from one comma-separated line: the first field becomes
   * the label and every remaining field is parsed as an `Int`.
   *
   * Fields are trimmed before use so that whitespace around a comma
   * (e.g. "7, 0, 255") does not make integer parsing fail.
   *
   * @param line one comma-separated record
   * @return the parsed row
   * @throws NumberFormatException if a non-label field is not a valid integer
   */
  def fromLine(line: String): Row = {
    val fields = line.split(",").map(_.trim)
    new Row(fields.head, fields.tail.map(_.toInt))
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment