Created
August 13, 2020 08:34
-
-
Save mostafam/bbbd88568c707fa2f33fd1cb62a5c970 to your computer and use it in GitHub Desktop.
StringChecker.scala (Spark side)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.apache.spark.ml.mleap.feature | |
import ml.combust.mleap.core.feature.StringCheckerModel | |
import org.apache.hadoop.fs.Path | |
import org.apache.spark.annotation.DeveloperApi | |
import org.apache.spark.ml.Transformer | |
import org.apache.spark.ml.param.ParamMap | |
import org.apache.spark.ml.param.shared.{HasInputCols, HasOutputCol} | |
import org.apache.spark.ml.util._ | |
import org.apache.spark.sql.functions._ | |
import org.apache.spark.sql.{DataFrame, Dataset} | |
import org.apache.spark.sql.types._ | |
/**
 * Spark ML [[Transformer]] that evaluates a [[StringCheckerModel]] on a pair of
 * string columns (text, query) and appends the model's `Double` result as a new column.
 *
 * @param uid   unique identifier of this pipeline stage
 * @param model core model invoked once per row as `model(text, query)`
 */
class StringChecker(override val uid: String,
                    val model: StringCheckerModel) extends Transformer
  with HasInputCols
  with HasOutputCol
  with MLWritable {

  /** Creates a transformer with a freshly generated `string_checker` uid. */
  def this(model: StringCheckerModel) =
    this(uid = Identifiable.randomUID("string_checker"), model = model)

  /**
   * Sets the two input columns.
   *
   * @param text  name of the column passed as the model's first argument
   * @param query name of the column passed as the model's second argument
   */
  def setInputCols(text: String, query: String): this.type = set(inputCols, Array(text, query))

  /** Sets the name of the column that receives the model's result. */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /**
   * Returns the configured (text, query) column names.
   *
   * Fails fast with a descriptive message when `inputCols` does not hold exactly two
   * names, instead of an `ArrayIndexOutOfBoundsException` at transform time or
   * silently ignoring surplus columns.
   */
  private def inputColumnPair: (String, String) = {
    val cols = $(inputCols)
    require(cols.length == 2,
      s"StringChecker requires exactly 2 input columns (text, query) but got ${cols.length}")
    (cols(0), cols(1))
  }

  /**
   * Appends the output column computed by applying the model to the two input columns.
   *
   * @param dataset input dataset containing both input columns
   * @return the dataset with the output column added
   */
  @org.apache.spark.annotation.Since("2.0.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    val (textCol, queryCol) = inputColumnPair
    // Bind the model to a local so the UDF closure captures only the model,
    // not the enclosing transformer instance.
    val checker = model
    val stringCheckerUdf = udf {
      (text: String, query: String) => checker(text, query): Double
    }
    dataset.withColumn($(outputCol), stringCheckerUdf(dataset(textCol), dataset(queryCol)))
  }

  /** Creates a copy sharing the same uid and model, with `extra` params applied. */
  override def copy(extra: ParamMap): Transformer = copyValues(new StringChecker(uid, model), extra)

  /**
   * Validates the input schema and derives the output schema.
   *
   * Requires exactly two input columns, both of string type, and an output column
   * name that is not already present in `schema`.
   *
   * @param schema schema of the incoming dataset
   * @return `schema` with the output column appended as a `DoubleType` field
   */
  @DeveloperApi
  override def transformSchema(schema: StructType): StructType = {
    val (textCol, queryCol) = inputColumnPair
    val inputTypes = Seq(textCol, queryCol).map(c => schema(c).dataType)
    // Render the types explicitly: interpolating an Array directly would print
    // its JVM identity (e.g. "[L...;@1a2b3c") rather than the element types.
    val renderedTypes = inputTypes.mkString("[", ", ", "]")
    require(
      inputTypes.forall(_.isInstanceOf[StringType]),
      s"Input column must be of type StringType but got $renderedTypes")
    require(!schema.fieldNames.contains($(outputCol)),
      s"Output column ${$(outputCol)} already exists.")
    StructType(schema.fields :+ StructField($(outputCol), DoubleType))
  }

  /** Returns the writer that persists this transformer's params and model data. */
  override def write: MLWriter = new StringChecker.StringCheckerWriter(this)
}
/** Companion object providing Spark ML persistence (`read` / `load`) for [[StringChecker]]. */
object StringChecker extends MLReadable[StringChecker] {

  override def read: MLReader[StringChecker] = new StringCheckerReader

  override def load(path: String): StringChecker = super.load(path)

  /** Writes the stage metadata/params plus the model's `caseSensitive` flag. */
  private class StringCheckerWriter(instance: StringChecker) extends MLWriter {

    /** Row shape of the parquet file stored under `<path>/data`. */
    private case class Data(caseSensitive: Boolean)

    override protected def saveImpl(path: String): Unit = {
      // Stage metadata and Params go first.
      DefaultParamsWriter.saveMetadata(instance, path, sc)

      // Model data: only the caseSensitive flag is persisted, as a single-row parquet file.
      val payload = Seq(Data(instance.model.caseSensitive))
      val target = new Path(path, "data").toString
      sparkSession.createDataFrame(payload).repartition(1).write.parquet(target)
    }
  }

  /** Rebuilds a [[StringChecker]] from the layout produced by `StringCheckerWriter`. */
  private class StringCheckerReader extends MLReader[StringChecker] {

    /** Checked against metadata when loading model */
    private val className = classOf[StringChecker].getName

    override def load(path: String): StringChecker = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

      // Read back the single persisted row holding the caseSensitive flag.
      val row = sparkSession.read
        .parquet(new Path(path, "data").toString)
        .select("caseSensitive")
        .head()

      val restored = new StringChecker(metadata.uid, new StringCheckerModel(row.getAs[Boolean](0)))
      metadata.getAndSetParams(restored)
      restored
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment