Created
October 19, 2016 09:44
-
-
Save yilaguan/c3427958b86cbc13959921dc1bda82fb to your computer and use it in GitHub Desktop.
MnistTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.deeplearning4j.mlp; | |
import com.beust.jcommander.JCommander; | |
import com.beust.jcommander.Parameter; | |
import com.beust.jcommander.ParameterException; | |
import org.apache.hadoop.io.BytesWritable; | |
import org.apache.hadoop.io.Text; | |
import org.apache.hadoop.mapred.SequenceFileOutputFormat; | |
import org.apache.spark.SparkConf; | |
import org.apache.spark.api.java.JavaPairRDD; | |
import org.apache.spark.api.java.JavaRDD; | |
import org.apache.spark.api.java.JavaSparkContext; | |
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; | |
import org.deeplearning4j.eval.Evaluation; | |
import org.deeplearning4j.mlp.sequence.FromSequenceFilePairFunction; | |
import org.deeplearning4j.mlp.sequence.ToSequenceFilePairFunction; | |
import org.deeplearning4j.nn.api.OptimizationAlgorithm; | |
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
import org.deeplearning4j.nn.conf.Updater; | |
import org.deeplearning4j.nn.conf.layers.DenseLayer; | |
import org.deeplearning4j.nn.conf.layers.OutputLayer; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.deeplearning4j.spark.api.Repartition; | |
import org.deeplearning4j.spark.api.RepartitionStrategy; | |
import org.deeplearning4j.spark.api.TrainingMaster; | |
import org.deeplearning4j.spark.api.stats.SparkTrainingStats; | |
import org.deeplearning4j.spark.data.DataSetExportFunction; | |
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer; | |
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; | |
import org.deeplearning4j.spark.stats.StatsUtils; | |
import org.nd4j.linalg.dataset.DataSet; | |
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; | |
import org.nd4j.linalg.lossfunctions.LossFunctions; | |
import java.util.ArrayList; | |
import java.util.List; | |
public class MnistTest { | |
@Parameter(names="-preprocessData", description = "Whether data should be saved and preprocessed (set to false to use already saved data)", arity = 1) | |
private boolean preprocessData = true; | |
@Parameter(names="-dataTrainSavePath", description = "Directory in which to save the serialized data sets - required. For example, file:/C:/Temp/MnistMLPPreprocessed/", required = true) | |
private String dataTrainSavePath; | |
@Parameter(names="-dataTestSavePath", description = "Directory in which to save the serialized data sets - required. For example, file:/C:/Temp/MnistMLPPreprocessed/", required = true) | |
private String dataTestSavePath; | |
@Parameter(names="-useSparkLocal", description = "Use spark local (helper for testing/running without spark submit)", arity = 1) | |
private boolean useSparkLocal = false; | |
@Parameter(names="-batchSizePerWorker", description = "Number of examples to fit each worker with") | |
private int batchSizePerWorker = 32; | |
@Parameter(names="-numEpochs", description = "Number of epochs for training") | |
private int numEpochs = 3; | |
public static void main(String[] args) throws Exception{ | |
new MnistTest().entryPoint(args); | |
} | |
protected void entryPoint(String[] args) throws Exception { | |
JCommander jcmdr = new JCommander(this); | |
try{ | |
jcmdr.parse(args); | |
} catch(ParameterException e){ | |
//User provides invalid input -> print the usage info | |
jcmdr.usage(); | |
try{ Thread.sleep(500); } catch(Exception e2){ } | |
throw e; | |
} | |
SparkConf sparkConf = new SparkConf(); | |
if(useSparkLocal) sparkConf.setMaster("local[*]"); | |
sparkConf.setAppName("MnistTestMLP"); | |
sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer"); | |
sparkConf.set("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator"); | |
JavaSparkContext sc = new JavaSparkContext(sparkConf); | |
//First: preprocess data into a sequence file | |
if(preprocessData) { | |
DataSetIterator trainIter = new MnistDataSetIterator(batchSizePerWorker, true, 12345); | |
DataSetIterator testIter = new MnistDataSetIterator(batchSizePerWorker, false, 12345); | |
List<DataSet> trainList = new ArrayList<>(); | |
List<DataSet> testList = new ArrayList<>(); | |
while (trainIter.hasNext()) { | |
trainList.add(trainIter.next()); | |
} | |
while (testIter.hasNext()) { | |
testList.add(testIter.next()); | |
} | |
JavaRDD<DataSet> trainrdd = sc.parallelize(trainList); | |
JavaRDD<DataSet> testrdd = sc.parallelize(testList); | |
JavaPairRDD<Text,BytesWritable> forTrainSequenceFile = trainrdd.mapToPair(new ToSequenceFilePairFunction()); | |
JavaPairRDD<Text, BytesWritable> forTestSequenceFile = testrdd.mapToPair(new ToSequenceFilePairFunction()); | |
forTrainSequenceFile.saveAsHadoopFile(dataTrainSavePath, Text.class, BytesWritable.class, SequenceFileOutputFormat.class); | |
forTestSequenceFile.saveAsHadoopFile(dataTestSavePath, Text.class, BytesWritable.class, SequenceFileOutputFormat.class); | |
} | |
//Second: load the data from a sequence file | |
JavaPairRDD<Text,BytesWritable> trainSequenceFile = sc.sequenceFile(dataTrainSavePath, Text.class, BytesWritable.class); | |
JavaPairRDD<Text, BytesWritable> testSequenceFile = sc.sequenceFile(dataTestSavePath, Text.class, BytesWritable.class); | |
JavaRDD<DataSet> trainData = trainSequenceFile.map(new FromSequenceFilePairFunction()); | |
JavaRDD<DataSet> testData = testSequenceFile.map(new FromSequenceFilePairFunction()); | |
//---------------------------------- | |
//Second: conduct network training | |
//Network configuration: | |
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() | |
.seed(12345) | |
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | |
.iterations(1) | |
.activation("relu") | |
.weightInit(WeightInit.XAVIER) | |
.learningRate(0.0069) | |
.updater(Updater.NESTEROVS).momentum(0.9) | |
.regularization(true).l2(1e-4) | |
.list() | |
.layer(0, new DenseLayer.Builder().nIn(28*28).nOut(500).build()) | |
.layer(1, new DenseLayer.Builder().nIn(500).nOut(100).build()) | |
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) | |
.activation("softmax").nIn(100).nOut(10).build()) | |
.pretrain(false).backprop(true) | |
.build(); | |
//Configuration for Spark training: see http://deeplearning4j.org/spark for explanation of these configuration options | |
TrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(batchSizePerWorker) | |
.averagingFrequency(10) | |
.saveUpdater(true) | |
.workerPrefetchNumBatches(2) | |
.batchSizePerWorker(batchSizePerWorker) | |
.repartionData(Repartition.Always) | |
.repartitionStrategy(RepartitionStrategy.SparkDefault) | |
.exportDirectory("hdfs:///user/yilaguan/deeplearning4jParmerter/") | |
.build(); | |
SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf, tm); | |
// sparkNet.setCollectTrainingStats(true); | |
//Execute training: | |
for( int i=0; i<numEpochs; i++ ){ | |
sparkNet.fit(trainData); | |
System.out.println("Completed Epoch " + i); | |
} | |
//Execute evaluate model: | |
// System.out.println("Begin to evaluate model"); | |
// Evaluation eval = new Evaluation(10); | |
sparkNet.evaluate(testData); | |
sparkNet.getScore(); | |
System.out.println("Finish to evaluate model"); | |
// SparkTrainingStats stats = sparkNet.getSparkTrainingStats(); | |
// StatsUtils.exportStatsAsHtml(stats, "hdfs:///user/yilaguan/SparkStats.html", sc); | |
System.out.println("----- DONE -----"); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment