Skip to content

Instantly share code, notes, and snippets.

@vvpreetham
Created August 13, 2016 23:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save vvpreetham/801817b949f0622d6f1eae0d3c979bf8 to your computer and use it in GitHub Desktop.
Save vvpreetham/801817b949f0622d6f1eae0d3c979bf8 to your computer and use it in GitHub Desktop.
// package nn.example;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.ListBuilder;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer.Builder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
/**
 * Minimal DeepLearning4J example: trains a 2-4-1 feed-forward network with sigmoid
 * activations to reproduce the logical AND of two binary inputs.
 *
 * <p>The program prints the parameter count of each layer and of the whole network. It then
 * prints the network's outputs for the four truth-table rows before and after training.
 */
public class ANDExample {
    /** Seed passed to the configuration builder so weight initialisation is reproducible. */
    private static final int SEED = 100;
    /** Value handed to {@code NeuralNetConfiguration.Builder#iterations}. */
    private static final int ITERATIONS = 350;
    private static final String HIDDEN_ACTIVATION = "sigmoid";
    private static final String OUTPUT_ACTIVATION = "sigmoid";
    private static final double LEARNING_RATE = 0.7;
    private static final int INPUT_NEURONS = 2;
    private static final int HIDDEN_NEURONS = 4;
    private static final int OUTPUT_NEURONS = 1;
    /** Value handed to {@link ScoreIterationListener} (score is printed every N iterations). */
    private static final int SCORE_PRINT_FREQUENCY = 100;

    /** Every (a, b) input pair of the AND truth table, one pair per row. */
    private static final double[][] TRUTH_TABLE_INPUTS = {
        {0, 0},
        {1, 0},
        {0, 1},
        {1, 1},
    };
    /** Expected value of {@code a AND b} for the matching row of {@link #TRUTH_TABLE_INPUTS}. */
    private static final double[] TRUTH_TABLE_LABELS = {0, 0, 0, 1};

    /**
     * Builds, trains and reports on the AND network.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        INDArray input = createInputs();
        INDArray labels = createLabels();
        DataSet ds = new DataSet(input, labels);

        MultiLayerNetwork net = new MultiLayerNetwork(buildConfiguration());
        net.init();
        printParameterCounts(net);

        net.setListeners(new ScoreIterationListener(SCORE_PRINT_FREQUENCY));
        // Outputs of the untrained network, for comparison with the post-training outputs below.
        System.out.println(net.output(input));
        net.fit(ds);
        System.out.println(net.output(input));
    }

    /** Creates the feature matrix: one row per truth-table entry, one column per input neuron. */
    private static INDArray createInputs() {
        INDArray input = Nd4j.zeros(TRUTH_TABLE_INPUTS.length, INPUT_NEURONS);
        for (int row = 0; row < TRUTH_TABLE_INPUTS.length; row++) {
            for (int col = 0; col < INPUT_NEURONS; col++) {
                input.putScalar(new int[] {row, col}, TRUTH_TABLE_INPUTS[row][col]);
            }
        }
        return input;
    }

    /** Creates the label matrix: one row per truth-table entry, holding the single 0/1 target. */
    private static INDArray createLabels() {
        INDArray labels = Nd4j.zeros(TRUTH_TABLE_LABELS.length, OUTPUT_NEURONS);
        for (int row = 0; row < TRUTH_TABLE_LABELS.length; row++) {
            labels.putScalar(new int[] {row, 0}, TRUTH_TABLE_LABELS[row]);
        }
        return labels;
    }

    /** Assembles the two-layer (dense hidden + output) network configuration. */
    private static MultiLayerConfiguration buildConfiguration() {
        NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
        builder.iterations(ITERATIONS);
        builder.learningRate(LEARNING_RATE);
        builder.seed(SEED);
        builder.useDropConnect(false);
        builder.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);
        builder.biasInit(0);
        // Treat the four examples as the full dataset rather than as a minibatch.
        builder.miniBatch(false);
        ListBuilder listBuilder = builder.list();

        // Hidden layer: INPUT_NEURONS -> HIDDEN_NEURONS, weights drawn from UniformDistribution(0, 1).
        DenseLayer.Builder hiddenLayerBuilder = new DenseLayer.Builder();
        hiddenLayerBuilder.nIn(INPUT_NEURONS);
        hiddenLayerBuilder.nOut(HIDDEN_NEURONS);
        hiddenLayerBuilder.activation(HIDDEN_ACTIVATION);
        hiddenLayerBuilder.weightInit(WeightInit.DISTRIBUTION);
        hiddenLayerBuilder.dist(new UniformDistribution(0, 1));
        listBuilder.layer(0, hiddenLayerBuilder.build());

        // Output layer: a single sigmoid unit with 0/1 labels, so use binary cross-entropy (XENT).
        // The previous NEGATIVELOGLIKELIHOOD is the multi-class log loss intended for softmax +
        // one-hot labels; with one output column it only scores the rows whose label is 1.
        OutputLayer.Builder outputLayerBuilder =
                new OutputLayer.Builder(LossFunctions.LossFunction.XENT);
        outputLayerBuilder.nIn(HIDDEN_NEURONS);
        outputLayerBuilder.nOut(OUTPUT_NEURONS);
        outputLayerBuilder.activation(OUTPUT_ACTIVATION);
        outputLayerBuilder.weightInit(WeightInit.DISTRIBUTION);
        outputLayerBuilder.dist(new UniformDistribution(0, 1));
        listBuilder.layer(1, outputLayerBuilder.build());

        // Purely supervised: no layer-wise pretraining, train with backpropagation.
        listBuilder.pretrain(false);
        listBuilder.backprop(true);
        return listBuilder.build();
    }

    /** Prints the number of parameters in each layer of {@code net}, then the total. */
    private static void printParameterCounts(MultiLayerNetwork net) {
        Layer[] layers = net.getLayers();
        int totalNumParams = 0;
        for (int i = 0; i < layers.length; i++) {
            int nParams = layers[i].numParams();
            System.out.println("Number of parameters in layer " + i + ": " + nParams);
            totalNumParams += nParams;
        }
        System.out.println("Total number of network parameters: " + totalNumParams);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment