Skip to content

Instantly share code, notes, and snippets.

@geoHeil
Last active November 24, 2016 08:05
Show Gist options
  • Save geoHeil/6f5007f1c230ab7218e96f28e53bb947 to your computer and use it in GitHub Desktop.
Save geoHeil/6f5007f1c230ab7218e96f28e53bb947 to your computer and use it in GitHub Desktop.
Find the nearest holidays around each date — days since the previous and days until the next — as separate columns.
import java.sql.Date
import java.time.temporal.ChronoUnit

import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkConf
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
/**
 * Demo: for every row's date, compute the distance (in days) to the nearest
 * surrounding holidays, returned as a pair (daysSincePreviousHoliday, daysUntilNextHoliday).
 * Either side is None when the date falls before the first / after the last known holiday.
 */
object Foo extends App {
  Logger.getLogger("org").setLevel(Level.WARN)

  val conf: SparkConf = new SparkConf()
    .setAppName("nearest holidays")
    .setMaster("local[*]")
    .set("spark.executor.memory", "2G")
    .set("spark.executor.cores", "4")
    .set("spark.default.parallelism", "4")
    .set("spark.driver.memory", "1G")
    .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")

  val spark: SparkSession = SparkSession
    .builder()
    .config(conf)
    // .enableHiveSupport()
    .getOrCreate()

  import spark.implicits._

  // Sample input: one (date, ISO country code) per row; deliberately unsorted.
  val dates = Seq(
    ("2016-01-01", "ABC"),
    ("2016-01-02", "ABC"),
    ("2016-01-03", "POL"),
    ("2016-01-04", "ABC"),
    ("2016-01-05", "POL"),
    ("2016-01-06", "ABC"),
    ("2016-01-08", "ABC"),
    ("2016-01-09", "POL"),
    ("2016-01-07", "POL"),
    ("2016-01-10", "ABC")
  ).toDF("dates", "ISO")
    .withColumn("dates", 'dates.cast("Date"))

  dates.show
  dates.sort('dates).show // TODO could it be optimized if both holidays and dates were sorted?
  dates.printSchema

  // Holidays collected to the driver, sorted ascending so consecutive pairs bracket any date.
  val holidays = Seq("2016-01-01", "2016-01-07", "2016-02-02")
    .toDF("holiday")
    .withColumn("holiday", 'holiday.cast("Date"))
    .select("holiday").as[Date]
    .sort("holiday")
    .collect

  // Fail fast on the driver: with fewer than 2 holidays the zipped pair list is empty
  // and the UDF's head/last fallback would throw an opaque error inside an executor.
  require(holidays.length >= 2, "need at least two holidays to bracket dates")

  // Broadcast consecutive holiday pairs [(h1,h2), (h2,h3), ...] to all executors.
  val hP = spark.sparkContext.broadcast(holidays.zip(holidays.tail))

  def geq(d1: Date, d2: Date): Boolean = d1.after(d2) || d1.equals(d2)
  def leq(d1: Date, d2: Date): Boolean = d1.before(d2) || d1.equals(d2)

  /**
   * For inDate between two consecutive holidays (inclusive) returns
   * (Some(days since earlier), Some(days until later)); before the first holiday
   * the left side is None, after the last holiday the right side is None.
   */
  val findNearestHoliday = udf((inDate: Date) => {
    val pairs = hP.value
    val bracketed = pairs.collectFirst {
      case (d1, d2) if geq(inDate, d1) && leq(inDate, d2) =>
        (Some(ChronoUnit.DAYS.between(d1.toLocalDate, inDate.toLocalDate)),
         Some(ChronoUnit.DAYS.between(inDate.toLocalDate, d2.toLocalDate)))
    }
    bracketed.getOrElse {
      if (leq(inDate, pairs.head._1))
        // before the first holiday: only a "next" distance exists
        (None, Some(ChronoUnit.DAYS.between(inDate.toLocalDate, pairs.head._1.toLocalDate)))
      else
        // after the last holiday: only a "previous" distance exists
        (Some(ChronoUnit.DAYS.between(pairs.last._2.toLocalDate, inDate.toLocalDate)), None)
    }
  })

  val withHoliday = dates.withColumn("nearestHollidays", findNearestHoliday('dates))
  withHoliday.show
  withHoliday.printSchema

  // A tuple-returning UDF yields a struct column; flatten it by addressing the
  // struct fields with Column syntax (Symbol'._1 does not compile):
  // withHoliday.select($"dates", $"nearestHollidays._1", $"nearestHollidays._2")
  // TODO get distance - otherwise this does not work
  // val vectorAssembler = new VectorAssembler()
  //   .setInputCols(withHoliday.columns)
  //   .setOutputCol("features")
  //
  // vectorAssembler.transform(withHoliday).show

  spark.stop
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment