//SameDiff MNIST MLP example (gist by @AlexDBlack, November 7, 2018)
//Imports for the 1.0.0-beta-era DL4J/ND4J releases this gist targets (package
//locations are assumptions for that version and may differ in later releases):
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.weightinit.impl.XavierInitScheme;

import java.io.File;

SameDiff sd = SameDiff.create();
//Properties for MNIST dataset:
int nIn = 28*28;
int nOut = 10;
//Create input and label variables
SDVariable in = sd.var("input", -1, nIn); //Shape: [?, 784] - i.e., minibatch x 784 for MNIST
SDVariable label = sd.var("label", -1, nOut); //Shape: [?, 10] - i.e., minibatch x 10 for MNIST
sd.addAsPlaceHolder("input");
sd.addAsPlaceHolder("label");
//Define hidden layer - MLP (fully connected)
int layerSize0 = 128;
SDVariable w0 = sd.var("w0", new XavierInitScheme('c', nIn, layerSize0), nIn, layerSize0);
SDVariable b0 = sd.zero("b0", 1, layerSize0);
SDVariable activations0 = sd.tanh(in.mmul(w0).add(b0));
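//Hidden layer shapes: [minibatch, 784] mmul [784, 128] -> [minibatch, 128]; the
//[1, 128] bias row is broadcast-added, then tanh is applied elementwise, i.e.
//activations0 = tanh(in * w0 + b0).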
//Define output layer - MLP (fully connected) + softmax
SDVariable w1 = sd.var("w1", new XavierInitScheme('c', layerSize0, nOut), layerSize0, nOut);
SDVariable b1 = sd.zero("b1", 1, nOut);
SDVariable z1 = activations0.mmul(w1).add("prediction", b1);
SDVariable softmax = sd.softmax("softmax", z1);
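//Output layer shapes: [minibatch, 128] mmul [128, 10] -> [minibatch, 10]; softmax
//is applied per row, so each row of "softmax" is a probability distribution over
//the 10 digit classes (entries sum to 1).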
//Define loss function:
SDVariable diff = sd.f().squaredDifference(softmax, label);
SDVariable lossMse = diff.mean();
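//Note: squaredDifference(...).mean() is plain MSE between the softmax output and
//the one-hot labels. Cross-entropy is the more common choice for classification;
//MSE is used here only to keep the loss definition simple.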
//Create and set the training configuration
double learningRate = 1e-3;
TrainingConfig config = new TrainingConfig.Builder()
.l2(1e-4) //L2 regularization
.updater(new Adam(learningRate)) //Adam optimizer with specified learning rate
.dataSetFeatureMapping("input") //DataSet features array should be associated with variable "input"
.dataSetLabelMapping("label") //DataSet label array should be associated with variable "label"
.build();
sd.setTrainingConfig(config);
int batchSize = 32;
DataSetIterator trainData = new MnistDataSetIterator(batchSize, true, 12345);  //Training set, RNG seed 12345
DataSetIterator testData = new MnistDataSetIterator(batchSize, false, 12345);  //Test set
//Perform training for 2 epochs
int numEpochs = 2;
sd.fit(trainData, numEpochs);
//Evaluate on test set:
String outputVariable = "softmax";
Evaluation evaluation = new Evaluation();
sd.evaluate(testData, outputVariable, evaluation);
//Print evaluation statistics:
System.out.println(evaluation.stats());
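//stats() reports the standard classification metrics (accuracy, precision,
//recall, F1) computed over the test set.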
//Save the trained network. Two options:
//Saving Option 1: Save for inference only - no updater state. FlatBuffers format
File saveFileForInference = new File("sameDiffExampleInference.fb");
sd.asFlatFile(saveFileForInference);
SameDiff loadedForInference = SameDiff.fromFlatFile(saveFileForInference);
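//Quick inference check on the reloaded graph. A minimal sketch, hedged: it
//assumes the beta-era SameDiff methods associateArrayWithVariable(...), exec()
//and getArrForVarName(...) - adjust if your version differs. The label array is
//fed too, since the graph as saved still contains the loss calculation.
testData.reset();
DataSet batch = testData.next();
loadedForInference.associateArrayWithVariable(batch.getFeatures(), "input");
loadedForInference.associateArrayWithVariable(batch.getLabels(), "label");
loadedForInference.exec();
INDArray probs = loadedForInference.getArrForVarName("softmax"); //Shape [minibatch, 10]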
//Saving Option 2: Save for further training - with updater state (larger file size, updater not required for inference). Zip format
File saveFileForFurtherTraining = new File("sameDiffExampleTraining.zip");
sd.saveWithTrainingConfig(saveFileForFurtherTraining);
SameDiff loadedForFurtherTraining = SameDiff.restoreFromTrainingConfigZip(saveFileForFurtherTraining);
loadedForFurtherTraining.fit(trainData, 1); //************ EXCEPTION HERE **************
testData.reset();
evaluation = new Evaluation();
loadedForFurtherTraining.evaluate(testData, outputVariable, evaluation); //Evaluate the reloaded (and further trained) network, not the original sd
System.out.println(evaluation.stats());