Skip to content

Instantly share code, notes, and snippets.

@traviskaufman
Created October 25, 2016 01:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save traviskaufman/3f0fca735e7f482e16ad77d22fcfddf7 to your computer and use it in GitHub Desktop.
Save traviskaufman/3f0fca735e7f482e16ad77d22fcfddf7 to your computer and use it in GitHub Desktop.
dl4j's MLPMnistSingleLayerExample in Scala
/**
* @see https://github.com/deeplearning4j/dl4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/feedforward/mnist/MLPMnistSingleLayerExample.java
*/
import scala.collection.JavaConversions._
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
import org.deeplearning4j.eval.Evaluation
import org.deeplearning4j.nn.api.OptimizationAlgorithm
import org.deeplearning4j.nn.conf.NeuralNetConfiguration
import org.deeplearning4j.nn.conf.Updater
import org.deeplearning4j.nn.conf.layers.DenseLayer
import org.deeplearning4j.nn.conf.layers.OutputLayer
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.nn.weights.WeightInit
import org.deeplearning4j.optimize.listeners.ScoreIterationListener
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction
import org.slf4j.LoggerFactory
/**
  * Trains and evaluates a feed-forward network with a single hidden layer on MNIST.
  *
  * Running the object performs, in order:
  *  1. creation of the train and test MNIST iterators,
  *  2. construction and initialisation of a two-layer network
  *     (dense ReLU layer followed by a softmax output layer),
  *  3. `numEpochs` calls to `model.fit` on the training iterator,
  *  4. evaluation over every batch of the test iterator, logging the resulting statistics.
  */
object MLPMnistSingleLayerExample extends App {
  val log = LoggerFactory.getLogger(getClass)

  // The number of rows of a matrix
  val numRows: Int = 28
  // The number of columns of a matrix
  val numCols: Int = 28
  // Number of possible outcomes (e.g. labels 0 through 9).
  val outputNum: Int = 10
  // How many examples to fetch with each step
  val batchSize: Int = 128
  // Seed handed to both the data iterators and the network configuration so that
  // repeated runs start from the same state.
  val rngSeed: Int = 123
  // An epoch is a complete pass through a given dataset
  val numEpochs: Int = 15

  // NOTE(review): the boolean argument presumably selects the training (true) or
  // test (false) split, as the value names suggest; confirm against the DL4J API.
  val mnistTrain: DataSetIterator = new MnistDataSetIterator(batchSize, true, rngSeed)
  val mnistTest: DataSetIterator = new MnistDataSetIterator(batchSize, false, rngSeed)

  log.info("Build model...")
  val conf = new NeuralNetConfiguration.Builder()
    .seed(rngSeed)
    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
    .iterations(1)
    .learningRate(0.006)
    .updater(Updater.NESTEROVS).momentum(0.9)
    .regularization(true).l2(1e-4)
    .list()
    // Hidden layer: one input per pixel (numRows * numCols), 1000 ReLU units.
    .layer(0, new DenseLayer.Builder()
      .nIn(numRows * numCols)
      .nOut(1000)
      .activation("relu")
      .weightInit(WeightInit.XAVIER)
      .build())
    // Output layer: its nIn must equal the hidden layer's nOut; one softmax unit per label.
    .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
      .nIn(1000)
      .nOut(outputNum)
      .activation("softmax")
      .weightInit(WeightInit.XAVIER)
      .build())
    .pretrain(false)
    .backprop(true)
    .build()

  val model = new MultiLayerNetwork(conf)
  model.init()
  // Report the score through a listener configured with a frequency of 200 iterations.
  model.setListeners(new ScoreIterationListener(200))

  log.info("Train model...")
  // `until` is exclusive, so this runs exactly numEpochs passes.
  // (`0 to numEpochs` is inclusive and would train one extra epoch.)
  for (_ <- 0 until numEpochs) {
    model.fit(mnistTrain)
  }

  log.info("Evaluate model...")
  val eval = new Evaluation(outputNum)
  // Walk the Java iterator explicitly instead of relying on the deprecated implicit
  // JavaConversions wrapper to supply `map`/`foreach`.
  while (mnistTest.hasNext) {
    val batch = mnistTest.next()
    val features = batch.getFeatureMatrix
    val labels = batch.getLabels
    eval.eval(labels, model.output(features))
  }
  log.info(eval.stats())
  log.info("***** Example Finished *****")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment