Created
February 9, 2016 15:52
-
-
Save kreyssel/44ba25fbf30ef65a3ce6 to your computer and use it in GitHub Desktop.
DBNMnistFullExample
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.canova.api.split.FileSplit; | |
import org.canova.image.recordreader.ImageRecordReader; | |
import org.deeplearning4j.datasets.canova.RecordReaderDataSetIterator; | |
import org.deeplearning4j.datasets.iterator.DataSetIterator; | |
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; | |
import org.deeplearning4j.eval.Evaluation; | |
import org.deeplearning4j.nn.api.OptimizationAlgorithm; | |
import org.deeplearning4j.nn.conf.GradientNormalization; | |
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
import org.deeplearning4j.nn.conf.layers.OutputLayer; | |
import org.deeplearning4j.nn.conf.layers.RBM; | |
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.deeplearning4j.optimize.api.IterationListener; | |
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.dataset.DataSet; | |
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import java.awt.image.BufferedImage; | |
import java.io.File; | |
import java.util.Collections; | |
/** | |
* Created by agibsonccc on 9/11/14. | |
*/ | |
public class DBNMnistFullExample { | |
/** DOCUMENT ME! */ | |
private static Logger log = LoggerFactory.getLogger( DBNMnistFullExample.class ); | |
/** | |
* DOCUMENT ME! | |
* | |
* @param args DOCUMENT ME! | |
* | |
* @throws Exception DOCUMENT ME! | |
*/ | |
public static void main( String[] args ) throws Exception { | |
final int numRows = 20; | |
final int numColumns = 20; | |
int outputNum = 10; | |
int numSamples = 60000; | |
int batchSize = numRows * numColumns; | |
int batchRows = batchSize; | |
int iterations = 10; | |
int seed = 123; | |
int listenerFreq = batchSize / 5; | |
log.info( "Load data...." ); | |
//DataSetIterator iter = new MnistDataSetIterator( batchSize, numSamples, true ); | |
ImageRecordReader imageRecordReader = new ImageRecordReader( numRows, numColumns, BufferedImage.TYPE_BYTE_GRAY, | |
true ); | |
imageRecordReader.initialize( new FileSplit( new File( "src/test/resources/data/" ) ) ); | |
//DataSetIterator iter = new MnistDataSetIterator( batchSize, numSamples, true ); | |
DataSetIterator iter = new RecordReaderDataSetIterator( imageRecordReader, batchSize ); | |
log.info( "Build model...." ); | |
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed( seed ) | |
.gradientNormalization( | |
GradientNormalization.ClipElementWiseAbsoluteValue ) | |
.gradientNormalizationThreshold( 1.0 ) | |
.iterations( iterations ).momentum( 0.5 ) | |
.momentumAfter( Collections.singletonMap( 3, | |
0.9 ) ) | |
.optimizationAlgo( | |
OptimizationAlgorithm.CONJUGATE_GRADIENT ) | |
.list( 4 ) | |
.layer( 0, new RBM.Builder().nIn( | |
batchSize ) | |
.nOut( | |
batchRows ) | |
.weightInit( | |
WeightInit.XAVIER ) | |
.lossFunction( | |
LossFunction.RMSE_XENT ) | |
.visibleUnit( | |
RBM | |
.VisibleUnit.BINARY ) | |
.hiddenUnit( | |
RBM | |
.HiddenUnit.BINARY ) | |
.build() ) | |
.layer( 1, new RBM.Builder().nIn( | |
batchRows ) | |
.nOut( 250 ) | |
.weightInit( | |
WeightInit.XAVIER ) | |
.lossFunction( | |
LossFunction.RMSE_XENT ) | |
.visibleUnit( | |
RBM | |
.VisibleUnit.BINARY ) | |
.hiddenUnit( | |
RBM | |
.HiddenUnit.BINARY ) | |
.build() ) | |
.layer( 2, new RBM.Builder().nIn( 250 ) | |
.nOut( 200 ) | |
.weightInit( | |
WeightInit.XAVIER ) | |
.lossFunction( | |
LossFunction.RMSE_XENT ) | |
.visibleUnit( | |
RBM | |
.VisibleUnit.BINARY ) | |
.hiddenUnit( | |
RBM | |
.HiddenUnit.BINARY ) | |
.build() ) | |
.layer( 3, | |
new OutputLayer.Builder( | |
LossFunction.NEGATIVELOGLIKELIHOOD ) | |
.activation( "softmax" ).nIn( 200 ).nOut( outputNum ).build() ).pretrain( true ).backprop( false ) | |
.build(); | |
MultiLayerNetwork model = new MultiLayerNetwork( conf ); | |
model.init(); | |
model.setListeners( Collections.singletonList( | |
( IterationListener )new ScoreIterationListener( listenerFreq ) ) ); | |
log.info( "Train model...." ); | |
model.fit( iter ); // achieves end to end pre-training | |
log.info( "Evaluate model...." ); | |
Evaluation eval = new Evaluation( outputNum ); | |
DataSetIterator testIter = new MnistDataSetIterator( 100, 10000 ); | |
while( testIter.hasNext() ) { | |
DataSet testMnist = testIter.next(); | |
INDArray predict2 = model.output( testMnist.getFeatureMatrix() ); | |
eval.eval( testMnist.getLabels(), predict2 ); | |
} | |
log.info( eval.stats() ); | |
log.info( "****************Example finished********************" ); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment