WineClassifier.java (gist by @vvpreetham, created August 16, 2016)
package nn.example;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
public class WineClassifier {

    public static final int SEED = 123;           // Fixed seed so results are reproducible across runs
    public static final int EPOCHS = 30;          // Number of passes over the training data
    public static final int ITERATIONS = 20;      // Parameter updates per call to fit()
    public static final double EPSILON = 0.05;    // Learning rate: step size applied to the error gradient
    private static final double ALPHA = 0.1;      // Momentum: fraction of the previous weight update carried forward
    public static final int FILE_SIZE = 178;      // Batch size equal to the record count, so all rows load at once
    public static final int CLASS_INDEX = 0;      // The class label is the first column of wine.data
    public static final int INPUT_NEURONS = 13;   // One input per feature in the UCI Wine dataset
    public static final int NUM_OF_CLASSES = 3;
    public static final int HIDDEN_NEURONS = 6;
    public static final int OUTPUT_NEURONS = NUM_OF_CLASSES;
    public static final double DATA_SPLIT_TRAIN_TEST = 0.65; // Use 65% of the data for training, the rest for validation
    private static final String HIDDEN_LAYER_ACTIVATION = "relu";
    private static final String OUTPUT_LAYER_ACTIVATION = "softmax";
    public static void main(String[] args) throws Exception {
        // Load the UCI Wine dataset (CSV, class label in the first column):
        RecordReader rr = new CSVRecordReader();
        rr.initialize(new FileSplit(new ClassPathResource("wine.data").getFile()));
        DataSetIterator iterator = new RecordReaderDataSetIterator(rr, FILE_SIZE, CLASS_INDEX, NUM_OF_CLASSES);
        DataSet wineData = iterator.next();

        // Shuffle, then split into training and test sets:
        wineData.shuffle();
        SplitTestAndTrain testAndTrain = wineData.splitTestAndTrain(DATA_SPLIT_TRAIN_TEST);
        DataSet trainingData = testAndTrain.getTrain();
        DataSet testData = testAndTrain.getTest();
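        // Because the iterator's batch size equals the total record count, next()
        // returns the whole dataset as a single in-memory DataSet, which is what
        // makes the shuffle-then-split above possible.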
        // A standard normalizer which gives zero-mean, unit-variance features:
        DataNormalization normalizer = new NormalizerStandardize();
        normalizer.fit(trainingData);       // Collect mean/stdev statistics from the training data; the data itself is not modified
        normalizer.transform(trainingData); // Apply normalization to the training data
        normalizer.transform(testData);     // Apply it to the test data, reusing the statistics from the *training* set
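        // Note: any data fed to the trained network later (e.g. at inference time)
        // must first pass through this same training-fitted normalizer via
        // normalizer.transform(...); the weights below are learned on standardized
        // inputs, so unnormalized inputs would silently degrade accuracy.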
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(SEED)
                .iterations(ITERATIONS)
                .weightInit(WeightInit.RELU)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(EPSILON)
                .regularization(true).l2(1e-4)
                .updater(Updater.NESTEROVS).momentum(ALPHA)
                .list()
                .layer(0, new DenseLayer.Builder()
                        .nIn(INPUT_NEURONS).nOut(HIDDEN_NEURONS)
                        .activation(HIDDEN_LAYER_ACTIVATION).build())
                .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nIn(HIDDEN_NEURONS).nOut(OUTPUT_NEURONS)
                        .activation(OUTPUT_LAYER_ACTIVATION).build())
                .pretrain(false).backprop(true).build();
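        // Resulting topology: 13 input features -> 6 ReLU hidden units -> 3 softmax
        // outputs. Softmax paired with negative log-likelihood is the standard
        // pairing for multi-class classification: the network emits a probability
        // distribution over the three wine cultivars.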
        MultiLayerNetwork networkModel = new MultiLayerNetwork(conf);
        networkModel.init();
        networkModel.setListeners(new ScoreIterationListener(10)); // Print the score every 10 parameter updates

        for (int n = 0; n < EPOCHS; n++) {
            networkModel.fit(trainingData);
        }
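        // With ITERATIONS = 20 and EPOCHS = 30, the network performs 600 gradient
        // updates in total, each computed over the full (single-batch) training set.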
System.out.println("Evaluate model....");
Evaluation eval = new Evaluation(OUTPUT_NEURONS);
INDArray output = networkModel.output(testData.getFeatureMatrix());
System.out.println(testData.getFeatureMatrix());
System.out.println(output);
eval.eval(testData.getLabels(), output);
//Print the evaluation statistics
System.out.println(eval.stats());
}
}
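If the trained network is worth keeping, DL4J can persist it to disk with ModelSerializer, so inference does not require retraining. The helper below is a minimal, untested sketch: the class name WineModelPersistence and the file handling are illustrative, and the fitted normalizer's statistics would also need to be saved separately (the normalizer-persistence API differs across DL4J versions).

package nn.example;

import java.io.File;

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;

public class WineModelPersistence {

    // Save the trained network. The third argument also persists the updater
    // state (momentum history etc.) so that training could be resumed later.
    public static void save(MultiLayerNetwork model, File target) throws Exception {
        ModelSerializer.writeModel(model, target, true);
    }

    // Reload the network for inference. Remember to run new inputs through the
    // same training-fitted NormalizerStandardize before calling output().
    public static MultiLayerNetwork load(File source) throws Exception {
        return ModelSerializer.restoreMultiLayerNetwork(source);
    }
}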