Skip to content

Instantly share code, notes, and snippets.

@aidancbrady
Created February 20, 2019 18:32
Show Gist options
  • Save aidancbrady/274ef25ba3b402d42749e73458806219 to your computer and use it in GitHub Desktop.
Save aidancbrady/274ef25ba3b402d42749e73458806219 to your computer and use it in GitHub Desktop.
import func.nn.backprop.BackPropagationNetwork;
import func.nn.backprop.BackPropagationNetworkFactory;
import opt.OptimizationAlgorithm;
import opt.example.NeuralNetworkOptimizationProblem;
import opt.ga.StandardGeneticAlgorithm;
import shared.DataSet;
import shared.Instance;
import shared.SumOfSquaresError;
import shared.filt.TestTrainSplitFilter;
import shared.reader.ArffDataSetReader;
/**
 * Trains a small classification network on an ARFF dataset with a randomized optimization
 * algorithm (ABAGAIL) in place of backpropagation, and reports train/test accuracy.
 */
public class RandomizedOptimization
{
    /** Number of {@code train()} calls performed per experiment. */
    private static final int ITERATIONS = 1000;

    /** A prediction counts as correct when |label - output| is below this value. */
    private static final double CORRECT_THRESHOLD = 0.5D;

    /** Start of the current recording, as a {@link System#nanoTime()} reading. */
    private static long timestamp = 0;

    /**
     * Loads the phishing dataset, splits it 70/30 into train/test, optimizes the network's
     * weights with a genetic algorithm and prints per-iteration accuracy plus the elapsed time.
     */
    public static void main(String[] args) {
        // Load instances (the path is resolved against the user's home directory in getData).
        Instance[] data = getData("/Documents/Georgia Tech/Spring 2019/cs4641/Datasets/phishing-websites.arff");
        if(data == null || data.length == 0) {
            // An empty dataset would otherwise crash below on data[0].
            System.out.println("Error loading data.");
            return;
        }
        DataSet dataSet = new DataSet(data);
        SumOfSquaresError errorFunction = new SumOfSquaresError();
        TestTrainSplitFilter filter = new TestTrainSplitFilter(70);
        filter.filter(dataSet);
        BackPropagationNetworkFactory nnFactory = new BackPropagationNetworkFactory();
        // Layer sizes: one input per feature, then 1 and 1.
        BackPropagationNetwork neuralNetwork = nnFactory.createClassificationNetwork(new int[] {data[0].getData().size(), 1, 1});
        NeuralNetworkOptimizationProblem weightProblem = new NeuralNetworkOptimizationProblem(filter.getTrainingSet(), neuralNetwork, errorFunction);
        // Alternative optimizers for the same problem:
        //RandomizedHillClimbing rhc = new RandomizedHillClimbing(weightProblem);
        //SimulatedAnnealing sa = new SimulatedAnnealing(1.0E5, 0.25, weightProblem);
        StandardGeneticAlgorithm ga = new StandardGeneticAlgorithm(300, 72, 54, weightProblem);
        startRecording("Curve");
        run(neuralNetwork, ga, filter);
        stopRecording();
        //gridSearch(neuralNetwork, weightProblem, filter);
    }

    /**
     * For each candidate population size, trains a fresh genetic algorithm for ITERATIONS steps
     * and prints the final test-set accuracy of its best candidate.
     * Mate/mutate counts scale with the population (0.24x / 0.18x, i.e. 72/54 at 300).
     */
    private static void gridSearch(BackPropagationNetwork neuralNetwork, NeuralNetworkOptimizationProblem problem, TestTrainSplitFilter filter) {
        Instance[] testSet = filter.getTestingSet().getInstances();
        int[] populations = {100, 200, 300, 400, 500};
        for(int population : populations) {
            StandardGeneticAlgorithm algorithm = new StandardGeneticAlgorithm(population, (int)(0.24*population), (int)(0.18*population), problem);
            for(int i = 0; i < ITERATIONS; i++) {
                algorithm.train();
            }
            // Only the final accuracy is reported, so score once after training finishes.
            applyOptimalWeights(neuralNetwork, algorithm);
            System.out.println("Pop: " + population + " = " + accuracy(neuralNetwork, testSet));
        }
    }

    /**
     * Runs ITERATIONS training steps. After each step, prints
     * {@code iteration,trainAccuracy,testAccuracy} for the algorithm's current best candidate.
     */
    private static void run(BackPropagationNetwork neuralNetwork, OptimizationAlgorithm algorithm, TestTrainSplitFilter filter) {
        Instance[] trainSet = filter.getTrainingSet().getInstances();
        Instance[] testSet = filter.getTestingSet().getInstances();
        for(int i = 0; i < ITERATIONS; i++) {
            algorithm.train();
            applyOptimalWeights(neuralNetwork, algorithm);
            System.out.println(i + "," + accuracy(neuralNetwork, trainSet) + "," + accuracy(neuralNetwork, testSet));
        }
    }

    /**
     * Copies the algorithm's best candidate into the network before it is scored.
     * ABAGAIL's NeuralNetworkOptimizationProblem loads each candidate's weights into the shared
     * network while evaluating it. Without this call, the network holds whichever candidate was
     * evaluated last, not the best one found.
     * Note: for population-based algorithms getOptimal() re-evaluates candidates, so this adds
     * time to the recorded run.
     */
    private static void applyOptimalWeights(BackPropagationNetwork neuralNetwork, OptimizationAlgorithm algorithm) {
        neuralNetwork.setWeights(algorithm.getOptimal().getData());
    }

    /**
     * Returns the fraction of {@code instances} whose single network output lies within
     * CORRECT_THRESHOLD of the instance's label. An empty array yields NaN (0.0 / 0).
     */
    private static double accuracy(BackPropagationNetwork neuralNetwork, Instance[] instances) {
        int correct = 0;
        for(Instance instance : instances) {
            neuralNetwork.setInputValues(instance.getData());
            neuralNetwork.run(); // feed forward
            // Read the single label/output values directly instead of round-tripping via toString().
            double expected = instance.getLabel().getData().get(0);
            double out = neuralNetwork.getOutputValues().get(0);
            if(Math.abs(expected - out) < CORRECT_THRESHOLD) {
                correct++;
            }
        }
        return (double)correct / instances.length;
    }

    /**
     * Reads the ARFF file at {@code user.home + s} and rebuilds each row as an Instance.
     * Every attribute except the last becomes a feature. The label is 1 when the last
     * attribute equals 1, otherwise 0.
     *
     * @param s path suffix appended to the user's home directory
     * @return the converted instances, or null if reading/conversion fails (stack trace printed)
     */
    private static Instance[] getData(String s) {
        try {
            ArffDataSetReader reader = new ArffDataSetReader(getHomeDirectory() + s);
            Instance[] instances = reader.read().getInstances();
            for(int i = 0; i < instances.length; i++) {
                Instance row = instances[i];
                int featureCount = row.getData().size() - 1;
                double[] in = new double[featureCount];
                for(int j = 0; j < featureCount; j++) {
                    in[j] = row.getData().get(j);
                }
                double out = row.getData().get(featureCount);
                instances[i] = new Instance(in);
                instances[i].setLabel(new Instance(out == 1 ? 1 : 0));
            }
            return instances;
        } catch(Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /** Returns the {@code user.home} system property. */
    private static String getHomeDirectory() {
        return System.getProperty("user.home");
    }

    /** Prints a start banner for {@code s} and remembers the current monotonic time. */
    public static void startRecording(String s) {
        timestamp = System.nanoTime();
        System.out.println("Recording time for: " + s);
    }

    /**
     * Prints the milliseconds elapsed since the last startRecording call.
     * Uses nanoTime so wall-clock adjustments cannot distort the measurement.
     */
    public static void stopRecording() {
        long diff = (System.nanoTime() - timestamp) / 1_000_000L;
        System.out.println("Time elapsed: " + diff);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment