Spark convert estimator to transformer
// Copyright (C) 2016 Georg Heiler
// master thesis to detect fraud of never-paying customers
package org.apache.spark.ml.feature

import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkConf
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.util._
import org.apache.spark.ml.{Estimator, Model, Transformer}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

import scala.language.postfixOps
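
// Shared parameter trait so that the Transformer, the Estimator, and the
// fitted Model below all expose the same isInList parameter.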
trait PreprocessingParam2s extends Params {
  final val isInList = new Param[Array[String]](this, "isInList", "list of isInList items")
}
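
// A parameter-only Transformer: it needs no fitting step and simply flags
// rows whose ISO code appears in the configured list.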
class ExampleTrans(override val uid: String) extends Transformer with PreprocessingParam2s {

  def this() = this(Identifiable.randomUID("exampleTrans"))

  override def copy(extra: ParamMap): ExampleTrans = defaultCopy(extra)

  def setIsInList(value: Array[String]): this.type = set(isInList, value)

  override def transformSchema(schema: StructType): StructType = {
    val idx = schema.fieldIndex("ISO")
    val field = schema.fields(idx)
    if (field.dataType != StringType) {
      throw new IllegalArgumentException(s"Input type ${field.dataType} of column ISO did not match required type StringType")
    }
    schema.add(StructField("isInList", IntegerType, nullable = false))
  }

  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    import dataset.sparkSession.implicits._
    dataset.withColumn("isInList", when('ISO isin ($(isInList): _*), 1).otherwise(0))
  }
}
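
// The corresponding Estimator. Estimator is generic in the model type it
// produces, so it must be extended as Estimator[ExampleTransModel]; omitting
// the type parameter is what triggers the "... overrides nothing" compile
// errors discussed in the Stack Overflow question linked below.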
class ExampleEstimator(override val uid: String) extends Estimator[ExampleTransModel] with PreprocessingParam2s {

  def this() = this(Identifiable.randomUID("exampleEstimator"))

  override def copy(extra: ParamMap): ExampleEstimator = defaultCopy(extra)

  def setIsInList(value: Array[String]): this.type = set(isInList, value)

  override def transformSchema(schema: StructType): StructType = {
    val idx = schema.fieldIndex("ISO")
    val field = schema.fields(idx)
    if (field.dataType != StringType) {
      throw new IllegalArgumentException(s"Input type ${field.dataType} of column ISO did not match required type StringType")
    }
    schema
      .add(StructField("isInList", IntegerType, nullable = false))
      .add(StructField("someField", DoubleType, nullable = false))
  }

  // In reality, perform some computation over the dataset here. copyValues
  // propagates the estimator's parameters (e.g. isInList) to the fitted model;
  // without it, $(isInList) in the model's transform fails at runtime.
  override def fit(dataset: Dataset[_]): ExampleTransModel =
    copyValues(new ExampleTransModel(uid, 1.0).setParent(this))
}
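
// The fitted model produced by ExampleEstimator.fit. Besides the shared
// parameters, it carries the value "learned" during fitting (someValue).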
class ExampleTransModel(
    override val uid: String,
    val someValue: Double
) extends Model[ExampleTransModel] with PreprocessingParam2s {

  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    import dataset.sparkSession.implicits._
    dataset
      .withColumn("isInList", when('ISO isin ($(isInList): _*), 1).otherwise(0))
      // emit the fitted value as a double column, matching the schema declared
      // below (the original when('ISO, "fooBar") neither type-checked as a
      // condition nor produced a DoubleType column)
      .withColumn("someField", lit(someValue))
  }

  override def transformSchema(schema: StructType): StructType = {
    val idx = schema.fieldIndex("ISO")
    val field = schema.fields(idx)
    if (field.dataType != StringType) {
      throw new IllegalArgumentException(s"Input type ${field.dataType} of column ISO did not match required type StringType")
    }
    schema
      .add(StructField("isInList", IntegerType, nullable = false))
      .add(StructField("someField", DoubleType, nullable = false))
  }

  // defaultCopy requires a single-argument (uid) constructor, which this model
  // does not have, so copy the parameter values explicitly.
  override def copy(extra: ParamMap): ExampleTransModel =
    copyValues(new ExampleTransModel(uid, someValue), extra).setParent(parent)
}
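
// Minimal driver exercising both the Transformer and the Estimator on a
// small in-memory DataFrame.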
object Foo extends App {
  Logger.getLogger("org").setLevel(Level.WARN)

  val conf: SparkConf = new SparkConf()
    .setAppName("example trans")
    .setMaster("local[*]")
    .set("spark.executor.memory", "2G")
    .set("spark.executor.cores", "4")
    .set("spark.default.parallelism", "4")
    .set("spark.driver.memory", "1G")
    .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")

  val spark: SparkSession = SparkSession
    .builder()
    .config(conf)
    .getOrCreate()

  import spark.implicits._

  val dates = Seq(
    ("2016-01-01", "ABC"),
    ("2016-01-02", "ABC"),
    ("2016-01-03", "POL"),
    ("2016-01-04", "ABC"),
    ("2016-01-05", "POL"),
    ("2016-01-06", "ABC"),
    ("2016-01-07", "POL"),
    ("2016-01-08", "ABC"),
    ("2016-01-09", "def"),
    ("2016-01-10", "ABC")
  ).toDF("dates", "ISO")
  dates.show

  new ExampleTrans().setIsInList(Array("def", "ABC")).transform(dates).show
  new ExampleEstimator().setIsInList(Array("def", "ABC")).fit(dates).transform(dates).show
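
  // A minimal sketch (not in the original gist): because ExampleEstimator
  // extends Estimator[ExampleTransModel], it is a valid PipelineStage and can
  // be chained in an org.apache.spark.ml.Pipeline like any built-in stage.
  val pipeline = new org.apache.spark.ml.Pipeline()
    .setStages(Array[org.apache.spark.ml.PipelineStage](
      new ExampleEstimator().setIsInList(Array("def", "ABC"))))
  pipeline.fit(dates).transform(dates).show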
  spark.stop
}
Related Stack Overflow question: http://stackoverflow.com/questions/40847169/spark-custom-estimator-override-nothing