Skip to content

Instantly share code, notes, and snippets.

@OlegYch
Created April 7, 2012 22:07
Show Gist options
  • Save OlegYch/2332409 to your computer and use it in GitHub Desktop.
Save OlegYch/2332409 to your computer and use it in GitHub Desktop.
naive bayes
/**
 * Naive Bayes text classifier with add-one smoothing.
 *
 * Training examples are accumulated with [[addExample]]; [[classify]] then picks the class whose
 * log prior plus summed log word likelihoods is largest.
 *
 * All derived statistics are computed on demand from a snapshot of [[examples]] and recomputed
 * whenever `examples` refers to a different list, so examples added after a classification are
 * taken into account by later calls.
 */
class NB {
  /** Training examples as (class label, words) pairs; the most recently added example is first. */
  var examples = List[(String, Iterable[String])]()

  /**
   * Records one training example.
   *
   * @param klass class label of the example
   * @param words words of the example; repeated words count once per occurrence
   */
  def addExample(klass: String, words: Iterable[String]): Unit = {
    examples = (klass, words) :: examples
  }

  /**
   * Sums weights per key.
   *
   * @param i  elements to aggregate
   * @param pf maps an element to a (key, weight) pair; elements where it is undefined are skipped
   * @return total weight per key; looking up a key that never occurred yields 0.0
   */
  def count[T, TT](i: Iterable[T])(pf: PartialFunction[T, (TT, Double)]): Map[TT, Double] =
    i.view.collect(pf).foldLeft(Map[TT, Double]().withDefaultValue(0.0)) {
      case (totals, (key, weight)) => totals + (key -> (totals(key) + weight))
    }

  /** Statistics derived from one fixed list of training examples. */
  private final class Model(val source: List[(String, Iterable[String])]) {
    lazy val classes: Set[String] = source.map(_._1).toSet
    lazy val vocabulary: Set[String] = source.flatMap(_._2).toSet
    lazy val prior: Map[String, Double] = count(source) {
      case (c, _) => (c, 1.0 / source.size)
    }
    lazy val tokensPerClass: Map[String, Double] = count(source) {
      case (c, ws) => (c, ws.size.toDouble)
    }
    lazy val wordClassPairs: List[(String, String)] =
      source.flatMap { case (c, ws) => ws.map(w => (w, c)) }
    lazy val pairCounts: Map[(String, String), Double] = count(wordClassPairs) {
      case pair => (pair, 1.0)
    }
    // An empty map whose default computes the smoothed estimate, so every (word, class) pair,
    // seen or unseen, gets a non-zero probability.
    lazy val likelihood: Map[(String, String), Double] =
      Map[(String, String), Double]().withDefault {
        case wc @ (_, c) => (pairCounts(wc) + 1) / (tokensPerClass(c) + vocabulary.size)
      }
  }

  // Last model built, together with the example list it was built from.
  private var cached: Option[Model] = None

  /** Model for the current `examples`; rebuilt when `examples` is no longer the cached list. */
  private def model: Model = cached match {
    case Some(m) if m.source eq examples => m
    case _ =>
      val fresh = new Model(examples)
      cached = Some(fresh)
      fresh
  }

  /** Distinct class labels seen in the training examples. */
  def classes: Set[String] = model.classes

  /** Vocabulary: distinct words seen in the training examples. */
  def words: Set[String] = model.vocabulary

  /** Prior per class: the fraction of training examples carrying that label. */
  def prior: Map[String, Double] = model.prior

  /** Total number of word occurrences per class. */
  def count_c: Map[String, Double] = model.tokensPerClass

  /** One (word, class) pair per word occurrence in the training examples. */
  def `word->class`: List[(String, String)] = model.wordClassPairs

  /** Number of occurrences of each (word, class) pair. */
  def count_w_c: Map[(String, String), Double] = model.pairCounts

  /**
   * Smoothed likelihood of a word given a class, looked up by applying the map to a
   * (word, class) pair: (count(word, class) + 1) / (count(class) + vocabulary size).
   */
  def p_w_c: Map[(String, String), Double] = model.likelihood

  /** Prints a summary of the current model to standard output; intended for debugging. */
  def inspect: Unit = {
    val m = model
    println(m.classes)
    println(m.prior)
    println(m.tokensPerClass)
    println(m.pairCounts.take(10))
    println(m.wordClassPairs.toSet[(String, String)].map(m.likelihood(_)).take(10))
  }

  /**
   * Chooses the most probable class for a document.
   *
   * @param words words of the document to classify
   * @return the class label with the highest log-probability score
   * @throws UnsupportedOperationException if no training examples have been added
   */
  def classify(words: Iterable[String]): String = {
    val m = model
    val tokens = words.toList
    // Work in log space: a product of many small probabilities would underflow.
    val scores = m.classes.map { c =>
      val probabilities = m.prior(c) :: tokens.map(w => m.likelihood(w -> c))
      c -> probabilities.map(math.log).sum
    }
    if (scores.isEmpty)
      throw new UnsupportedOperationException(
        "cannot classify: no training examples have been added")
    scores.maxBy(_._2)._1
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment