Created
August 25, 2016 14:47
-
-
Save frankibem/94e588cb2d8ccda2af675f9bde3e25fa to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.test; | |
import org.encog.engine.network.activation.ActivationSigmoid; | |
import org.encog.engine.network.activation.ActivationTANH; | |
import org.encog.ml.data.MLDataSet; | |
import org.encog.ml.data.basic.BasicMLDataSet; | |
import org.encog.ml.train.strategy.StopTrainingStrategy; | |
import org.encog.neural.networks.BasicNetwork; | |
import org.encog.neural.networks.layers.BasicLayer; | |
import org.encog.neural.networks.training.propagation.Propagation; | |
import org.encog.neural.networks.training.propagation.back.Backpropagation; | |
import org.encog.persist.EncogDirectoryPersistence; | |
import org.encog.util.simple.EncogUtility; | |
import java.io.File; | |
public class Xor { | |
private static final String FILENAME = "network.eg"; | |
private static final double[][] XOR_INPUT = {{0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}, {1.0, 1.0}}; | |
private static final double[][] XOR_IDEAL = {{0.0}, {1.0}, {1.0}, {0.0}}; | |
public static void main(String[] args) { | |
trainNetwork(); | |
loadAndEvaluate(); | |
} | |
private static void trainNetwork() { | |
// Create a simple 2-3-1 feedforward network with tanh activation functions | |
BasicNetwork network = EncogUtility.simpleFeedForward(2, 3, 0, 1, true); | |
MLDataSet trainingSet = new BasicMLDataSet(XOR_INPUT, XOR_IDEAL); | |
Propagation train = new Backpropagation(network, trainingSet); | |
int epoch = 1; | |
do { | |
train.iteration(); | |
System.out.println("Epoch #" + epoch + " Error: " + train.getError()); | |
epoch++; | |
} while (train.getError() > 0.009); | |
double e = network.calculateError(trainingSet); | |
System.out.println("Network trained to error: " + e); | |
System.out.println("Saving network..."); | |
EncogDirectoryPersistence.saveObject(new File(FILENAME), network); | |
} | |
private static void loadAndEvaluate() { | |
System.out.println("Loading Network..."); | |
BasicNetwork network = (BasicNetwork) EncogDirectoryPersistence.loadObject(new File(FILENAME)); | |
BasicMLDataSet trainingSet = new BasicMLDataSet(XOR_INPUT, XOR_IDEAL); | |
double e = network.calculateError(trainingSet); | |
System.out.println("Loaded network's error is (should be the same as above):" + e); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment