Last active
December 15, 2015 05:21
-
-
Save LittleWat/c40e03b3d60f30945b01 to your computer and use it in GitHub Desktop.
SparkLDAの実行についてのメモ ref: http://qiita.com/LittleWat/items/3ad733f7362edd954714
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
name := "LDA"

version := "0.0.1"

scalaVersion := "2.10.4"

// spark-core is "provided": supplied by the Spark runtime at spark-submit time,
// so it is excluded from the packaged jar. spark-mllib is a compile dependency.
libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-core" % "1.5.1" % "provided",
  "org.apache.spark" %% "spark-mllib" % "1.5.1"
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Build the Spark context; `master` is taken from the CLI (e.g. "local[*]").
val conf = new SparkConf().setMaster(master).setAppName("MLlib_LDA")
// Register the model class with Kryo for faster serialization.
conf.registerKryoClasses(Array(classOf[MLlib_LDA]))
val sc = new SparkContext(conf)
// Create the model wrapper and run LDA with the requested optimizer and topic count.
val model = new MLlib_LDA(inputFile, outputFileDirPath, sc)
model.run(algo, K.toInt)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Logger for progress reporting (warn level so it shows under default log config).
val logger = LoggerFactory.getLogger(getClass)
logger.warn("inputFile: " + inputFile)
logger.warn("outputDirPath: " + outputDirPath)
// Output file paths, built in the constructor body.
// NOTE(review): "\\" is a Windows-only separator — consider java.io.File.separator
// for portability.
val topicDist_file = outputDirPath + "\\User_Topic_Distribution.csv"
val wordDist_file = outputDirPath + "\\Topic_Jancode_Distribution.csv"
val idToWord_file = outputDirPath + "\\IdToWord.csv"
/** Trains an LDA topic model over the CSV corpus and writes the results.
 *
 *  Pipeline: CSV load -> RegexTokenizer ("jan" column split on non-word chars)
 *  -> CountVectorizer (sparse term-count vectors) -> MLlib LDA.
 *
 *  Output files:
 *    topicDist_file - one row per document (sorted by doc id): its topic mixture
 *    wordDist_file  - one row per topic: weight of every vocabulary term,
 *                     column j corresponding to term id j
 *    idToWord_file  - line j is the word whose term id is j
 *
 *  @param algorithm LDA optimizer, "em" or "online"
 *  @param topic_num number of topics K
 *  @return the trained LDAModel
 */
def run(algorithm: String, topic_num: Int): LDAModel = {
  import java.io._
  // Writes one line per element, closing the writer even if a write fails.
  def writeLines(path: String, lines: TraversableOnce[String]): Unit = {
    val out = new PrintWriter(new File(path))
    try lines.foreach(line => out.write(line + "\n")) finally out.close()
  }
  // Logs wall-clock seconds elapsed since `since` for the named stage.
  def logElapsed(stage: String, since: Long): Unit = {
    val sec = (System.currentTimeMillis() - since) / 1000
    println(stage + " elapsed time: " + sec + "sec")
    logger.warn(stage + " elapsed time: " + sec + "sec")
  }

  val start = System.currentTimeMillis()
  val sqlContext = new SQLContext(sc)
  // All columns are read as strings; enable inferSchema if typed columns are needed.
  val df = sqlContext.read.format("com.databricks.spark.csv").option("header", "true").load(inputFile)
  // Split the "jan" column into word tokens on non-word characters.
  val regexTokenizer = new RegexTokenizer().setInputCol("jan").setOutputCol("jans").setPattern("\\W")
  val df2 = regexTokenizer.transform(df).select("ID", "jans").toDF
  // Learn the vocabulary and turn each document into a sparse term-count vector.
  val cvModel: CountVectorizerModel = new CountVectorizer().setInputCol("jans").setOutputCol("features").fit(df2)
  val corpus = cvModel.transform(df2).select("ID", "features", "jans")
  // MLlib LDA expects RDD[(docId, termCounts)]; the ID column must parse as Long.
  val corpusRDD = corpus.rdd.map { row =>
    (row.getString(0).toLong, row.get(1).asInstanceOf[Vector])
  }
  logElapsed("beforeLDA", start)

  val ldaModel: LDAModel = new LDA().setK(topic_num).setOptimizer(algorithm).run(corpusRDD)
  logElapsed("afterLDA", start)

  // Per-document topic distributions. The EM optimizer yields a DistributedLDAModel
  // that stores them with the model; the online optimizer yields a LocalLDAModel
  // that must infer them from the corpus.
  val topicDistributions =
    if (algorithm == "em") {
      println("algorithm: em")
      logger.warn("algorithm: em")
      ldaModel.asInstanceOf[DistributedLDAModel].topicDistributions.collect()
    } else {
      println("algorithm: online")
      logger.warn("algorithm: online")
      ldaModel.asInstanceOf[LocalLDAModel].topicDistributions(corpusRDD).collect()
    }
  writeLines(topicDist_file,
    topicDistributions.sortBy(_._1).map { case (_, dist) => dist.toArray.mkString(",") })

  // Per-topic word distributions. describeTopics returns, for each topic, term
  // indices ordered by descending weight; re-sort each topic's weights by term id
  // so that column j of every row is vocabulary term j.
  val describeTopics = ldaModel.describeTopics()
  writeLines(wordDist_file, describeTopics.toSeq.map { case (termIndices, termWeights) =>
    (termIndices zip termWeights).sortBy(_._1).map(_._2).mkString(",")
  })

  // Vocabulary: line j is the word with term id j.
  writeLines(idToWord_file, cvModel.vocabulary)

  logElapsed("all", start)
  ldaModel
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build the application jar (target/scala-2.10/lda_2.10-0.0.1.jar)
sbt clean package
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE: spark-submit options must come BEFORE the application jar; anything after the
# jar is passed to the application as args. The original put --packages after the jar,
# so the spark-csv package was never loaded. App args: <master> <input> <output> <K> <algo>.
${SPARK_HOME}/bin/spark-submit --class LDA_model --packages com.databricks:spark-csv_2.10:1.2.0 ./target/scala-2.10/lda_2.10-0.0.1.jar "local[*]" input_folder_name output_folder_name "10" "online"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.collection.mutable | |
import org.apache.spark.{SparkContext, SparkConf} | |
import org.apache.spark.SparkContext._ | |
import org.apache.spark.mllib.clustering.{EMLDAOptimizer, OnlineLDAOptimizer, LDAModel, LocalLDAModel, DistributedLDAModel, LDA} | |
import org.apache.spark.mllib.linalg.{Vector, Vectors} | |
import org.apache.spark.rdd.RDD | |
import org.apache.spark.sql.SQLContext | |
import org.apache.spark.sql._ | |
import org.apache.spark.ml.feature.{Tokenizer, RegexTokenizer} | |
import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel} | |
import java.io._ | |
import java.text.BreakIterator | |
import java.io.PrintWriter | |
import org.slf4j.LoggerFactory | |
object LDA_model {

  /** Entry point.
   *
   *  Expected arguments:
   *    args(0) master        - Spark master URL, e.g. "local[*]"
   *    args(1) inputFile     - input CSV path (header row with columns "ID" and "jan")
   *    args(2) outputDirPath - directory where the result CSV files are written
   *    args(3) K             - number of LDA topics (parsed as Int)
   *    args(4) algo          - LDA optimizer: "em" or "online"
   */
  def main(args: Array[String]) {
    val master = args(0)
    val inputFile = args(1)
    val outputFileDirPath = args(2)
    val K = args(3)
    val algo = args(4)
    val conf = new SparkConf().setMaster(master).setAppName("MLlib_LDA")
    // Register the model class with Kryo for faster serialization.
    conf.registerKryoClasses(Array(classOf[MLlib_LDA]))
    val sc = new SparkContext(conf)
    val model = new MLlib_LDA(inputFile, outputFileDirPath, sc)
    model.run(algo, K.toInt)
  }

  /** Loads a CSV corpus, trains an LDA topic model and writes the results.
   *
   *  Output files (under outputDirPath):
   *    docTopicDistribution.csv  - one row per document (sorted by doc id): topic mixture
   *    topicWordDistribution.csv - one row per topic: weight of every vocabulary term,
   *                                column j corresponding to term id j
   *    idToWord.csv              - line j is the word whose term id is j
   *
   *  NOTE(review): "\\" is a Windows-only path separator — consider java.io.File.separator.
   */
  class MLlib_LDA(var inputFile: String, var outputDirPath: String, var sc: SparkContext) {
    val logger = LoggerFactory.getLogger(getClass)
    logger.warn("inputFile: " + inputFile)
    logger.warn("outputDirPath: " + outputDirPath)
    // Output file paths, built in the constructor body.
    val topicDist_file = outputDirPath + "\\docTopicDistribution.csv"
    val wordDist_file = outputDirPath + "\\topicWordDistribution.csv"
    val idToWord_file = outputDirPath + "\\idToWord.csv"

    /** Trains LDA with the given optimizer ("em" or "online") and topic count,
     *  writes the three output CSVs and returns the trained model.
     */
    def run(algorithm: String, topic_num: Int): LDAModel = {
      // Writes one line per element, closing the writer even if a write fails.
      def writeLines(path: String, lines: TraversableOnce[String]): Unit = {
        val out = new PrintWriter(new File(path))
        try lines.foreach(line => out.write(line + "\n")) finally out.close()
      }
      // Logs wall-clock seconds elapsed since `since` for the named stage.
      def logElapsed(stage: String, since: Long): Unit = {
        val sec = (System.currentTimeMillis() - since) / 1000
        println(stage + " elapsed time: " + sec + "sec")
        logger.warn(stage + " elapsed time: " + sec + "sec")
      }

      val start = System.currentTimeMillis()
      val sqlContext = new SQLContext(sc)
      // All columns are read as strings; enable inferSchema if typed columns are needed.
      val df = sqlContext.read.format("com.databricks.spark.csv").option("header", "true").load(inputFile)
      // Split the "jan" column into word tokens on non-word characters.
      val regexTokenizer = new RegexTokenizer().setInputCol("jan").setOutputCol("jans").setPattern("\\W")
      val df2 = regexTokenizer.transform(df).select("ID", "jans").toDF
      // Learn the vocabulary and turn each document into a sparse term-count vector.
      val cvModel: CountVectorizerModel = new CountVectorizer().setInputCol("jans").setOutputCol("features").fit(df2)
      val corpus = cvModel.transform(df2).select("ID", "features", "jans")
      // MLlib LDA expects RDD[(docId, termCounts)]; the ID column must parse as Long.
      val corpusRDD = corpus.rdd.map { row =>
        (row.getString(0).toLong, row.get(1).asInstanceOf[Vector])
      }
      logElapsed("beforeLDA", start)

      val ldaModel: LDAModel = new LDA().setK(topic_num).setOptimizer(algorithm).run(corpusRDD)
      logElapsed("afterLDA", start)

      // Per-document topic distributions. The EM optimizer yields a DistributedLDAModel
      // that stores them with the model; the online optimizer yields a LocalLDAModel
      // that must infer them from the corpus.
      val topicDistributions =
        if (algorithm == "em") {
          println("algorithm: em")
          logger.warn("algorithm: em")
          ldaModel.asInstanceOf[DistributedLDAModel].topicDistributions.collect()
        } else {
          println("algorithm: online")
          logger.warn("algorithm: online")
          ldaModel.asInstanceOf[LocalLDAModel].topicDistributions(corpusRDD).collect()
        }
      writeLines(topicDist_file,
        topicDistributions.sortBy(_._1).map { case (_, dist) => dist.toArray.mkString(",") })

      // Per-topic word distributions. describeTopics returns, for each topic, term
      // indices ordered by descending weight; re-sort each topic's weights by term id
      // so that column j of every row is vocabulary term j.
      // Fixes two bugs in the original: it looped over `K`, a local of main() that is
      // not in scope here (compile error), and it reused topic 0's term indices
      // (describeTopics(0)._1) for every topic, producing wrong rows for topics > 0.
      val describeTopics = ldaModel.describeTopics()
      writeLines(wordDist_file, describeTopics.toSeq.map { case (termIndices, termWeights) =>
        (termIndices zip termWeights).sortBy(_._1).map(_._2).mkString(",")
      })

      // Vocabulary: line j is the word with term id j.
      writeLines(idToWord_file, cvModel.vocabulary)

      logElapsed("all", start)
      ldaModel
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment