Skip to content

Instantly share code, notes, and snippets.

@einblicker
Created August 19, 2012 00:09
Show Gist options
  • Save einblicker/3390476 to your computer and use it in GitHub Desktop.
Save einblicker/3390476 to your computer and use it in GitHub Desktop.
naive bayes
// A port of the program at the link below; its explanation is detailed and instructive.
//http://d.hatena.ne.jp/aidiary/20100613/1276389337
/** Multinomial naive Bayes text classifier with add-one smoothing.
  *
  * Train with [[NaiveBayes.train]] and query the returned [[NaiveBayes.NBClassifier]].
  */
object NaiveBayes {
  type Category = String
  type Word = String

  /** One labelled training document: its category and its words (repeated words are kept). */
  case class Data(category: Category, words: List[Word])

  /** Builds a classifier from the labelled documents in `data`. */
  def train(data: List[Data]): NBClassifier = new NBClassifier(data)

  /** Classifier whose statistics are computed once, at construction, from `data`.
    *
    *  - `categories`: every category seen in training
    *  - `vocabularies`: every distinct word seen in training
    *  - `wordcount`: occurrences of each (category, word) pair; 0 for unseen pairs
    *  - `catcount`: number of documents per category; 0 for unseen categories
    */
  class NBClassifier(data: List[Data]) {
    // Single pass over the training set. Both count maps carry a default of 0, and
    // `updated` keeps that default, so lookups of unseen keys never throw.
    val (categories, vocabularies, wordcount, catcount) =
      data.foldLeft(
        (Set[Category](), Set[Word](),
         Map[(Category, Word), Int]().withDefaultValue(0),
         Map[Category, Int]().withDefaultValue(0))
      ) {
        case ((cats, vocas, wcnt, ccnt), Data(cat, doc)) =>
          val newWcnt = doc.foldLeft(wcnt) { (counts, word) =>
            counts.updated((cat, word), counts((cat, word)) + 1)
          }
          (cats + cat, vocas ++ doc, newWcnt, ccnt.updated(cat, ccnt(cat) + 1))
      }

    /** Smoothed denominator per category: total word occurrences in that category plus the
      * vocabulary size. A category never seen in training falls back to the vocabulary size
      * alone (zero occurrences), consistent with the zero defaults of `wordcount`/`catcount`.
      */
    val denominator: Map[Category, Int] = {
      // One pass over wordcount instead of rescanning it for every category.
      val occurrencesPerCategory =
        wordcount.foldLeft(Map[Category, Int]().withDefaultValue(0)) {
          case (acc, ((cat, _), n)) => acc.updated(cat, acc(cat) + n)
        }
      categories
        .map(cat => cat -> (occurrencesPerCategory(cat) + vocabularies.size))
        .toMap
        .withDefaultValue(vocabularies.size)
    }

    // Total number of training documents; fixed after construction, so computed once.
    private val totalDocuments: Int = catcount.values.sum

    /** Returns the category with the highest [[score]] for `doc`.
      *
      * Throws `UnsupportedOperationException` if the classifier was trained on no data,
      * because there is then no category to choose from.
      */
    def classify(doc: List[Word]): Category =
      categories.map { cat => (cat, score(doc, cat)) }.maxBy(_._2)._1

    /** Smoothed estimate of P(word | cat): (count of `word` in `cat` + 1) / denominator(cat). */
    def wordProb(word: Word, cat: Category): Double =
      (wordcount((cat, word)) + 1).toDouble / denominator(cat).toDouble

    /** Log prior of `cat` plus the sum of log [[wordProb]] over every word of `doc`.
      *
      * For a category with no training documents the prior is log(0), so the result is
      * negative infinity.
      */
    def score(doc: List[Word], cat: Category): Double = {
      val logPrior = math.log(catcount(cat).toDouble / totalDocuments) // log P(cat)
      val logLikelihood = doc.map(word => math.log(wordProb(word, cat))).sum
      logPrior + logLikelihood
    }

    override def toString(): String =
      "documents: %d, vocabularies: %d, categories: %d".format(
        totalDocuments, vocabularies.size, categories.size
      )
  }

  /** Demo: trains on a four-document corpus and prints probabilities, scores and the
    * predicted category for one test document.
    */
  def main(args: Array[String]): Unit = {
    val data =
      List(Data("yes", List("Chinese", "Beijing", "Chinese")),
           Data("yes", List("Chinese", "Chinese", "Shanghai")),
           Data("yes", List("Chinese", "Macao")),
           Data("no", List("Tokyo", "Japan", "Chinese")))
    val nb = train(data)
    println(nb)
    println("P(Chinese|yes) = " + nb.wordProb("Chinese", "yes"))
    println("P(Tokyo|yes) = " + nb.wordProb("Tokyo", "yes"))
    println("P(Japan|yes) = " + nb.wordProb("Japan", "yes"))
    println("P(Chinese|no) = " + nb.wordProb("Chinese", "no"))
    println("P(Tokyo|no) = " + nb.wordProb("Tokyo", "no"))
    println("P(Japan|no) = " + nb.wordProb("Japan", "no"))
    val test = List("Chinese", "Chinese", "Chinese", "Tokyo", "Japan")
    println("log P(yes|test) = " + nb.score(test, "yes"))
    println("log P(no|test) = " + nb.score(test, "no"))
    println(nb.classify(test))
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment