Created August 13, 2016 23:31
-
-
Save vvpreetham/801817b949f0622d6f1eae0d3c979bf8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// package nn.example; | |
import org.deeplearning4j.nn.api.Layer; | |
import org.deeplearning4j.nn.api.OptimizationAlgorithm; | |
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.ListBuilder; | |
import org.deeplearning4j.nn.conf.distribution.UniformDistribution; | |
import org.deeplearning4j.nn.conf.layers.DenseLayer; | |
import org.deeplearning4j.nn.conf.layers.OutputLayer; | |
import org.deeplearning4j.nn.conf.layers.OutputLayer.Builder; | |
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.dataset.DataSet; | |
import org.nd4j.linalg.factory.Nd4j; | |
import org.nd4j.linalg.lossfunctions.LossFunctions; | |
public class ANDExample { | |
private static final int SEED = 100; | |
private static final int ITERATIONS = 350; | |
private static final String OUTPUT_ACTIVATION = "sigmoid"; | |
private static final double LEARNING_RATE = 0.7; | |
private static final int INPUT_NEURONS = 2; | |
private static final int HIDDEN_NEURONS = 4; | |
private static final int OUTPUT_NEURONS = 1; | |
public static void main(String[] args) { | |
INDArray input = Nd4j.zeros(4, 2); | |
INDArray labels = Nd4j.zeros(4, 1); | |
input.putScalar(new int[] { 0, 0 }, 0); | |
input.putScalar(new int[] { 0, 1 }, 0); | |
labels.putScalar(new int[] { 0, 0 }, 0); | |
input.putScalar(new int[] { 1, 0 }, 1); | |
input.putScalar(new int[] { 1, 1 }, 0); | |
labels.putScalar(new int[] { 1, 0 }, 0); | |
input.putScalar(new int[] { 2, 0 }, 0); | |
input.putScalar(new int[] { 2, 1 }, 1); | |
labels.putScalar(new int[] { 2, 0 }, 0); | |
input.putScalar(new int[] { 3, 0 }, 1); | |
input.putScalar(new int[] { 3, 1 }, 1); | |
labels.putScalar(new int[] { 3, 0 }, 1); | |
DataSet ds = new DataSet(input, labels); | |
NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder(); | |
builder.iterations(ITERATIONS); | |
builder.learningRate(LEARNING_RATE); | |
builder.seed(SEED); | |
builder.useDropConnect(false); | |
builder.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT); | |
builder.biasInit(0); | |
builder.miniBatch(false); | |
ListBuilder listBuilder = builder.list(); | |
// Hidden Layer | |
DenseLayer.Builder hiddenLayerBuilder = new DenseLayer.Builder(); | |
hiddenLayerBuilder.nIn(INPUT_NEURONS); | |
hiddenLayerBuilder.nOut(HIDDEN_NEURONS); | |
hiddenLayerBuilder.activation("sigmoid"); | |
hiddenLayerBuilder.weightInit(WeightInit.DISTRIBUTION); | |
hiddenLayerBuilder.dist(new UniformDistribution(0, 1)); | |
listBuilder.layer(0, hiddenLayerBuilder.build()); | |
// Output Layer | |
Builder outputLayerBuilder = new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD); | |
outputLayerBuilder.nIn(HIDDEN_NEURONS); | |
outputLayerBuilder.nOut(OUTPUT_NEURONS); | |
outputLayerBuilder.activation(OUTPUT_ACTIVATION); | |
outputLayerBuilder.weightInit(WeightInit.DISTRIBUTION); | |
outputLayerBuilder.dist(new UniformDistribution(0, 1)); | |
listBuilder.layer(1, outputLayerBuilder.build()); | |
listBuilder.pretrain(false); | |
listBuilder.backprop(true); | |
MultiLayerConfiguration conf = listBuilder.build(); | |
MultiLayerNetwork net = new MultiLayerNetwork(conf); | |
net.init(); | |
Layer[] layers = net.getLayers(); | |
int totalNumParams = 0; | |
for (int i = 0; i < layers.length; i++) { | |
int nParams = layers[i].numParams(); | |
System.out.println("Number of parameters in layer " + i + ": " + nParams); | |
totalNumParams += nParams; | |
} | |
System.out.println("Total number of network parameters: " + totalNumParams); | |
net.setListeners(new ScoreIterationListener(100)); | |
System.out.println(net.output(input)); | |
net.fit(ds); | |
System.out.println(net.output(input)); | |
} | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.