Last active
November 24, 2016 08:05
-
-
Save geoHeil/6f5007f1c230ab7218e96f28e53bb947 to your computer and use it in GitHub Desktop.
Find the nearest holidays (days since the previous and days until the next) for each date, as separate columns
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.sql.Date
import java.time.temporal.ChronoUnit

import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkConf
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
/**
 * For every row-date, finds the nearest known holidays: how many days have
 * passed since the previous holiday and how many days remain until the next
 * one. Either side is None when the date falls before the first / after the
 * last known holiday. Runs locally as a self-contained Spark job.
 */
object Foo extends App {
  Logger.getLogger("org").setLevel(Level.WARN)

  val conf: SparkConf = new SparkConf()
    .setAppName("nearest holidays")
    .setMaster("local[*]")
    .set("spark.executor.memory", "2G")
    .set("spark.executor.cores", "4")
    .set("spark.default.parallelism", "4")
    .set("spark.driver.memory", "1G")
    .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")

  val spark: SparkSession = SparkSession
    .builder()
    .config(conf)
    //    .enableHiveSupport()
    .getOrCreate()

  import spark.implicits._

  // Sample input: (date, ISO country code). Deliberately unsorted.
  val dates = Seq(
    ("2016-01-01", "ABC"),
    ("2016-01-02", "ABC"),
    ("2016-01-03", "POL"),
    ("2016-01-04", "ABC"),
    ("2016-01-05", "POL"),
    ("2016-01-06", "ABC"),
    ("2016-01-08", "ABC"),
    ("2016-01-09", "POL"),
    ("2016-01-07", "POL"),
    ("2016-01-10", "ABC")
  ).toDF("dates", "ISO")
    .withColumn("dates", 'dates.cast("Date"))

  dates.show
  dates.sort('dates).show // TODO could it be optimized if both holidays and dates were sorted?
  dates.printSchema

  // Holidays collected to the driver, sorted ascending. The list is small, so
  // pairing consecutive holidays into intervals and broadcasting them lets each
  // executor answer lookups with a linear scan instead of a join.
  val holidays = Seq("2016-01-01", "2016-01-07", "2016-02-02")
    .toDF("holiday")
    .withColumn("holiday", 'holiday.cast("Date"))
    .select("holiday").as[Date]
    .sort("holiday")
    .collect

  // Consecutive (previous, next) holiday intervals, e.g. (Jan 1, Jan 7), (Jan 7, Feb 2).
  val hP = spark.sparkContext.broadcast(holidays.zip(holidays.tail))

  def geq(d1: Date, d2: Date): Boolean = !d1.before(d2)
  def leq(d1: Date, d2: Date): Boolean = !d1.after(d2)

  // Whole days from `from` to `to` (non-negative when from <= to).
  def daysBetween(from: Date, to: Date): Long =
    ChronoUnit.DAYS.between(from.toLocalDate, to.toLocalDate)

  // Returns (daysSincePreviousHoliday, daysUntilNextHoliday). A date that is
  // itself a holiday yields a 0 on the matching side.
  val findNearestHolliday = udf((inDate: Date) => {
    val intervals = hP.value
    intervals.collectFirst {
      // Date falls inside a known interval: distance to both endpoints.
      case (prev, next) if geq(inDate, prev) && leq(inDate, next) =>
        (Some(daysBetween(prev, inDate)), Some(daysBetween(inDate, next)))
    }.getOrElse {
      if (leq(inDate, intervals.head._1)) {
        // Before the first known holiday: only a "next" distance exists.
        (None, Some(daysBetween(inDate, intervals.head._1)))
      } else {
        // After the last known holiday: only a "previous" distance exists.
        (Some(daysBetween(intervals.last._2, inDate)), None)
      }
    }
  })

  val withHoliday = dates.withColumn("nearestHollidays", findNearestHolliday('dates))
  withHoliday.show
  withHoliday.printSchema

  // Tuples are encoded as a struct column, so the two distances can be
  // flattened into separate columns by addressing the struct fields.
  withHoliday
    .select(
      $"dates",
      $"nearestHollidays._1".as("daysSincePrevHoliday"),
      $"nearestHollidays._2".as("daysUntilNextHoliday"))
    .show

  // TODO get distance - otherwise this does not work
  //  val vectorAssembler = new VectorAssembler()
  //    .setInputCols(withHoliday.columns)
  //    .setOutputCol("features")
  //
  //  vectorAssembler.transform(withHoliday).show

  spark.stop()
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
http://chat.stackoverflow.com/rooms/128910/discussion-between-georg-heiler-and-evan058