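//SameDiff MNIST example: build a two-layer MLP, train it, evaluate it, then save and
//reload it two ways. The imports below are a sketch assuming the 1.0.0-beta-era
//DL4J/ND4J packages this API comes from; exact paths may vary between releases.
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.weightinit.impl.XavierInitScheme;
import java.io.File;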
SameDiff sd = SameDiff.create();
//Properties for MNIST dataset:
int nIn = 28*28;
int nOut = 10;
//Create input and label variables
SDVariable in = sd.var("input", -1, nIn);     //Shape: [?, 784] - i.e., minibatch x 784 for MNIST
SDVariable label = sd.var("label", -1, nOut); //Shape: [?, 10] - i.e., minibatch x 10 for MNIST
sd.addAsPlaceHolder("input");
sd.addAsPlaceHolder("label");
//Define hidden layer - MLP (fully connected)
int layerSize0 = 128;
SDVariable w0 = sd.var("w0", new XavierInitScheme('c', nIn, layerSize0), nIn, layerSize0);
SDVariable b0 = sd.zero("b0", 1, layerSize0);
SDVariable activations0 = sd.tanh(in.mmul(w0).add(b0));
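//Shape flow for the hidden layer, for minibatch size m:
//[m, 784] mmul [784, 128] -> [m, 128], plus broadcast bias [1, 128], then tanh -> [m, 128]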
//Define output layer - MLP (fully connected) + softmax
SDVariable w1 = sd.var("w1", new XavierInitScheme('c', layerSize0, nOut), layerSize0, nOut);
SDVariable b1 = sd.zero("b1", 1, nOut);
SDVariable z1 = activations0.mmul(w1).add("prediction", b1);
SDVariable softmax = sd.softmax("softmax", z1);
//Define loss function:
SDVariable diff = sd.f().squaredDifference(softmax, label);
SDVariable lossMse = diff.mean();
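//Note: mean squared error on the softmax output is used here for simplicity;
//cross-entropy is the more common loss for classification.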
//Create and set the training configuration
double learningRate = 1e-3;
TrainingConfig config = new TrainingConfig.Builder()
    .l2(1e-4)                        //L2 regularization
    .updater(new Adam(learningRate)) //Adam optimizer with specified learning rate
    .dataSetFeatureMapping("input")  //DataSet features array should be associated with variable "input"
    .dataSetLabelMapping("label")    //DataSet label array should be associated with variable "label"
    .build();
sd.setTrainingConfig(config);
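//The training configuration must be set before calling fit(), so SameDiff knows
//which updater to use and how DataSet arrays map onto the placeholder variables.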
int batchSize = 32;
DataSetIterator trainData = new MnistDataSetIterator(batchSize, true, 12345);
DataSetIterator testData = new MnistDataSetIterator(batchSize, false, 12345);
//Perform training for 2 epochs
int numEpochs = 2;
sd.fit(trainData, numEpochs);
//Evaluate on test set:
String outputVariable = "softmax";
Evaluation evaluation = new Evaluation();
sd.evaluate(testData, outputVariable, evaluation);
//Print evaluation statistics:
System.out.println(evaluation.stats());
//Save the trained network: two options
//Saving option 1: save for inference only - no updater state. FlatBuffers format
File saveFileForInference = new File("sameDiffExampleInference.fb");
sd.asFlatFile(saveFileForInference);
SameDiff loadedForInference = SameDiff.fromFlatFile(saveFileForInference);
//Saving option 2: save for further training - with updater state (larger file size; updater state not required for inference). Zip format
File saveFileForFurtherTraining = new File("sameDiffExampleTraining.zip");
sd.saveWithTrainingConfig(saveFileForFurtherTraining);
SameDiff loadedForFurtherTraining = SameDiff.restoreFromTrainingConfigZip(saveFileForFurtherTraining);
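//Unlike option 1, the restored instance also carries the Adam updater state,
//so calling fit() below should resume training where the saved run left off.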
loadedForFurtherTraining.fit(trainData, 1); //************ EXCEPTION HERE **************
evaluation = new Evaluation();
loadedForFurtherTraining.evaluate(testData, outputVariable, evaluation); //Evaluate the restored network after the additional epoch
System.out.println(evaluation.stats());