Skip to content

Instantly share code, notes, and snippets.

@mostafam
Created August 13, 2020 08:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mostafam/bbbd88568c707fa2f33fd1cb62a5c970 to your computer and use it in GitHub Desktop.
StringChecker.scala (Spark side)
package org.apache.spark.ml.mleap.feature
import ml.combust.mleap.core.feature.StringCheckerModel
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.{HasInputCols, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types._
/**
 * Spark [[Transformer]] wrapping an MLeap [[StringCheckerModel]].
 *
 * Reads two string input columns (a text column and a query column), applies
 * the model to each row, and appends a double-valued output column.
 *
 * @param uid   unique identifier for this stage
 * @param model the underlying MLeap model applied row-by-row
 */
class StringChecker(override val uid: String,
val model: StringCheckerModel) extends Transformer
with HasInputCols
with HasOutputCol
with MLWritable {
def this(model: StringCheckerModel) = this(uid = Identifiable.randomUID("string_checker"), model = model)

/** Sets the two input columns; order matters: text first, then query. */
def setInputCols(text: String, query: String): this.type = set(inputCols, Array(text, query))

/** Sets the name of the double-valued output column. */
def setOutputCol(value: String): this.type = set(outputCol, value)

@org.apache.spark.annotation.Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
// Validate column types and output-column absence before doing any work
// (the original skipped this; Spark transformers conventionally call it here).
transformSchema(dataset.schema)
val stringCheckerUdf = udf {
(text: String, query: String) => model(text, query): Double
}
dataset.withColumn($(outputCol), stringCheckerUdf(dataset($(inputCols)(0)), dataset($(inputCols)(1))))
}

override def copy(extra: ParamMap): Transformer = copyValues(new StringChecker(uid, model), extra)

@DeveloperApi
override def transformSchema(schema: StructType): StructType = {
// Compare against the StringType singleton (idiomatic for Spark DataTypes).
// The failure message lists each offending column with its actual type;
// the original interpolated the raw Array, which prints as "[Lorg...@hash".
require(
$(inputCols).forall(c => schema(c).dataType == StringType),
s"Input columns must be of type StringType but got " +
$(inputCols).map(c => s"$c: ${schema(c).dataType}").mkString(", "))
require(!schema.fields.exists(_.name == $(outputCol)),
s"Output column ${$(outputCol)} already exists.")
StructType(schema.fields :+ StructField($(outputCol), DoubleType))
}

override def write: MLWriter = new StringChecker.StringCheckerWriter(this)
}
/**
 * Companion providing ML persistence: saves/loads a [[StringChecker]] as
 * standard Params metadata plus a one-row Parquet file holding the model data.
 */
object StringChecker extends MLReadable[StringChecker] {

override def read: MLReader[StringChecker] = new StringCheckerReader

override def load(path: String): StringChecker = super.load(path)

/** Writes Params metadata and the model's caseSensitive flag to disk. */
private class StringCheckerWriter(instance: StringChecker) extends MLWriter {

// Schema of the single-row Parquet payload.
private case class Data(caseSensitive: Boolean)

override protected def saveImpl(path: String): Unit = {
// Params and stage metadata go to the conventional ML metadata location.
DefaultParamsWriter.saveMetadata(instance, path, sc)
// The underlying model is fully described by its caseSensitive flag.
val payload = Data(instance.model.caseSensitive)
val dataPath = new Path(path, "data").toString
sparkSession
.createDataFrame(Seq(payload))
.repartition(1)
.write
.parquet(dataPath)
}
}

/** Restores a [[StringChecker]] previously written by [[StringCheckerWriter]]. */
private class StringCheckerReader extends MLReader[StringChecker] {

/** Checked against metadata when loading model */
private val className = classOf[StringChecker].getName

override def load(path: String): StringChecker = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
// Read back the single-row model payload and rebuild the MLeap model.
val row = sparkSession.read.parquet(new Path(path, "data").toString)
.select("caseSensitive")
.head()
val restored = new StringChecker(metadata.uid, new StringCheckerModel(row.getAs[Boolean](0)))
metadata.getAndSetParams(restored)
restored
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment