Created: February 20, 2019 18:32
-
-
Save aidancbrady/274ef25ba3b402d42749e73458806219 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import func.nn.backprop.BackPropagationNetwork;
import func.nn.backprop.BackPropagationNetworkFactory;
import opt.OptimizationAlgorithm;
import opt.example.NeuralNetworkOptimizationProblem;
import opt.ga.StandardGeneticAlgorithm;
import shared.DataSet;
import shared.Instance;
import shared.SumOfSquaresError;
import shared.filt.TestTrainSplitFilter;
import shared.reader.ArffDataSetReader;
public class RandomizedOptimization | |
{ | |
private static final double ITERATIONS = 1000; | |
public static void main(String[] args) { | |
// Load instances | |
Instance[] data = getData("/Documents/Georgia Tech/Spring 2019/cs4641/Datasets/phishing-websites.arff"); | |
if(data == null) { | |
System.out.println("Error loading data."); | |
return; | |
} | |
DataSet dataSet = new DataSet(data); | |
SumOfSquaresError errorFunction = new SumOfSquaresError(); | |
TestTrainSplitFilter filter = new TestTrainSplitFilter(70); | |
filter.filter(dataSet); | |
BackPropagationNetworkFactory nnFactory = new BackPropagationNetworkFactory(); | |
BackPropagationNetwork neuralNetwork = nnFactory.createClassificationNetwork(new int[] {data[0].getData().size(), 1, 1}); | |
NeuralNetworkOptimizationProblem weightProblem = new NeuralNetworkOptimizationProblem(filter.getTrainingSet(), neuralNetwork, errorFunction); | |
//RandomizedHillClimbing rhc = new RandomizedHillClimbing(weightProblem); | |
//SimulatedAnnealing sa = new SimulatedAnnealing(1.0E5, 0.25, weightProblem); | |
StandardGeneticAlgorithm ga = new StandardGeneticAlgorithm(300, 72, 54, weightProblem); | |
startRecording("Curve"); | |
run(dataSet, neuralNetwork, ga, filter); | |
stopRecording(); | |
//gridSearch(dataSet, neuralNetwork, weightProblem, filter); | |
} | |
private static void gridSearch(DataSet dataSet, BackPropagationNetwork neuralNetwork, NeuralNetworkOptimizationProblem problem, TestTrainSplitFilter filter) { | |
Instance[] testSet = filter.getTestingSet().getInstances(); | |
int[] populations = {100, 200, 300, 400, 500}; | |
for(int population : populations) { | |
StandardGeneticAlgorithm algorithm = new StandardGeneticAlgorithm(population, (int)(0.24*population), (int)(0.18*population), problem); | |
for(int i = 0; i < ITERATIONS; i++) { | |
int correctTest = 0; | |
algorithm.train(); | |
for(int j = 0; j < testSet.length; j++) { | |
neuralNetwork.setInputValues(testSet[j].getData()); | |
neuralNetwork.run(); // feed forward | |
double expected = Double.parseDouble(testSet[j].getLabel().toString()); | |
double out = Double.parseDouble(neuralNetwork.getOutputValues().toString()); | |
if(Math.abs(expected-out) < 0.5D) { | |
correctTest++; | |
} | |
} | |
if(i == ITERATIONS-1) { | |
System.out.println("Pop: " + population + " = " + ((double)correctTest/testSet.length)); | |
} | |
} | |
} | |
} | |
private static void run(DataSet dataSet, BackPropagationNetwork neuralNetwork, OptimizationAlgorithm algorithm, TestTrainSplitFilter filter) { | |
Instance[] trainSet = filter.getTrainingSet().getInstances(); | |
Instance[] testSet = filter.getTestingSet().getInstances(); | |
for(int i = 0; i < ITERATIONS; i++) { | |
int correctTrain = 0, correctTest = 0; | |
algorithm.train(); | |
for(int j = 0; j < trainSet.length; j++) { | |
neuralNetwork.setInputValues(trainSet[j].getData()); | |
neuralNetwork.run(); // feed forward | |
double expected = Double.parseDouble(trainSet[j].getLabel().toString()); | |
double out = Double.parseDouble(neuralNetwork.getOutputValues().toString()); | |
if(Math.abs(expected-out) < 0.5D) { | |
correctTrain++; | |
} | |
} | |
for(int j = 0; j < testSet.length; j++) { | |
neuralNetwork.setInputValues(testSet[j].getData()); | |
neuralNetwork.run(); // feed forward | |
double expected = Double.parseDouble(testSet[j].getLabel().toString()); | |
double out = Double.parseDouble(neuralNetwork.getOutputValues().toString()); | |
if(Math.abs(expected-out) < 0.5D) { | |
correctTest++; | |
} | |
} | |
System.out.println(i + "," + ((double)correctTrain/trainSet.length) + "," + ((double)correctTest/testSet.length)); | |
} | |
} | |
private static Instance[] getData(String s) { | |
try { | |
ArffDataSetReader reader = new ArffDataSetReader(getHomeDirectory() + s); | |
Instance[] instances = reader.read().getInstances(); | |
for(int i = 0; i < instances.length; i++) { | |
double[] in = new double[instances[i].getData().size()-1]; | |
double out = instances[i].getData().get(instances[i].getData().size()-1); | |
for(int j = 0; j < instances[i].getData().size()-1; j++) { | |
in[j] = instances[i].getData().get(j); | |
} | |
instances[i] = new Instance(in); | |
instances[i].setLabel(new Instance(out == 1 ? 1 : 0)); | |
} | |
return instances; | |
} catch(Exception e) { | |
e.printStackTrace(); | |
return null; | |
} | |
} | |
private static String getHomeDirectory() { | |
return System.getProperty("user.home"); | |
} | |
private static long timestamp = 0; | |
public static void startRecording(String s) { | |
timestamp = System.currentTimeMillis(); | |
System.out.println("Recording time for: " + s); | |
} | |
public static void stopRecording() { | |
long diff = System.currentTimeMillis()-timestamp; | |
System.out.println("Time elapsed: " + diff); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment