@osipov
Created October 11, 2016 14:54
package org.deeplearning4j.examples.feedforward.xor;

import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.util.Arrays;
import java.util.Random;
/**
 * FizzBuzz learned by a feed-forward neural network: integers are binary-encoded as
 * input features and the four FizzBuzz outcomes (number, fizz, buzz, fizzbuzz) are
 * one-hot labels. The network trains on values above 100 and is tested on 1..100.
 *
 * Created by osipov on 6/28/16.
 */
public class FizzBuzz {

    // Decodes a little-endian bit vector (least significant bit first) back into an int.
    public static int decodeBinary(INDArray arr) {
        int i = 0;
        for (int j = 0; j < arr.length(); j++) {
            i += Math.pow(2, j) * arr.getInt(j);
        }
        return i;
    }

    // Same decoding for a plain float array.
    public static int decodeBinary(float[] b) {
        int i = 0;
        for (int j = 0; j < b.length; j++) {
            i += Math.pow(2, j) * b[j];
        }
        return i;
    }
    // Encodes val as a little-endian bit vector with numDigits entries.
    public static INDArray encodeBinary(int val, int numDigits) {
        INDArray encoded = Nd4j.zeros(numDigits);
        for (int i = 0; i < numDigits; i++)
            encoded.putScalar(i, (val >> i) & 1);
        return encoded;
    }
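    // Worked example: encodeBinary(6, 4) produces [0, 1, 1, 0] (least significant bit
    // first), and decodeBinary on that vector recovers 6.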
    // One-hot encodes the FizzBuzz class of i:
    // index 0 = plain number, 1 = "fizz", 2 = "buzz", 3 = "fizzbuzz".
    public static INDArray encodeFizzBuzz(int i) {
        INDArray encoded = Nd4j.zeros(4);
        if (i % 15 == 0) return encoded.putScalar(3, 1);
        else if (i % 5 == 0) return encoded.putScalar(2, 1);
        else if (i % 3 == 0) return encoded.putScalar(1, 1);
        else return encoded.putScalar(0, 1);
    }
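    // Hypothetical helper, not part of the original gist: a minimal sketch of how a
    // softmax output row maps back to the printed FizzBuzz token. The argmax index
    // selects among the four one-hot classes defined in encodeFizzBuzz above.
    public static String decodeFizzBuzz(int i, INDArray output) {
        int argMax = 0;
        for (int j = 1; j < 4; j++)
            if (output.getDouble(j) > output.getDouble(argMax)) argMax = j;
        switch (argMax) {
            case 3: return "fizzbuzz";
            case 2: return "buzz";
            case 1: return "fizz";
            default: return Integer.toString(i);
        }
    }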
    public static void main(String[] args) {
        // org.nd4j.jita.conf.CudaEnvironment.getInstance().getConfiguration().allowMultiGPU(true);
        Nd4j.ENFORCE_NUMERICAL_STABILITY = true;

        final int NUM_UPPER = 8192; // exclusive upper bound of the candidate training range
        final int NUM_DIGITS = 10;  // bits per input; note that values >= 2^NUM_DIGITS
                                    // alias onto the same pattern as their low 10 bits
        int rngSeed = 12345;
        int numEpochs = 5000;
        int batchSize = 128;
        double learningRate = 0.3;
        double regularizationRate = learningRate * 0.0005; // used only if L2 regularization is re-enabled below
        double nesterovsMomentum = 0.9;
        Random rnd = new Random(rngSeed);

        // Candidate training examples are the integers in [101, NUM_UPPER);
        // 1..100 is held out as the test set.
        INDArray trainFeaturesTmp = Nd4j.zeros(NUM_UPPER - 101, NUM_DIGITS);
        INDArray trainLabelsTmp = Nd4j.zeros(NUM_UPPER - 101, 4);
        int trainCount = 0;
        // Downsample the majority classes so the four labels come out roughly balanced:
        // plain numbers are 8/15 of the range and "fizzbuzz" only 1/15, so keep about
        // 1 in 8 plain numbers, 1 in 4 "fizz", 1 in 2 "buzz", and every "fizzbuzz".
        for (int i = 101; i < NUM_UPPER; i++) {
            INDArray features = encodeBinary(i, NUM_DIGITS);
            INDArray labels = encodeFizzBuzz(i);
            boolean lucky = false;
            if (labels.getInt(0) == 1) lucky = rnd.nextInt(8) == 0;
            else if (labels.getInt(1) == 1) lucky = rnd.nextInt(4) == 0;
            else if (labels.getInt(2) == 1) lucky = rnd.nextInt(2) == 0;
            else if (labels.getInt(3) == 1) lucky = true;
            if (lucky) {
                trainFeaturesTmp.putRow(trainCount, features);
                trainLabelsTmp.putRow(trainCount, labels);
                trainCount++;
            }
        }
        // Report how many examples of each class survived the sampling.
        int[] counts = new int[4];
        for (int i = 0; i < trainCount; i++) {
            if (trainLabelsTmp.getRow(i).getInt(0) == 1) counts[0] += 1;
            else if (trainLabelsTmp.getRow(i).getInt(1) == 1) counts[1] += 1;
            else if (trainLabelsTmp.getRow(i).getInt(2) == 1) counts[2] += 1;
            else if (trainLabelsTmp.getRow(i).getInt(3) == 1) counts[3] += 1;
        }
        System.out.println("Train count: " + Arrays.toString(counts));
        // Copy the kept rows into exactly-sized training matrices.
        INDArray trainFeatures = Nd4j.zeros(trainCount, NUM_DIGITS);
        INDArray trainLabels = Nd4j.zeros(trainCount, 4);
        for (int i = 0; i < trainCount; i++) {
            trainFeatures.putRow(i, trainFeaturesTmp.getRow(i));
            trainLabels.putRow(i, trainLabelsTmp.getRow(i));
        }

        // The test set is the classic FizzBuzz range 1..100, never seen in training.
        INDArray testFeatures = Nd4j.zeros(100, NUM_DIGITS);
        for (int i = 1; i < 101; i++) testFeatures.putRow(i - 1, encodeBinary(i, NUM_DIGITS));
        INDArray testLabels = Nd4j.zeros(100, 4);
        for (int i = 1; i < 101; i++) testLabels.putRow(i - 1, encodeFizzBuzz(i));

        final DataSet trainDataset = new DataSet(trainFeatures, trainLabels);
        final DataSet testDataset = new DataSet(testFeatures, testLabels);

        trainDataset.shuffle(rngSeed);
        DataSetIterator trainDatasetBatches = new ListDataSetIterator(trainDataset.asList(), batchSize);
System.out.println("Build model....");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(rngSeed)
.optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT)
// .biasInit(0)
.iterations(1)
.learningRate(learningRate)
.activation("relu")
.weightInit(WeightInit.XAVIER)
.miniBatch(true)
.useDropConnect(false)
.updater(Updater.NESTEROVS).momentum(nesterovsMomentum)
// .regularization(true).l2(regularizationRate)
.list()
.layer(0, new DenseLayer.Builder()
.nIn(10)
.nOut(100)
// .weightInit(WeightInit.DISTRIBUTION)
// .dist(new NormalDistribution(0.0, 0.01))
// .activation("relu")
.build())
.layer(1, new DenseLayer.Builder()
.nIn(100)
.nOut(100)
// .weightInit(WeightInit.DISTRIBUTION)
// .dist(new NormalDistribution(0.0, 0.01))
// .activation("relu")
.build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.nIn(100)
.nOut(4)
// .weightInit(WeightInit.DISTRIBUTION)
// .dist(new UniformDistribution(0.1, 1))
.activation("softmax")
.build())
.pretrain(false).backprop(true)
.build();
        final MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        // Log the training score every 100 parameter updates.
        model.setListeners(new ScoreIterationListener(100));
System.out.println("Train model....");
for( int i=0; i<numEpochs; i++ ) {
model.fit(trainDatasetBatches);
// model.fit(trainDataset);
}
System.out.println("Evaluate model....");
{
System.out.println("****************Train eval********************");
Evaluation eval = new Evaluation(4);
eval.eval(trainDataset.getLabels(), model.output(trainDataset.getFeatures()));
System.out.println(eval.stats());
System.out.println("****************Train eval********************");
}
{
System.out.println("****************Test eval********************");
Evaluation eval = new Evaluation(4);
eval.eval(testDataset.getLabels(), model.output(testDataset.getFeatures()));
System.out.println(eval.stats());
System.out.println("****************Test eval********************");
}
System.out.println("****************Example finished********************");
for (int i = 0; i < 16; i++) {
System.out.println((i + 1) + " " + testFeatures.getRow(i).toString() + " " + model.output(testFeatures.getRow(i)).toString());
}
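        // A hypothetical final pass (not in the original gist): print the model's answer
        // for the classic 1..100 range using the decodeFizzBuzz helper sketched above.
        for (int i = 1; i <= 100; i++) {
            INDArray o = model.output(testFeatures.getRow(i - 1));
            System.out.println(i + " -> " + decodeFizzBuzz(i, o));
        }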
    }
}