Skip to content

Instantly share code, notes, and snippets.

@Saurabh7
Created October 18, 2019 18:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Saurabh7/9b5ea7def2a167903e7d206e272e2662 to your computer and use it in GitHub Desktop.
package org.deeplearning4j.examples.dataexamples;
import org.datavec.api.util.ClassPathResource;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.examples.download.DownloaderUtility;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nadam;
import org.nd4j.linalg.primitives.Pair;
import java.io.File;
import java.io.IOException;
import java.util.List;
public class SyntheticTest {

    /**
     * Loads an entire CSV file into a single in-memory {@link DataSet}.
     *
     * @param csvFileClasspath path to the CSV file, resolved against the current working
     *                         directory. NOTE(review): despite the parameter name, this does
     *                         NOT read from the classpath — confirm whether classpath loading
     *                         was intended.
     * @param batchSize        number of examples to read into the returned DataSet
     * @param labelIndex       zero-based column index of the label in each CSV row
     * @param numClasses       number of label classes
     * @return the first (and only consumed) batch produced by the iterator
     * @throws IOException          if the file cannot be read or the reader cannot be closed
     * @throws InterruptedException if record-reader initialization is interrupted
     */
    private static DataSet readCSVDataset(
            String csvFileClasspath, int batchSize, int labelIndex, int numClasses)
            throws IOException, InterruptedException {
        RecordReader rr = new CSVRecordReader();
        try {
            rr.initialize(new FileSplit(new File("./", csvFileClasspath)));
            DataSetIterator iterator =
                    new RecordReaderDataSetIterator(rr, batchSize, labelIndex, numClasses);
            // next() materializes the batch, so the reader can be closed afterwards.
            return iterator.next();
        } finally {
            rr.close(); // was leaked in the original: CSVRecordReader holds an open file handle
        }
    }

    /**
     * Trains a small MLP by computing an "external" error outside the network and
     * back-propagating it manually via {@code backpropGradient}, following DL4J's
     * external-errors workflow (forward pass, custom error, manual updater + apply).
     */
    public static void main(String[] args) throws IOException, InterruptedException {
        int nIn = 2;
        int nOut = 1;
        Nd4j.getRandom().setSeed(12345);

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .activation(Activation.TANH)
                .weightInit(WeightInit.XAVIER)
                .updater(new Nadam())
                .list()
                .layer(new DenseLayer.Builder().nIn(nIn).nOut(3).build())
                // Final layer is a plain DenseLayer (no OutputLayer/loss function) because
                // the error signal is supplied externally via backpropGradient().
                .layer(new DenseLayer.Builder().nIn(3).nOut(nOut).build())
                .build();
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();

        // Calculate gradient with respect to an external error.
        int minibatch = 25;
        DataSet trainingData = readCSVDataset("fake_data_2.csv", minibatch, 2, 1);
        INDArray input = trainingData.getFeatures();
        model.setInput(input);

        int numEpochs = 30;
        for (int i = 0; i < numEpochs; i++) {
            // Forward pass in training mode; second arg false = no layer caching/workspaces
            // for this call (assumption — TODO confirm against the DL4J feedForward javadoc).
            List<INDArray> activations = model.feedForward(true, false);
            INDArray predictions = activations.get(activations.size() - 1);

            INDArray diff = trainingData.getLabels().sub(predictions);
            // NOTE(review): diff.mul(diff) is the squared error (the LOSS value), not
            // dL/dOutput. For MSE the gradient w.r.t. the output would be -2*diff, and
            // backpropGradient expects the gradient. Kept as-is to preserve the original
            // behavior — confirm which quantity was intended.
            INDArray externalError = diff.mul(diff);

            System.out.println("------");
            System.out.println("EPOCH:");
            System.out.println(i);
            System.out.println(externalError.sum());
            System.out.println("------");

            // Backprop the external error to obtain parameter gradients.
            Pair<Gradient, INDArray> p = model.backpropGradient(externalError, null);
            Gradient gradient = p.getFirst();

            // Let the updater (Nadam) transform the raw gradient in place
            // (iteration/epoch fixed at 0, matching the original code)...
            int iteration = 0;
            int epoch = 0;
            model.getUpdater()
                    .update(model, gradient, iteration, epoch, minibatch, LayerWorkspaceMgr.noWorkspaces());
            // ...then apply it manually: subi mutates the parameter vector in place,
            // so no return value needs to be kept (the original stored it in a dead local).
            model.params().subi(gradient.gradient());
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment