Skip to content

Instantly share code, notes, and snippets.

@raidery
Last active May 8, 2019 03:14
Show Gist options
  • Save raidery/8efacb809c92c7202c489527b6866917 to your computer and use it in GitHub Desktop.
Save raidery/8efacb809c92c7202c489527b6866917 to your computer and use it in GitHub Desktop.
Custom UnaryTransformer
import org.apache.spark.ml._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types._
// Toy input: six (label, category) rows where categories repeat, so the effect of the
// transformer on the "category" column is easy to see.
// NOTE(review): `toDF` is not provided by the imports shown; this assumes
// `spark.implicits._` is already in scope (as in spark-shell) — confirm for other setups.
val df = Seq(
  0 -> "a",
  1 -> "b",
  2 -> "c",
  3 -> "a",
  4 -> "a",
  5 -> "c"
).toDF("label", "category")
/**
 * A minimal custom [[UnaryTransformer]] that upper-cases a string column.
 *
 * Upper-casing uses `Locale.ROOT`, so the output does not depend on the JVM's default
 * locale. Null cells are passed through as null instead of failing with a
 * NullPointerException.
 *
 * @param uid unique identifier of this transformer instance
 */
class UpperTransformer(override val uid: String)
  extends UnaryTransformer[String, String, UpperTransformer] {

  /** Creates a transformer whose uid is a random id prefixed with "upper". */
  def this() = this(Identifiable.randomUID("upper"))

  /**
   * Accepts only string input columns.
   *
   * @throws IllegalArgumentException if `inputType` is not `StringType`
   */
  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType == StringType, s"Input type must be StringType but got $inputType.")
  }

  /** Per-value transformation: locale-independent, null-safe upper-casing. */
  override protected def createTransformFunc: String => String = {
    import java.util.Locale
    // A null String cell reaches this function as null; wrap it so it maps to null
    // rather than throwing. Locale.ROOT avoids locale-specific case mappings
    // (e.g. Turkish "i" -> "İ").
    (s: String) => Option(s).map(_.toUpperCase(Locale.ROOT)).orNull
  }

  /** The output column has the same type as the input: a string column. */
  override protected def outputDataType: DataType = StringType
}
// Upper-case df's "category" column and print the result.
// The output column is named explicitly because the default output column name is
// derived from the transformer's random uid, which would change on every run.
val upper = new UpperTransformer()
  .setInputCol("category")
  .setOutputCol("category_upper")
upper.transform(df).show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment