Skip to content

Instantly share code, notes, and snippets.

@fabrizioc1
Created February 12, 2019 00:01
Show Gist options
  • Save fabrizioc1/0b1647894420daec3ab7b8e648c7f2b1 to your computer and use it in GitHub Desktop.
Save fabrizioc1/0b1647894420daec3ab7b8e648c7f2b1 to your computer and use it in GitHub Desktop.
Example of a Scala Spark transformer with a Python wrapper
from pyspark import since, keyword_only
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, Param
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
from pyspark.ml.wrapper import JavaTransformer
class Stemmer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
    """Python proxy for the JVM transformer ``org.fct.spark.transformer.Stemmer``.

    Delegates the actual stemming of an array-of-strings column to the
    Scala implementation; this class only forwards params over Py4J.
    """

    @keyword_only
    def __init__(self, inputCol=None, outputCol=None):
        """Create the wrapper and its backing JVM object.

        :param inputCol: name of the input column (array of strings)
        :param outputCol: name of the column to write stemmed tokens to
        """
        super(Stemmer, self).__init__()
        # Instantiate the Scala transformer this object proxies, sharing its uid.
        self._java_obj = self._new_java_obj("org.fct.spark.transformer.Stemmer", self.uid)
        # @keyword_only captures the constructor kwargs on self._input_kwargs.
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol=None, outputCol=None):
        """Set params for this Stemmer and return ``self`` for chaining."""
        kwargs = self._input_kwargs
        return self._set(**kwargs)
package org.fct.spark.transformer
import edu.stanford.nlp.process.Morphology
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.{DataType, StringType, ArrayType}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.UnaryTransformer
/**
 * Spark ML [[UnaryTransformer]] that stems every token of an array-of-strings
 * column using the Stanford NLP [[edu.stanford.nlp.process.Morphology]] stemmer.
 *
 * Input column type: array&lt;string&gt;; output: ArrayType(StringType, containsNull = false).
 */
class Stemmer(override val uid: String)
    extends UnaryTransformer[Seq[String], Seq[String], Stemmer] {

  /** Zero-arg constructor required for ML persistence; generates a random uid. */
  def this() = this(Identifiable.randomUID("stemmer"))

  /** Builds the row-level stemming function applied to each token array. */
  override protected def createTransformFunc: Seq[String] => Seq[String] = { tokens =>
    // NOTE(review): Morphology is presumably not thread-safe, so a fresh
    // instance is created inside the closure rather than shared — confirm
    // against Stanford NLP docs before hoisting it out.
    val stemmer = new Morphology()
    tokens.map(stemmer.stem)
  }

  override protected def validateInputType(inputType: DataType): Unit = {
    // Accept array<string> regardless of the containsNull flag. The original
    // `inputType == ArrayType(StringType)` comparison implied containsNull = true
    // and therefore rejected ArrayType(StringType, false) — the very type this
    // transformer itself emits from outputDataType.
    val isStringArray = inputType match {
      case ArrayType(StringType, _) => true
      case _                        => false
    }
    require(isStringArray, s"Input type must be ArrayType(StringType) but got $inputType.")
  }

  // ArrayType is a case class; no `new` needed. Stemmed tokens are never null.
  override protected def outputDataType: DataType = ArrayType(StringType, containsNull = false)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment