// Source: GitHub gist terrisgit/0061d506e2e38a7e7f69cf5596d32308
// Author: @terrisgit, created November 25, 2018.
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.evaluation.ClusteringEvaluator
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.{Pipeline}
import org.apache.spark.sql.functions.col
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
import org.apache.spark.sql.types._
// Input CSV layout (no header row), one record per line, in column order:
//   age     - integer in the file, loaded as a double
//   gender  - 'M' or 'F'
//   days    - days since prior purchase; integer in the file, loaded as a double
//   month   - three-character month abbreviation
//   amount  - purchase amount (float), loaded as a double
// Every column is declared nullable.
val schema = new StructType()
  .add("age", DoubleType, nullable = true)
  .add("gender", StringType, nullable = true)
  .add("days", DoubleType, nullable = true)
  .add("month", StringType, nullable = true)
  .add("amount", DoubleType, nullable = true)
// Read the headerless CSV input using the explicit `schema` declared above.
// `inferSchema` is intentionally not set: when a user-specified schema is
// supplied, Spark skips schema inference entirely, so that option would be
// dead configuration that misleadingly suggests the column types are inferred.
// NOTE(review): in the default PERMISSIVE mode, malformed fields load as null,
// which later pipeline stages may reject — confirm the input is clean.
val df = spark.read
  .format("csv")
  .option("header", "false")
  .schema(schema)
  .load("/home/jovyan/work/data.csv")
// Categorical columns are turned into numeric label indices (StringIndexer),
// then into one-hot vectors (OneHotEncoder) so they can be assembled as
// numeric features for k-means.
val gindexer = new StringIndexer()
  .setInputCol("gender")
  .setOutputCol("genderIndex")
val mindexer = new StringIndexer()
  .setInputCol("month")
  .setOutputCol("monthIndex")

// Encoders consume the *Index columns produced by the indexers above.
val gencoder = new OneHotEncoder()
  .setInputCol("genderIndex")
  .setOutputCol("genderVec")
val mencoder = new OneHotEncoder()
  .setInputCol("monthIndex")
  .setOutputCol("monthVec")
// Concatenate the clustering inputs, in this order, into a single "features"
// vector column: the numeric columns plus the one-hot encoded categoricals.
val assembler = new VectorAssembler()
  .setOutputCol("features")
  .setInputCols(Array("age", "genderVec", "days", "monthVec", "amount"))
// k-means estimator with k = 2 clusters; the fixed seed makes its random
// initialization repeatable across runs.
val kmeans = new KMeans()
  .setSeed(1L)
  .setK(2)
// Stages run in order: index and one-hot encode gender and month, assemble
// the feature vector, then cluster it with k-means.
val pipeline = new Pipeline()
  .setStages(Array(gindexer, gencoder, mindexer, mencoder, assembler, kmeans))

// Fit every stage of the pipeline against the input data.
val kMeansPredictionModel = pipeline.fit(df)

// Apply the fitted model: the result is the input plus the intermediate
// columns and a 'prediction' column holding each row's cluster number.
val predictionResult = kMeansPredictionModel.transform(df)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment