Created
February 10, 2019 19:01
-
-
Save aidancbrady/8be61d0f97e6d6311846acfb25683d03 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.text.DecimalFormat; | |
import java.util.Random; | |
import java.util.function.Function; | |
import weka.classifiers.AbstractClassifier; | |
import weka.classifiers.Classifier; | |
import weka.classifiers.Evaluation; | |
import weka.classifiers.functions.MultilayerPerceptron; | |
import weka.classifiers.functions.SMO; | |
import weka.classifiers.lazy.IBk; | |
import weka.classifiers.meta.AdaBoostM1; | |
import weka.classifiers.trees.J48; | |
import weka.core.Instances; | |
import weka.core.converters.ConverterUtils.DataSource; | |
import weka.filters.Filter; | |
import weka.filters.unsupervised.instance.Resample; | |
public class WekaWrapper | |
{ | |
private static DecimalFormat format = new DecimalFormat("#.00"); | |
private static long timestamp = 0; | |
public static void main(String[] args) { | |
try { | |
Instances cancer = getData("/Documents/Georgia Tech/Spring 2019/cs4641/Assignment 1/Datasets/breast-cancer.arff"); | |
Instances phishing = getData("/Documents/Georgia Tech/Spring 2019/cs4641/Assignment 1/Datasets/phishing-websites.arff"); | |
//DT | |
J48 tree = new J48(); | |
tree.setOptions(new String[] {"-C", "0.25", "-M", "2"}); | |
startRecording("J48, cancer data"); | |
buildLearningCurve(cancer, tree, 0.8); | |
stopRecording(); | |
startRecording("J48, phishing data"); | |
buildLearningCurve(phishing, tree, 0.8); | |
stopRecording(); | |
//Boost | |
AdaBoostM1 boost = new AdaBoostM1(); | |
boost.setOptions(new String[] {"-P", "100", "-S", "1", "-I", "10", "-W", "weka.classifiers.trees.J48", "--", "-C", "0.25", "-M", "2"}); | |
boost.setClassifier(tree); | |
startRecording("Boost, cancer data"); | |
buildLearningCurve(cancer, boost, 0.8); | |
stopRecording(); | |
boost = new AdaBoostM1(); | |
boost.setOptions(new String[] {"-P", "100", "-S", "1", "-I", "40", "-W", "weka.classifiers.trees.J48", "--", "-C", "0.25", "-M", "2"}); | |
boost.setClassifier(tree); | |
startRecording("Boost, phishing data"); | |
buildLearningCurve(phishing, boost, 0.8); | |
stopRecording(); | |
//NN | |
MultilayerPerceptron nn = new MultilayerPerceptron(); | |
nn.setOptions(new String[] {"-L", "0.3", "-M", "0.2", "-N", "300", "-V", "0", "-S", "0", "-E", "20", "-H", "a,1"}); | |
startRecording("NN, cancer data"); | |
buildLearningCurve(cancer, boost, 0.8); | |
stopRecording(); | |
nn = new MultilayerPerceptron(); | |
nn.setOptions(new String[] {"-L", "0.3", "-M", "0.2", "-N", "200", "-V", "0", "-S", "0", "-E", "20", "-H", "a,5"}); | |
startRecording("NN, phishing data"); | |
buildLearningCurve(phishing, boost, 0.8); | |
stopRecording(); | |
//KNN | |
IBk knn = new IBk(); | |
knn.setOptions(new String[] {"-K", "15", "-W", "0", "-F", "-A", "\"weka.core.neighboursearch.LinearNNSearch\" -A \"weka.core.EuclideanDistance -R first-last\""}); | |
startRecording("KNN, cancer data"); | |
buildLearningCurve(cancer, knn, 0.8); | |
stopRecording(); | |
knn = new IBk(); | |
knn.setOptions(new String[] {"-K", "1", "-W", "0", "-I", "-A", "\"weka.core.neighboursearch.LinearNNSearch\" -A \"weka.core.EuclideanDistance -R first-last\""}); | |
startRecording("KNN, phishing data"); | |
buildLearningCurve(phishing, knn, 0.8); | |
stopRecording(); | |
//SMO | |
SMO svm = new SMO(); | |
svm.setOptions(new String[] {"-C", "1.0", "-L", "0.001", "-P", "1.0E-12", "-N", "0", "-V", "-1", "-W", "1", "-K", "\"weka.classifiers.functions.supportVector.NormalizedPolyKernel\" -E 2.0 -C 250007", "-calibrator", "\"weka.classifiers.functions.Logistic\" -R 1.0E-8 -M -1 -num-decimal-places 4"}); | |
startRecording("SVM, cancer data"); | |
buildLearningCurve(cancer, svm, 0.8); | |
stopRecording(); | |
svm = new SMO(); | |
svm.setOptions(new String[] {"-C", "1.0", "-L", "0.001", "-P", "1.0E-12", "-N", "0", "-V", "-1", "-W", "1", "-K", "\"weka.classifiers.functions.supportVector.NormalizedPolyKernel\" -E 2.0 -C 250007", "-calibrator", "\"weka.classifiers.functions.Logistic\" -R 1.0E-8 -M -1 -num-decimal-places 4"}); | |
startRecording("SVM, phishing data"); | |
buildLearningCurve(phishing, svm, 0.8); | |
stopRecording(); | |
/*IBk knn = new IBk(); | |
tuneParam(cancer, knn, 1, 35, 2, (val) -> { | |
return new String[] {"-K", Integer.toString((int)Math.round(val)), "-W", "0", "-I", "-A", "\"weka.core.neighboursearch.LinearNNSearch\" -A \"weka.core.EuclideanDistance -R first-last\""}; | |
});*/ | |
} catch(Exception e) { | |
e.printStackTrace(); | |
} | |
} | |
public static void tuneParam(Instances data, AbstractClassifier classifier, double start, double end, double step, Function<Double, String[]> func) throws Exception { | |
for(double d = start; d <= end; d += step) { | |
classifier.setOptions(func.apply(d)); | |
Evaluation eval = new Evaluation(data); | |
eval.crossValidateModel(classifier, data, 10, new Random(1)); | |
System.out.println(format.format(eval.pctCorrect())); | |
} | |
} | |
public static void buildLearningCurve(Instances data, Classifier classifier, double split) throws Exception { | |
System.out.println("Training results:"); | |
//train | |
for(int i = 0; i < 10; i++) { | |
Resample filter = new Resample(); | |
String[] options = new String[] {"-S", "1", "-Z", Double.toString(((i+1))*split*10D)}; | |
filter.setOptions(options); | |
filter.setInputFormat(data); | |
Instances newData = Filter.useFilter(data, filter); | |
classifier.buildClassifier(newData); | |
Evaluation eval = new Evaluation(newData); | |
eval.evaluateModel(classifier, newData); | |
System.out.println(format.format(eval.pctCorrect())); | |
} | |
System.out.println(); | |
System.out.println("Testing results:"); | |
//test | |
for(int i = 0; i < 10; i++) { | |
Resample filter = new Resample(); | |
double datasetMult = (1-split) + (i+1)*(split/10D); | |
String[] options = new String[] {"-S", "1", "-Z", Double.toString(datasetMult*100)}; | |
filter.setOptions(options); | |
filter.setInputFormat(data); | |
Instances newData = Filter.useFilter(data, filter); | |
double testSplit = ((i+1)*(split/10D))/datasetMult; | |
int trainSize = (int)Math.round(newData.numInstances() * testSplit); | |
Instances trainData = new Instances(newData, 0, trainSize); | |
Instances testData = new Instances(newData, trainSize, newData.numInstances() - trainSize); | |
classifier.buildClassifier(trainData); | |
Evaluation eval = new Evaluation(trainData); | |
eval.evaluateModel(classifier, testData); | |
System.out.println(format.format(eval.pctCorrect())); | |
} | |
} | |
public static void startRecording(String s) { | |
timestamp = System.currentTimeMillis(); | |
System.out.println("Recording time for: " + s); | |
} | |
public static void stopRecording() { | |
long diff = System.currentTimeMillis()-timestamp; | |
System.out.println("Time elapsed: " + diff); | |
} | |
public static Instances getData(String path) throws Exception { | |
DataSource source = new DataSource(getHomeDirectory() + path); | |
Instances data = source.getDataSet(); | |
if(data.classIndex() == -1) data.setClassIndex(data.numAttributes() - 1); | |
return data; | |
} | |
public static String getHomeDirectory() { | |
return System.getProperty("user.home"); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment