Last active
May 27, 2019 17:59
-
-
Save LooooongTran/13d831cc72fc222d4373 to your computer and use it in GitHub Desktop.
Bag of Words using Spark and Scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// A labelled sentence: `isRhetorical` is the training label, `content` is the raw text.
// Input record type for toLabelledPoints below.
final case class AnnotatedText(isRhetorical: Boolean, content: String)
/** Converts annotated sentences into a cached RDD of MLlib LabeledPoints using a
  * bag-of-words encoding: each distinct word gets a global index, and each sentence
  * becomes a sparse vector of word counts, labelled 1.0 if rhetorical else 0.0.
  *
  * @param annotatedTextList sentences with their rhetorical/non-rhetorical labels
  * @return cached RDD[LabeledPoint], one point per input sentence
  */
def toLabelledPoints(annotatedTextList: List[AnnotatedText]) = {
  //contributors: @rgcase, @ppremont
  import org.apache.spark.mllib.linalg.Vectors
  import org.apache.spark.mllib.regression.LabeledPoint
  import org.apache.spark.rdd.RDD
  import org.apache.spark.{SparkConf, SparkContext}
  // NOTE(review): a fresh SparkContext is created per call and never stopped — it
  // cannot be stopped here because the returned RDD is lazy/cached. Consider
  // accepting an existing SparkContext as a parameter instead.
  val conf = new SparkConf().setAppName("knowtakerWeb").setMaster("local[2]").set("spark.executor.memory", "1g")
  val sc = new SparkContext(conf)

  // Pair each sentence with a stable sentence index.
  val zippedText = annotatedTextList.zipWithIndex
  // O(1) label lookup by sentence index. The original did a linear
  // zippedText.filter(...) per RDD row (O(n^2) overall) and referenced a
  // non-existent field `isLabelledTrue`; the case class field is `isRhetorical`.
  val labelByIndex: Map[Int, Boolean] =
    zippedText.map { case (text, idx) => idx -> text.isRhetorical }.toMap
  val zippedContent = zippedText.map { case (text, idx) => (text.content, idx) }

  // (sentenceIndex, List[(wordIndex, count)]) — sparse bag-of-words per sentence.
  val svm: RDD[(Int, List[(Int, Double)])] = sc.makeRDD(zippedContent)
    .flatMap { case (content, sentenceIdx) =>       // break out each word with its sentence ID
      content.split(' ').map(word => (word, sentenceIdx))
    }
    .groupByKey                                     // word -> all sentence IDs it appears in
    .zipWithIndex                                   // assign each distinct word a unique index (Long)
    .flatMap { case ((word, sentences), wordIdx) => // (word, wordIndex, sentenceIndex) triples
      sentences.map(sentenceIdx => (word, wordIdx, sentenceIdx))
    }
    .groupBy(_._3)                                  // regroup triples by sentence
    .map { case (sentenceIdx, triples) =>
      // Count occurrences of each word within the sentence; wordIdx is a Long
      // from zipWithIndex, narrowed to Int for Vectors.sparse.
      val counts = triples
        .groupBy(_._2)
        .map { case (wordIdx, occurrences) => (wordIdx.toInt, occurrences.size.toDouble) }
        .toList
      (sentenceIdx, counts)
    }

  // Vector dimensionality = highest word index + 1. Computed distributed-side;
  // the original collected the entire RDD onto the driver just to find the max.
  val indexSize = svm.flatMap { case (_, features) => features.map(_._1) }.max + 1

  val parsedData: RDD[LabeledPoint] = svm.map { case (sentenceIdx, features) =>
    val label = if (labelByIndex(sentenceIdx)) 1.0 else 0.0
    LabeledPoint(label, Vectors.sparse(indexSize, features.toSeq))
  }.cache()
  parsedData
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment