// Classifier.java
import java.io.FileNotFoundException;
import java.io.PrintStream;
import java.util.ArrayList;

public abstract class Classifier {

    public abstract Integer classify(DataPoint dataPoint);

    public void test(String fileName) throws FileNotFoundException {
        test(Main.parse(fileName));
    }

    public void test(ArrayList<DataPoint> testingData) {
        PrintStream out = System.out;
        int falseCorrect = 0;
        int numFalse = 0;
        int trueCorrect = 0;
        int numTrue = 0;
        for (int i = 0; i < testingData.size(); i++) {
            Integer category = classify(testingData.get(i));
            if (category == 1)
                numTrue++;
            else
                numFalse++;
            // Compare the two Integer objects with equals() rather than ==,
            // so the check does not depend on integer boxing behavior.
            if (category.equals(testingData.get(i).category)) {
                if (category == 1)
                    trueCorrect++;
                else
                    falseCorrect++;
            }
        }
        out.println("Class 0: tested " + numFalse + ", correctly classified " + falseCorrect);
        out.println("Class 1: tested " + numTrue + ", correctly classified " + trueCorrect);
        out.println("Overall: tested " + (numTrue + numFalse) + ", correctly classified " + (trueCorrect + falseCorrect));
        double accuracy = (double) (falseCorrect + trueCorrect) / testingData.size();
        out.println("Accuracy: " + accuracy);
    }
}
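Note: Classifier only asks subclasses to implement classify(DataPoint) with binary 0/1 labels; test() then prints per-class and overall accuracy. As a minimal illustrative sketch of that contract (a hypothetical subclass, not part of the gist):

// Hypothetical baseline that always predicts class 1, just to show the contract.
public class AlwaysOneClassifier extends Classifier {
    public Integer classify(DataPoint dataPoint) {
        return 1;
    }
}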
// DataPoint.java
import java.util.ArrayList;

public class DataPoint {
    // Binary class label (0 or 1) and the feature vector for one example.
    public Integer category;
    public ArrayList<Integer> data;

    public DataPoint(Integer category, ArrayList<Integer> data) {
        this.category = category;
        this.data = data;
    }
}
// LogisticRegression.java
import java.io.FileNotFoundException;
import java.util.ArrayList;

public class LogisticRegression extends Classifier {

    double[] parameters;

    // Prepend a constant 1 to every feature vector so parameters[0] acts as the bias term.
    public ArrayList<DataPoint> padData(ArrayList<DataPoint> data) {
        for (DataPoint d : data) {
            d.data.add(0, 1);
        }
        return data;
    }

    public DataPoint padData(DataPoint d) {
        d.data.add(0, 1);
        return d;
    }

    public void train(String fileName, int epochs, double learningRate) throws FileNotFoundException {
        train(Main.parse(fileName), epochs, learningRate);
    }

    // Full-batch gradient ascent on the log-likelihood of the logistic model.
    public void train(ArrayList<DataPoint> trainingData, int epochs, double learningRate) {
        trainingData = padData(trainingData);
        int dim = trainingData.get(0).data.size();
        parameters = new double[dim];
        for (int i = 0; i < epochs; i++) {
            double[] gradient = new double[dim];
            for (DataPoint d : trainingData) {
                // z = parameters . x
                double z = 0;
                for (int j = 0; j < dim; j++) {
                    z += parameters[j] * d.data.get(j);
                }
                // Accumulate x_j * (y - sigmoid(z)) for each weight.
                for (int j = 0; j < dim; j++) {
                    gradient[j] += d.data.get(j) * (d.category - 1 / (1 + Math.exp(-z)));
                }
            }
            for (int j = 0; j < dim; j++) {
                parameters[j] += learningRate * gradient[j];
            }
        }
    }

    public Integer classify(DataPoint d) {
        // Note: this pads d in place, so each DataPoint should be classified at most once.
        d = padData(d);
        int dim = d.data.size();
        double z = 0;
        for (int i = 0; i < dim; i++) {
            z += parameters[i] * d.data.get(i);
        }
        double pr = 1 / (1 + Math.exp(-z));
        return (pr < .5) ? 0 : 1;
    }
}
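Reading of the training loop above: each epoch applies the batch logistic-regression update w_j <- w_j + learningRate * sum over examples of x_j * (y - sigma(w . x)), where sigma(z) = 1 / (1 + e^(-z)) and y is the 0/1 label; the constant 1 prepended by padData makes w_0 the intercept. With the settings in Main this runs 10000 epochs at a learning rate of 0.0001.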
// Main.java
import java.util.*;
import java.io.*;

public class Main {

    public static void main(String[] args) throws FileNotFoundException {
//        NaiveBayes bayes = new NaiveBayes();
//        bayes.train("vote-train.txt", 1);
//        bayes.test("vote-test.txt");
        LogisticRegression log = new LogisticRegression();
        log.train("vote-train.txt", 10000, .0001d);
        log.test("vote-test.txt");
    }

    public static ArrayList<DataPoint> parse(String fileName) throws FileNotFoundException {
        File file = new File(fileName);
        Scanner scanner = new Scanner(new FileReader(file));
        // Skip the two header lines at the top of the data file.
        scanner.nextLine();
        scanner.nextLine();
        ArrayList<DataPoint> dataVector = new ArrayList<DataPoint>();
        while (scanner.hasNextLine()) {
            String line = scanner.nextLine();
            // Each record is "<space-separated integer features> : <class label>".
            String[] tokens = line.split(":");
            Integer category = Integer.parseInt(tokens[1].trim());
            ArrayList<Integer> data = new ArrayList<Integer>();
            for (String s : tokens[0].split(" ")) {
                data.add(Integer.parseInt(s.trim()));
            }
            dataVector.add(new DataPoint(category, data));
        }
        scanner.close();
        return dataVector;
    }
}
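Based on parse(), vote-train.txt and vote-test.txt are assumed to start with two header lines (which are skipped), followed by one example per line: space-separated integer features, a colon, then the 0/1 class label. An illustrative, made-up fragment of such a file:

<header line 1>
<header line 2>
1 0 1 1 0 : 1
0 1 0 0 1 : 0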
// NaiveBayes.java
import java.io.FileNotFoundException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashMap;

public class NaiveBayes extends Classifier {

    // Only store the prior for class 1 (the prior for class 0 is 1 minus this).
    Double prTrue = 0d;

    // Only store the conditional probability that each variable is 1;
    // the conditional probability that a variable is 0 is just 1 minus this.
    HashMap<Integer, Double> condFalse = new HashMap<Integer, Double>();
    HashMap<Integer, Double> condTrue = new HashMap<Integer, Double>();

    public void train(String fileName, Integer laplaceBias) throws FileNotFoundException {
        train(Main.parse(fileName), laplaceBias);
    }

    public void train(ArrayList<DataPoint> trainingData, Integer laplaceBias) {
        int dim = trainingData.get(0).data.size();
        int numTrue = 0;
        int numFalse = 0;
        // Count, per feature, how often it is 1 within each class.
        int[] numCondFalse = new int[dim];
        int[] numCondTrue = new int[dim];
        for (DataPoint d : trainingData) {
            if (d.category == 1)
                numTrue++;
            else
                numFalse++;
            for (int i = 0; i < dim; i++) {
                if (d.category == 0 && d.data.get(i) == 1) {
                    numCondFalse[i]++;
                }
                if (d.category == 1 && d.data.get(i) == 1) {
                    numCondTrue[i]++;
                }
            }
        }
        // prTrue = (double) (numTrue + laplaceBias) / (trainingData.size() + 2 * laplaceBias);
        prTrue = (double) numTrue / trainingData.size();
        // Laplace-smoothed conditional probabilities P(x_i = 1 | class).
        for (int i = 0; i < dim; i++) {
            condFalse.put(i, (double) (numCondFalse[i] + laplaceBias) / (numFalse + 2 * laplaceBias));
            condTrue.put(i, (double) (numCondTrue[i] + laplaceBias) / (numTrue + 2 * laplaceBias));
        }
    }

    public Integer classify(DataPoint dataPoint) {
        // Joint probability (up to the shared evidence term) of the features and class 0.
        double jointFalse = (1 - prTrue);
        for (int i = 0; i < dataPoint.data.size(); i++) {
            if (dataPoint.data.get(i) == 1)
                jointFalse *= condFalse.get(i);
            else
                jointFalse *= 1 - condFalse.get(i);
        }
        // Joint probability of the features and class 1.
        double jointTrue = prTrue;
        for (int i = 0; i < dataPoint.data.size(); i++) {
            if (dataPoint.data.get(i) == 1)
                jointTrue *= condTrue.get(i);
            else
                jointTrue *= 1 - condTrue.get(i);
        }
        return (jointTrue > jointFalse) ? 1 : 0;
    }
}
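Note on smoothing: with laplaceBias = k, each conditional probability is estimated as P(x_i = 1 | class) = (count + k) / (classCount + 2k), which keeps a single unseen feature value from zeroing out an entire product in classify(). The class prior, as written, is the unsmoothed maximum-likelihood estimate numTrue / N; the smoothed version is left commented out in train().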