Created August 13, 2020 08:46
-
-
Save mostafam/5d2654601283cc14ae61403d659c2a61 to your computer and use it in GitHub Desktop.
StringCheckerOp.scala (Spark side)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.apache.spark.ml.bundle.extension.ops.feature | |
import ml.combust.bundle.BundleContext | |
import ml.combust.bundle.dsl._ | |
import ml.combust.bundle.op.{OpModel, OpNode} | |
import ml.combust.mleap.core.feature.StringCheckerModel | |
import org.apache.spark.ml.bundle.SparkBundleContext | |
import org.apache.spark.ml.mleap.feature.StringChecker | |
/** | |
* Created by mostafam on 6/18/20. | |
*/ | |
/**
  * MLeap bundle serializer for the Spark-side [[StringChecker]] transformer.
  *
  * Stores the transformer's [[StringCheckerModel]] (its `caseSensitive` flag) together with
  * the node shape. The shape has two input sockets, on ports "text" and "query", and the
  * standard output socket. On load, it rebuilds an equivalent [[StringChecker]].
  */
class StringCheckerOp extends OpNode[SparkBundleContext, StringChecker, StringCheckerModel] {
  import StringCheckerOp._

  override val Model: OpModel[SparkBundleContext, StringCheckerModel] =
    new OpModel[SparkBundleContext, StringCheckerModel] {
      // the class of the model is needed for when we go to serialize JVM objects
      override val klazz: Class[StringCheckerModel] = classOf[StringCheckerModel]

      // a unique name for our op: "string_checker"
      // this should be the same as for the MLeap transformer serialization
      override def opName: String = Bundle.BuiltinOps.feature.string_checker

      /** Writes the `caseSensitive` parameter into the bundle model. */
      override def store(model: Model, obj: StringCheckerModel)
                        (implicit context: BundleContext[SparkBundleContext]): Model =
        model.withValue("caseSensitive", Value.boolean(obj.caseSensitive))

      /** Reconstructs a [[StringCheckerModel]] from the stored `caseSensitive` parameter. */
      override def load(model: Model)
                       (implicit context: BundleContext[SparkBundleContext]): StringCheckerModel =
        StringCheckerModel(model.value("caseSensitive").getBoolean)
    }

  override val klazz: Class[StringChecker] = classOf[StringChecker]

  override def name(node: StringChecker): String = node.uid

  override def model(node: StringChecker): StringCheckerModel = node.model

  /**
    * Rebuilds a [[StringChecker]] from a deserialized bundle node.
    *
    * Sockets are looked up by the ports written in [[shape]]. A socket is not a `String`;
    * its `name` is the DataFrame column bound to that port.
    */
  override def load(node: Node, model: StringCheckerModel)
                   (implicit context: BundleContext[SparkBundleContext]): StringChecker = {
    val textCol = node.shape.input(TextPort).name
    val queryCol = node.shape.input(QueryPort).name
    new StringChecker(uid = node.name, model = model).
      setInputCols(textCol, queryCol).
      setOutputCol(node.shape.standardOutput.name)
  }

  /**
    * Describes how the transformer's columns are wired to the model's ports.
    *
    * The first input column is bound to the "text" port, the second to the "query" port,
    * and the output column to the standard output port.
    *
    * @throws IllegalArgumentException if the transformer does not have exactly two input columns
    */
  override def shape(node: StringChecker)
                    (implicit context: BundleContext[SparkBundleContext]): NodeShape = {
    val inputCols = node.getInputCols
    require(inputCols.length == 2,
      s"StringChecker expects exactly 2 input columns (text, query) but got ${inputCols.length}")
    // bundle-ml's NodeShape.withInput / withOutput take (fieldName, port), in that order.
    NodeShape().
      withInput(inputCols(0), TextPort).
      withInput(inputCols(1), QueryPort).
      withStandardOutput(node.getOutputCol)
  }
}

/** Port names shared by [[StringCheckerOp.shape]] and [[StringCheckerOp.load]]. */
object StringCheckerOp {
  // NOTE(review): presumably these must also match the ports read by the MLeap runtime-side op — confirm.
  private val TextPort: String = "text"
  private val QueryPort: String = "query"
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.