kMeans with Spark + Groovy
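The script reads per-customer metrics from a MySQL table, exports them to a semicolon-delimited CSV, and clusters the companies with Spark ML's k-means after standardizing the features. It prints the Silhouette score and training cost, and writes the cluster centers to a small JavaScript file for charting.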
/*
 * This Groovy source file was generated by the Gradle 'init' task.
 */
package cluster

import groovy.json.JsonOutput
import groovy.sql.Sql
import org.apache.spark.SparkConf
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.evaluation.ClusteringEvaluator
import org.apache.spark.ml.feature.StandardScaler
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession

import java.sql.ResultSet
class App {

    static void main(String[] args) {
        def mlDB = Sql.newInstance("jdbc:mysql://localhost/ml", "***", "****", 'com.mysql.jdbc.Driver')
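        // Configure MySQL result-set streaming: TYPE_FORWARD_ONLY,
        // CONCUR_READ_ONLY and a fetch size of Integer.MIN_VALUE make the
        // MySQL driver stream rows one at a time instead of buffering the
        // whole table in memory.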
        mlDB.resultSetType = ResultSet.TYPE_FORWARD_ONLY
        mlDB.resultSetConcurrency = ResultSet.CONCUR_READ_ONLY
        mlDB.withStatement { stm ->
            stm.fetchSize = Integer.MIN_VALUE
        }
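        // CSV header names mapped to the corresponding ml.clientes columns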
        def labels = [
                'Users'    : 'nusers',
                'Documents': 'ndocuments',
                'Finished' : 'docusfinished',
                'Days'     : 'daystofinish',
                'API'      : 'template',
                'Web'      : 'web',
                'Workflow' : 'workflow'
        ]
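        // Dump the query results to a semicolon-delimited CSV. The company
        // name is replaced by a 1-based row index, and missing metrics
        // default to 0.0.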
        def rows = mlDB.rows('select * from ml.clientes order by nombre')
        def file = new File("out.csv")
        file.text = (["Company"] + labels.keySet()).join(";") + "\n"
        rows.eachWithIndex { row, idx ->
            List<String> details = []
            labels.values().each { column ->
                details << (row[column] ?: 0.0).toString()
            }
            file << "${idx + 1};" + details.join(';') + "\n"
        }
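        // Start a local, single-JVM Spark session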
        def spark = SparkSession
                .builder()
                .appName("CustomersKMeans")
                .config(new SparkConf().setMaster("local"))
                .getOrCreate()
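        // Read the CSV back as a Spark Dataset, letting Spark infer the schema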
        def dataset = spark.read()
                .option("delimiter", ";")
                .option("header", "true")
                .option("inferSchema", "true")
                .csv("out.csv")
        def assembler = new VectorAssembler(inputCols: labels.keySet() as String[], outputCol: "features")
        dataset = assembler.transform(dataset)
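        // Standardize every feature to zero mean and unit variance so that
        // columns with large ranges do not dominate the Euclidean distance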
        def scaler = new StandardScaler(inputCol: "features", outputCol: "scaledFeatures", withStd: true, withMean: true)
        def scalerModel = scaler.fit(dataset)
        dataset = scalerModel.transform(dataset)
        // Train a k-means model with k=5 clusters on the scaled features
        def kmeans = new KMeans(k: 5, seed: 1, predictionCol: "Cluster", featuresCol: "scaledFeatures")
        def kmeansModel = kmeans.fit(dataset)
        // Make predictions
        def predictions = kmeansModel.transform(dataset)

        // Evaluate the clustering with the Silhouette score, computed on the
        // same scaled features the model was trained on
        def evaluator = new ClusteringEvaluator(predictionCol: "Cluster", featuresCol: "scaledFeatures")
        double silhouette = evaluator.evaluate(predictions)
        println "Silhouette with squared euclidean distance = " + silhouette
        println "Cost " + kmeansModel.summary().trainingCost()
        // Attach each company's cluster assignment back to the original rows
        def copy = dataset.alias("copy")
        copy = copy.join(predictions.select("Company", "Cluster"), "Company", "inner")
        copy.show(3)
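        // Export the cluster centers as a JavaScript constant. The
        // labels/datasets layout matches what a Chart.js-style radar chart
        // expects: one dataset per cluster center.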
        def json = [
                labels  : labels.keySet(),
                datasets: []
        ]
        kmeansModel.clusterCenters().eachWithIndex { v, i ->
            json.datasets << [
                    label: "Cluster ${i + 1}",
                    data : v.toArray(),
                    fill : true
            ]
        }
        new File("data2.js").text = "const dataArr = " + JsonOutput.prettyPrint(JsonOutput.toJson(json))
        // Release the JDBC connection and stop the Spark session
        mlDB.close()
        spark.stop()
    }
}
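To build this with Gradle you need Spark SQL and MLlib, the MySQL JDBC driver, and Groovy on the classpath. A minimal sketch of the dependencies block, assuming a Scala 2.12 build of Spark; the version numbers are illustrative, not the ones the gist was written against. Note that the single-column Dataset.join(right, "Company", "inner") overload used above requires a recent Spark release (3.4 or newer).

// build.gradle (sketch) -- coordinates are standard, versions are illustrative
dependencies {
    implementation 'org.apache.groovy:groovy-all:4.0.21'
    implementation 'org.apache.spark:spark-sql_2.12:3.5.1'
    implementation 'org.apache.spark:spark-mllib_2.12:3.5.1'
    // Connector/J 8 still accepts the legacy com.mysql.jdbc.Driver
    // class name used in the script as a deprecated alias
    runtimeOnly 'com.mysql:mysql-connector-j:8.4.0'
}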