Notes on running Spark LDA ref: http://qiita.com/LittleWat/items/3ad733f7362edd954714
name := "LDA"
version := "0.0.1"
scalaVersion := "2.10.4"
libraryDependencies ++= Seq(
"org.apache.spark" %% "spark-core" % "1.5.1" % "provided",
"org.apache.spark" %% "spark-mllib" % "1.5.1"
)
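
The spark-csv reader used in the code is supplied at submit time via --packages (see the spark-submit command below). As an alternative sketch, it could also be declared in build.sbt, assuming the same version as the --packages coordinate:

libraryDependencies += "com.databricks" %% "spark-csv" % "1.2.0"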
val conf = new SparkConf().setMaster(master).setAppName("MLlib_LDA")
conf.registerKryoClasses(Array(classOf[MLlib_LDA]))
val sc = new SparkContext(conf)
// create instance
val model = new MLlib_LDA(inputFile, outputFileDirPath, sc)
// run lda
model.run(algo, K.toInt)
// logger
val logger = LoggerFactory.getLogger(getClass)
logger.warn("inputFile: " + inputFile);
logger.warn("outputDirPath: " + outputDirPath);
// constructor
val topicDist_file = outputDirPath + "\\User_Topic_Distribution.csv"
val wordDist_file = outputDirPath + "\\Topic_Jancode_Distribution.csv"
val idToWord_file = outputDirPath + "\\IdToWord.csv"
def run(algorithm: String, topic_num: Int): LDAModel = {
// TODO: would prefer to import this outside the class
import java.io._
val start = System.currentTimeMillis();
var sqlContext = new SQLContext(sc);
//val df = sqlContext.read.format("com.databricks.spark.csv").option("header", "true").option("inferSchema", "true").load(inputFile)
val df = sqlContext.read.format("com.databricks.spark.csv").option("header", "true").load(inputFile)
//split a sentence to words
val regexTokenizer = new RegexTokenizer().setInputCol("jan").setOutputCol("jans").setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false)
// tokenization complete
val regexTokenized = regexTokenizer.transform(df)
val df2 = regexTokenized.select("ID", "jans").toDF
// fit a CountVectorizerModel from the corpus
val cvModel: CountVectorizerModel = new CountVectorizer().setInputCol("jans").setOutputCol("features").fit(df2)
// create a DataFrame whose "features" column holds a SparseVector
val corpus = cvModel.transform(df2).select("ID", "features", "jans")
val corpusRDD = corpus.rdd.map { row =>
val id = row.getString(0).toLong
val jan = row.get(1).asInstanceOf[Vector]
(id, jan)
}
val start2 = System.currentTimeMillis();
var interval = (start2 - start) / 1000;
println("beforeLDA elapsed time: " + interval + "sec");
logger.warn("beforeLDA elapsed time: " + interval + "sec");
//corpusRDD.collect.foreach(println)
val ldaModel: LDAModel = new LDA().setK(topic_num).setOptimizer(algorithm).run(corpusRDD)
val end = System.currentTimeMillis();
interval = (end - start) / 1000;
println("afterLDA elapsed time: " + interval + "sec");
logger.warn("afterLDA elapsed time: " + interval + "sec");
// cast LDAModel to sub class
if (algorithm == "em") {
println("algorithm: em")
logger.warn("algorithm: em")
val topicDistributions = ldaModel.asInstanceOf[DistributedLDAModel].topicDistributions.collect()
// TODO duplicated code
var f = new PrintWriter(new File(topicDist_file))
topicDistributions.sortBy(_._1).foreach { case (docID, topicDistribution) => f.write(topicDistribution.toArray.mkString(",") + '\n') }
f.close
} else {
println("algorithm: online")
logger.warn("algorithm: online")
val topicDistributions = ldaModel.asInstanceOf[LocalLDAModel].topicDistributions(corpusRDD).collect()
// TODO duplicated code
var f = new PrintWriter(new File(topicDist_file))
topicDistributions.sortBy(_._1).foreach { case (docID, topicDistribution) => f.write(topicDistribution.toArray.mkString(",") + '\n') }
f.close
}
//write wordDistributions to file
// reorder the parallel (term index, weight) arrays from describeTopics into vocabulary-index order
def sortName(array1: Array[Int], array2: Array[Double]): List[Double] = {
var temp = List(array2(array1.indexWhere(_ == 0)))
val arrayLength = array1.length
for (i <- Range(1, arrayLength)) {
temp = temp :+ array2(array1.indexWhere(_ == i))
}
return temp //.mkString(",")
}
var f4 = new PrintWriter(new File(wordDist_file))
val describeTopics = ldaModel.describeTopics();
val topicNo = describeTopics.length;
for (i <- Range(0, topicNo)) {
val arrayWordID = describeTopics(i)._1
val arrayWeight = describeTopics(i)._2
f4.write(sortName(arrayWordID, arrayWeight).mkString(",") + "\n");
}
f4.close;
var f5 = new PrintWriter(new File(idToWord_file))
for (word <- cvModel.vocabulary) {
f5.write(word + "\n")
}
f5.close;
val end2 = System.currentTimeMillis();
interval = (end2 - start) / 1000;
println("all elapsed time: " + interval + "sec");
logger.warn("all elapsed time: " + interval + "sec");
return ldaModel
}
sbt clean package
${SPARK_HOME}/bin/spark-submit --class LDA_model --packages com.databricks:spark-csv_2.10:1.2.0 ./target/scala-2.10/lda_2.10-0.0.1.jar "local[*]" input_folder_name output_folder_name "10" "online"
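
The five trailing arguments map to args(0)–args(4) in main: the Spark master, the input CSV path, the output directory, the number of topics K, and the optimizer ("em" or "online"). The code expects the input CSV to have a header with an ID column and a jan column, and the RegexTokenizer splits the jan field on non-word characters, so a purely hypothetical row (illustrative only, not from the original article) might look like:

ID,jan
10001,4901234567890 4909876543210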
import scala.collection.mutable
import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.clustering.{EMLDAOptimizer, OnlineLDAOptimizer, LDAModel, LocalLDAModel, DistributedLDAModel, LDA}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql._
import org.apache.spark.ml.feature.{Tokenizer, RegexTokenizer}
import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel}
import java.io._
import java.text.BreakIterator
import java.io.PrintWriter
import org.slf4j.LoggerFactory
object LDA_model {
def main(args: Array[String]) {
val master = args(0) // "local" or "cluster"
val inputFile = args(1) // input_file_path
val outputFileDirPath = args(2)
val K = args(3) // LDA parameter
val algo = args(4)
val conf = new SparkConf().setMaster(master).setAppName("MLlib_LDA")
conf.registerKryoClasses(Array(classOf[MLlib_LDA]))
val sc = new SparkContext(conf)
// create instance
val model = new MLlib_LDA(inputFile, outputFileDirPath, sc)
// run lda
model.run(algo, K.toInt)
}
class MLlib_LDA(var inputFile: String, var outputDirPath: String, var sc: SparkContext) {
// logger
val logger = LoggerFactory.getLogger(getClass)
logger.warn("inputFile: " + inputFile);
logger.warn("outputDirPath: " + outputDirPath);
// constructor
val topicDist_file = outputDirPath + "\\docTopicDistribution.csv"
val wordDist_file = outputDirPath + "\\topicWordDistribution.csv"
val idToWord_file = outputDirPath + "\\idToWord.csv"
def run(algorithm: String, topic_num: Int): LDAModel = {
import java.io._
val start = System.currentTimeMillis();
var sqlContext = new SQLContext(sc);
//val df = sqlContext.read.format("com.databricks.spark.csv").option("header", "true").option("inferSchema", "true").load(inputFile)
val df = sqlContext.read.format("com.databricks.spark.csv").option("header", "true").load(inputFile)
//split a sentence to words
val regexTokenizer = new RegexTokenizer().setInputCol("jan").setOutputCol("jans").setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false)
// tokenization complete
val regexTokenized = regexTokenizer.transform(df)
val df2 = regexTokenized.select("ID", "jans").toDF
// fit a CountVectorizerModel from the corpus
val cvModel: CountVectorizerModel = new CountVectorizer().setInputCol("jans").setOutputCol("features").fit(df2)
// create a DataFrame whose "features" column holds a SparseVector
val corpus = cvModel.transform(df2).select("ID", "features", "jans")
val corpusRDD = corpus.rdd.map { row =>
val id = row.getString(0).toLong
val jan = row.get(1).asInstanceOf[Vector]
(id, jan)
}
val start2 = System.currentTimeMillis();
var interval = (start2 - start) / 1000;
println("beforeLDA elapsed time: " + interval + "sec");
logger.warn("beforeLDA elapsed time: " + interval + "sec");
//corpusRDD.collect.foreach(println)
val ldaModel: LDAModel = new LDA().setK(topic_num).setOptimizer(algorithm).run(corpusRDD)
val end = System.currentTimeMillis();
interval = (end - start) / 1000;
println("afterLDA elapsed time: " + interval + "sec");
logger.warn("afterLDA elapsed time: " + interval + "sec");
// cast LDAModel to sub class
if (algorithm == "em") {
println("algorithm: em")
logger.warn("algorithm: em")
val topicDistributions = ldaModel.asInstanceOf[DistributedLDAModel].topicDistributions.collect()
// TODO duplicated code
var f = new PrintWriter(new File(topicDist_file))
topicDistributions.sortBy(_._1).foreach { case (docID, topicDistribution) => f.write(topicDistribution.toArray.mkString(",") + '\n') }
f.close
} else {
println("algorithm: online")
logger.warn("algorithm: online")
val topicDistributions = ldaModel.asInstanceOf[LocalLDAModel].topicDistributions(corpusRDD).collect()
// TODO duplicated code
var f = new PrintWriter(new File(topicDist_file))
topicDistributions.sortBy(_._1).foreach { case (docID, topicDistribution) => f.write(topicDistribution.toArray.mkString(",") + '\n') }
f.close
}
//write wordDistributions to file
// reorder the parallel (term index, weight) arrays from describeTopics into vocabulary-index order
def sortName(array1: Array[Int], array2: Array[Double]): List[Double] = {
var temp = List(array2(array1.indexWhere(_ == 0)))
val arrayLength = array1.length
for (i <- Range(1, arrayLength)) {
temp = temp :+ array2(array1.indexWhere(_ == i))
}
return temp //.mkString(",")
}
var f4 = new PrintWriter(new File(wordDist_file))
val describeTopics = ldaModel.describeTopics();
for (i <- Range(0, topic_num)) {
val arrayWordID = describeTopics(i)._1 // term indices for topic i, sorted by descending weight
val arrayWeight = describeTopics(i)._2 // matching weights for topic i
f4.write(sortName(arrayWordID, arrayWeight).mkString(",") + "\n");
}
f4.close;
var f5 = new PrintWriter(new File(idToWord_file))
for (word <- cvModel.vocabulary) {
f5.write(word + "\n")
}
f5.close;
val end2 = System.currentTimeMillis();
interval = (end2 - start) / 1000;
println("all elapsed time: " + interval + "sec");
logger.warn("all elapsed time: " + interval + "sec");
return ldaModel
}
}
}
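
To sanity-check a trained model, the term indices returned by describeTopics can be mapped back to words through cvModel.vocabulary. A minimal sketch, assuming the ldaModel and cvModel values from the run method above are in scope (this is not part of the original code):

// print the top 10 terms of each topic with their weights
val vocab = cvModel.vocabulary
ldaModel.describeTopics(10).zipWithIndex.foreach { case ((termIndices, weights), topicId) =>
  val terms = termIndices.zip(weights).map { case (idx, w) => vocab(idx) + ":" + w }
  println("topic " + topicId + ": " + terms.mkString(", "))
}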