@geoHeil
Last active November 28, 2016 15:02
Spark: convert an Estimator to a Transformer
Custom Spark ML pipeline stages sharing one Param trait: a plain Transformer, an Estimator, and the Model the Estimator produces, plus a small driver app that exercises them.
// Copyright (C) 2016 Georg Heiler
// Master's thesis: detecting fraud by never-paying customers
package org.apache.spark.ml.feature

import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkConf
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.util._
import org.apache.spark.ml.{Estimator, Model, Transformer}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

import scala.language.postfixOps
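
// Shared Param definition, mixed into every stage below: a whitelist of ISO codes to flag.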
trait PreprocessingParam2s extends Params {
  final val isInList = new Param[Array[String]](this, "isInList", "list of isInList items")
}
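
// A stand-alone Transformer: there is no fit() step, the whitelist Param is set directly.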
class ExampleTrans(override val uid: String) extends Transformer with PreprocessingParam2s {

  def this() = this(Identifiable.randomUID("exampleTrans"))

  def copy(extra: ParamMap): ExampleTrans = defaultCopy(extra)

  def setIsInList(value: Array[String]): this.type = set(isInList, value)

  override def transformSchema(schema: StructType): StructType = {
    val idx = schema.fieldIndex("ISO")
    val field = schema.fields(idx)
    if (field.dataType != StringType) {
      throw new IllegalArgumentException(s"Input type ${field.dataType} did not match required type StringType")
    }
    schema.add(StructField("isInList", IntegerType, nullable = false))
  }

  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    import dataset.sparkSession.implicits._
    dataset.withColumn("isInList", when('ISO isin ($(isInList): _*), 1).otherwise(0))
  }
}
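
// The Estimator counterpart: fit() produces an ExampleTransModel, which performs the actual transformation.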
class ExampleEstimator(override val uid: String) extends Estimator[ExampleTransModel] with PreprocessingParam2s {

  def this() = this(Identifiable.randomUID("exampleEstimator"))

  def copy(extra: ParamMap): ExampleEstimator = defaultCopy(extra)

  def setIsInList(value: Array[String]): this.type = set(isInList, value)

  override def transformSchema(schema: StructType): StructType = {
    val idx = schema.fieldIndex("ISO")
    val field = schema.fields(idx)
    if (field.dataType != StringType) {
      throw new IllegalArgumentException(s"Input type ${field.dataType} did not match required type StringType")
    }
    schema
      .add(StructField("isInList", IntegerType, nullable = false))
      .add(StructField("someField", DoubleType, nullable = false))
  }

  // In reality the fitted value would be computed from the dataset here;
  // copyValues propagates the isInList Param to the resulting Model.
  override def fit(dataset: Dataset[_]): ExampleTransModel =
    copyValues(new ExampleTransModel(uid, 1.0)).setParent(this)
}
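
// The fitted Model: carries the value learned in fit() (here a constant 1.0) along with the copied Params.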
class ExampleTransModel(
    override val uid: String,
    val someValue: Double
) extends Model[ExampleTransModel] with PreprocessingParam2s {

  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    import dataset.sparkSession.implicits._
    dataset
      .withColumn("isInList", when('ISO isin ($(isInList): _*), 1).otherwise(0))
      // emit the value learned during fit() for whitelisted rows, 0.0 otherwise (DoubleType, as declared below)
      .withColumn("someField", when('ISO isin ($(isInList): _*), someValue).otherwise(0.0))
  }

  override def transformSchema(schema: StructType): StructType = {
    val idx = schema.fieldIndex("ISO")
    val field = schema.fields(idx)
    if (field.dataType != StringType) {
      throw new IllegalArgumentException(s"Input type ${field.dataType} did not match required type StringType")
    }
    schema
      .add(StructField("isInList", IntegerType, nullable = false))
      .add(StructField("someField", DoubleType, nullable = false))
  }

  // defaultCopy would fail here: the constructor takes (uid, someValue), not just a uid,
  // so the copy must be built explicitly and the Params carried over.
  override def copy(extra: ParamMap): ExampleTransModel =
    copyValues(new ExampleTransModel(uid, someValue), extra).setParent(parent)
}
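
// Small demo driver: builds a toy DataFrame and runs both the Transformer and the fitted Estimator on it.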
object Foo extends App {

  Logger.getLogger("org").setLevel(Level.WARN)

  val conf: SparkConf = new SparkConf()
    .setAppName("example trans")
    .setMaster("local[*]")
    .set("spark.executor.memory", "2G")
    .set("spark.executor.cores", "4")
    .set("spark.default.parallelism", "4")
    .set("spark.driver.memory", "1G")
    .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")

  val spark: SparkSession = SparkSession
    .builder()
    .config(conf)
    .getOrCreate()

  import spark.implicits._

  val dates = Seq(
    ("2016-01-01", "ABC"),
    ("2016-01-02", "ABC"),
    ("2016-01-03", "POL"),
    ("2016-01-04", "ABC"),
    ("2016-01-05", "POL"),
    ("2016-01-06", "ABC"),
    ("2016-01-07", "POL"),
    ("2016-01-08", "ABC"),
    ("2016-01-09", "def"),
    ("2016-01-10", "ABC")
  ).toDF("dates", "ISO")

  dates.show

  new ExampleTrans().setIsInList(Array("def", "ABC")).transform(dates).show
  new ExampleEstimator().setIsInList(Array("def", "ABC")).fit(dates).transform(dates).show
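
  // Hypothetical extension (not part of the original gist): the same Estimator can be
  // chained through Spark ML's standard Pipeline API, which runs fit() and returns a
  // PipelineModel whose transform() applies the fitted Model.
  import org.apache.spark.ml.Pipeline
  val pipelineModel = new Pipeline()
    .setStages(Array(new ExampleEstimator().setIsInList(Array("def", "ABC"))))
    .fit(dates)
  pipelineModel.transform(dates).show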

  spark.stop
}