Skip to content

Instantly share code, notes, and snippets.

@marcovivero
Created January 20, 2016 21:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save marcovivero/cc45d7ed9cb12c10d95a to your computer and use it in GitHub Desktop.
Save marcovivero/cc45d7ed9cb12c10d95a to your computer and use it in GitHub Desktop.
/**
 * Clustering model that, for each target label, selects a subset of features from a
 * logistic-regression fit and then clusters rows with LDA on the selected, binarized
 * features.
 *
 * For every label the output DataFrame gains two columns:
 *  - `featuresCol|label`: the row's reduced feature vector (1.0 for every selected
 *    feature present in the row);
 *  - `clusterCol|label`: the index of the row's most probable LDA topic, as a Double.
 *
 * @param params supplies `regParam` / `elasticNetParam` for the logistic regression,
 *               `featureProp` (fraction of features kept), and `numTopics`, `alpha`,
 *               `beta` for LDA.
 */
class LogRegLDAModel (params : LogRegLDAParams) extends ClusteringModel {

  /**
   * Appends per-label reduced feature columns and per-label LDA cluster assignments.
   *
   * @param data        wrapper whose `data` DataFrame must contain `featuresCol` and the
   *                    label columns "conversion", "churnThreeMonths", "churnSixMonths"
   *                    and "churnNineMonths".
   * @param featuresCol name of the input feature-vector column.
   * @param clusterCol  prefix of the generated cluster columns.
   * @return the input rows extended with a `featuresCol|label` vector column for each
   *         label, followed by a `clusterCol|label` Double column for each label.
   */
  def transform(data : Data, featuresCol : String, clusterCol : String) : DataFrame = {
    val labels = Array("conversion", "churnThreeMonths", "churnSixMonths", "churnNineMonths")

    // Fits a logistic regression for `label` and returns a UDF that keeps only the
    // selected features, re-indexed into a compact vector of the selected size.
    def featureUDF(label : String) : UserDefinedFunction = {
      // A fresh estimator per label rather than mutating one shared instance.
      val weights : Array[Double] = new LogisticRegression()
        .setFeaturesCol(featuresCol)
        .setLabelCol(label)
        .setRegParam(params.regParam)
        .setElasticNetParam(params.elasticNetParam)
        .fit(data.data)
        .weights
        .toArray
      val prop = (weights.length * params.featureProp).toInt
      // NOTE(review): the ascending sort keeps the `prop` most NEGATIVE coefficients,
      // not the largest or most influential ones -- confirm this is the intent.
      val selected : Array[Int] = weights.indices
        .sortBy(k => weights(k))
        .take(prop)
        .sorted
        .toArray
      val reducedSize = selected.length
      require(reducedSize > 0,
        "featureProp selects no features for label " + label + "; LDA needs a non-empty vocabulary")
      // Original feature index -> position in the reduced vector. Positions follow the
      // ascending order of `selected`, so re-mapped indices stay strictly increasing
      // (a SparseVector requirement) and always lie in [0, reducedSize).
      val position : Map[Int, Int] = selected.zipWithIndex.toMap
      functions.udf(
        (v : Vector) => {
          val active : Array[Int] = v match {
            case sv : SparseVector => sv.indices
            case other => other.toSparse.indices
          }
          val newIndices : Array[Int] = active.flatMap(idx => position.get(idx))
          new SparseVector(reducedSize, newIndices, Array.fill(newIndices.length)(1.0))
        }
      )
    }

    // Add one reduced-feature column per label, reading the features from the frame
    // being extended.
    val newData : DataFrame = labels.foldLeft(data.data) { (df, label) =>
      df.withColumn(featuresCol + "|" + label, featureUDF(label)(df(featuresCol)))
    }
    val featureIndices : Array[Int] = labels
      .map(label => newData.columns.indexOf(featuresCol + "|" + label))

    // Row ids let each label's LDA output be joined back to its source row.
    // NOTE(review): idxData is never unpersisted. The returned DataFrame still depends
    // on it lazily, so release it only after that result has been consumed.
    val idxData : RDD[(Long, Row)] = newData
      .rdd
      .zipWithIndex
      .map(_.swap)
      .persist
    val ldaData : Seq[RDD[(Long, Vector)]] = featureIndices
      .map(idx => idxData.map { case (id, row) => (id, row.getAs[Vector](idx)) })

    val lda = new LDA()
      .setK(params.numTopics)
      .setAlpha(params.alpha)
      .setBeta(params.beta)
    // The default (EM) optimizer yields a DistributedLDAModel, which exposes per-document
    // topic distributions. Cluster id = index of the most probable topic.
    val clusters : Seq[RDD[(Long, Double)]] = ldaData.map { docs =>
      val model = lda.run(docs).asInstanceOf[DistributedLDAModel]
      model.topicDistributions.map {
        case (id, dist) => (id, (0 until dist.size).maxBy(k => dist(k)).toDouble)
      }
    }

    // Cluster columns are appended in `labels` order, matching the fold below.
    val schema = StructType(
      newData.schema.toSeq ++
        labels.map(label => StructField(clusterCol + "|" + label, DoubleType))
    )
    val finalData : RDD[(Long, Row)] = clusters.foldLeft(idxData) { (acc, labelClusters) =>
      acc.join(labelClusters).map {
        case (id, (row, cluster)) => (id, Row.fromSeq(row.toSeq :+ cluster))
      }
    }
    data.data.sqlContext.createDataFrame(finalData.map(_._2), schema)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment