Skip to content

Instantly share code, notes, and snippets.

@bobbyali
Last active August 29, 2015 14:22
Show Gist options
  • Save bobbyali/004913a0456ef5db8c23 to your computer and use it in GitHub Desktop.
Save bobbyali/004913a0456ef5db8c23 to your computer and use it in GitHub Desktop.
Continuous Naive Bayes Classifier in Java
The raw data in the CSV file has one observation per row, with columns separated by single spaces. The first three columns are continuous features and the fourth is the class label:
1st col: Number of hours since last nap
2nd col: Number of hours since last meal
3rd col: Number of toys present
4th col: Outcome (1 = meltdown, 0 = no meltdown)
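TestClassifier.java is the entry point: it loads the CSV file (./continuous.csv) and passes its lines to the NaiveBayesContinuous class.
NaiveBayesContinuous then fits a normal (Gaussian) distribution to each feature for each outcome, estimates the outcome priors,
and uses them to compute the probabilities needed to predict whether a meltdown is likely.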
For more info on what's going on, see my blog post at
http://www.hacker-dad.com/how-to-predict-a-meltdown/
5 8 8 1
6 2 7 1
8 4 5 1
7 7 4 1
4 3 7 1
9 9 0 1
5 3 3 0
3 7 2 0
2 5 5 0
5 3 3 0
2 5 7 0
7 8 1 0
package naive_bayes;
import java.util.ArrayList;
import java.util.List;
public class NaiveBayesContinuous {
// lists containing training data
private List<Float> training_nap = new ArrayList<Float>();
private List<Float> training_eat = new ArrayList<Float>();
private List<Float> training_toy = new ArrayList<Float>();
private List<Boolean> training_outcome = new ArrayList<Boolean>();
// prior probabilities
public float p_meltdown, p_noMeltdown;
// pdf parameters (mean and variance)
public float mean_nap_meltdown, var_nap_meltdown;
public float mean_eat_meltdown, var_eat_meltdown;
public float mean_toy_meltdown, var_toy_meltdown;
public float mean_nap_noMeltdown, var_nap_noMeltdown;
public float mean_eat_noMeltdown, var_eat_noMeltdown;
public float mean_toy_noMeltdown, var_toy_noMeltdown;
public float mean_nap, var_nap;
public float mean_eat, var_eat;
public float mean_toy, var_toy;
// posterior probabilities
public float p_meltdown_data, p_noMeltdown_data;
public NaiveBayesContinuous(String[] data) {
for (String line: data) {
training_nap.add( (float) Character.getNumericValue( line.charAt(0) ));
training_eat.add( (float) Character.getNumericValue( line.charAt(2) ));
training_toy.add( (float) Character.getNumericValue( line.charAt(4) ));
training_outcome.add( convertCharToBoolean( line.charAt(6) ) );
}
calcPriorProbabilities();
calcPdfParameters();
}
private void calcPriorProbabilities() {
float numMeltdown = 0, numNoMeltdown = 0;
for (Boolean b: this.training_outcome) {
if (b == true) {
numMeltdown++;
}
else {
numNoMeltdown++;
}
}
this.p_meltdown = numMeltdown / this.training_outcome.size();
this.p_noMeltdown = numNoMeltdown / this.training_outcome.size();
}
private void calcPdfParameters() {
List<Float> nap_meltdown = new ArrayList<Float>(), eat_meltdown = new ArrayList<Float>(), toy_meltdown = new ArrayList<Float>();;
List<Float> nap_noMeltdown = new ArrayList<Float>(), eat_noMeltdown = new ArrayList<Float>(), toy_noMeltdown = new ArrayList<Float>();
for (int i = 0; i < this.training_outcome.size(); i++) {
if (this.training_outcome.get(i) == true) {
nap_meltdown.add( this.training_nap.get(i) );
eat_meltdown.add( this.training_eat.get(i) );
toy_meltdown.add( this.training_toy.get(i) );
} else {
nap_noMeltdown.add( this.training_nap.get(i) );
eat_noMeltdown.add( this.training_eat.get(i) );
toy_noMeltdown.add( this.training_toy.get(i) );
}
}
this.mean_nap_meltdown = calcMean(nap_meltdown);
this.mean_eat_meltdown = calcMean(eat_meltdown);
this.mean_toy_meltdown = calcMean(toy_meltdown);
this.mean_nap_noMeltdown = calcMean(nap_noMeltdown);
this.mean_eat_noMeltdown = calcMean(eat_noMeltdown);
this.mean_toy_noMeltdown = calcMean(toy_noMeltdown);
this.var_nap_meltdown = calcVariance(nap_meltdown, this.mean_nap_meltdown);
this.var_eat_meltdown = calcVariance(eat_meltdown, this.mean_eat_meltdown);
this.var_toy_meltdown = calcVariance(toy_meltdown, this.mean_toy_meltdown);
this.var_nap_noMeltdown = calcVariance(nap_noMeltdown, this.mean_nap_noMeltdown);
this.var_eat_noMeltdown = calcVariance(eat_noMeltdown, this.mean_eat_noMeltdown);
this.var_toy_noMeltdown = calcVariance(toy_noMeltdown, this.mean_toy_noMeltdown);
this.mean_nap = calcMean(this.training_nap);
this.mean_eat = calcMean(this.training_eat);
this.mean_toy = calcMean(this.training_toy);
this.var_nap = calcVariance(this.training_nap, this.mean_nap);
this.var_eat = calcVariance(this.training_eat, this.mean_eat);
this.var_toy = calcVariance(this.training_toy, this.mean_toy);
}
public void calcPosterior(float nap, float eat, float toy) {
float numerator_meltdown = calcGaussianConditionalProbability(nap, this.mean_nap_meltdown, this.var_nap_meltdown)
* calcGaussianConditionalProbability(eat, this.mean_eat_meltdown, this.var_eat_meltdown)
* calcGaussianConditionalProbability(toy, this.mean_toy_meltdown, this.var_toy_meltdown)
* this.p_meltdown;
float numerator_noMeltdown = calcGaussianConditionalProbability(nap, this.mean_nap_noMeltdown, this.var_nap_noMeltdown)
* calcGaussianConditionalProbability(eat, this.mean_eat_noMeltdown, this.var_eat_noMeltdown)
* calcGaussianConditionalProbability(toy, this.mean_toy_noMeltdown, this.var_toy_noMeltdown)
* this.p_meltdown;
float denominator = calcGaussianConditionalProbability(nap, this.mean_nap, this.var_nap)
* calcGaussianConditionalProbability(eat, this.mean_eat, this.var_eat)
* calcGaussianConditionalProbability(toy, this.mean_toy, this.var_toy);
this.p_meltdown_data = numerator_meltdown / denominator;
this.p_noMeltdown_data = numerator_noMeltdown / denominator;
printPosteriors();
}
public void printPosteriors() {
System.out.println("Posteriors:");
System.out.println("p(Breakdown|Data) = " + this.p_meltdown_data);
System.out.println("p(No Breakdown|Data) = " + this.p_noMeltdown_data);
System.out.println("Sum of posteriors = " + (this.p_meltdown_data + this.p_noMeltdown_data));
if (this.p_meltdown_data > this.p_noMeltdown_data) {
System.out.println("Breakdown is more likely.");
} else if (this.p_meltdown_data < this.p_noMeltdown_data) {
System.out.println("No Breakdown is more likely.");
} else {
System.out.println("Equal chance of breakdown vs no breakdown.");
}
System.out.println(" ");
}
private float calcGaussianConditionalProbability(float v, float mean, float variance) {
float term1 = (float) (1 / (Math.sqrt(2 * Math.PI * variance)));
float term2 = (float) -(Math.pow(v-mean,2)) / (2 * variance);
return (float) (term1 * Math.exp(term2));
}
private float calcMean(List<Float> data) {
float total = 0;
for (float i : data) {
total += i;
}
return total / data.size();
}
private float calcVariance(List<Float> data, float mean) {
float ssds = 0; // sum of squared differences
for (float i : data) {
ssds += Math.pow(i - mean, 2);
}
return ssds / (data.size() - 1);
}
private Boolean convertCharToBoolean(char c) {
if (c == '1') {
return true;
} else if (c == '0') {
return false;
} else {
return null;
}
}
}
package naive_bayes;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;
/**
 * Command-line entry point. Trains a {@link NaiveBayesContinuous} classifier from a data file
 * and prints the posteriors for a sample query: 6 hours since the last nap, 5 hours since the
 * last meal, 3 toys.
 *
 * <p>The data file defaults to {@code ./continuous.csv}. Pass a different path as the first
 * command-line argument to override it. If the file cannot be read, an error is printed to
 * stderr and the process exits with status 1.
 */
public class TestClassifier {
    public static void main(String[] args) {
        String fileName = (args != null && args.length > 0) ? args[0] : "./continuous.csv";
        List<String> lines;
        try {
            lines = Files.readAllLines(Paths.get(fileName), Charset.defaultCharset());
        } catch (IOException e) {
            // Report the failure and signal it through the exit status instead of exiting 0.
            System.err.println("Could not read training data from " + fileName + ": " + e);
            System.exit(1);
            return;
        }
        NaiveBayesContinuous dataProcessor = new NaiveBayesContinuous(lines.toArray(new String[0]));
        dataProcessor.calcPosterior(6, 5, 3);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment