Skip to content

Instantly share code, notes, and snippets.

@mostafam
Created August 13, 2020 08:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mostafam/3c5f871f748e562d32f326c4007160ca to your computer and use it in GitHub Desktop.
Save mostafam/3c5f871f748e562d32f326c4007160ca to your computer and use it in GitHub Desktop.
Create a Spark PipelineModel and export it to an MLeap bundle
scala> import org.apache.spark.ml.mleap.feature.StringChecker
scala> import ml.combust.mleap.core.feature.StringCheckerModel
scala> import org.apache.spark.ml.bundle.SparkBundleContext
scala> import ml.combust.bundle.BundleFile
scala> import ml.combust.mleap.spark.SparkSupport._
scala> import org.apache.spark.ml.{Pipeline, PipelineModel}
scala> import org.apache.spark.ml.feature.{StringIndexer,VectorAssembler}
scala> import org.apache.spark.sql._
scala> import org.apache.spark.sql.functions._
scala> import resource._
// Three sample rows of (id, email, first_name) used to fit and exercise the pipeline below.
scala> val rows = Seq(
     |   (0, "john.doe@gmail.com", "John"),
     |   (1, "JackieChan234@xyz.com", "jack"),
     |   (2, "ping_pong@missed.org", "Al"))
scala> val df = spark.createDataFrame(rows).toDF("id", "email", "first_name")
// Stage 1: index first_name into a numeric categoryIndex column.
scala> val indexer = new StringIndexer().setInputCol("first_name").setOutputCol("categoryIndex")
// Name of the column that StringChecker writes and VectorAssembler reads. It is bound once so
// producer and consumer cannot drift apart: the assembler previously asked for "is_it_there"
// (no "?"), a column that no stage in this pipeline produces.
scala> val isItThereCol = "is_it_there?"
// Stage 2: StringChecker over email and first_name, configured with caseSensitive = false.
// The trailing dots plus "|" markers keep the setters attached when the transcript is pasted.
scala> val stringChecker = new StringChecker(uid = "string_checker", model = new StringCheckerModel(caseSensitive = false)).
     |   setInputCols("email", "first_name").
     |   setOutputCol(isItThereCol)
// Stage 3: assemble categoryIndex and the checker's output into a single "combined" vector column.
scala> val vectorAssembler = new VectorAssembler().setInputCols(Array("categoryIndex", isItThereCol)).setOutputCol("combined")
scala> val pipeline = new Pipeline().setStages(Array(indexer, stringChecker, vectorAssembler))
scala> val pipelineModel = pipeline.fit(df)
// making sure it works properly: apply the fitted model to df and print the result untruncated
scala> val transformed = pipelineModel.transform(df)
scala> transformed.show(truncate = false)
+---+---------------------+----------+-------------+------------+---------+
|id |email                |first_name|categoryIndex|is_it_there?|combined |
+---+---------------------+----------+-------------+------------+---------+
|0  |john.doe@gmail.com   |John      |0.0          |1.0         |[0.0,1.0]|
|1  |JackieChan234@xyz.com|jack      |1.0          |1.0         |[1.0,1.0]|
|2  |ping_pong@missed.org |Al        |2.0          |0.0         |[2.0,0.0]|
+---+---------------------+----------+-------------+------------+---------+
// making sure it serializes and deserializes properly in Spark:
// persist the fitted model under sparkModelPath, load it back, and re-run it on df
scala> val sparkModelPath = "pm"
scala> pipelineModel.write.save(sparkModelPath)
scala> val pipelineModel2 = PipelineModel.load(sparkModelPath)
scala> pipelineModel2.transform(df).show(truncate = false)
+---+---------------------+----------+-------------+------------+---------+
|id |email                |first_name|categoryIndex|is_it_there?|combined |
+---+---------------------+----------+-------------+------------+---------+
|0  |john.doe@gmail.com   |John      |0.0          |1.0         |[0.0,1.0]|
|1  |JackieChan234@xyz.com|jack      |1.0          |1.0         |[1.0,1.0]|
|2  |ping_pong@missed.org |Al        |2.0          |0.0         |[2.0,0.0]|
+---+---------------------+----------+-------------+------------+---------+
// making sure it serializes properly to an MLeap bundle
// The bundle context is given a DataFrame already transformed by the fitted pipeline
// (presumably so the writer can inspect output columns/types; confirm against MLeap docs).
scala> val sbc = SparkBundleContext().withDataset(pipelineModel.transform(df))
// Build the jar: URI from File.toURI, which yields an absolute, hierarchical
// "file:/.../MLeap.zip" (percent-encoding any spaces in the path). The former literal
// "jar:file:MLeap.zip" wraps an opaque "file:MLeap.zip" URI, which Paths.get(URI), used by
// Java's zip filesystem to locate the archive, rejects as not hierarchical.
// NOTE(review): confirm against BundleFile's jar-URI handling for the MLeap version in use.
scala> val bundleUri = s"jar:${new java.io.File("MLeap.zip").toURI}"
// managed(...) scopes the BundleFile to this save. Both .get calls (on the inner Try and on
// .tried) rethrow any failure so it surfaces in the REPL rather than being returned as a value.
scala> (for (bf <- managed(BundleFile(bundleUri))) yield {
     |   pipelineModel.writeBundle.save(bf)(sbc).get
     | }).tried.get
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment