import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.mllib.clustering.DistributedLDAModel
import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.linalg.SparseVector
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions
import org.apache.spark.sql.Row
import org.apache.spark.sql.UserDefinedFunction
import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.StringType
import scala.collection.immutable.HashSet
import scala.math._
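
// ---------------------------------------------------------------------------
// NOTE: `Data` and `LDAParams` are referenced below but are not defined in
// this gist. The definitions that follow are assumptions, reconstructed from
// how the fields are used (data.data, data.nameArray, data.sqlContext,
// params.numTopics, params.alpha, params.beta); the original definitions may
// differ.
// ---------------------------------------------------------------------------
case class Data(
    data: DataFrame,                              // rows holding the bag-of-words feature vectors
    nameArray: Array[String],                     // vocabulary: term index -> term name
    sqlContext: org.apache.spark.sql.SQLContext)  // used to rebuild the output DataFrame

case class LDAParams(
    numTopics: Int, // number of topics k
    alpha: Double,  // document-topic (Dirichlet) concentration; must be > 1.0 for the EM optimizer
    beta: Double)   // topic-word (Dirichlet) concentration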

abstract class ClusteringModel extends Serializable {
  def transform(data: Data, featuresCol: String, clusterCol: String): DataFrame
}

class LDAModel(params: LDAParams) extends ClusteringModel {

  def transform(data: Data, featuresCol: String, clusterCol: String): DataFrame = {
    // Pair every row with a stable long index; LDA identifies documents by ID,
    // and the same index is reused later to join the cluster assignments back on.
    val featureIndex: Int = data.data.columns.indexOf(featuresCol)
    val idxData: RDD[(Long, Row)] = data.data.rdd.zipWithIndex.map(_.swap)
    val ldaData: RDD[(Long, Vector)] =
      idxData.map(e => (e._1, e._2.getAs[Vector](featureIndex)))

    // Fit LDA; the default EM optimizer yields a DistributedLDAModel, which
    // exposes per-document topic distributions.
    val lda: DistributedLDAModel = new LDA()
      .setK(params.numTopics)
      .setAlpha(params.alpha)
      .setBeta(params.beta)
      .run(ldaData)
      .asInstanceOf[DistributedLDAModel]

    // Assign each document to its most probable topic.
    val clusters: RDD[(Long, Double)] = lda
      .topicDistributions
      .map(e => (e._1, (0 until e._2.size).maxBy(k => e._2.toArray(k)).toDouble))
      .persist

    // Describe each document's assigned topic by its top 100 terms, translated
    // into human-readable names through the vocabulary in data.nameArray.
    val wordGivenTopic: RDD[(Long, Array[String])] = {
      val topics: Array[Array[(Int, Double)]] = lda
        .describeTopics(100)
        .map(e => e._1.zip(e._2))
      clusters.map(e =>
        (e._1, topics(e._2.toInt).map(_._1).map(termIndex => data.nameArray(termIndex))))
    }

    // Extend the input schema with the cluster label and its term description.
    val schema: StructType = StructType(
      data.data.schema.toSeq ++
        Seq(
          StructField(clusterCol, DoubleType),
          StructField(clusterCol + "Description", new ArrayType(StringType, false))
        )
    )

    // Join the original rows with their cluster label and description, then
    // rebuild a DataFrame over the extended schema.
    val dataWithCluster: RDD[Row] = idxData
      .join(clusters)
      .join(wordGivenTopic)
      .map(e => Row.fromSeq(e._2._1._1.toSeq ++ Seq[Any](e._2._1._2, e._2._2)))

    data.sqlContext.createDataFrame(dataWithCluster, schema)
  }
}
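
// ---------------------------------------------------------------------------
// Example usage: a minimal sketch, not part of the original gist. The toy
// corpus, vocabulary, column names, and parameter values are illustrative
// assumptions; in practice the feature vectors would come from a real
// term-count pipeline.
// ---------------------------------------------------------------------------
object LDAModelExample {
  import org.apache.spark.{SparkConf, SparkContext}
  import org.apache.spark.mllib.linalg.Vectors
  import org.apache.spark.sql.SQLContext

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("lda-example").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)

    // Two toy documents as term-count vectors over a four-word vocabulary.
    val vocab = Array("spark", "scala", "topic", "cluster")
    val df: DataFrame = sqlContext.createDataFrame(Seq(
      Tuple1(Vectors.dense(3.0, 1.0, 0.0, 0.0)),
      Tuple1(Vectors.dense(0.0, 0.0, 2.0, 4.0))
    )).toDF("features")

    val corpus = Data(df, vocab, sqlContext)
    val model = new LDAModel(LDAParams(numTopics = 2, alpha = 1.1, beta = 1.1))

    // Adds "cluster" and "clusterDescription" columns to the input rows.
    val clustered: DataFrame = model.transform(corpus, featuresCol = "features", clusterCol = "cluster")
    clustered.show()

    sc.stop()
  }
}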