Instantly share code, notes, and snippets.

Embed
What would you like to do?
import org.apache.spark.ml.PipelineStage
import org.apache.spark.ml.feature.{OneHotEncoder, VectorAssembler, StringIndexer}
import org.apache.spark.mllib.linalg.VectorUDT
import scala.Function.const
import scala.language.{implicitConversions, reflectiveCalls}
/* Spark Pipeline API is kind of sad to use, so let's make a nicer, more compositional API!
*/
/**
 * A `Col` is a DataFrame column that has not been built yet.
 *
 * Calling `build` with the name of the desired output column returns the pipeline stages that
 * produce that column. `suggestedName` is the name to use for this column when it ends up being an
 * intermediate column of a larger pipeline.
 *
 * The type parameter `T` appears in none of the fields: it is a phantom type whose only purpose is
 * to put a typed API on top of the pipeline API, where columns are referred to by plain strings.
 *
 * {{{
 * def example() = {
 *   col[String]("cyber_channel") |> stringIndexer |> oneHotEncoder build "output"
 *
 *   val col1 = col[String]("someCol")
 *   val col2 = stringIndexer(col1)
 *   val pipeline: Array[PipelineStage] = col2.build("output2")
 *
 *   vectorAssembler(Array(col[String]("someCol"), col[String]("anotherCol"))).build("vector")
 * }
 * }}}
 *
 * @param suggestedName name to give this column if it is materialized as an intermediate column
 * @param build         given an output column name, returns the stages that produce this column
 *                      under that name
 * @tparam T phantom type describing the column's content
 */
case class Col[T](
  suggestedName: String,
  build: String => Array[PipelineStage]
)
object PipelineBuilder {
  /**
   * An existing column, one that is already built: whatever output name is requested, asking it for
   * its pipeline stages yields an empty array.
   *
   * @param name the name of the column as it already exists in the DataFrame
   * @tparam T phantom type describing the column's content
   */
  def col[T](name: String): Col[T] =
    // Spell out the element type: Array is invariant, so Array[Nothing] would not be an Array[PipelineStage].
    Col[T](name, const(Array.empty[PipelineStage]))

  /**
   * With that definition of Col, a transformation (or any pipeline stage, for that matter) is just a
   * function that takes a column and returns another.
   */
  type Transfo[S, T] = Col[S] => Col[T]

  /**
   * Creates a `Transfo` from the constructor of a single-input pipeline stage (one that has
   * `setInputCol` and `setOutputCol`).
   *
   * The resulting column is named after its input column, suffixed with the stage's simple class
   * name. Building it first builds the input column under that column's suggested name, then appends
   * a fresh stage reading that column and writing the requested output column.
   *
   * The bound on `A` is a structural type, so the setters are invoked reflectively (hence the
   * `reflectiveCalls` language import).
   *
   * @param a   creates a new, unconfigured stage; called once to get the class name and once per `build`
   * @param col the input column
   */
  def mkColTransform[A <: PipelineStage {def setInputCol(c: String): A; def setOutputCol(c: String): A}, S, T]
    (a: () => A)(col: Col[S]): Col[T] = {
    val newName = s"${col.suggestedName}_${a().getClass.getSimpleName}"
    Col[T](newName, (nextCol: String) =>
      col.build(col.suggestedName) :+ a().setInputCol(col.suggestedName).setOutputCol(nextCol))
  }

  /**
   * Creates a function that returns a `Col` from an array of input columns (for `VectorAssembler`,
   * for example).
   *
   * The resulting column is named after all input columns joined by `_`, suffixed with the stage's
   * simple class name. Building it first builds every distinct input column under its suggested name,
   * then appends a fresh stage reading all input columns (in the given order) and writing the
   * requested output column.
   *
   * @param a    creates a new, unconfigured stage; called once to get the class name and once per `build`
   * @param cols the input columns
   */
  def mkColsTransform[A <: PipelineStage {def setInputCols(c: Array[String]): A; def setOutputCol(c: String): A}, S, T]
    (a: () => A)(cols: Array[Col[S]]): Col[T] = {
    val newName = cols.map(_.suggestedName).mkString("_") + "_" + a().getClass.getSimpleName
    // Build each upstream column only once (keyed by suggested name). If the same derived column is
    // passed twice, building it twice emits two stages writing the same output column, which stages
    // such as StringIndexer reject because the output column already exists.
    val upstream = cols.foldLeft(Vector.empty[Col[S]]) { (kept, c) =>
      if (kept.exists(_.suggestedName == c.suggestedName)) kept else kept :+ c
    }
    Col[T](newName, (nextCol: String) =>
      upstream.flatMap(c => c.build(c.suggestedName)).toArray :+
        a().setInputCols(cols.map(_.suggestedName)).setOutputCol(nextCol))
  }

  // Now, we can define some `Transfo`s
  val stringIndexer: Transfo[String, Double] = mkColTransform(() => new StringIndexer())
  val oneHotEncoder: Transfo[Double, VectorUDT] = mkColTransform(() => new OneHotEncoder())
  // NOTE(review): Spark's VectorAssembler documents numeric, boolean and vector input columns, not
  // strings, so `Col[String]` looks like the wrong input type. It is kept unchanged for source
  // compatibility with existing callers.
  val vectorAssembler: Array[Col[String]] => Col[VectorUDT] = mkColsTransform(() => new VectorAssembler())

  // Finally, the `|>` operator copy-pasted from scalaz, to avoid a scalaz dependency
  final class IdOps[A](val self: A) extends AnyVal {
    /** Applies the provided function to `self`. The Thrush combinator. */
    def |>[B](f: A => B): B =
      f(self)
  }

  /** Enables `x |> f` syntax on any value. */
  implicit def ToIdOps[A](a: A): IdOps[A] = new IdOps(a)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment