Skip to content

Instantly share code, notes, and snippets.

@aidancbrady
Created February 10, 2019 19:01
Show Gist options
  • Save aidancbrady/8be61d0f97e6d6311846acfb25683d03 to your computer and use it in GitHub Desktop.
Save aidancbrady/8be61d0f97e6d6311846acfb25683d03 to your computer and use it in GitHub Desktop.
import java.text.DecimalFormat;
import java.util.Random;
import java.util.function.Function;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.MultilayerPerceptron;
import weka.classifiers.functions.SMO;
import weka.classifiers.lazy.IBk;
import weka.classifiers.meta.AdaBoostM1;
import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.Filter;
import weka.filters.unsupervised.instance.Resample;
/**
 * Command-line driver that trains several Weka classifiers (J48, AdaBoostM1,
 * MultilayerPerceptron, IBk and SMO) on two ARFF datasets and prints, for each
 * classifier/dataset pair, a ten-point learning curve plus the elapsed time.
 *
 * <p>All output goes to {@code System.out}. The class keeps a single static
 * timestamp for timing, so it is meant to be used from one thread only.
 */
public class WekaWrapper
{
    /** Formatter used for every accuracy percentage that is printed. */
    private static final DecimalFormat format = new DecimalFormat("#.00");

    /** Start of the current timing window, as returned by {@link System#nanoTime()}. */
    private static long timestamp = 0;

    /** Dataset folder, relative to the user's home directory. */
    private static final String DATASET_DIR = "/Documents/Georgia Tech/Spring 2019/cs4641/Assignment 1/Datasets/";

    /** Value passed as {@code split} to {@link #buildLearningCurve} for every experiment. */
    private static final double TRAIN_FRACTION = 0.8;

    /** Neighbour-search specification shared by both IBk configurations. */
    private static final String LINEAR_SEARCH =
            "\"weka.core.neighboursearch.LinearNNSearch\" -A \"weka.core.EuclideanDistance -R first-last\"";

    /**
     * Runs every experiment in sequence. Any exception aborts the remaining
     * experiments and its stack trace is printed.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        try {
            Instances cancer = getData(DATASET_DIR + "breast-cancer.arff");
            Instances phishing = getData(DATASET_DIR + "phishing-websites.arff");

            // Decision tree
            J48 tree = new J48();
            tree.setOptions(new String[] {"-C", "0.25", "-M", "2"});
            recordLearningCurve("J48, cancer data", cancer, tree);
            recordLearningCurve("J48, phishing data", phishing, tree);

            // Boosting (J48 as base learner)
            recordLearningCurve("Boost, cancer data", cancer, newBoost(tree, 10));
            recordLearningCurve("Boost, phishing data", phishing, newBoost(tree, 40));

            // Neural network. The original code passed the boosting classifier here,
            // so the configured perceptrons were never actually evaluated.
            recordLearningCurve("NN, cancer data", cancer, newNeuralNet(300, "a,1"));
            recordLearningCurve("NN, phishing data", phishing, newNeuralNet(200, "a,5"));

            // k-nearest neighbours
            recordLearningCurve("KNN, cancer data", cancer, newKnn(15, "-F"));
            recordLearningCurve("KNN, phishing data", phishing, newKnn(1, "-I"));

            // Support vector machine
            recordLearningCurve("SVM, cancer data", cancer, newSvm());
            recordLearningCurve("SVM, phishing data", phishing, newSvm());
        } catch(Exception e) {
            e.printStackTrace();
        }
    }

    /** Prints a timed learning curve for one classifier/dataset pair under the given label. */
    private static void recordLearningCurve(String label, Instances data, Classifier classifier) throws Exception {
        startRecording(label);
        buildLearningCurve(data, classifier, TRAIN_FRACTION);
        stopRecording();
    }

    /** Creates an AdaBoostM1 with the given {@code -I} value, using {@code base} as its classifier. */
    private static AdaBoostM1 newBoost(J48 base, int iterations) throws Exception {
        AdaBoostM1 boost = new AdaBoostM1();
        boost.setOptions(new String[] {"-P", "100", "-S", "1", "-I", Integer.toString(iterations),
                "-W", "weka.classifiers.trees.J48", "--", "-C", "0.25", "-M", "2"});
        boost.setClassifier(base);
        return boost;
    }

    /** Creates a MultilayerPerceptron with the given {@code -N} and {@code -H} values. */
    private static MultilayerPerceptron newNeuralNet(int epochs, String hiddenLayers) throws Exception {
        MultilayerPerceptron nn = new MultilayerPerceptron();
        nn.setOptions(new String[] {"-L", "0.3", "-M", "0.2", "-N", Integer.toString(epochs),
                "-V", "0", "-S", "0", "-E", "20", "-H", hiddenLayers});
        return nn;
    }

    /** Creates an IBk with the given {@code -K} value and distance-weighting flag ({@code -F} or {@code -I}). */
    private static IBk newKnn(int k, String weightingFlag) throws Exception {
        IBk knn = new IBk();
        knn.setOptions(new String[] {"-K", Integer.toString(k), "-W", "0", weightingFlag, "-A", LINEAR_SEARCH});
        return knn;
    }

    /** Creates the SMO configuration that is used for both datasets. */
    private static SMO newSvm() throws Exception {
        SMO svm = new SMO();
        svm.setOptions(new String[] {"-C", "1.0", "-L", "0.001", "-P", "1.0E-12", "-N", "0", "-V", "-1", "-W", "1",
                "-K", "\"weka.classifiers.functions.supportVector.NormalizedPolyKernel\" -E 2.0 -C 250007",
                "-calibrator", "\"weka.classifiers.functions.Logistic\" -R 1.0E-8 -M -1 -num-decimal-places 4"});
        return svm;
    }

    /**
     * Sweeps one numeric parameter and prints the 10-fold cross-validated
     * accuracy (seed 1) for each value from {@code start} to {@code end}
     * inclusive, one line per value.
     *
     * <p>Example, sweeping IBk's {@code -K} over 1, 3, ..., 35:
     * <pre>{@code
     * IBk knn = new IBk();
     * tuneParam(cancer, knn, 1, 35, 2, (val) -> new String[] {
     *     "-K", Integer.toString((int)Math.round(val)), "-W", "0", "-I", "-A",
     *     "\"weka.core.neighboursearch.LinearNNSearch\" -A \"weka.core.EuclideanDistance -R first-last\""});
     * }</pre>
     *
     * @param data       dataset to cross-validate on
     * @param classifier classifier whose options are replaced for every value
     * @param start      first parameter value
     * @param end        last parameter value (inclusive)
     * @param step       increment between values; must be positive and finite
     * @param func       maps a parameter value to the full option array for {@code classifier}
     * @throws IllegalArgumentException if {@code step} is not a positive finite number
     * @throws Exception if Weka rejects the options or the evaluation fails
     */
    public static void tuneParam(Instances data, AbstractClassifier classifier, double start, double end, double step, Function<Double, String[]> func) throws Exception {
        if(!(step > 0) || Double.isInfinite(step)) {
            // A zero or negative step would never reach 'end'.
            throw new IllegalArgumentException("step must be positive and finite: " + step);
        }
        // Derive each value from the index instead of accumulating 'step',
        // so fractional steps do not drift and skip the final value.
        for(int k = 0; start + k * step <= end; k++) {
            double value = start + k * step;
            classifier.setOptions(func.apply(value));
            Evaluation eval = new Evaluation(data);
            eval.crossValidateModel(classifier, data, 10, new Random(1));
            System.out.println(format.format(eval.pctCorrect()));
        }
    }

    /**
     * Prints a ten-point learning curve for {@code classifier}.
     *
     * <p>First, under "Training results:", the classifier is built on samples
     * of {@code split*10%}, {@code split*20%}, ... {@code split*100%} of the
     * data and scored on that same sample. Then, under "Testing results:", a
     * sample is drawn whose leading part (growing in the same ten steps) is
     * used for training and whose trailing part is used for scoring.
     *
     * <p>NOTE(review): the Resample filter is used with its default settings
     * apart from seed and size; if that default samples with replacement, the
     * training and test partitions can contain copies of the same instance —
     * confirm whether {@code -no-replacement} is wanted.
     *
     * @param data       full dataset with its class index set
     * @param classifier classifier to rebuild at every point of the curve
     * @param split      fraction of the data reserved for training, strictly between 0 and 1
     * @throws IllegalArgumentException if {@code split} is not strictly between 0 and 1
     * @throws Exception if filtering, training or evaluation fails
     */
    public static void buildLearningCurve(Instances data, Classifier classifier, double split) throws Exception {
        if(!(split > 0 && split < 1)) {
            // Outside this range either the training or the test partition is empty.
            throw new IllegalArgumentException("split must be strictly between 0 and 1: " + split);
        }

        System.out.println("Training results:");
        for(int i = 0; i < 10; i++) {
            Instances sample = resample(data, (i + 1) * split * 10D);
            classifier.buildClassifier(sample);
            Evaluation eval = new Evaluation(sample);
            eval.evaluateModel(classifier, sample);
            System.out.println(format.format(eval.pctCorrect()));
        }

        System.out.println();
        System.out.println("Testing results:");
        for(int i = 0; i < 10; i++) {
            // Sample = fixed-size test share (1 - split) plus the growing training share.
            double datasetMult = (1-split) + (i+1)*(split/10D);
            Instances sample = resample(data, datasetMult * 100);

            // Portion of the sample that goes to training; the remainder is the test set.
            double trainShare = ((i+1)*(split/10D))/datasetMult;
            int trainSize = (int)Math.round(sample.numInstances() * trainShare);
            Instances trainData = new Instances(sample, 0, trainSize);
            Instances testData = new Instances(sample, trainSize, sample.numInstances() - trainSize);

            classifier.buildClassifier(trainData);
            Evaluation eval = new Evaluation(trainData);
            eval.evaluateModel(classifier, testData);
            System.out.println(format.format(eval.pctCorrect()));
        }
    }

    /** Applies a Resample filter with seed 1 and the given {@code -Z} size value to {@code data}. */
    private static Instances resample(Instances data, double percent) throws Exception {
        Resample filter = new Resample();
        filter.setOptions(new String[] {"-S", "1", "-Z", Double.toString(percent)});
        filter.setInputFormat(data);
        return Filter.useFilter(data, filter);
    }

    /**
     * Starts a timing window and prints its label.
     *
     * @param s label describing what is being timed
     */
    public static void startRecording(String s) {
        // nanoTime is meant for measuring elapsed time; it is unaffected by wall-clock changes.
        timestamp = System.nanoTime();
        System.out.println("Recording time for: " + s);
    }

    /** Prints the milliseconds elapsed since the last {@link #startRecording(String)} call. */
    public static void stopRecording() {
        long diff = (System.nanoTime() - timestamp) / 1_000_000L;
        System.out.println("Time elapsed: " + diff);
    }

    /**
     * Loads a dataset located under the user's home directory.
     *
     * @param path path appended verbatim to the home directory (callers pass a leading "/")
     * @return the loaded instances; if no class index was set, the last attribute becomes the class
     * @throws Exception if the data source cannot be read
     */
    public static Instances getData(String path) throws Exception {
        DataSource source = new DataSource(getHomeDirectory() + path);
        Instances data = source.getDataSet();
        if(data.classIndex() == -1) {
            data.setClassIndex(data.numAttributes() - 1);
        }
        return data;
    }

    /**
     * @return the value of the {@code user.home} system property
     */
    public static String getHomeDirectory() {
        return System.getProperty("user.home");
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment