@kreyssel
Created February 9, 2016 15:52
DBNMnistFullExample
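A Deeplearning4j example (circa early 2016) that reads 20x20 greyscale digit images with Canova's ImageRecordReader, greedily pretrains three stacked binary RBMs plus a softmax output layer as a deep belief network, and prints evaluation statistics against the MNIST test set.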
import org.canova.api.split.FileSplit;
import org.canova.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.canova.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.RBM;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.awt.image.BufferedImage;
import java.io.File;
import java.util.Collections;
/**
 * Created by agibsonccc on 9/11/14.
 */
public class DBNMnistFullExample {
    /** Shared logger for this example. */
    private static Logger log = LoggerFactory.getLogger( DBNMnistFullExample.class );

    /**
     * Loads 20x20 greyscale digit images, greedily pretrains a stack of three binary RBMs topped
     * by a softmax output layer, and evaluates the resulting deep belief network against the
     * MNIST test set.
     *
     * @param args unused
     *
     * @throws Exception if the image data cannot be read or training fails
     */
    public static void main( String[] args ) throws Exception {
        final int numRows = 20;               // input image height in pixels
        final int numColumns = 20;            // input image width in pixels
        int outputNum = 10;                   // digit classes 0-9
        int numSamples = 60000;               // MNIST training-set size (for the commented-out iterator below)
        int batchSize = numRows * numColumns; // minibatch size; also 400, the per-image pixel count reused as nIn of layer 0
        int batchRows = batchSize;            // hidden-unit count of the first RBM
        int iterations = 10;                  // optimization iterations per minibatch
        int seed = 123;                       // fixed RNG seed for reproducibility
        int listenerFreq = batchSize / 5;     // log the score every listenerFreq iterations
        log.info( "Load data...." );
        // Alternative: use the bundled MNIST iterator instead of reading images from disk
        //DataSetIterator iter = new MnistDataSetIterator( batchSize, numSamples, true );
        ImageRecordReader imageRecordReader =
                new ImageRecordReader( numRows, numColumns, BufferedImage.TYPE_BYTE_GRAY, true );
        imageRecordReader.initialize( new FileSplit( new File( "src/test/resources/data/" ) ) );
        DataSetIterator iter = new RecordReaderDataSetIterator( imageRecordReader, batchSize );
        log.info( "Build model...." );
        // Three stacked binary-binary RBMs followed by a softmax output layer; only greedy
        // layer-wise pretraining is run (pretrain(true), backprop(false)).
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed( seed )
                .gradientNormalization( GradientNormalization.ClipElementWiseAbsoluteValue )
                .gradientNormalizationThreshold( 1.0 )
                .iterations( iterations )
                .momentum( 0.5 )
                .momentumAfter( Collections.singletonMap( 3, 0.9 ) ) // momentum schedule: switch to 0.9 at step 3
                .optimizationAlgo( OptimizationAlgorithm.CONJUGATE_GRADIENT )
                .list( 4 )
                .layer( 0, new RBM.Builder()
                        .nIn( batchSize )   // 400 input pixels (numRows * numColumns)
                        .nOut( batchRows )
                        .weightInit( WeightInit.XAVIER )
                        .lossFunction( LossFunction.RMSE_XENT )
                        .visibleUnit( RBM.VisibleUnit.BINARY )
                        .hiddenUnit( RBM.HiddenUnit.BINARY )
                        .build() )
                .layer( 1, new RBM.Builder()
                        .nIn( batchRows )
                        .nOut( 250 )
                        .weightInit( WeightInit.XAVIER )
                        .lossFunction( LossFunction.RMSE_XENT )
                        .visibleUnit( RBM.VisibleUnit.BINARY )
                        .hiddenUnit( RBM.HiddenUnit.BINARY )
                        .build() )
                .layer( 2, new RBM.Builder()
                        .nIn( 250 )
                        .nOut( 200 )
                        .weightInit( WeightInit.XAVIER )
                        .lossFunction( LossFunction.RMSE_XENT )
                        .visibleUnit( RBM.VisibleUnit.BINARY )
                        .hiddenUnit( RBM.HiddenUnit.BINARY )
                        .build() )
                .layer( 3, new OutputLayer.Builder( LossFunction.NEGATIVELOGLIKELIHOOD )
                        .activation( "softmax" )
                        .nIn( 200 )
                        .nOut( outputNum )  // 10 digit classes
                        .build() )
                .pretrain( true ).backprop( false )
                .build();
        MultiLayerNetwork model = new MultiLayerNetwork( conf );
        model.init();
        model.setListeners( Collections.singletonList(
                ( IterationListener )new ScoreIterationListener( listenerFreq ) ) );
        log.info( "Train model...." );
        model.fit( iter ); // achieves end to end pre-training
        log.info( "Evaluate model...." );
        Evaluation eval = new Evaluation( outputNum );
        // Note: MnistDataSetIterator serves the standard 28x28 (784-pixel) MNIST test images, while the
        // network above expects numRows * numColumns (400) inputs; the two sizes must match for this
        // evaluation to run.
        DataSetIterator testIter = new MnistDataSetIterator( 100, 10000 );
        while( testIter.hasNext() ) {
            DataSet testMnist = testIter.next();
            INDArray predict2 = model.output( testMnist.getFeatureMatrix() );
            eval.eval( testMnist.getLabels(), predict2 );
        }
        log.info( eval.stats() );
        log.info( "****************Example finished********************" );
    }
}
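The example ends after printing the evaluation statistics and never shows the trained network classifying a single example. A minimal sketch of how that could look, assuming the model built in main(...) above, a single already-flattened 1 x (numRows * numColumns) greyscale row vector, and an additional import of org.nd4j.linalg.factory.Nd4j; the helper name predictDigit is an illustrative assumption and not part of the original gist:

// Hypothetical helper (not in the original gist): returns the most probable digit class
// for one flattened greyscale image, using the network pretrained in main(...).
static int predictDigit( MultiLayerNetwork model, INDArray flattenedPixels ) {
    INDArray probabilities = model.output( flattenedPixels );   // one softmax row per example
    return Nd4j.argMax( probabilities, 1 ).getInt( 0 );         // column index of the highest probability
}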