Skip to content

Instantly share code, notes, and snippets.

@frankibem
Created August 25, 2016 14:47
Show Gist options
  • Save frankibem/94e588cb2d8ccda2af675f9bde3e25fa to your computer and use it in GitHub Desktop.
Save frankibem/94e588cb2d8ccda2af675f9bde3e25fa to your computer and use it in GitHub Desktop.
package com.test;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.engine.network.activation.ActivationTANH;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.ml.train.strategy.StopTrainingStrategy;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.Propagation;
import org.encog.neural.networks.training.propagation.back.Backpropagation;
import org.encog.persist.EncogDirectoryPersistence;
import org.encog.util.simple.EncogUtility;
import java.io.File;
/**
 * Demonstrates training, saving and reloading an Encog network on the XOR truth table.
 *
 * <p>{@link #main} trains a small feedforward network with backpropagation and saves it to
 * {@value #FILENAME} (a path relative to the working directory). It then loads the network
 * back from that file and prints its error on the same data, so the two error figures can be
 * compared.
 */
public class Xor {
    /** File the trained network is written to and read back from (relative path). */
    private static final String FILENAME = "network.eg";

    /** Training stops as soon as the reported training error is at or below this value. */
    private static final double TARGET_ERROR = 0.009;

    /**
     * Hard cap on training epochs. Backpropagation on XOR can stall in a local minimum where
     * {@link #TARGET_ERROR} is never reached; without a cap the training loop would never end.
     */
    private static final int MAX_EPOCHS = 100000;

    /** The four XOR input pairs. */
    private static final double[][] XOR_INPUT = {{0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}, {1.0, 1.0}};

    /** Expected XOR output for each row of {@link #XOR_INPUT}, in the same order. */
    private static final double[][] XOR_IDEAL = {{0.0}, {1.0}, {1.0}, {0.0}};

    /**
     * Trains and saves the network, then reloads it and reports its error.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        trainNetwork();
        loadAndEvaluate();
    }

    /**
     * Trains a 2-3-1 feedforward network on XOR and saves it to {@link #FILENAME}.
     *
     * <p>Training runs until the error is at or below {@link #TARGET_ERROR}, or until
     * {@link #MAX_EPOCHS} epochs have run, whichever comes first. The network is saved in either
     * case; a warning is printed if the target error was not reached.
     */
    private static void trainNetwork() {
        // Create a simple 2-3-1 feedforward network with tanh activation functions
        // (2 inputs, 3 neurons in the first hidden layer, no second hidden layer, 1 output;
        // the trailing 'true' selects tanh).
        BasicNetwork network = EncogUtility.simpleFeedForward(2, 3, 0, 1, true);
        MLDataSet trainingSet = new BasicMLDataSet(XOR_INPUT, XOR_IDEAL);
        Propagation train = new Backpropagation(network, trainingSet);

        int epoch = 1;
        do {
            train.iteration();
            System.out.println("Epoch #" + epoch + " Error: " + train.getError());
            epoch++;
            // Bounded by MAX_EPOCHS so a stalled run cannot loop forever.
        } while (train.getError() > TARGET_ERROR && epoch <= MAX_EPOCHS);

        double trainingError = train.getError();
        train.finishTraining();

        if (trainingError > TARGET_ERROR) {
            // The loop exited because of the epoch cap, not because the error target was met.
            System.out.println("Warning: target error " + TARGET_ERROR + " not reached after "
                    + MAX_EPOCHS + " epochs (last error " + trainingError
                    + "); saving the network anyway.");
        }

        double e = network.calculateError(trainingSet);
        System.out.println("Network trained to error: " + e);
        System.out.println("Saving network...");
        EncogDirectoryPersistence.saveObject(new File(FILENAME), network);
    }

    /**
     * Loads the network saved by {@link #trainNetwork()} and prints its error on the XOR data.
     *
     * @throws IllegalStateException if {@link #FILENAME} does not hold a {@link BasicNetwork}
     */
    private static void loadAndEvaluate() {
        System.out.println("Loading Network...");
        Object loaded = EncogDirectoryPersistence.loadObject(new File(FILENAME));
        if (!(loaded instanceof BasicNetwork)) {
            throw new IllegalStateException(
                    FILENAME + " does not contain a BasicNetwork; found: " + loaded);
        }
        BasicNetwork network = (BasicNetwork) loaded;
        MLDataSet trainingSet = new BasicMLDataSet(XOR_INPUT, XOR_IDEAL);
        double e = network.calculateError(trainingSet);
        System.out.println("Loaded network's error is (should be the same as above):" + e);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment