Skip to content

Instantly share code, notes, and snippets.

@nitinnat
Last active December 5, 2019 07:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nitinnat/88562f236326f5058e36789986b50707 to your computer and use it in GitHub Desktop.
Save nitinnat/88562f236326f5058e36789986b50707 to your computer and use it in GitHub Desktop.
package org.deeplearning4j.examples.feedforward.classification;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.nd4j.linalg.primitives.Pair;
import java.io.File;
import java.io.IOException;
import java.util.List;
/**
 * MLP binary classification with a manual (per-epoch, full-batch) training loop.
 *
 * Originally the "Saturn" classification example based on data from Jason Baldridge
 * (https://github.com/jasonbaldridge/try-tf/tree/master/simdata), here adapted to
 * load the Arcene binary dataset from local CSV files instead.
 *
 * @author Josh Patterson
 * @author Alex Black (added plots)
 *
 */
public class MLPClassifierSaturn {

    /**
     * Reads an entire CSV file into a single {@link DataSet} (one batch).
     *
     * @param csvFileClasspath absolute path to the CSV file on local disk
     *                         (despite the name, this is a filesystem path, not a classpath resource)
     * @param batchSize        number of rows to pull into the returned DataSet
     * @param labelIndex       column index of the label within each CSV row
     * @param numClasses       number of label values/output columns
     * @return the first batch produced by the iterator
     * @throws IOException          if the file cannot be read
     * @throws InterruptedException if record-reader initialization is interrupted
     */
    private static DataSet readCSVDataset(
            String csvFileClasspath, int batchSize, int labelIndex, int numClasses)
            throws IOException, InterruptedException {
        RecordReader rr = new CSVRecordReader();
        rr.initialize(new FileSplit(new File(csvFileClasspath)));
        DataSetIterator iterator = new RecordReaderDataSetIterator(rr, batchSize, labelIndex, numClasses);
        return iterator.next();
    }

    /**
     * Builds a 2-layer MLP (ReLU hidden layer, sigmoid output) and trains it with a
     * hand-rolled full-batch gradient-descent loop against an MSE-style loss.
     *
     * @param args unused
     * @throws Exception on any data-loading or training failure
     */
    public static void main(String[] args) throws Exception {
        Nd4j.ENFORCE_NUMERICAL_STABILITY = true;

        int seed = 123;
        double learningRate = 0.005;
        int nEpochs = 50;          // number of full passes over the training data
        int numInputs = 10000;     // Arcene has 10,000 features per row
        int numOutputs = 1;        // single sigmoid unit for binary classification
        int numHiddenNodes = 20;

        // NOTE(review): hard-coded local Windows paths — parameterize before reuse.
        final String localTrainFilepath = "C:\\Users\\Nitin\\Documents\\consensus-dl\\data\\arcene\\arcene_train_binary.csv";
        final String localTestFilepath = "C:\\Users\\Nitin\\Documents\\consensus-dl\\data\\arcene\\arcene_test_binary.csv";

        DataSet trainSet = readCSVDataset(
                localTrainFilepath,
                80, 0, 1
        );
        // Test set is loaded but not yet evaluated; kept for parity with the original.
        DataSet testSet = readCSVDataset(
                localTestFilepath,
                20, 0, 1
        );

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .updater(new Nesterovs(learningRate, 0.9))
                .list()
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .weightInit(WeightInit.XAVIER)
                        .activation(Activation.RELU)
                        .build())
                .layer(1, new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numOutputs)
                        .weightInit(WeightInit.XAVIER)
                        .activation(Activation.SIGMOID)
                        .build())
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1)); // print score every iteration

        INDArray trainFeatures = trainSet.getFeatures();
        INDArray trainLabels = trainSet.getLabels();

        for (int n = 0; n < nEpochs; n++) {
            // Forward pass on the full training batch.
            model.setInput(trainFeatures);
            List<INDArray> output = model.feedForward(true, false);
            INDArray predictions = output.get(output.size() - 1);

            // Loss for reporting: L = 0.5 * (y - yhat)^2, element-wise.
            INDArray diff = trainLabels.sub(predictions);
            INDArray loss = diff.mul(diff).mul(0.5);

            // FIX: backpropGradient expects the gradient of the loss w.r.t. the network
            // output (dL/dyhat = yhat - y), not the loss values themselves. Also pass a
            // real workspace manager instead of null, matching the updater call below.
            INDArray epsilon = predictions.sub(trainLabels);
            Pair<Gradient, INDArray> p =
                    model.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces());
            Gradient gradient = p.getFirst();

            // FIX: report the true iteration/epoch counter instead of a constant 0.
            model.getUpdater().update(model, gradient, n, n,
                    (int) trainFeatures.size(0), LayerWorkspaceMgr.noWorkspaces());

            // FIX: gradient DESCENT subtracts the update; addi() was ascending the loss.
            model.params().subi(gradient.gradient());

            System.out.println("Loss: " + loss.sum());
        }
        System.out.println("****************Example finished********************");
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment