Skip to content

Instantly share code, notes, and snippets.

@DmitryBe
Created April 28, 2017 04:07
Show Gist options
  • Save DmitryBe/95276932e2f9770368b9f62c748c38d5 to your computer and use it in GitHub Desktop.
Save DmitryBe/95276932e2f9770368b9f62c748c38d5 to your computer and use it in GitHub Desktop.
Parse NYC taxi trip data with Spark
/**
 * One parsed NYC taxi trip record, as built by `parse`.
 *
 * @param license     value of the `hack_license` column (null when that column is null)
 * @param pickupTime  pickup timestamp in epoch milliseconds (0 when the column is null)
 * @param dropoffTime dropoff timestamp in epoch milliseconds (0 when the column is null)
 * @param pickupX     pickup longitude (0.0 when the column is null)
 * @param pickupY     pickup latitude (0.0 when the column is null)
 * @param dropoffX    dropoff longitude (0.0 when the column is null)
 * @param dropoffY    dropoff latitude (0.0 when the column is null)
 */
case class Trip(
    license: String,
    pickupTime: Long, dropoffTime: Long,
    pickupX: Double, pickupY: Double,
    dropoffX: Double, dropoffY: Double)
/**
 * Null-safe, by-name accessor for a Spark SQL `Row`.
 *
 * `Row.getAs` returns a zero value (0, false, ...) for a null primitive column and `null`
 * for a null reference column. Nullness is therefore checked first and surfaced as `None`.
 */
class RichRow(row: org.apache.spark.sql.Row) {
  /**
   * Read a column by name.
   *
   * @param field column name; must exist in the row's schema
   * @tparam T expected column type; the cast is erased, so a wrong `T` may only fail where
   *           the value is used
   * @return `Some(value)` when the column is non-null, `None` when it is null
   * @throws IllegalArgumentException if `field` is not a column of the row
   */
  def getAs[T](field: String): Option[T] = {
    // Resolve the column name once and read by position. The by-name getAs would look it up again.
    val idx = row.fieldIndex(field)
    if (row.isNullAt(idx)) None else Some(row.getAs[T](idx))
  }
}
/**
 * Parse a "yyyy-MM-dd HH:mm:ss" timestamp column into epoch milliseconds.
 *
 * @param rr        null-safe wrapper around the source row
 * @param timeField name of the string column holding the timestamp
 * @return milliseconds since the epoch, or 0L when the column is null
 * @throws java.text.ParseException if the value does not begin with a parseable timestamp
 */
def parseTaxiTime(rr: RichRow, timeField: String): Long = {
  // Fully qualified names: this script has no import lines, and spark-shell does not
  // import java.text / java.util by default, so the short names would not resolve.
  // A formatter is created per call because SimpleDateFormat is not thread-safe.
  // NOTE(review): parsing uses the JVM default time zone; confirm that matches the data.
  val formatter = new java.text.SimpleDateFormat("yyyy-MM-dd HH:mm:ss", java.util.Locale.ENGLISH)
  rr.getAs[String](timeField).map(dt => formatter.parse(dt).getTime).getOrElse(0L)
}
/**
 * Parse a coordinate column into a Double.
 *
 * @param rr       null-safe wrapper around the source row
 * @param locField name of the string column holding the coordinate
 * @return the parsed number, or 0.0 when the column is null
 * @throws NumberFormatException if the value is not a valid number
 */
def parseTaxiLoc(rr: RichRow, locField: String): Double =
  rr.getAs[String](locField).fold(0.0)(raw => raw.toDouble)
/**
 * Convert one raw taxi row into a [[Trip]].
 *
 * Null columns become `null` (license), `0L` (timestamps) or `0.0` (coordinates).
 * Malformed values, or a missing column name, cause an exception to propagate.
 *
 * @param row a raw row; expected to contain the columns named below
 * @return the parsed trip
 */
def parse(row: org.apache.spark.sql.Row): Trip = {
  val fields = new RichRow(row)
  // Local shorthands so each Trip field reads as "column name -> parser".
  def time(column: String): Long = parseTaxiTime(fields, column)
  def coord(column: String): Double = parseTaxiLoc(fields, column)

  Trip(
    license = fields.getAs[String]("hack_license").orNull,
    pickupTime = time("pickup_datetime"),
    dropoffTime = time("dropoff_datetime"),
    pickupX = coord("pickup_longitude"),
    pickupY = coord("pickup_latitude"),
    dropoffX = coord("dropoff_longitude"),
    dropoffY = coord("dropoff_latitude")
  )
}
/**
 * Wrap `f` so that an `Exception` it throws is captured instead of propagated.
 *
 * Note the convention: `Left` holds the SUCCESSFUL result and `Right` holds the failing
 * input paired with the exception it raised. Only `Exception` subclasses are caught; other
 * `Throwable`s still propagate. The wrapper explicitly mixes in `Serializable` so it can be
 * shipped inside Spark closures.
 *
 * @param f the function to guard
 * @return a serializable function returning `Left(f(s))` or `Right((s, exception))`
 */
def safe[S, T](f: S => T): S => Either[T, (S, Exception)] =
  new (S => Either[T, (S, Exception)]) with Serializable {
    override def apply(input: S): Either[T, (S, Exception)] =
      try Left(f(input))
      catch { case failure: Exception => Right(input -> failure) }
  }
// Read the raw NYC taxi CSV data; the first line of each file is treated as a header.
// NOTE(review): assumes `spark` is the SparkSession provided by spark-shell.
val taxiRaw = spark.read.option("header", "true").csv("taxidata")
taxiRaw.show()
// Parse every row, capturing failures instead of letting one bad row abort the job.
// Left(trip) = parsed OK, Right((row, exception)) = parse failure.
// taxiParsed is not cached, so each action below re-reads and re-parses the CSV.
val safeParse = safe(parse)
val taxiParsed = taxiRaw.rdd.map(safeParse)
// Print how many rows parsed successfully (true) vs. failed (false).
taxiParsed.map(_.isLeft).
  countByValue().
  foreach(println)
// Keep only the successfully parsed trips and cache them. Selecting Lefts explicitly
// avoids `_.left.get`, which throws NoSuchElementException on the first failed record.
// `.toDS` relies on spark.implicits._ being in scope (spark-shell imports it by default).
val taxiGood = taxiParsed.collect { case Left(trip) => trip }.toDS
taxiGood.cache()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment