Created
October 12, 2016 19:08
-
-
Save osipov/11bcc59c14b1a140d4f67ca865d56648 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.deeplearning4j.examples.feedforward.xor; | |
import org.deeplearning4j.eval.Evaluation; | |
import org.deeplearning4j.nn.api.OptimizationAlgorithm; | |
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
import org.deeplearning4j.nn.conf.Updater; | |
import org.deeplearning4j.nn.conf.layers.DenseLayer; | |
import org.deeplearning4j.nn.conf.layers.OutputLayer; | |
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax; | |
import org.nd4j.linalg.dataset.DataSet; | |
import org.nd4j.linalg.factory.Nd4j; | |
import org.nd4j.linalg.lossfunctions.LossFunctions; | |
import java.util.Random; | |
/** | |
* Created by @osipov on 10/12/16. | |
*/ | |
public class FizzBuzz { | |
/** | |
* Encode a positive integer in binary using little endian style with a numDigits width | |
* @param val positive integer | |
* @param numDigits width of the binary number | |
* @return INDArray array with the binary encoding | |
*/ | |
public static INDArray encodeBinary(int val, int numDigits) { | |
INDArray encoded = Nd4j.zeros(numDigits); | |
for (int i = 0; i < numDigits; i++) | |
encoded.putScalar(i, (val >> i) & 1); | |
return encoded; | |
} | |
/** | |
* Decode a binary number into a positive integer | |
* @param arr binary number stored little endian style in INDArray | |
* @return decoding of the binary number back into an integer | |
*/ | |
public static int decodeBinary(INDArray arr) { | |
int i = 0; | |
for (int j = 0; j < arr.length(); j++) { | |
i += Math.pow(2, j) * arr.getInt(j); | |
} | |
return i; | |
} | |
/** | |
* Hot one encode a positive integer to one of the 4 "fizzbuzz" classes as follows | |
* [0.0 1.0 0.0 0.0] if the number is divisible by 3 | |
* [0.0 0.0 1.0 0.0] if the number is divisible by 5 | |
* [0.0 0.0 0.0 1.0] if the number is divisible by 3 and 5 | |
* [1.0 0.0 0.0 0.0] otherwise | |
* @param i | |
* @return INDArray containing the encoding | |
*/ | |
public static INDArray encodeFizzBuzz(int i) { | |
INDArray encoded = Nd4j.zeros(4); | |
if (i % 15 == 0 && i != 1) return encoded.putScalar(3, 1); | |
else if (i % 5 == 0) return encoded.putScalar(2, 1); | |
else if (i % 3 == 0) return encoded.putScalar(1, 1); | |
else return encoded.putScalar(0, 1); | |
} | |
/** | |
* Decode a hotone encoded binary number using the following rules | |
* if the number is | |
* - divisible by 3 return "fizz" | |
* - divisible by 5 return "buzz" | |
* - divisible by 3 and 5, return "fizzbuzz" | |
* otherwise, return null | |
* @param arr INDArray specified hotone encoding per @see encodeFizzBuzz. | |
* @return String which can be null, "fizz", "buzz", or "fizzbuzz" | |
*/ | |
public static String decodeFizzBuzz(INDArray arr) { | |
int idx = Nd4j.getExecutioner().execAndReturn(new IAMax(arr)).getFinalResult(); | |
if (idx == 0) | |
return null; | |
else | |
if (idx == 1) | |
return "fizz"; | |
else | |
if (idx == 2) | |
return "buzz"; | |
else | |
return "fizzbuzz"; | |
} | |
public static void main(String[] args) { | |
Nd4j.ENFORCE_NUMERICAL_STABILITY = true; | |
//random number generator seed | |
int rngSeed = 12345; | |
Random rnd = new Random(rngSeed); | |
//width of the binary number to store hotone encoding of the input | |
final int NUM_DIGITS = 10; | |
//figure out the largest number we can represent using a NUM_DIGITS wide binary number | |
final int NUM_UPPER = (int)Math.pow(2.0, NUM_DIGITS); | |
int numEpochs = 10000; | |
double learningRate = 0.5; | |
double regularizationRate = 0.75; | |
double nesterovsMomentum = 0.95; | |
//populate the train set with the numbers in the range [101,923] | |
INDArray trainFeatures = Nd4j.zeros(NUM_UPPER - 101, NUM_DIGITS); | |
INDArray trainLabels = Nd4j.zeros(NUM_UPPER - 101, 4); | |
for (int i = 0; i < NUM_UPPER - 101; i++) { | |
INDArray features = encodeBinary(i + 101, NUM_DIGITS); | |
INDArray labels = encodeFizzBuzz(i + 101); | |
trainFeatures.putRow(i, features); | |
trainLabels.putRow(i, labels); | |
} | |
//populate the test set with the numbers in the range [1, 100] | |
INDArray testFeatures = Nd4j.zeros(100, NUM_DIGITS); | |
INDArray testLabels = Nd4j.zeros(100, 4); | |
for (int i = 1; i < 101; i++) { | |
testFeatures.putRow(i - 1, encodeBinary(i, NUM_DIGITS)); | |
testLabels.putRow(i - 1, encodeFizzBuzz(i)); | |
} | |
final DataSet trainDataset = new DataSet(trainFeatures, trainLabels); | |
final DataSet testDataset = new DataSet(testFeatures, testLabels); | |
trainDataset.shuffle(rngSeed); | |
System.out.println("Build model...."); | |
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() | |
.seed(rngSeed) | |
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | |
.iterations(1) | |
.learningRate(learningRate) | |
.activation("relu") | |
.weightInit(WeightInit.XAVIER) | |
.updater(Updater.NESTEROVS).momentum(nesterovsMomentum) | |
.regularization(regularizationRate > 0.0).l2(regularizationRate) | |
.list() | |
.layer(0, new DenseLayer.Builder() | |
.nIn(NUM_DIGITS) | |
.nOut(100) | |
.build()) | |
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) | |
.nIn(100) | |
.nOut(4) | |
.activation("softmax") | |
.build()) | |
.pretrain(false).backprop(true) | |
.build(); | |
final MultiLayerNetwork model = new MultiLayerNetwork(conf); | |
model.init(); | |
model.setListeners(new ScoreIterationListener(100)); | |
System.out.println("Train model...."); | |
for( int i=0; i<numEpochs; i++ ) { | |
model.fit(trainDataset); | |
} | |
System.out.println("Evaluate model...."); | |
{ | |
System.out.println("****************Train eval********************"); | |
Evaluation eval = new Evaluation(4); | |
eval.eval(trainDataset.getLabels(), model.output(trainDataset.getFeatures())); | |
System.out.println(eval.stats()); | |
System.out.println("****************Train eval********************"); | |
} | |
{ | |
System.out.println("****************Test eval********************"); | |
Evaluation eval = new Evaluation(4); | |
eval.eval(testDataset.getLabels(), model.output(testDataset.getFeatures())); | |
System.out.println(eval.stats()); | |
System.out.println("****************Test eval********************"); | |
} | |
System.out.println("****************Example finished********************"); | |
for (int i = 0; i < 100; i++) { | |
String decoded = decodeFizzBuzz(model.output(testFeatures.getRow(i))); | |
System.out.println((i + 1) + " " + testFeatures.getRow(i).toString() + " " + encodeFizzBuzz(i + 1) + " " + model.output(testFeatures.getRow(i)).toString() + " " + (decoded == null ? i + 1 : decoded)); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment