Skip to content

Instantly share code, notes, and snippets.

@mostafam
Created August 13, 2020 08:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mostafam/5d2654601283cc14ae61403d659c2a61 to your computer and use it in GitHub Desktop.
Save mostafam/5d2654601283cc14ae61403d659c2a61 to your computer and use it in GitHub Desktop.
StringCheckerOp.scala (Spark side)
package org.apache.spark.ml.bundle.extension.ops.feature
import ml.combust.bundle.BundleContext
import ml.combust.bundle.dsl._
import ml.combust.bundle.op.{OpModel, OpNode}
import ml.combust.mleap.core.feature.StringCheckerModel
import org.apache.spark.ml.bundle.SparkBundleContext
import org.apache.spark.ml.mleap.feature.StringChecker
/**
  * MLeap bundle op for the Spark-side [[StringChecker]] transformer.
  *
  * Serializes a [[StringChecker]] to a bundle node and back:
  *  - the model stores a single boolean attribute, `caseSensitive`;
  *  - the node shape has two inputs (ports `text` and `query`) and one output (port `output`).
  *
  * Created by mostafam on 6/18/20.
  */
class StringCheckerOp extends OpNode[SparkBundleContext, StringChecker, StringCheckerModel] {
  import StringCheckerOp._

  override val Model: OpModel[SparkBundleContext, StringCheckerModel] =
    new OpModel[SparkBundleContext, StringCheckerModel] {
      // The class of the model is needed when serializing JVM objects.
      override val klazz: Class[StringCheckerModel] = classOf[StringCheckerModel]

      // Unique op name; it must match the name used by the MLeap-runtime side of this op.
      // NOTE(review): `string_checker` is presumably a custom entry added to
      // Bundle.BuiltinOps.feature — confirm it exists in the MLeap build in use.
      override def opName: String = Bundle.BuiltinOps.feature.string_checker

      /** Writes the model's `caseSensitive` flag into the bundle model. */
      override def store(model: Model, obj: StringCheckerModel)
                        (implicit context: BundleContext[SparkBundleContext]): Model =
        model.withValue(CaseSensitiveKey, Value.boolean(obj.caseSensitive))

      /** Rebuilds a [[StringCheckerModel]] from the stored `caseSensitive` flag. */
      override def load(model: Model)
                       (implicit context: BundleContext[SparkBundleContext]): StringCheckerModel =
        StringCheckerModel(model.value(CaseSensitiveKey).getBoolean)
    }

  override val klazz: Class[StringChecker] = classOf[StringChecker]

  /** The bundle node name is the transformer's uid. */
  override def name(node: StringChecker): String = node.uid

  override def model(node: StringChecker): StringCheckerModel = node.model

  /**
    * Reconstructs a [[StringChecker]] from a deserialized bundle node.
    *
    * Column names are read from the `Socket`s registered under the same ports that
    * `shape` writes. A `Socket` is not a `String`: the column name lives in its `name`
    * field, so casting the socket itself would throw a ClassCastException.
    */
  override def load(node: Node, model: StringCheckerModel)
                   (implicit context: BundleContext[SparkBundleContext]): StringChecker = {
    val textCol = node.shape.input(TextPort).name
    val queryCol = node.shape.input(QueryPort).name
    val outputCol = node.shape.output(OutputPort).name

    new StringChecker(uid = node.name, model = model)
      .setInputCols(textCol, queryCol)
      .setOutputCol(outputCol)
  }

  /**
    * Describes the node's inputs and output for serialization.
    *
    * The first input column is bound to port `text` and the second to port `query`.
    * The output column is bound to port `output`.
    *
    * @throws IllegalArgumentException if the transformer has fewer than two input columns
    */
  override def shape(node: StringChecker)
                    (implicit context: BundleContext[SparkBundleContext]): NodeShape = {
    val inputCols = node.getInputCols
    require(inputCols.length >= 2,
      s"StringChecker ${node.uid} needs two input columns (text, query), got ${inputCols.length}")

    NodeShape()
      .withInput(TextPort, inputCols(0))
      .withInput(QueryPort, inputCols(1))
      .withOutput(OutputPort, node.getOutputCol)
  }
}

/** Constants shared by the serialization and deserialization paths of [[StringCheckerOp]]. */
object StringCheckerOp {
  /** Bundle model attribute holding the case-sensitivity flag. */
  private val CaseSensitiveKey: String = "caseSensitive"

  /** Input port for the text column. */
  private val TextPort: String = "text"

  /** Input port for the query column. */
  private val QueryPort: String = "query"

  /** Output port; MLeap's standard output port is also named "output". */
  private val OutputPort: String = "output"
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment