Created
October 18, 2019 18:12
-
-
Save Saurabh7/9b5ea7def2a167903e7d206e272e2662 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.deeplearning4j.examples.dataexamples; | |
import org.datavec.api.util.ClassPathResource; | |
import org.datavec.api.records.reader.RecordReader; | |
import org.datavec.api.records.reader.impl.csv.CSVRecordReader; | |
import org.datavec.api.split.FileSplit; | |
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; | |
import org.deeplearning4j.eval.Evaluation; | |
import org.deeplearning4j.examples.download.DownloaderUtility; | |
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
import org.deeplearning4j.nn.conf.layers.DenseLayer; | |
import org.deeplearning4j.nn.conf.layers.OutputLayer; | |
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | |
import org.nd4j.linalg.activations.Activation; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.dataset.DataSet; | |
import org.nd4j.linalg.dataset.SplitTestAndTrain; | |
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; | |
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; | |
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; | |
import org.nd4j.linalg.learning.config.Sgd; | |
import org.nd4j.linalg.lossfunctions.LossFunctions; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
import org.deeplearning4j.nn.conf.layers.DenseLayer; | |
import org.deeplearning4j.nn.gradient.Gradient; | |
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; | |
import org.nd4j.linalg.activations.Activation; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.factory.Nd4j; | |
import org.nd4j.linalg.learning.config.Nadam; | |
import org.nd4j.linalg.primitives.Pair; | |
import java.io.File; | |
import java.io.IOException; | |
import java.util.List; | |
public class SyntheticTest { | |
private static DataSet readCSVDataset( | |
String csvFileClasspath, int batchSize, int labelIndex, int numClasses) | |
throws IOException, InterruptedException { | |
RecordReader rr = new CSVRecordReader(); | |
rr.initialize(new FileSplit(new File("./", csvFileClasspath))); | |
DataSetIterator iterator = new RecordReaderDataSetIterator(rr, batchSize, labelIndex, numClasses); | |
return iterator.next(); | |
} | |
public static void main(String[] args) throws IOException, InterruptedException { | |
int nIn = 2; | |
int nOut = 1; | |
Nd4j.getRandom().setSeed(12345); | |
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() | |
.seed(12345) | |
.activation(Activation.TANH) | |
.weightInit(WeightInit.XAVIER) | |
.updater(new Nadam()) | |
.list() | |
.layer(new DenseLayer.Builder().nIn(nIn).nOut(3).build()) | |
.layer(new DenseLayer.Builder().nIn(3).nOut(nOut).build()) | |
.build(); | |
MultiLayerNetwork model = new MultiLayerNetwork(conf); | |
model.init(); | |
//Calculate gradient with respect to an external error | |
int minibatch = 25; | |
int numLinesToSkip = 0; | |
char delimiter = ','; | |
DataSet trainingData = readCSVDataset( | |
"fake_data_2.csv", | |
minibatch, 2, 1 | |
); | |
INDArray input = trainingData.getFeatures(); | |
model.setInput(input); | |
int numEpochs = 30; | |
for (int i = 0; i < numEpochs; i++) { | |
List<INDArray> output = model.feedForward(true, false); | |
INDArray predictions = output.get(output.size() - 1); | |
// | |
INDArray diff = trainingData.getLabels().sub(predictions); | |
INDArray externalError = diff.mul(diff); | |
System.out.println("------"); | |
System.out.println("EPOCH:"); | |
System.out.println(i); | |
System.out.println(externalError.sum()); | |
System.out.println("------"); | |
Pair<Gradient, INDArray> p = model.backpropGradient(externalError, null); //Calculate backprop gradient based on error array | |
Gradient gradient = p.getFirst(); | |
int iteration = 0; | |
int epoch = 0; | |
model.getUpdater().update(model, gradient, iteration, epoch, minibatch, LayerWorkspaceMgr.noWorkspaces()); | |
INDArray updateVector = gradient.gradient(); | |
INDArray a = model.params().subi(updateVector); | |
} | |
} | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment