@ally1221
Created January 14, 2022 21:52

import com.dtech.scala.pipeline.CustomUnaryTransformer
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.col
// Initialize the Spark session
val sparkSession = SparkSession.builder().getOrCreate()
// Create the custom transformer
val transformer = new CustomUnaryTransformer()
  .setShift(0.5)
  .setInputCol("input")
  .setOutputCol("output")
// Save the custom transformer on its own
transformer.write.overwrite().save("/opt/spark-output/notebook/zenurian-test/transformers/unarytransformer")
// Create a test DataFrame with a single Double column named "input"
// (it must exist before the pipeline is fit below)
val df: DataFrame = sparkSession.range(0, 5).toDF("input")
  .select(col("input").cast("double").as("input"))
// Create a pipeline whose only stage is the custom transformer, fit it, and save the PipelineModel
val pipeline = new Pipeline().setStages(Array(transformer))
val model = pipeline.fit(df)
model.write.overwrite().save("/opt/spark-output/notebook/zenurian-test/pipelines/unarytransformer")
// Succeeds: load the saved custom transformer directly and use it to transform the DataFrame
val loadedTransformer = CustomUnaryTransformer.load("/opt/spark-output/notebook/zenurian-test/transformers/unarytransformer")
val transformed = loadedTransformer.transform(df)
println("Transformed DF")
transformed.show()
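// Fails: load the saved pipeline model
val loadedModel = PipelineModel.load("/opt/spark-output/notebook/zenurian-test/pipelines/unarytransformer")
// A model's `parent` estimator is not persisted, so it is null after loading; list the loaded stages instead
println(s"Loaded pipeline stages: ${loadedModel.stages.mkString(", ")}")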
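For reference, the source of `CustomUnaryTransformer` is not part of this gist. The block below is a minimal sketch, under stated assumptions, of a custom unary transformer built on Spark ML's default persistence helpers: `DefaultParamsWritable` on the class and `DefaultParamsReadable` on its companion object. The class name `ShiftTransformer`, the meaning of `shift` (a constant added to each Double input value), and `getShift` are hypothetical. Spark's default reader recreates a saved stage from the class name stored in its metadata and calls a public `(uid: String)` constructor on it, so the sketch keeps that constructor alongside the usual no-arg one.

```scala
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.DoubleParam
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.types.{DataType, DoubleType}

// Hypothetical stand-in for CustomUnaryTransformer: adds a constant `shift`
// to every value of a Double input column.
class ShiftTransformer(override val uid: String)
    extends UnaryTransformer[Double, Double, ShiftTransformer]
    with DefaultParamsWritable {

  // No-arg constructor for normal use; the (uid: String) constructor above is
  // the one Spark's default reader calls when loading a saved instance.
  def this() = this(Identifiable.randomUID("shiftTransformer"))

  final val shift: DoubleParam =
    new DoubleParam(this, "shift", "constant added to each input value")
  setDefault(shift -> 0.0)

  def setShift(value: Double): this.type = set(shift, value)
  def getShift: Double = $(shift)

  // Read the param once on the driver so the per-row function only captures a Double.
  override protected def createTransformFunc: Double => Double = {
    val s = $(shift)
    (x: Double) => x + s
  }

  // Reject non-Double input columns during schema validation.
  override protected def validateInputType(inputType: DataType): Unit =
    require(inputType == DoubleType, s"Input type must be DoubleType but got $inputType")

  override protected def outputDataType: DataType = DoubleType
}

// Companion object: DefaultParamsReadable supplies `read`, and MLReadable's
// `load(path)` delegates to it, so ShiftTransformer.load(path) is available.
object ShiftTransformer extends DefaultParamsReadable[ShiftTransformer]
```

Under these assumptions the class is configured the same way as `CustomUnaryTransformer` above, for example `new ShiftTransformer().setShift(0.5).setInputCol("input").setOutputCol("output")`, and a saved instance is reloaded with `ShiftTransformer.load(path)`. Reading `$(shift)` outside the returned function keeps the transformer object itself out of the per-row closure.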