Skip to content

Instantly share code, notes, and snippets.

@charlieda
Created January 8, 2014 18:46
Show Gist options
  • Save charlieda/8322111 to your computer and use it in GitHub Desktop.
Save charlieda/8322111 to your computer and use it in GitHub Desktop.
import java.util.*;
import java.lang.Math;
/**
 * Naive Bayes spam/ham scorer that uses add-one (Laplace) smoothing.
 *
 * <p>An instance is a snapshot of training statistics: how many ham and spam
 * documents were seen, how many word occurrences each class contained, and the
 * per-word occurrence counts. {@link #getLikelihoodRatio(ArrayList)} turns
 * those statistics into a log-space score for a tokenised message.
 *
 * <p>Serialization note: no explicit {@code serialVersionUID} is declared, so
 * the JVM derives one from the class shape. Adding one now, or changing field
 * modifiers or non-private members, would make previously serialized models
 * unreadable.
 */
public class NaiveBayesClassifier implements java.io.Serializable {
    // the number of spam and ham documents
    private int numSpam = 0;
    private int numHam = 0;
    // total number of words we've seen
    private int totalSpamWords = 0;
    private int totalHamWords = 0;
    // table of word counts for each word we've seen
    // the size of this is the size of our vocabulary
    private Map<String, Counts> wordCounts;

    /**
     * Creates a classifier from already-collected training statistics.
     *
     * @param numHam         number of ham documents in the training set
     * @param numSpam        number of spam documents in the training set
     * @param totalHamWords  total word occurrences across all ham documents
     * @param totalSpamWords total word occurrences across all spam documents
     * @param wordCounts     per-word ham/spam occurrence counts; its size is
     *                       used as the vocabulary size. The map is stored by
     *                       reference, not copied, and must not be null when
     *                       messages are scored.
     */
    public NaiveBayesClassifier(int numHam, int numSpam, int totalHamWords, int totalSpamWords, Map<String, Counts> wordCounts) {
        this.numHam = numHam;
        this.numSpam = numSpam;
        this.totalSpamWords = totalSpamWords;
        this.totalHamWords = totalHamWords;
        this.wordCounts = wordCounts;
    }

    /**
     * Scores a tokenised message as {@code ln(ham score) - ln(spam score)}.
     *
     * <p>Each class score starts at the log prior (the class's share of the
     * training documents) and adds, for every occurrence of a known word,
     * {@code ln((count + 1) / (classWordTotal + vocabularySize))}. Words that
     * are not in the vocabulary are skipped. A repeated word contributes once
     * per occurrence.
     *
     * <p>If the model holds no training documents at all, both priors are
     * treated as equal instead of evaluating {@code 0 / 0}; an untrained
     * model therefore returns {@code 0.0} rather than {@code NaN}.
     *
     * @param words the message's words, one entry per occurrence
     * @return a positive value when the ham score is higher, a negative value
     *         when the spam score is higher; infinite when exactly one class
     *         has zero training documents
     * @throws NullPointerException if {@code words} is null
     */
    public double getLikelihoodRatio(ArrayList<String> words) {
        // Sum in double: int addition could wrap for very large counts and
        // produce a negative denominator, which Math.log turns into NaN.
        final double totalDocs = (double) numHam + numSpam;

        double hamLogRatio = 0.0;
        double spamLogRatio = 0.0;
        if (totalDocs > 0) {
            // initialise our ratios to the prior distribution
            hamLogRatio = Math.log(numHam / totalDocs);
            spamLogRatio = Math.log(numSpam / totalDocs);
        }

        // Loop-invariant smoothing denominators, again summed in double.
        final int vocabSize = wordCounts.size();
        final double hamDenominator = (double) totalHamWords + vocabSize;
        final double spamDenominator = (double) totalSpamWords + vocabSize;

        for (String word : words) {
            // Single lookup; a missing (or null) entry means the word is
            // outside the vocabulary and carries no evidence.
            final Counts counts = wordCounts.get(word);
            if (counts == null) {
                continue;
            }
            hamLogRatio += Math.log((counts.hamCount + 1.0) / hamDenominator);
            spamLogRatio += Math.log((counts.spamCount + 1.0) / spamDenominator);
        }

        return hamLogRatio - spamLogRatio;
    }

    /**
     * Computes {@code ln(value ^ exp)} as {@code exp * ln(value)}.
     *
     * <p>Working in log space avoids the underflow of {@code Math.pow} for
     * small probabilities, handles fractional exponents exactly, and runs in
     * constant time. Useful when a word's log-probability must be weighted by
     * its occurrence count; not called by the current list-based scoring path.
     *
     * @param value the base; expected to be positive (e.g. a probability)
     * @param exp   the exponent
     * @return ln( value ^ exp )
     */
    private static double logPow(double value, double exp) {
        if (exp == 0) {
            // value^0 is 1 for every base, so the log is 0; this also avoids
            // 0 * -Infinity = NaN when value is 0.
            return 0.0;
        }
        return exp * Math.log(value);
    }

    /**
     * Occurrence counts of a single word in spam and in ham training
     * documents. Both counts start at zero.
     */
    public static class Counts implements java.io.Serializable {
        public int spamCount;
        public int hamCount;

        /** Creates a counter with both counts set to zero. */
        public Counts() {
            this.spamCount = 0;
            this.hamCount = 0;
        }

        /** @return the counts formatted as {@code "Spam: <n> Ham: <m>"} */
        @Override
        public String toString() {
            return "Spam: " + this.spamCount + " Ham: " + this.hamCount;
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment