Skip to content

Instantly share code, notes, and snippets.

@osipov
Created October 12, 2016 19:08
Show Gist options
  • Save osipov/11bcc59c14b1a140d4f67ca865d56648 to your computer and use it in GitHub Desktop.
Save osipov/11bcc59c14b1a140d4f67ca865d56648 to your computer and use it in GitHub Desktop.
package org.deeplearning4j.examples.feedforward.xor;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Random;
/**
 * FizzBuzz posed as a 4-class classification problem for a small feed-forward network.
 * <p>
 * Each integer is fed to the network as a little-endian binary vector of
 * {@link #NUM_DIGITS} digits and labelled with a one-hot vector over the classes
 * "plain number", "fizz", "buzz" and "fizzbuzz". The network is trained on the
 * numbers 101..1023 and evaluated on the numbers 1..100.
 * <p>
 * Created by @osipov on 10/12/16.
 */
public class FizzBuzz {

    /** Number of output classes: plain number, "fizz", "buzz", "fizzbuzz". */
    private static final int NUM_CLASSES = 4;

    /** Width of the binary number used to encode each input integer. */
    private static final int NUM_DIGITS = 10;

    /** Exclusive upper bound of the integers representable with NUM_DIGITS binary digits. */
    private static final int NUM_UPPER = 1 << NUM_DIGITS;

    /** The test set holds the numbers 1..TEST_SET_SIZE; the train set starts right after it. */
    private static final int TEST_SET_SIZE = 100;

    /**
     * Encode a non-negative integer in binary, little endian (least significant
     * digit at index 0), using exactly {@code numDigits} digits. Bits of
     * {@code val} above {@code numDigits} are silently dropped.
     *
     * @param val       integer to encode
     * @param numDigits width of the binary number
     * @return INDArray of length {@code numDigits} holding one 0/1 value per digit
     */
    public static INDArray encodeBinary(int val, int numDigits) {
        INDArray encoded = Nd4j.zeros(numDigits);
        for (int bit = 0; bit < numDigits; bit++) {
            encoded.putScalar(bit, (val >> bit) & 1);
        }
        return encoded;
    }

    /**
     * Decode a little-endian binary number back into an integer; the inverse of
     * {@link #encodeBinary(int, int)}.
     * <p>
     * Uses integer shifts rather than {@code Math.pow}, so no floating-point
     * arithmetic or implicit narrowing is involved. Arrays wider than 31 digits
     * cannot be represented in an {@code int} and are not supported.
     *
     * @param arr binary number stored little endian style in an INDArray
     * @return the integer whose binary digits are stored in {@code arr}
     */
    public static int decodeBinary(INDArray arr) {
        int decoded = 0;
        for (int bit = 0; bit < arr.length(); bit++) {
            decoded += arr.getInt(bit) << bit;
        }
        return decoded;
    }

    /**
     * One-hot encode an integer into one of the 4 "fizzbuzz" classes as follows
     * <pre>
     * [0.0 1.0 0.0 0.0] if the number is divisible by 3 (but not by 5)
     * [0.0 0.0 1.0 0.0] if the number is divisible by 5 (but not by 3)
     * [0.0 0.0 0.0 1.0] if the number is divisible by both 3 and 5
     * [1.0 0.0 0.0 0.0] otherwise
     * </pre>
     *
     * @param i the integer to classify
     * @return INDArray of length 4 containing the one-hot encoding
     */
    public static INDArray encodeFizzBuzz(int i) {
        final int classIndex;
        // Divisibility by 15 must be tested first: such numbers are also
        // divisible by 3 and by 5 and would otherwise land in the wrong class.
        if (i % 15 == 0) {
            classIndex = 3;
        } else if (i % 5 == 0) {
            classIndex = 2;
        } else if (i % 3 == 0) {
            classIndex = 1;
        } else {
            classIndex = 0;
        }
        INDArray encoded = Nd4j.zeros(NUM_CLASSES);
        return encoded.putScalar(classIndex, 1);
    }

    /**
     * Decode a class vector laid out as in {@link #encodeFizzBuzz(int)} into its
     * textual form. The class is chosen as the index reported by ND4J's
     * {@code IAMax} op, so the input may be an exact one-hot vector or a vector
     * of network outputs.
     * <ul>
     * <li>index 1 (divisible by 3) gives "fizz"</li>
     * <li>index 2 (divisible by 5) gives "buzz"</li>
     * <li>index 0 (neither) gives {@code null}</li>
     * <li>any other index (divisible by 3 and 5) gives "fizzbuzz"</li>
     * </ul>
     *
     * @param arr INDArray holding the class vector, see {@link #encodeFizzBuzz(int)}
     * @return String which can be null, "fizz", "buzz", or "fizzbuzz"
     */
    public static String decodeFizzBuzz(INDArray arr) {
        int idx = Nd4j.getExecutioner().execAndReturn(new IAMax(arr)).getFinalResult();
        switch (idx) {
            case 0:
                // "Plain number" class: the caller prints the number itself.
                return null;
            case 1:
                return "fizz";
            case 2:
                return "buzz";
            default:
                return "fizzbuzz";
        }
    }

    /**
     * Build a data set for {@code count} consecutive integers starting at
     * {@code firstNumber}: row {@code r} holds the binary features and the
     * one-hot label of {@code firstNumber + r}.
     */
    private static DataSet buildDataSet(int firstNumber, int count) {
        INDArray features = Nd4j.zeros(count, NUM_DIGITS);
        INDArray labels = Nd4j.zeros(count, NUM_CLASSES);
        for (int row = 0; row < count; row++) {
            int number = firstNumber + row;
            features.putRow(row, encodeBinary(number, NUM_DIGITS));
            labels.putRow(row, encodeFizzBuzz(number));
        }
        return new DataSet(features, labels);
    }

    /** Evaluate the model on the given data set and print the statistics between two banner lines. */
    private static void printEvaluation(String name, MultiLayerNetwork model, DataSet data) {
        String banner = "****************" + name + " eval********************";
        System.out.println(banner);
        Evaluation eval = new Evaluation(NUM_CLASSES);
        eval.eval(data.getLabels(), model.output(data.getFeatures()));
        System.out.println(eval.stats());
        System.out.println(banner);
    }

    /**
     * Train the network on the numbers 101..1023, print evaluation statistics for
     * the train and test sets, then print the prediction for each of the numbers
     * 1..100.
     *
     * @param args command line arguments (unused)
     */
    public static void main(String[] args) {
        Nd4j.ENFORCE_NUMERICAL_STABILITY = true;

        //seed shared by the data set shuffle and the network initialisation
        int rngSeed = 12345;
        int numEpochs = 10000;
        int numHiddenUnits = 100;
        double learningRate = 0.5;
        double regularizationRate = 0.75;
        double nesterovsMomentum = 0.95;

        //populate the train set with the numbers in the range [101, 1023],
        //i.e. everything representable in NUM_DIGITS bits that is not in the test set
        final int firstTrainNumber = TEST_SET_SIZE + 1;
        final DataSet trainDataset = buildDataSet(firstTrainNumber, NUM_UPPER - firstTrainNumber);

        //populate the test set with the numbers in the range [1, 100]
        final DataSet testDataset = buildDataSet(1, TEST_SET_SIZE);

        trainDataset.shuffle(rngSeed);

        System.out.println("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(rngSeed)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .learningRate(learningRate)
            .activation("relu")
            .weightInit(WeightInit.XAVIER)
            .updater(Updater.NESTEROVS).momentum(nesterovsMomentum)
            .regularization(regularizationRate > 0.0).l2(regularizationRate)
            .list()
            .layer(0, new DenseLayer.Builder()
                .nIn(NUM_DIGITS)
                .nOut(numHiddenUnits)
                .build())
            .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .nIn(numHiddenUnits)
                .nOut(NUM_CLASSES)
                .activation("softmax")
                .build())
            .pretrain(false).backprop(true)
            .build();

        final MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(100));

        System.out.println("Train model....");
        //each epoch is one fit over the whole train set
        for (int epoch = 0; epoch < numEpochs; epoch++) {
            model.fit(trainDataset);
        }

        System.out.println("Evaluate model....");
        printEvaluation("Train", model, trainDataset);
        printEvaluation("Test", model, testDataset);
        System.out.println("****************Example finished********************");

        //print, per test number: the number, its binary features, the expected
        //label, the network output and the decoded prediction
        final INDArray testFeatures = testDataset.getFeatures();
        for (int row = 0; row < TEST_SET_SIZE; row++) {
            int number = row + 1;
            //compute the row and its prediction once and reuse them below
            INDArray features = testFeatures.getRow(row);
            INDArray prediction = model.output(features);
            String decoded = decodeFizzBuzz(prediction);
            System.out.println(number + " " + features.toString() + " " + encodeFizzBuzz(number) + " " + prediction.toString() + " " + (decoded == null ? String.valueOf(number) : decoded));
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment