Skip to content

Instantly share code, notes, and snippets.

@LooooongTran
Last active May 27, 2019 17:59
Show Gist options
  • Save LooooongTran/13d831cc72fc222d4373 to your computer and use it in GitHub Desktop.
Save LooooongTran/13d831cc72fc222d4373 to your computer and use it in GitHub Desktop.
Bag of Words using Spark and Scala
/** One labelled sentence.
  *
  * @param isRhetorical the sentence's boolean label
  * @param content      the raw sentence text
  */
case class AnnotatedText(
  isRhetorical: Boolean,
  content: String
)
/** Converts labelled sentences into bag-of-words `LabeledPoint`s.
  *
  * Each sentence is tokenised by splitting on single spaces, so consecutive spaces yield an
  * empty-string token, as in the original implementation. Every distinct token gets a feature
  * index. Each sentence becomes a sparse vector whose entries are that sentence's token counts.
  * The label is 1.0 when `isRhetorical` is true and 0.0 otherwise.
  *
  * Feature indices are assigned in sorted token order. A given word therefore always maps to
  * the same index for the same input, even if Spark re-evaluates the lineage.
  *
  * @param annotatedTextList sentences to vectorise; a sentence's position in the list is its id
  * @return a cached RDD with one LabeledPoint per sentence (empty when the input is empty)
  */
def toLabelledPoints(
    annotatedTextList: List[AnnotatedText]
): org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint] = {
  //contributors: @rgcase, @ppremont
  import org.apache.spark.mllib.linalg.Vectors
  import org.apache.spark.mllib.regression.LabeledPoint
  import org.apache.spark.rdd.RDD
  import org.apache.spark.{SparkConf, SparkContext}

  val conf = new SparkConf().setAppName("knowtakerWeb").setMaster("local[2]").set("spark.executor.memory", "1g")
  // Only one SparkContext may be active per JVM, so constructing a new one on every call fails
  // on the second call. getOrCreate reuses a running context (in which case `conf` is ignored).
  val sc = SparkContext.getOrCreate(conf)

  // Label per sentence id. Each label is then an O(1) array lookup instead of a linear scan of
  // the input list, and only this small array is captured by the task closure.
  val labels: Array[Double] =
    annotatedTextList.map(text => if (text.isRhetorical) 1.0 else 0.0).toArray

  // One (word, sentenceId) pair per token occurrence; a word repeated in a sentence repeats here.
  val occurrences: RDD[(String, Int)] = sc
    .makeRDD(annotatedTextList.map(_.content).zipWithIndex)
    .flatMap { case (content, sentenceId) => content.split(' ').map(word => (word, sentenceId)) }

  // Group occurrences by word, then number the words. sortByKey pins the element order before
  // zipWithIndex. Without it, the indices are not guaranteed to be stable across re-evaluations.
  val indexedWords: RDD[((String, Iterable[Int]), Long)] = occurrences
    .groupByKey()
    .sortByKey()
    .zipWithIndex()

  // Word indices run from 0 to (distinct-word count - 1), so the vector size is that count.
  // Counting avoids collecting every sentence to the driver and yields 0 for empty input.
  val indexSize: Int = indexedWords.count().toInt

  // sentenceId -> list of (wordIndex, number of occurrences of that word in the sentence)
  val svm: RDD[(Int, List[(Int, Double)])] = indexedWords
    .flatMap { case ((_, sentenceIds), wordIndex) =>
      sentenceIds.map(sentenceId => (sentenceId, wordIndex.toInt))
    }
    .groupByKey()
    .mapValues { wordIndices =>
      wordIndices.groupBy(identity).map { case (wordIndex, hits) => (wordIndex, hits.size.toDouble) }.toList
    }

  val parsedData: RDD[LabeledPoint] = svm.map { case (sentenceId, features) =>
    LabeledPoint(labels(sentenceId), Vectors.sparse(indexSize, features))
  }.cache()
  parsedData
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment