package com.IceKontroI.MODEL_TEST;

import org.deeplearning4j.eval.IEvaluation;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.AdaDelta;
import org.nd4j.linalg.learning.config.Sgd;

import java.util.Random;

import static org.deeplearning4j.nn.api.OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
import static org.deeplearning4j.nn.weights.WeightInit.XAVIER;
import static org.nd4j.linalg.activations.Activation.IDENTITY;
import static org.nd4j.linalg.activations.Activation.RELU;
import static org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction.MSE;

/**
 * Trains and evaluates a two-output regression ComputationGraph on random data.
 * Three architecture variants are provided: modelV0 (minimal), modelV1 (shared
 * trunk), and modelV2 (two independent towers).
 */
@SuppressWarnings("unused")
public class DoubleBarrelVanilla {

    public static final int INPUT_SIZE = 10;
    public static final int BATCH_SIZE = 100;
    public static final int DATA_COUNT = 10000;

    @SuppressWarnings("InfiniteLoopStatement")
    public static void main(String[] args) {
        // Use single-precision floats globally
        Nd4j.setDataType(DataBuffer.Type.FLOAT);
        DataIterator iterator = new DataIterator(BATCH_SIZE);
        ComputationGraph model = new ComputationGraph(modelV2());
        model.init();
        for (int i = 1; ; i++) {
            long start = System.currentTimeMillis();
            // Train one epoch, then rewind the iterator for evaluation
            iterator.reset();
            model.fit(iterator);
            iterator.reset();
            System.out.println("Epoch " + i + " took " + (System.currentTimeMillis() - start) / 1000D + " seconds ===========================================================================================");
            // Evaluate each output head with its own regression metrics
            IEvaluation[] evals = new IEvaluation[]{new RegressionEvaluation(), new RegressionEvaluation()};
            while (iterator.hasNext()) {
                MultiDataSet next = iterator.next();
                INDArray[] output = model.output(false, next.getFeatures(), next.getFeaturesMaskArrays(), next.getLabelsMaskArrays());
                for (int a = 0; a < output.length; a++) {
                    evals[a].eval(next.getLabels(a), output[a], next.getLabelsMaskArray(a));
                }
            }
            System.out.println(evals[0].stats());
            System.out.println(evals[1].stats());
        }
    }
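
    // Hedged sketch, not part of the original gist: the training loop above runs
    // forever, so a checkpoint helper like this one could be called between
    // epochs. ModelSerializer.writeModel(model, file, saveUpdater) is standard
    // DL4J API; the method name and file name pattern are purely illustrative.
    public static void checkpoint(ComputationGraph model, int epoch) throws java.io.IOException {
        org.deeplearning4j.util.ModelSerializer.writeModel(
                model, new java.io.File("double-barrel-epoch-" + epoch + ".zip"), true);
    }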

    /**
     * Minimal baseline: a single shared dense layer feeding two independent
     * single-unit regression outputs.
     */
    public static ComputationGraphConfiguration modelV0() {
        return new NeuralNetConfiguration.Builder()
                .updater(new Sgd(0.01))
                .graphBuilder()
                .addInputs("Input")
                .setInputTypes(InputType.feedForward(INPUT_SIZE))
                .addLayer("L1", new DenseLayer.Builder()
                        .nIn(10)
                        .nOut(5)
                        .build(), "Input")
                .addLayer("O1", new OutputLayer.Builder()
                        .lossFunction(MSE)
                        .nIn(5)
                        .nOut(1)
                        .build(), "L1")
                .addLayer("O2", new OutputLayer.Builder()
                        .lossFunction(MSE)
                        .nIn(5)
                        .nOut(1)
                        .build(), "L1")
                .setOutputs("O1", "O2")
                .build();
    }

    /**
     * Shared-trunk variant: five dense layers (10-20-30-20-10) feed two
     * single-unit regression heads that branch off the last hidden layer.
     */
    public static ComputationGraphConfiguration modelV1() {
        return new NeuralNetConfiguration.Builder()
                .optimizationAlgo(STOCHASTIC_GRADIENT_DESCENT)
                .weightInit(XAVIER)
                .activation(RELU)
                .updater(new AdaDelta())
                .l2(0.0001)
                .seed(1000)
                .graphBuilder()
                .addInputs("Input")
                .setInputTypes(InputType.feedForward(INPUT_SIZE))
                .addLayer("FF 1", new DenseLayer.Builder()
                        .nOut(10)
                        .build(), "Input")
                .addLayer("FF 2", new DenseLayer.Builder()
                        .nOut(20)
                        .build(), "FF 1")
                .addLayer("FF 3", new DenseLayer.Builder()
                        .nOut(30)
                        .build(), "FF 2")
                .addLayer("FF 4", new DenseLayer.Builder()
                        .nOut(20)
                        .build(), "FF 3")
                .addLayer("FF 5", new DenseLayer.Builder()
                        .nOut(10)
                        .build(), "FF 4")
                .addLayer("Output 1", new OutputLayer.Builder()
                        .lossFunction(MSE)
                        .activation(IDENTITY)
                        .nOut(1)
                        .build(), "FF 5")
                .addLayer("Output 2", new OutputLayer.Builder()
                        .lossFunction(MSE)
                        .activation(IDENTITY)
                        .nOut(1)
                        .build(), "FF 5")
                .setOutputs("Output 1", "Output 2")
                .build();
    }

    /**
     * The "double barrel" variant: two fully independent towers (A and B) off
     * the same input, each ending in its own single-unit regression head.
     */
    public static ComputationGraphConfiguration modelV2() {
        return new NeuralNetConfiguration.Builder()
                .optimizationAlgo(STOCHASTIC_GRADIENT_DESCENT)
                .weightInit(XAVIER)
                .activation(RELU)
                .updater(new AdaDelta())
                .l2(0.0001)
                .seed(1000)
                .graphBuilder()
                .addInputs("Input")
                .setInputTypes(InputType.feedForward(INPUT_SIZE))
                .addLayer("FF 1A", new DenseLayer.Builder()
                        .nOut(10)
                        .build(), "Input")
                .addLayer("FF 2A", new DenseLayer.Builder()
                        .nOut(20)
                        .build(), "FF 1A")
                .addLayer("FF 3A", new DenseLayer.Builder()
                        .nOut(30)
                        .build(), "FF 2A")
                .addLayer("FF 4A", new DenseLayer.Builder()
                        .nOut(20)
                        .build(), "FF 3A")
                .addLayer("FF 5A", new DenseLayer.Builder()
                        .nOut(10)
                        .build(), "FF 4A")
                .addLayer("Output A", new OutputLayer.Builder()
                        .lossFunction(MSE)
                        .activation(IDENTITY)
                        .nOut(1)
                        .build(), "FF 5A")
                .addLayer("FF 1B", new DenseLayer.Builder()
                        .nOut(10)
                        .build(), "Input")
                .addLayer("FF 2B", new DenseLayer.Builder()
                        .nOut(20)
                        .build(), "FF 1B")
                .addLayer("FF 3B", new DenseLayer.Builder()
                        .nOut(30)
                        .build(), "FF 2B")
                .addLayer("FF 4B", new DenseLayer.Builder()
                        .nOut(20)
                        .build(), "FF 3B")
                .addLayer("FF 5B", new DenseLayer.Builder()
                        .nOut(10)
                        .build(), "FF 4B")
                .addLayer("Output B", new OutputLayer.Builder()
                        .lossFunction(MSE)
                        .activation(IDENTITY)
                        .nOut(1)
                        .build(), "FF 5B")
                .setOutputs("Output A", "Output B")
                .build();
    }
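
    // Hedged sketch, not part of the original gist: prints each architecture's
    // layer summary so the three variants can be compared at a glance.
    // ComputationGraph.summary() is standard DL4J API; the method name
    // printSummaries is illustrative.
    public static void printSummaries() {
        ComputationGraphConfiguration[] configs = {modelV0(), modelV1(), modelV2()};
        for (ComputationGraphConfiguration conf : configs) {
            ComputationGraph graph = new ComputationGraph(conf);
            graph.init();
            System.out.println(graph.summary());
        }
    }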

    /**
     * Iterates over a RandomDataStorage instance in mini-batches. The important
     * method here is next(int num). A usage sketch follows this class.
     */
    public static class DataIterator implements MultiDataSetIterator {

        public MultiDataSetPreProcessor preProcessor;
        public RandomDataStorage storage;
        public int batchSize;
        public int index = 0;

        public DataIterator(int batchSize) {
            this.storage = new RandomDataStorage();
            this.batchSize = batchSize;
        }

        @Override
        public void setPreProcessor(MultiDataSetPreProcessor preProcessor) {
            this.preProcessor = preProcessor;
        }

        @Override
        public MultiDataSetPreProcessor getPreProcessor() {
            return preProcessor;
        }

        @Override
        public boolean resetSupported() {
            return true;
        }

        @Override
        public boolean asyncSupported() {
            return false;
        }

        @Override
        public void reset() {
            index = 0;
        }

        public int getRemaining() {
            return storage.size - index;
        }

        @Override
        public boolean hasNext() {
            // <= rather than <, so the final full batch is not skipped
            return index + batchSize <= storage.size;
        }

        @Override
        public MultiDataSet next() {
            return next(batchSize);
        }

        @Override
        @SuppressWarnings("ConstantConditions")
        public MultiDataSet next(int num) {
            INDArray[] inputs = new INDArray[1];
            INDArray[] labels = new INDArray[2];
            // INDArray[] labelsMask = new INDArray[2];
            int startIndex = index;
            for (int a = 0; a < Math.max(inputs.length, labels.length); a++) {
                // Reset index so we don't mismatch data for different "a" iterations
                index = startIndex;
                INDArray[] inputsMerged = new INDArray[num];
                INDArray[] labelsMerged = new INDArray[num];
                INDArray[] labelsMaskMerged = new INDArray[num];
                for (int b = 0; b < num; b++, index++) {
                    float[] exampleInputs = storage.inputs[index];
                    float[] exampleLabels = {storage.label1[index], storage.label2[index]};
                    inputsMerged[b] = Nd4j.create(exampleInputs);
                    // if (Float.isNaN(exampleLabels[a])) {
                    //     labelsMerged[b] = Nd4j.create(new float[]{0});
                    //     labelsMaskMerged[b] = Nd4j.create(new float[]{0});
                    // } else {
                    labelsMerged[b] = Nd4j.create(new float[]{exampleLabels[a]});
                    //     labelsMaskMerged[b] = Nd4j.create(new float[]{1});
                    // }
                }
                if (a < inputs.length) {
                    inputs[a] = Nd4j.vstack(inputsMerged);
                }
                if (a < labels.length) {
                    labels[a] = Nd4j.vstack(labelsMerged);
                    // labelsMask[a] = Nd4j.vstack(labelsMaskMerged);
                }
            }
            org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(inputs, labels, null, null);
            // Honor any configured pre-processor
            if (preProcessor != null) {
                preProcessor.preProcess(mds);
            }
            return mds;
        }
    }
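
    // Hedged usage sketch, not part of the original gist: drains the iterator
    // once and checks the shapes that DoubleBarrelVanilla expects (one
    // [batch, 10] feature array, two [batch, 1] label arrays). All names here
    // are illustrative.
    public static void sanityCheckIterator() {
        DataIterator it = new DataIterator(BATCH_SIZE);
        int batches = 0;
        while (it.hasNext()) {
            MultiDataSet mds = it.next();
            if (mds.getFeatures(0).rows() != BATCH_SIZE || mds.getLabels().length != 2) {
                throw new IllegalStateException("Unexpected batch shape at batch " + batches);
            }
            batches++;
        }
        System.out.println("Drained " + batches + " batches of " + BATCH_SIZE + " examples");
    }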

    /**
     * Generates random float data: a single input of size 10 and two labels,
     * each of size 1. The commented-out lines sketch a mutually exclusive
     * labeling scheme (NaN where a label is absent); a masked-MultiDataSet
     * sketch follows this class.
     */
    public static class RandomDataStorage {

        public float[][] inputs = new float[DATA_COUNT][];
        public float[] label1 = new float[DATA_COUNT], label2 = new float[DATA_COUNT];
        public int size = DATA_COUNT;

        public RandomDataStorage() {
            Random random = new Random();
            for (int a = 0; a < DATA_COUNT; a++) {
                inputs[a] = new float[INPUT_SIZE];
                for (int b = 0; b < INPUT_SIZE; b++) {
                    // Uniform values in [-1, 1)
                    inputs[a][b] = 2 * random.nextFloat() - 1;
                }
                // if (random.nextFloat() < 0.5) {
                label1[a] = 2 * random.nextFloat() - 1;
                //     label2[a] = Float.NaN; // Outputs are mutually exclusive
                // } else {
                //     label1[a] = Float.NaN;
                label2[a] = 2 * random.nextFloat() - 1;
                // }
            }
        }
    }
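
    // Hedged sketch of the masked variant hinted at by the commented-out code in
    // DataIterator and RandomDataStorage above (not active in the original gist):
    // when label1/label2 are mutually exclusive, the absent label is NaN; it is
    // replaced with 0 and masked out so it contributes no loss or gradient. The
    // four-argument MultiDataSet constructor matches the one used in next(int);
    // the method name maskedExample is illustrative.
    public static MultiDataSet maskedExample(float[] in, float label1, float label2) {
        INDArray features = Nd4j.create(in).reshape(1, in.length);
        INDArray[] labels = new INDArray[2];
        INDArray[] labelMasks = new INDArray[2];
        float[] raw = {label1, label2};
        for (int a = 0; a < 2; a++) {
            boolean present = !Float.isNaN(raw[a]);
            labels[a] = Nd4j.create(new float[]{present ? raw[a] : 0f});
            labelMasks[a] = Nd4j.create(new float[]{present ? 1f : 0f});
        }
        return new org.nd4j.linalg.dataset.MultiDataSet(
                new INDArray[]{features}, labels, null, labelMasks);
    }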
}