@geoHeil
Last active January 12, 2017 16:53
spark window function replacement problem
import java.sql.Date

import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession

case class FooBar(foo: Option[Date], bar: String)

object WindowFunctionExample extends App {
  Logger.getLogger("org").setLevel(Level.WARN)

  val conf: SparkConf = new SparkConf()
    .setAppName("foo")
    .setMaster("local[*]")
    .set("spark.default.parallelism", "12") // prototyping on a MacBook with 4 physical cores; roughly 3x the core count is a common rule of thumb

  val spark: SparkSession = SparkSession
    .builder()
    .config(conf)
    .getOrCreate()

  import spark.implicits._

  val myDf = Seq(
    ("2016-01-01", "first"),
    ("2016-01-02", "second"),
    ("2016-wrongFormat", "noValidFormat"),
    ("2016-01-04", "lastAssumingSameDate"))
    .toDF("foo", "bar")
    .withColumn("foo", 'foo.cast("Date")) // the unparsable string becomes null
    .as[FooBar]
  myDf.show

  def notMissing(row: Option[FooBar]): Boolean = row.isDefined && row.get.foo.isDefined

  // last valid row per partition; empty partitions (and partitions holding only
  // rows with a null date) contribute None
  val toCarry = myDf.rdd.mapPartitionsWithIndex { case (i, iter) =>
    Iterator((i, iter.filter(x => notMissing(Some(x))).toSeq.lastOption))
  }.collectAsMap

  // println("###################### carry ")
  println(toCarry.foreach(println)) // foreach returns Unit, hence the stray () in the output
  // println("###################### carry ")

  val toCarryBd = spark.sparkContext.broadcast(toCarry)

  def fill(i: Int, iter: Iterator[FooBar]): Iterator[FooBar] = {
    if (iter.isEmpty) {
      iter
    } else {
      var lastNotNullRow: Option[FooBar] = toCarryBd.value.get(i).get
      while (lastNotNullRow.isEmpty) {
        println("choosing next value")
        lastNotNullRow = toCarryBd.value.get(i + 1).get // TODO: never advances past i + 1, so this loops forever if that partition is also None
      }
      iter.map { foo =>
        if (!notMissing(Some(foo))) {
          // FooBar(lastNotNullRow.getOrElse(FooBar(Option(Date.valueOf("2016-01-01")), "DUMMY")).foo, foo.bar)
          FooBar(lastNotNullRow.get.foo, foo.bar) // TODO: warning, this throws when lastNotNullRow is None
        } else {
          lastNotNullRow = Some(foo)
          foo
        }
      }
    }
  }

  val imputed: RDD[FooBar] = myDf.rdd.mapPartitionsWithIndex { case (i, iter) => fill(i, iter) }
  val imputedDF = imputed.toDS()
  // println(imputedDF.orderBy($"foo").collect.toList)
  imputedDF.show

  spark.stop
}
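Stripped of Spark, the per-partition fill above is just an iterator map with a mutable carry. A minimal sketch of that pattern in plain Scala (the fillPartition helper is hypothetical; rows are reduced to Option[String] dates for brevity):

```scala
// Forward fill over the rows of one partition, seeded with the carry
// obtained from the preceding partitions.
def fillPartition(seed: Option[String], rows: Iterator[Option[String]]): Iterator[Option[String]] = {
  var last = seed // mutable carry, updated as valid rows stream past
  rows.map {
    case Some(v) => last = Some(v); Some(v) // valid row: keep it and remember it
    case None    => last                    // missing row: substitute the carry
  }
}
```

Because the map is lazy, the carry only advances as the iterator is consumed, which matches how mapPartitionsWithIndex evaluates.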

geoHeil commented Dec 9, 2016

The input is

+----------+--------------------+
|       foo|                 bar|
+----------+--------------------+
|2016-01-01|               first|
|2016-01-02|              second|
|      null|       noValidFormat|
|2016-01-04|lastAssumingSameDate|
+----------+--------------------+

and output

+----------+--------------------+
|       foo|                 bar|
+----------+--------------------+
|2016-01-01|               first|
|2016-01-02|              second|
|2016-01-01|       noValidFormat|
|2016-01-04|lastAssumingSameDate|
+----------+--------------------+

As you can see, noValidFormat is replaced by 2016-01-01. It should have been 2016-01-02, because that was the last known good value; the fallback FooBar(lastNotNullRow.getOrElse(FooBar(Option(Date.valueOf("2016-01-01")), "DUMMY")).foo, foo.bar) uses 2016-01-01 only as a dummy replacement value.

A deeper look at the map shows that None values are present. How can this happen if iter.filter(x => notMissing(Some(x))) is used to build the toCarry map?

###################### carry 
Map(2 -> None, 5 -> None, 4 -> None, 7 -> Some(FooBar(2016-01-04,lastAssumingSameDate)), 1 -> Some(FooBar(2016-01-01,first)), 3 -> Some(FooBar(2016-01-02,second)), 6 -> None, 0 -> None)
(2,None)
(5,None)
(4,None)
(7,Some(FooBar(2016-01-04,lastAssumingSameDate)))
(1,Some(FooBar(2016-01-01,first)))
(3,Some(FooBar(2016-01-02,second)))
(6,None)
(0,None)
()
###################### carry 
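A likely explanation for the None entries (an assumption, but consistent with the printed map): with parallelism far above the row count, most partitions are empty, and an empty partition contributes None because lastOption on an empty sequence is None; the partition holding only the null-date noValidFormat row is filtered down to empty as well. For example:

```scala
// An empty partition yields None as its carry value...
val emptyCarry = Seq.empty[java.sql.Date].lastOption

// ...and so does a partition whose only row has a missing date,
// because the notMissing filter drops it before lastOption is taken.
val onlyMissing = Seq(Option.empty[java.sql.Date]).filter(_.isDefined).lastOption
```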


geoHeil commented Jan 9, 2017

With the last update the code sort of works, but the wrong date is still chosen for imputation:

+----------+--------------------+
|       foo|                 bar|
+----------+--------------------+
|2016-01-01|               first|
|2016-01-02|              second|
|      null|       noValidFormat|
|2016-01-04|lastAssumingSameDate|
+----------+--------------------+

+----------+--------------------+
|       foo|                 bar|
+----------+--------------------+
|2016-01-01|               first|
|2016-01-02|              second|
|2016-01-04|       noValidFormat|
|2016-01-04|lastAssumingSameDate|
+----------+--------------------+

You can see that noValidFormat should have been filled with 2016-01-02, the last known good value (forward fill).
The main change was to add:

var lastNotNullRow: Option[FooBar] = toCarryBd.value.get(i).get
if (lastNotNullRow == None) {
  lastNotNullRow = toCarryBd.value.get(i + 1).get
}
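The lookup direction looks like the root cause: a forward fill needs the last good value from the partitions before partition i, so when the preceding partition's carry is None the search should continue backwards (i - 1, i - 2, ...), not jump forward to i + 1, which pulls in a future value (hence 2016-01-04). A sketch of that backwards lookup over the collected map, using a hypothetical carryFor helper and the dates from the Dec 9 printout:

```scala
// toCarry as printed above (dates only, for brevity).
val toCarry: Map[Int, Option[String]] = Map(
  0 -> None, 1 -> Some("2016-01-01"), 2 -> None, 3 -> Some("2016-01-02"),
  4 -> None, 5 -> None, 6 -> None, 7 -> Some("2016-01-04"))

// Carry INTO partition i: the last defined entry among partitions 0 .. i - 1,
// scanned backwards. None only when no earlier partition has a value.
def carryFor(i: Int, carry: Map[Int, Option[String]]): Option[String] =
  (i - 1 to 0 by -1).flatMap(j => carry(j)).headOption
```

Seeding fill with carryFor(i) instead of toCarryBd.value.get(i).get plus the i + 1 fallback would hand the noValidFormat partition 2016-01-02, the last known good value.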
