Skip to content

Instantly share code, notes, and snippets.

Last active January 27, 2019 13:51
Show Gist options
  • Save bobbyali/9b98316580b1035dc005 to your computer and use it in GitHub Desktop.
Save bobbyali/9b98316580b1035dc005 to your computer and use it in GitHub Desktop.
Binary Naive Bayes Classifier in Java
The raw data in the csv file contains binary values representing the following:
1st col: Is mum present?
2nd col: Is dad present?
3rd col: Is baby present?
4th col: Outcome - 1 means meltdown, 0 means no meltdown
The file TestClassifier.java executes the app, loads the csv file (meltdown.csv), and
passes its contents to the NaiveBayesBinary class in NaiveBayesBinary.java. NaiveBayesBinary
then works out all the probabilities required to make a prediction.
For more info on what's going on, see my blog post at
1 1 1 0
0 1 1 1
0 1 1 1
1 0 1 1
1 1 1 0
1 0 1 1
1 1 0 1
1 0 0 0
1 0 1 0
0 1 1 1
0 1 1 1
package naive_bayes;
import java.util.*;
/**
 * Binary Naive Bayes classifier over three yes/no features (mum, dad and baby
 * present) that predicts a yes/no outcome (meltdown).
 *
 * <p>All probabilities are estimated once, in the constructor, from the training
 * rows. {@link #calcPosterior(Boolean, Boolean, Boolean)} then combines them for
 * one observation and stores the result in {@link #p_meltdown_data} and
 * {@link #p_noMeltdown_data}.
 */
public class NaiveBayesBinary {

    /** Minimum row length: four one-character flags at offsets 0, 2, 4 and 6. */
    private static final int MIN_ROW_LENGTH = 7;

    // lists containing training data (index-aligned: entry i of each list is row i)
    private final List<Boolean> training_mumPresent = new ArrayList<Boolean>();
    private final List<Boolean> training_dadPresent = new ArrayList<Boolean>();
    private final List<Boolean> training_babyPresent = new ArrayList<Boolean>();
    private final List<Boolean> training_outcome = new ArrayList<Boolean>();

    // prior probabilities
    public float p_meltdown, p_noMeltdown;

    // conditional probabilities
    public float p_mum_meltdown, p_dad_meltdown, p_baby_meltdown;
    public float p_mum_noMeltdown, p_dad_noMeltdown, p_baby_noMeltdown;

    // marginal feature probabilities (kept as public fields for existing callers;
    // they are not used to normalise the posteriors)
    public float p_mum, p_dad, p_baby;

    // posterior probabilities
    public float p_meltdown_data, p_noMeltdown_data;

    /**
     * Parses the training rows and estimates every probability the classifier needs.
     *
     * <p>Each row holds four '0'/'1' flags at character offsets 0, 2, 4 and 6
     * (mum, dad, baby, outcome), separated by any single character. Rows are
     * trimmed first; {@code null} and blank rows are skipped.
     *
     * @param data training rows, one string per row
     * @throws IllegalArgumentException if {@code data} is null, contains no usable
     *         rows, or contains a row that is too short or has a flag other than
     *         '0' or '1'
     */
    public NaiveBayesBinary(String[] data) {
        if (data == null) {
            throw new IllegalArgumentException("data must not be null");
        }
        for (String line : data) {
            if (line == null) {
                continue;
            }
            String row = line.trim();
            if (row.isEmpty()) {
                continue;
            }
            if (row.length() < MIN_ROW_LENGTH) {
                throw new IllegalArgumentException(
                        "Expected four 0/1 values but got: \"" + line + "\"");
            }
            training_mumPresent.add(convertCharToBoolean(row.charAt(0), line));
            training_dadPresent.add(convertCharToBoolean(row.charAt(2), line));
            training_babyPresent.add(convertCharToBoolean(row.charAt(4), line));
            training_outcome.add(convertCharToBoolean(row.charAt(6), line));
        }
        if (training_outcome.isEmpty()) {
            throw new IllegalArgumentException("data contains no training rows");
        }

        calcPriorProbabilities();
        calcNormalisingProbabilities();
        calcConditionalProbabilities();
    }

    /** Estimates p(meltdown) and p(no meltdown) as class frequencies in the training rows. */
    private void calcPriorProbabilities() {
        int numMeltdown = 0;
        for (boolean outcome : this.training_outcome) {
            if (outcome) {
                numMeltdown++;
            }
        }
        int total = this.training_outcome.size();
        this.p_meltdown = ratio(numMeltdown, total);
        this.p_noMeltdown = ratio(total - numMeltdown, total);
    }

    /** Estimates the marginal probability of each feature being present. */
    private void calcNormalisingProbabilities() {
        int numMum = 0, numDad = 0, numBaby = 0;
        int total = this.training_outcome.size();
        for (int i = 0; i < total; i++) {
            if (this.training_mumPresent.get(i)) numMum++;
            if (this.training_dadPresent.get(i)) numDad++;
            if (this.training_babyPresent.get(i)) numBaby++;
        }
        this.p_mum = ratio(numMum, total);
        this.p_dad = ratio(numDad, total);
        this.p_baby = ratio(numBaby, total);
    }

    /** Estimates p(feature present | class) for each feature and both classes. */
    private void calcConditionalProbabilities() {
        int numMeltdown = 0, numNoMeltdown = 0;
        int numMeltdownMum = 0, numMeltdownDad = 0, numMeltdownBaby = 0;
        int numNoMeltdownMum = 0, numNoMeltdownDad = 0, numNoMeltdownBaby = 0;

        for (int i = 0; i < this.training_outcome.size(); i++) {
            if (this.training_outcome.get(i)) {
                // The per-class row count is the denominator below, so it must be
                // counted here alongside the per-feature counts.
                numMeltdown++;
                if (this.training_mumPresent.get(i)) numMeltdownMum++;
                if (this.training_dadPresent.get(i)) numMeltdownDad++;
                if (this.training_babyPresent.get(i)) numMeltdownBaby++;
            } else {
                numNoMeltdown++;
                if (this.training_mumPresent.get(i)) numNoMeltdownMum++;
                if (this.training_dadPresent.get(i)) numNoMeltdownDad++;
                if (this.training_babyPresent.get(i)) numNoMeltdownBaby++;
            }
        }

        this.p_mum_meltdown = ratio(numMeltdownMum, numMeltdown);
        this.p_dad_meltdown = ratio(numMeltdownDad, numMeltdown);
        this.p_baby_meltdown = ratio(numMeltdownBaby, numMeltdown);
        this.p_mum_noMeltdown = ratio(numNoMeltdownMum, numNoMeltdown);
        this.p_dad_noMeltdown = ratio(numNoMeltdownDad, numNoMeltdown);
        this.p_baby_noMeltdown = ratio(numNoMeltdownBaby, numNoMeltdown);
    }

    /**
     * Computes the posterior probability of each outcome for one observation and
     * stores it in {@link #p_meltdown_data} and {@link #p_noMeltdown_data}.
     *
     * <p>A feature passed as {@code true} contributes its conditional probability
     * to each class. A feature passed as {@code false} or {@code null} is left out
     * of the product, i.e. it is not used as evidence.
     *
     * @param mumPresent  whether "mum present" is observed
     * @param dadPresent  whether "dad present" is observed
     * @param babyPresent whether "baby present" is observed
     */
    public void calcPosterior(Boolean mumPresent, Boolean dadPresent, Boolean babyPresent) {
        // Joint probability of the observation with each class: prior x likelihoods.
        float jointMeltdown = this.p_meltdown;
        float jointNoMeltdown = this.p_noMeltdown;

        // Boolean.TRUE.equals(...) is null-safe, unlike unboxing with "== true".
        if (Boolean.TRUE.equals(mumPresent)) {
            jointMeltdown *= this.p_mum_meltdown;
            jointNoMeltdown *= this.p_mum_noMeltdown;
        }
        if (Boolean.TRUE.equals(dadPresent)) {
            jointMeltdown *= this.p_dad_meltdown;
            jointNoMeltdown *= this.p_dad_noMeltdown;
        }
        if (Boolean.TRUE.equals(babyPresent)) {
            jointMeltdown *= this.p_baby_meltdown;
            jointNoMeltdown *= this.p_baby_noMeltdown;
        }

        // Bayes' rule: the evidence is the sum of the joints over both classes.
        // The features are only assumed independent *given the class*, so the
        // product of their marginals is not the evidence and would yield values
        // that do not sum to 1.
        float evidence = jointMeltdown + jointNoMeltdown;
        if (evidence > 0) {
            this.p_meltdown_data = jointMeltdown / evidence;
            this.p_noMeltdown_data = jointNoMeltdown / evidence;
        } else {
            // The observation has zero probability under both classes.
            this.p_meltdown_data = 0;
            this.p_noMeltdown_data = 0;
        }
    }

    /** Prints the posteriors from the last {@code calcPosterior} call and the likelier outcome. */
    public void printPosteriors() {
        System.out.println("p(Breakdown|Data) = " + this.p_meltdown_data);
        System.out.println("p(No Breakdown|Data) = " + this.p_noMeltdown_data);
        System.out.println("Sum of posteriors = " + (this.p_meltdown_data + this.p_noMeltdown_data));
        if (this.p_meltdown_data > this.p_noMeltdown_data) {
            System.out.println("Breakdown is more likely.");
        } else if (this.p_meltdown_data < this.p_noMeltdown_data) {
            System.out.println("No Breakdown is more likely.");
        } else {
            System.out.println("Equal chance of breakdown vs no breakdown.");
        }
        System.out.println(" ");
    }

    /** Prints the prior and conditional probabilities estimated from the training rows. */
    public void printProbabilities() {
        System.out.println("p(Breakdown) = " + this.p_meltdown);
        System.out.println("p(No Breakdown) = " + this.p_noMeltdown);
        System.out.println("p(Mum|Meltdown) = " + this.p_mum_meltdown);
        System.out.println("p(Dad|Meltdown) = " + this.p_dad_meltdown);
        System.out.println("p(Baby|Meltdown) = " + this.p_baby_meltdown);
        System.out.println("p(Mum|No Meltdown) = " + this.p_mum_noMeltdown);
        System.out.println("p(Dad|No Meltdown) = " + this.p_dad_noMeltdown);
        System.out.println("p(Baby|No Meltdown) = " + this.p_baby_noMeltdown);
    }

    /**
     * Converts a '0'/'1' flag to a boolean.
     *
     * @param c    the flag character
     * @param line the row it came from, used only for the error message
     * @return {@code true} for '1', {@code false} for '0'
     * @throws IllegalArgumentException for any other character
     */
    private static boolean convertCharToBoolean(char c, String line) {
        if (c == '1') {
            return true;
        } else if (c == '0') {
            return false;
        } else {
            // Fail here with context instead of returning null and triggering an
            // unexplained NullPointerException when the value is later unboxed.
            throw new IllegalArgumentException(
                    "Expected '0' or '1' but found '" + c + "' in: \"" + line + "\"");
        }
    }

    /** Returns count / total, or 0 when total is 0 (avoids NaN for an empty class). */
    private static float ratio(int count, int total) {
        return total == 0 ? 0f : (float) count / total;
    }
}
package naive_bayes;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;
public class TestClassifier {
public static void main(String[] args) {
String fileName = "./meltdown.csv";
String[] data;
try {
List<String> lines = Files.readAllLines(Paths.get(fileName), Charset.defaultCharset());
data = lines.toArray(new String[0]);
NaiveBayesBinary dataProcessor = new NaiveBayesBinary(data);
dataProcessor.calcPosterior(true, false, true);
dataProcessor.calcPosterior(false, true, true);
} catch (IOException e) {
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment