Created
March 13, 2013 17:52
-
-
Save veered/5154508 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.io.FileNotFoundException; | |
import java.io.PrintStream; | |
import java.util.ArrayList; | |
public abstract class Classifier { | |
public abstract Integer classify(DataPoint dataPoint); | |
public void test(String fileName) throws FileNotFoundException { | |
test(Main.parse(fileName)); | |
} | |
public void test(ArrayList<DataPoint> testingData) { | |
PrintStream out = System.out; | |
int falseCorrect = 0; | |
int numFalse = 0; | |
int trueCorrect = 0; | |
int numTrue = 0; | |
for (int i = 0; i < testingData.size(); i++) { | |
Integer category = classify(testingData.get(i)); | |
if (category == 1) | |
numTrue++; | |
else | |
numFalse++; | |
if (category == testingData.get(i).category) { | |
if (category == 1) | |
trueCorrect++; | |
else | |
falseCorrect++; | |
} | |
} | |
out.println("Class 0: tested " + numFalse + ", correctly classified " + falseCorrect); | |
out.println("Class 1: tested " + numTrue + ", correctly classified " + trueCorrect); | |
out.println("Overall: tested " + (numTrue + numFalse) + ", correctly classified " + (trueCorrect + falseCorrect)); | |
Double accuracy = (double) (falseCorrect + trueCorrect) / testingData.size(); | |
out.println("Accuracy: " + accuracy); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** | |
* Created with IntelliJ IDEA. | |
* User: lucas | |
* Date: 3/7/13 | |
* Time: 6:02 AM | |
* To change this template use File | Settings | File Templates. | |
*/ | |
import java.util.ArrayList; | |
/**
 * A single labelled example: a class label together with its list of integer feature values.
 *
 * <p>Both fields are public and mutable, and the feature list is stored by reference
 * (it is not copied), so changes made through either reference are visible to both.
 */
public class DataPoint {

    /** Class label of this example. */
    public Integer category;

    /** Feature values of this example, in file order. */
    public ArrayList<Integer> data;

    /**
     * Creates a data point that holds the given label and the given feature list.
     *
     * @param category class label of the example
     * @param data     feature values; the list itself is kept, not a copy
     */
    public DataPoint(Integer category, ArrayList<Integer> data) {
        this.data = data;
        this.category = category;
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.io.FileNotFoundException; | |
import java.util.ArrayList; | |
public class LogisticRegression extends Classifier{ | |
double[] parameters; | |
public ArrayList<DataPoint> padData(ArrayList<DataPoint> data) { | |
for (DataPoint d : data) { | |
d.data.add(0, 1); | |
} | |
return data; | |
} | |
public DataPoint padData(DataPoint d) { | |
d.data.add(0, 1); | |
return d; | |
} | |
public void train(String fileName, int epochs, double learningRate) throws FileNotFoundException { | |
train(Main.parse(fileName), epochs, learningRate); | |
} | |
public void train(ArrayList<DataPoint> trainingData, int epochs, double learningRate) { | |
trainingData = padData(trainingData); | |
Integer dim = trainingData.get(0).data.size(); | |
parameters = new double[dim]; | |
for (int i = 0; i < epochs; i++) { | |
double[] gradient = new double[dim]; | |
for (DataPoint d : trainingData) { | |
double z = 0; | |
for (int j = 0; j < dim; j++) { | |
z += parameters[j] * d.data.get(j); | |
} | |
for (int j = 0; j < dim; j ++) { | |
gradient[j] += d.data.get(j) * (d.category - 1/(1 + Math.exp(-z))); | |
} | |
} | |
for (int j = 0; j < dim; j++) { | |
parameters[j] += learningRate*gradient[j]; | |
} | |
} | |
} | |
public Integer classify(DataPoint d) { | |
d = padData(d); | |
int dim = d.data.size(); | |
double z = 0; | |
for (int i = 0; i < dim; i++) { | |
z += parameters[i] * d.data.get(i); | |
} | |
double pr = 1/(1 + Math.exp(-z)); | |
return (pr < .5) ? 0 : 1; | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.*; | |
import java.io.*; | |
public class Main { | |
public static void main(String[] args) throws FileNotFoundException { | |
// NaiveBayes bayes = new NaiveBayes(); | |
// bayes.train("vote-train.txt", 1); | |
// bayes.test("vote-test.txt"); | |
LogisticRegression log = new LogisticRegression(); | |
log.train("vote-train.txt", 10000, .0001d); | |
log.test("vote-test.txt"); | |
} | |
public static ArrayList<DataPoint> parse(String fileName) throws FileNotFoundException { | |
File file = new File(fileName); | |
Scanner scanner = new Scanner(new FileReader(file)); | |
scanner.nextLine(); | |
scanner.nextLine(); | |
ArrayList<DataPoint> dataVector = new ArrayList<DataPoint>(); | |
while(scanner.hasNextLine()) { | |
String line = scanner.nextLine(); | |
String[] tokens = line.split(":"); | |
Integer category = Integer.parseInt(tokens[1].trim()); | |
ArrayList<Integer> data = new ArrayList<Integer>(); | |
for(String s : tokens[0].split(" ")) { | |
data.add(Integer.parseInt(s.trim())); | |
} | |
dataVector.add(new DataPoint(category, data)); | |
} | |
scanner.close(); | |
return dataVector; | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.io.FileNotFoundException; | |
import java.io.PrintStream; | |
import java.util.ArrayList; | |
import java.util.HashMap; | |
public class NaiveBayes extends Classifier { | |
// Only store prior for true (prior for false is 1 minus this) | |
Double prTrue = 0d; | |
// Only store condition probability for when each variable is true | |
// the conditional probability for when each variable is false is | |
// just 1 minus this | |
HashMap<Integer, Double> condFalse = new HashMap<Integer, Double>(); | |
HashMap<Integer, Double> condTrue = new HashMap<Integer, Double>(); | |
public void train(String fileName, Integer laplaceBias) throws FileNotFoundException{ | |
train(Main.parse(fileName), laplaceBias); | |
} | |
public void train(ArrayList<DataPoint> trainingData, Integer laplaceBias) { | |
Integer dim = trainingData.get(0).data.size(); | |
Integer numTrue = 0; | |
Integer numFalse = 0; | |
int[] numCondFalse = new int[dim]; | |
int[] numCondTrue = new int[dim]; | |
for(DataPoint d : trainingData) { | |
if (d.category == 1) | |
numTrue++; | |
else | |
numFalse++; | |
for(int i = 0; i < dim; i++) { | |
if (d.category == 0 && d.data.get(i) == 1) { | |
numCondFalse[i]++; | |
} | |
if (d.category == 1 && d.data.get(i) == 1) { | |
numCondTrue[i]++; | |
} | |
} | |
} | |
//prTrue = (double)(numTrue + laplaceBias)/(trainingData.size() + 2*laplaceBias); | |
prTrue = (double)(numTrue)/(trainingData.size()); | |
for(int i = 0; i < dim; i++) { | |
condFalse.put(i, (double)(numCondFalse[i]+ laplaceBias)/(numFalse+2*laplaceBias)); | |
condTrue.put(i, (double)(numCondTrue[i]+ laplaceBias)/(numTrue+2*laplaceBias)); | |
} | |
} | |
public Integer classify(DataPoint dataPoint) { | |
Double jointFalse = (1-prTrue); | |
for(int i = 0; i < dataPoint.data.size(); i++) { | |
if (dataPoint.data.get(i) == 1) | |
jointFalse *= condFalse.get(i); | |
else | |
jointFalse *= 1 - condFalse.get(i); | |
} | |
Double jointTrue = prTrue; | |
for(int i = 0; i < dataPoint.data.size(); i++) { | |
if (dataPoint.data.get(i) == 1) | |
jointTrue *= condTrue.get(i); | |
else | |
jointTrue *= 1 - condTrue.get(i); | |
} | |
return (jointTrue > jointFalse) ? 1 : 0; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment