Skip to content

Instantly share code, notes, and snippets.

@bobbyali
Last active January 27, 2019 13:51
Show Gist options
  • Save bobbyali/9b98316580b1035dc005 to your computer and use it in GitHub Desktop.
Save bobbyali/9b98316580b1035dc005 to your computer and use it in GitHub Desktop.
Binary Naive Bayes Classifier in Java
The raw data in the csv file holds one observation per line, as four space-separated binary values representing the following:
1st col: Is mum present?
2nd col: Is dad present?
3rd col: Is baby present?
4th col: Outcome - 1 means meltdown, 0 means no meltdown
The TestClassifier.java file runs the app: it loads the csv file and
passes its lines to the NaiveBayesBinary class (NaiveBayesBinary.java), which
then works out all the probabilities required to make a prediction.
For more info on what's going on, see my blog post at
http://www.hacker-dad.com/how-to-predict-a-meltdown/
1 1 1 0
0 1 1 1
0 1 1 1
1 0 1 1
1 1 1 0
1 0 1 1
1 1 0 1
1 0 0 0
1 0 1 0
0 1 1 1
0 1 1 1
package naive_bayes;
import java.util.*;
/**
 * Binary Naive Bayes classifier for the "meltdown" data set.
 *
 * <p>Each training row holds four binary digits separated by single characters:
 * mum present, dad present, baby present and the outcome ({@code 1} = meltdown,
 * {@code 0} = no meltdown). The digits are read from character positions 0, 2, 4
 * and 6 of the row. All probabilities are estimated once, in the constructor, as
 * plain relative frequencies (no smoothing).
 *
 * <p>The estimated probabilities are exposed as public fields; the posterior
 * fields are overwritten by every call to
 * {@link #calcPosterior(Boolean, Boolean, Boolean)}.
 */
public class NaiveBayesBinary {

    // character positions of the four binary digits within a training row
    private static final int MUM_INDEX = 0;
    private static final int DAD_INDEX = 2;
    private static final int BABY_INDEX = 4;
    private static final int OUTCOME_INDEX = 6;

    // lists containing training data
    private List<Boolean> training_mumPresent = new ArrayList<Boolean>();
    private List<Boolean> training_dadPresent = new ArrayList<Boolean>();
    private List<Boolean> training_babyPresent = new ArrayList<Boolean>();
    private List<Boolean> training_outcome = new ArrayList<Boolean>();

    // prior probabilities
    public float p_meltdown, p_noMeltdown;
    // conditional probabilities
    public float p_mum_meltdown, p_dad_meltdown, p_baby_meltdown;
    public float p_mum_noMeltdown, p_dad_noMeltdown, p_baby_noMeltdown;
    // marginal feature probabilities (kept for inspection; the posterior is
    // normalised by the total evidence instead, see calcPosterior)
    public float p_mum, p_dad, p_baby;
    // posterior probabilities
    public float p_meltdown_data, p_noMeltdown_data;

    /**
     * Parses the training rows and estimates every probability the classifier needs.
     *
     * @param data training rows; {@code null} or blank rows are skipped
     * @throws IllegalArgumentException if a non-blank row is too short or holds a
     *         character other than {@code '0'} or {@code '1'} in a digit position
     * @throws NullPointerException if {@code data} itself is {@code null}
     */
    public NaiveBayesBinary(String[] data) {
        Objects.requireNonNull(data, "data");
        for (String line : data) {
            // Tolerate empty lines (e.g. a trailing blank line in the file).
            if (line == null || line.trim().isEmpty()) {
                continue;
            }
            if (line.length() <= OUTCOME_INDEX) {
                throw new IllegalArgumentException(
                        "Training row is too short, expected four binary values: '" + line + "'");
            }
            training_mumPresent.add(convertCharToBoolean(line.charAt(MUM_INDEX), line));
            training_dadPresent.add(convertCharToBoolean(line.charAt(DAD_INDEX), line));
            training_babyPresent.add(convertCharToBoolean(line.charAt(BABY_INDEX), line));
            training_outcome.add(convertCharToBoolean(line.charAt(OUTCOME_INDEX), line));
        }
        estimateProbabilities();
    }

    /**
     * Counts the training rows in a single pass and derives the priors, the
     * marginal feature probabilities and the per-class likelihoods from the counts.
     */
    private void estimateProbabilities() {
        int total = training_outcome.size();
        int meltdowns = 0;
        int mum = 0, dad = 0, baby = 0;
        int mumMeltdown = 0, dadMeltdown = 0, babyMeltdown = 0;

        for (int i = 0; i < total; i++) {
            boolean meltdown = training_outcome.get(i);
            boolean mumHere = training_mumPresent.get(i);
            boolean dadHere = training_dadPresent.get(i);
            boolean babyHere = training_babyPresent.get(i);

            if (meltdown) {
                meltdowns++;
            }
            if (mumHere) {
                mum++;
                if (meltdown) {
                    mumMeltdown++;
                }
            }
            if (dadHere) {
                dad++;
                if (meltdown) {
                    dadMeltdown++;
                }
            }
            if (babyHere) {
                baby++;
                if (meltdown) {
                    babyMeltdown++;
                }
            }
        }
        int noMeltdowns = total - meltdowns;

        this.p_meltdown = ratio(meltdowns, total);
        this.p_noMeltdown = ratio(noMeltdowns, total);

        this.p_mum = ratio(mum, total);
        this.p_dad = ratio(dad, total);
        this.p_baby = ratio(baby, total);

        this.p_mum_meltdown = ratio(mumMeltdown, meltdowns);
        this.p_dad_meltdown = ratio(dadMeltdown, meltdowns);
        this.p_baby_meltdown = ratio(babyMeltdown, meltdowns);

        // "feature present and no meltdown" = "feature present" minus "present and meltdown"
        this.p_mum_noMeltdown = ratio(mum - mumMeltdown, noMeltdowns);
        this.p_dad_noMeltdown = ratio(dad - dadMeltdown, noMeltdowns);
        this.p_baby_noMeltdown = ratio(baby - babyMeltdown, noMeltdowns);
    }

    /**
     * Returns {@code count / total} as a float, or 0 when {@code total} is 0 so
     * that a class without any training rows yields 0 rather than NaN.
     */
    private static float ratio(int count, int total) {
        return total == 0 ? 0f : (float) count / total;
    }

    /**
     * Computes the posterior probability of meltdown / no meltdown, stores the
     * results in {@link #p_meltdown_data} and {@link #p_noMeltdown_data}, and
     * prints them via {@link #printPosteriors()}.
     *
     * <p>Only features passed as {@code true} contribute a likelihood term; a
     * feature passed as {@code false} or {@code null} is left out of the product,
     * i.e. the prediction is not conditioned on it.
     *
     * <p>The two joint probabilities are normalised by their sum (the evidence),
     * so the posteriors add up to 1. If the evidence is 0 (the requested
     * combination has zero probability under both classes, or there is no
     * training data) both posteriors are set to 0.
     *
     * @param mumPresent whether to condition on mum being present
     * @param dadPresent whether to condition on dad being present
     * @param babyPresent whether to condition on the baby being present
     */
    public void calcPosterior(Boolean mumPresent, Boolean dadPresent, Boolean babyPresent) {
        // joint probability p(data, class) = p(class) * product of p(feature | class)
        float jointMeltdown = this.p_meltdown;
        float jointNoMeltdown = this.p_noMeltdown;

        // Boolean.TRUE.equals(...) is null-safe, unlike unboxing with "== true".
        if (Boolean.TRUE.equals(mumPresent)) {
            jointMeltdown *= this.p_mum_meltdown;
            jointNoMeltdown *= this.p_mum_noMeltdown;
        }
        if (Boolean.TRUE.equals(dadPresent)) {
            jointMeltdown *= this.p_dad_meltdown;
            jointNoMeltdown *= this.p_dad_noMeltdown;
        }
        if (Boolean.TRUE.equals(babyPresent)) {
            jointMeltdown *= this.p_baby_meltdown;
            jointNoMeltdown *= this.p_baby_noMeltdown;
        }

        // p(data) is the sum of the joints over both classes. Dividing by the
        // product of the feature marginals instead can yield values above 1.
        float evidence = jointMeltdown + jointNoMeltdown;
        if (evidence > 0) {
            this.p_meltdown_data = jointMeltdown / evidence;
            this.p_noMeltdown_data = jointNoMeltdown / evidence;
        } else {
            this.p_meltdown_data = 0;
            this.p_noMeltdown_data = 0;
        }
        printPosteriors();
    }

    /**
     * Prints the most recently computed posteriors, their sum, and which outcome
     * is more likely, to standard output.
     */
    public void printPosteriors() {
        System.out.println("Posteriors:");
        System.out.println("p(Breakdown|Data) = " + this.p_meltdown_data);
        System.out.println("p(No Breakdown|Data) = " + this.p_noMeltdown_data);
        System.out.println("Sum of posteriors = " + (this.p_meltdown_data + this.p_noMeltdown_data));
        if (this.p_meltdown_data > this.p_noMeltdown_data) {
            System.out.println("Breakdown is more likely.");
        } else if (this.p_meltdown_data < this.p_noMeltdown_data) {
            System.out.println("No Breakdown is more likely.");
        } else {
            System.out.println("Equal chance of breakdown vs no breakdown.");
        }
        System.out.println(" ");
    }

    /** Prints the estimated priors and per-class likelihoods to standard output. */
    public void printProbabilities() {
        System.out.println("Priors:");
        System.out.println("p(Breakdown) = " + this.p_meltdown);
        System.out.println("p(No Breakdown) = " + this.p_noMeltdown);
        System.out.println("Likelihoods:");
        System.out.println("p(Mum|Meltdown) = " + this.p_mum_meltdown);
        System.out.println("p(Dad|Meltdown) = " + this.p_dad_meltdown);
        System.out.println("p(Baby|Meltdown) = " + this.p_baby_meltdown);
        System.out.println("p(Mum|No Meltdown) = " + this.p_mum_noMeltdown);
        System.out.println("p(Dad|No Meltdown) = " + this.p_dad_noMeltdown);
        System.out.println("p(Baby|No Meltdown) = " + this.p_baby_noMeltdown);
    }

    /**
     * Converts a binary digit to a boolean.
     *
     * @param c the character to convert
     * @param line the row the character came from, used only in the error message
     * @return {@code true} for {@code '1'}, {@code false} for {@code '0'}
     * @throws IllegalArgumentException for any other character, so that bad input
     *         is reported where it is read rather than as a later null failure
     */
    private static boolean convertCharToBoolean(char c, String line) {
        if (c == '1') {
            return true;
        }
        if (c == '0') {
            return false;
        }
        throw new IllegalArgumentException(
                "Expected '0' or '1' but found '" + c + "' in training row '" + line + "'");
    }
}
package naive_bayes;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;
/**
 * Command-line driver: loads the training rows from a file, builds a
 * {@code NaiveBayesBinary} classifier from them and prints the posteriors for
 * two example situations.
 */
public class TestClassifier {

    /** Training data file used when no path is given on the command line. */
    private static final String DEFAULT_DATA_FILE = "./meltdown.csv";

    /**
     * Runs the demo.
     *
     * @param args optional; {@code args[0]} is the path of the training data
     *        file, defaulting to {@code ./meltdown.csv} when absent
     */
    public static void main(String[] args) {
        String fileName = (args != null && args.length > 0) ? args[0] : DEFAULT_DATA_FILE;
        try {
            List<String> lines = Files.readAllLines(Paths.get(fileName), Charset.defaultCharset());
            String[] data = lines.toArray(new String[0]);
            NaiveBayesBinary dataProcessor = new NaiveBayesBinary(data);
            // mum and baby present
            dataProcessor.calcPosterior(true, false, true);
            // dad and baby present
            dataProcessor.calcPosterior(false, true, true);
        } catch (IOException e) {
            // Report which file failed instead of only dumping a stack trace.
            System.err.println("Could not read training data from '" + fileName + "': " + e);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment