Created
August 13, 2020 08:56
-
-
Save mostafam/3c5f871f748e562d32f326c4007160ca to your computer and use it in GitHub Desktop.
Create Spark PipelineModel and Export it
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
scala> import org.apache.spark.ml.mleap.feature.StringChecker | |
scala> import ml.combust.mleap.core.feature.StringCheckerModel | |
scala> import org.apache.spark.ml.bundle.SparkBundleContext | |
scala> import ml.combust.bundle.BundleFile | |
scala> import ml.combust.mleap.spark.SparkSupport._ | |
scala> import org.apache.spark.ml.{Pipeline, PipelineModel} | |
scala> import org.apache.spark.ml.feature.{StringIndexer,VectorAssembler} | |
scala> import org.apache.spark.sql._ | |
scala> import org.apache.spark.sql.functions._ | |
scala> import resource._ | |
// Toy input data: an id, an email address, and a first name per row.
val df = spark.createDataFrame(
  Seq(
    (0, "john.doe@gmail.com", "John"),
    (1, "JackieChan234@xyz.com", "jack"),
    (2, "ping_pong@missed.org", "Al")
  )
).toDF("id", "email", "first_name")
// Stage 1: map first_name to a numeric category index.
val indexer = new StringIndexer()
  .setInputCol("first_name")
  .setOutputCol("categoryIndex")

// Stage 2: custom MLeap-backed transformer. Per the sample output below it
// appears to emit 1.0 when first_name occurs in email (case-insensitive,
// since caseSensitive = false) — NOTE(review): semantics inferred from the
// shown output; confirm against StringCheckerModel's implementation.
val stringChecker = new StringChecker(uid = "string_checker", model = new StringCheckerModel(caseSensitive = false))
  .setInputCols("email", "first_name")
  .setOutputCol("is_it_there?")
// Stage 3: combine the indexed name and the checker flag into one vector.
// FIX: the assembler must reference the StringChecker output column
// "is_it_there?" (with the trailing '?', as set via setOutputCol above and
// as shown in the output table headers). The original used "is_it_there",
// a column that never exists, so pipeline.fit/transform would fail with an
// unresolved-column AnalysisException instead of producing the shown output.
val vectorAssembler = new VectorAssembler()
  .setInputCols(Array("categoryIndex", "is_it_there?"))
  .setOutputCol("combined")

// Chain all three stages and fit on the toy DataFrame.
val pipeline = new Pipeline().setStages(Array(indexer, stringChecker, vectorAssembler))
val pipelineModel = pipeline.fit(df)
// Sanity check: run the fitted pipeline on the training data and inspect the
// result — categoryIndex from StringIndexer, is_it_there? from StringChecker,
// and combined from VectorAssembler should all be populated as shown below.
scala> pipelineModel.transform(df).show(false) | |
+---+---------------------+----------+-------------+------------+---------+ | |
|id |email                |first_name|categoryIndex|is_it_there?|combined | | |
+---+---------------------+----------+-------------+------------+---------+ | |
|0  |john.doe@gmail.com   |John      |0.0          |1.0         |[0.0,1.0]| | |
|1  |JackieChan234@xyz.com|jack      |1.0          |1.0         |[1.0,1.0]| | |
|2  |ping_pong@missed.org |Al        |2.0          |0.0         |[2.0,0.0]| | |
+---+---------------------+----------+-------------+------------+---------+ | |
// Round-trip check in native Spark ML persistence: save the fitted model to
// the local path "pm"...
scala> pipelineModel.save("pm") | |
// ...load it back as a fresh PipelineModel...
scala> val pipelineModel2 = PipelineModel.load("pm") | |
// ...and verify the reloaded model produces the same output as the original.
scala> pipelineModel2.transform(df).show(false) | |
+---+---------------------+----------+-------------+------------+---------+ | |
|id |email                |first_name|categoryIndex|is_it_there?|combined | | |
+---+---------------------+----------+-------------+------------+---------+ | |
|0  |john.doe@gmail.com   |John      |0.0          |1.0         |[0.0,1.0]| | |
|1  |JackieChan234@xyz.com|jack      |1.0          |1.0         |[1.0,1.0]| | |
|2  |ping_pong@missed.org |Al        |2.0          |0.0         |[2.0,0.0]| | |
+---+---------------------+----------+-------------+------------+---------+ | |
// Export the fitted pipeline as an MLeap bundle. The SparkBundleContext needs
// a transformed Dataset so MLeap can record the pipeline's output schema.
val sbc = SparkBundleContext().withDataset(pipelineModel.transform(df))

// scala-arm's `managed` closes the BundleFile when the block exits;
// `.tried.get` surfaces any serialization failure as a thrown exception.
(for (bundleFile <- managed(BundleFile("jar:file:MLeap.zip"))) yield {
  pipelineModel.writeBundle.save(bundleFile)(sbc).get
}).tried.get
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment