Last active
December 5, 2019 07:10
-
-
Save nitinnat/88562f236326f5058e36789986b50707 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.deeplearning4j.examples.feedforward.classification; | |
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.nd4j.linalg.primitives.Pair;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.List;
/** | |
* "Saturn" Data Classification Example | |
* | |
* Based on the data from Jason Baldridge: | |
* https://github.com/jasonbaldridge/try-tf/tree/master/simdata | |
* | |
* @author Josh Patterson | |
* @author Alex Black (added plots) | |
* | |
*/ | |
public class MLPClassifierSaturn { | |
private static DataSet readCSVDataset( | |
String csvFileClasspath, int batchSize, int labelIndex, int numClasses) | |
throws IOException, InterruptedException { | |
RecordReader rr = new CSVRecordReader(); | |
rr.initialize(new FileSplit(new File(csvFileClasspath))); | |
DataSetIterator iterator = new RecordReaderDataSetIterator(rr, batchSize, labelIndex, numClasses); | |
return iterator.next(); | |
} | |
public static void main(String[] args) throws Exception { | |
Nd4j.ENFORCE_NUMERICAL_STABILITY = true; | |
int seed = 123; | |
double learningRate = 0.005; | |
//Number of epochs (full passes of the data) | |
int nEpochs = 50; | |
int numInputs = 10000; | |
int numOutputs = 1; | |
int numHiddenNodes = 20; | |
final String localTrainFilepath = "C:\\Users\\Nitin\\Documents\\consensus-dl\\data\\arcene\\arcene_train_binary.csv"; | |
final String localTestFilepath = "C:\\Users\\Nitin\\Documents\\consensus-dl\\data\\arcene\\arcene_test_binary.csv"; | |
/* | |
//Load the training data: | |
RecordReader rr = new CSVRecordReader(); | |
rr.initialize(new FileSplit(new File(filenameTrain))); | |
DataSetIterator trainIter = new RecordReaderDataSetIterator(rr,80,0,1); | |
//Load the test/evaluation data: | |
RecordReader rrTest = new CSVRecordReader(); | |
rrTest.initialize(new FileSplit(new File(filenameTest))); | |
DataSetIterator testIter = new RecordReaderDataSetIterator(rrTest,80,0,1); | |
System.out.println(testIter.getLabels()); | |
*/ | |
DataSet trainSet = readCSVDataset( | |
localTrainFilepath, | |
80, 0, 1 | |
); | |
DataSet testSet = readCSVDataset( | |
localTestFilepath, | |
20, 0, 1 | |
); | |
//log.info("Build model...."); | |
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() | |
.seed(seed) | |
.updater(new Nesterovs(learningRate, 0.9)) | |
.list() | |
.layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes) | |
.weightInit(WeightInit.XAVIER) | |
.activation(Activation.RELU) | |
.build()) | |
//.layer(1, new OutputLayer.Builder(LossFunction.SQUARED_LOSS) | |
// .weightInit(WeightInit.XAVIER) | |
// .activation(Activation.SIGMOID) | |
// .nIn(numHiddenNodes).nOut(numOutputs).build()) | |
.layer(1, new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numOutputs) | |
.weightInit(WeightInit.XAVIER) | |
.activation(Activation.SIGMOID) | |
.build()) | |
.build(); | |
MultiLayerNetwork model = new MultiLayerNetwork(conf); | |
model.init(); | |
model.setListeners(new ScoreIterationListener(1)); //Print score every 10 parameter updates | |
INDArray trainFeatures = trainSet.getFeatures(); | |
INDArray trainLabels = trainSet.getLabels(); | |
for ( int n = 0; n < nEpochs; n++) { | |
//model.fit( trainSet); | |
model.setInput(trainFeatures); | |
List<INDArray> output = model.feedForward(true, false); | |
INDArray predictions = output.get(output.size() - 1); | |
INDArray diff = trainLabels.sub(predictions); | |
INDArray error = diff.mul(diff).mul(0.5); | |
Pair<Gradient, INDArray> p = model.backpropGradient(error, null); | |
Gradient gradient = p.getFirst(); | |
int iteration = 0; | |
int epoch = 0; | |
model.getUpdater().update(model, gradient, iteration, epoch, | |
(int)trainFeatures.size(0), LayerWorkspaceMgr.noWorkspaces()); | |
INDArray updateVector = gradient.gradient(); | |
INDArray a = model.params().addi(updateVector); | |
System.out.println("Loss: " + error.sum()); | |
} | |
System.out.println("****************Example finished********************"); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment