@ragnarok22
Created June 13, 2018 16:25
Sample code using DL4J on the MNIST dataset (PNG version)
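// Expects the MNIST images as individual PNG files under ~/datasets/mnist_png/,
// with a training/ and a testing/ folder, each containing one subdirectory per digit (0-9).
// Imports are omitted; the DL4J core, DataVec image and ND4J backend artifacts must be on the classpath.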
String dataset_home = System.getProperty("user.home") + File.separator + "datasets" + File.separator;
String mnist_home = dataset_home + "mnist_png" + File.separator;
String trainPath = mnist_home + "training";
String testPath = mnist_home + "testing";
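// Image dimensions and training hyper-parameters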
int numRows = 28; // height
int numColumns = 28; // width
int channels = 1; // depth
int outputNum = 10;
int batchSize = 128;
int seed = 123;
Random rng = new Random(seed);
int numEpochs = 1;
String pathToSaveModel = "D:\\";
boolean saveModel = true;
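// Training pipeline: read the PNGs, take each image's label from the name of its parent
// directory, batch the examples, and scale pixel values from [0, 255] down to [0, 1]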
File trainData = new File(trainPath);
FileSplit train = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, rng);
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
ImageRecordReader recordReader = new ImageRecordReader(numRows, numColumns, channels, labelMaker);
recordReader.initialize(train);
DataSetIterator mnistTrain = new RecordReaderDataSetIterator(recordReader, batchSize, 1, outputNum);
DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
scaler.fit(mnistTrain);
mnistTrain.setPreProcessor(scaler);
System.out.println("contruyendo el modelo");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(seed)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .updater(new Nesterovs())
        // .updater(new Nesterovs(0.006, 0.9))
        .l2(1e-4)
        .list()
        .layer(0, new DenseLayer.Builder()
                .nIn(numRows * numColumns)
                .nOut(100)
                .activation(Activation.RELU)
                .weightInit(WeightInit.XAVIER)
                .build())
        .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nOut(outputNum)
                .activation(Activation.SOFTMAX)
                .weightInit(WeightInit.XAVIER)
                .build())
        .pretrain(false).backprop(true)
        .setInputType(InputType.convolutional(numRows, numColumns, channels))
        .build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
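// Print the training score (loss) every 100 iterations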
model.setListeners(new ScoreIterationListener(100));
System.out.println("entrenando el modelo");
for (int i = 0; i < numEpochs; i++) {
    model.fit(mnistTrain);
}
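// Persist the trained model; the resulting .zip bundles the network configuration and parameters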
if (saveModel) {
    model.save(new File(pathToSaveModel + "modelDL4J.zip"));
}
System.out.println("evaluando el modelo");
File testData = new File(testPath);
FileSplit test = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, rng);
recordReader.reset();
recordReader.initialize(test);
DataSetIterator mnistTest = new RecordReaderDataSetIterator(recordReader, batchSize, 1, outputNum);
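// ImagePreProcessingScaler applies a fixed 0-255 -> [0, 1] rescaling, so fitting it on the
// test iterator does not learn any statistics from the test data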
scaler.fit(mnistTest);
mnistTest.setPreProcessor(scaler);
Evaluation evaluation = new Evaluation(outputNum);
while (mnistTest.hasNext()) {
    DataSet next = mnistTest.next();
    INDArray output = model.output(next.getFeatures());
    evaluation.eval(next.getLabels(), output);
}
System.out.println(evaluation.stats());
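For completeness, a minimal sketch of how the saved model could be loaded back and used to classify a single image. The PNG path below is a hypothetical placeholder, and the snippet assumes the same DL4J API version and the variables (numRows, mnist_home, scaler, etc.) defined above.

// Restore the trained network from the zip written by model.save() above
MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(new File(pathToSaveModel + "modelDL4J.zip"));

// Load a single 28x28 grayscale PNG and apply the same 0-1 pixel scaling used during training
NativeImageLoader loader = new NativeImageLoader(numRows, numColumns, channels);
INDArray image = loader.asMatrix(new File(mnist_home + "testing" + File.separator + "7" + File.separator + "example.png")); // hypothetical file
scaler.transform(image);

// The index of the largest softmax activation is the predicted digit
INDArray prediction = restored.output(image);
System.out.println("Predicted digit: " + Nd4j.argMax(prediction, 1));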