Autoencoder with evaluation
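A deep autoencoder for MNIST built from stacked RBM layers (a 784-1000-500-250-100-30 encoder mirrored by the decoder), based on DL4J's DeepAutoEncoderExample, with an evaluation pass added that compares each input image against its reconstruction.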
import java.util.Collections;

import org.deeplearning4j.datasets.fetchers.MnistDataFetcher;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.RBM;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * ***** NOTE: This example has not been tuned. It requires additional work to
 * produce sensible results. *****
 *
 * @author Adam Gibson
 */
public class DeepAutoEncoderExample {

    private static final Logger log = LoggerFactory.getLogger(DeepAutoEncoderExample.class);

    public static void main(String[] args) throws Exception {
        final int numRows = 28;
        final int numColumns = 28;
        int seed = 123;
        int numSamples = MnistDataFetcher.NUM_EXAMPLES;
        int batchSize = 1000;
        int iterations = 1;
        // iterations / 5 is 0 for iterations < 5, which would break the score listener;
        // clamp the frequency to at least 1.
        int listenerFreq = Math.max(1, iterations / 5);

        log.info("Load data....");
        DataSetIterator iter = new MnistDataSetIterator(batchSize, numSamples, true);
log.info("Build model....");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(seed)
.iterations(iterations)
.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
.list()
.layer(0, new RBM.Builder().nIn(numRows * numColumns).nOut(1000).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
.layer(1, new RBM.Builder().nIn(1000).nOut(500).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
.layer(2, new RBM.Builder().nIn(500).nOut(250).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
.layer(3, new RBM.Builder().nIn(250).nOut(100).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
.layer(4, new RBM.Builder().nIn(100).nOut(30).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) //encoding stops
.layer(5, new RBM.Builder().nIn(30).nOut(100).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) //decoding starts
.layer(6, new RBM.Builder().nIn(100).nOut(250).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
.layer(7, new RBM.Builder().nIn(250).nOut(500).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
.layer(8, new RBM.Builder().nIn(500).nOut(1000).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
.layer(9, new OutputLayer.Builder(LossFunctions.LossFunction.RMSE_XENT).nIn(1000).nOut(numRows * numColumns).build())
.pretrain(true).backprop(true)
.build();
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(Collections.singletonList((IterationListener) new ScoreIterationListener(listenerFreq)));

        log.info("Train model....");
        while (iter.hasNext()) {
            DataSet next = iter.next();
            // Autoencoder training: the input doubles as the label, so the network
            // learns to reconstruct each image through the 30-unit bottleneck.
            model.fit(new DataSet(next.getFeatureMatrix(), next.getFeatureMatrix()));
        }
System.out.println("Evaluate model....");
iter.reset();
Evaluation eval = new Evaluation(numRows*numColumns);
while (iter.hasNext()) {
DataSet next = iter.next();
INDArray features = next.getFeatureMatrix();
INDArray predicted = model.output(features, false);
eval.eval(features, predicted);
//check autoencode requirement input == output
}
System.out.println(eval.stats());
}
}
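A more direct score for an autoencoder is the mean squared reconstruction error over all pixels. The snippet below is a minimal sketch, not part of the original gist: it reuses the iterator and trained model from main above, and assumes an ND4J version whose INDArray exposes meanNumber(). It could replace or follow the Evaluation pass.

        // Minimal sketch (assumption: INDArray.meanNumber() is available in the
        // ND4J version on the classpath): average per-pixel squared error.
        iter.reset();
        double mseSum = 0.0;
        int batches = 0;
        while (iter.hasNext()) {
            INDArray features = iter.next().getFeatureMatrix();
            INDArray reconstructed = model.output(features, false);
            INDArray diff = features.sub(reconstructed);          // element-wise error
            mseSum += diff.mul(diff).meanNumber().doubleValue();  // mean over all entries
            batches++;
        }
        log.info("Mean squared reconstruction error: " + (mseSum / batches));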