Skip to content

Instantly share code, notes, and snippets.

@thechaudharysab
Last active May 23, 2021 05:54
Show Gist options
  • Save thechaudharysab/4f22c13459a5489c2b5af069c44e9f79 to your computer and use it in GitHub Desktop.
Save thechaudharysab/4f22c13459a5489c2b5af069c44e9f79 to your computer and use it in GitHub Desktop.
Iris classification neural network using the Deep Learning for Java (DL4J) library. Tutorial link: https://medium.com/@ibjects/deep-learning-for-java-dl4j-getting-started-tutorial-2259c76c0a7c
package com.ibjects;
import org.apache.log4j.BasicConfigurator;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
/**
 * Trains a small feed-forward network on the rows of {@code iris.csv} and prints
 * evaluation statistics for a held-out portion of the data.
 *
 * <p>The program reads the CSV from the classpath, shuffles it with a fixed seed,
 * splits it into training and test partitions, standardizes both partitions using
 * statistics computed from the training partition only, trains the network and
 * finally prints {@link Evaluation#stats()} for the test partition.
 */
public class IrisClassification {

    /**
     * Number of feature columns per CSV row. It is also passed to the iterator as the
     * label index, i.e. the label is expected in the column right after the features.
     */
    private static final int FEATURES_COUNT = 4;

    /** Number of distinct class labels. */
    private static final int CLASSES_COUNT = 3;

    /**
     * Batch size handed to the iterator; a single {@code next()} call returns at most
     * this many rows. Presumably the full row count of iris.csv — verify against the file.
     */
    private static final int BATCH_SIZE = 150;

    /** Width of each hidden layer. */
    private static final int HIDDEN_UNITS = 3;

    /** Fraction of the shuffled data used for training; the remainder is used for testing. */
    private static final double TRAIN_FRACTION = 0.65;

    /** Fixed shuffle seed so that the train/test split is reproducible between runs. */
    private static final int SHUFFLE_SEED = 123;

    /**
     * Number of times {@code fit} is invoked on the training set.
     * NOTE(review): in DL4J versions without the builder's {@code iterations(...)} option a
     * single {@code fit(DataSet)} call is expected to perform one parameter update, which is
     * far too little to train the network — confirm against the DL4J version in use.
     */
    private static final int TRAINING_ITERATIONS = 1000;

    /**
     * Program entry point: configures log4j with its basic console setup and runs the
     * load / train / evaluate pipeline.
     *
     * @param args command-line arguments (unused)
     */
    public static void main(String[] args) {
        BasicConfigurator.configure();
        loadData();
    }

    /**
     * Loads {@code iris.csv} from the classpath, prepares the training and test partitions
     * and hands them to {@link #irisNNetwork(DataSet, DataSet)}.
     *
     * <p>Any failure is reported on the console together with the stack trace of the
     * exception that actually occurred; nothing is rethrown.
     */
    private static void loadData() {
        // No header lines are skipped (0) and columns are comma separated.
        try (RecordReader recordReader = new CSVRecordReader(0, ',')) {
            recordReader.initialize(new FileSplit(
                    new ClassPathResource("iris.csv").getFile()));

            DataSetIterator iterator = new RecordReaderDataSetIterator(
                    recordReader, BATCH_SIZE, FEATURES_COUNT, CLASSES_COUNT);
            DataSet allData = iterator.next();
            allData.shuffle(SHUFFLE_SEED);

            SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(TRAIN_FRACTION);
            DataSet trainingData = testAndTrain.getTrain();
            DataSet testingData = testAndTrain.getTest();

            // Fit the normalizer on the training partition only, so that no statistics
            // of the test partition leak into preprocessing, then apply it to both.
            DataNormalization normalizer = new NormalizerStandardize();
            normalizer.fit(trainingData);
            normalizer.transform(trainingData);
            normalizer.transform(testingData);

            irisNNetwork(trainingData, testingData);
        } catch (Exception e) {
            // Print the trace of the caught exception itself; dumping the handler's own
            // stack would hide where the failure happened and what caused it.
            e.printStackTrace();
            System.out.println("Error: " + e.getLocalizedMessage());
        }
    }

    /**
     * Builds, trains and evaluates the network.
     *
     * <p>The network has two tanh dense layers of {@link #HIDDEN_UNITS} units followed by a
     * softmax output layer with negative log-likelihood loss. Weights use Xavier
     * initialization, updates use Nesterov momentum (learning rate 0.1, momentum 0.9) and
     * L2 regularization of 0.0001 is applied.
     *
     * @param trainingData examples used to fit the network
     * @param testData     examples used only for evaluation
     */
    private static void irisNNetwork(DataSet trainingData, DataSet testData) {
        MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
                .activation(Activation.TANH)
                .weightInit(WeightInit.XAVIER)
                .updater(new Nesterovs(0.1, 0.9))
                .l2(0.0001)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(FEATURES_COUNT).nOut(HIDDEN_UNITS).build())
                .layer(1, new DenseLayer.Builder().nIn(HIDDEN_UNITS).nOut(HIDDEN_UNITS).build())
                .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .activation(Activation.SOFTMAX)
                        .nIn(HIDDEN_UNITS).nOut(CLASSES_COUNT).build())
                .backprop(true).pretrain(false)
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(configuration);
        model.init();

        // Repeat fitting on the (single-batch) training set; see TRAINING_ITERATIONS.
        for (int i = 0; i < TRAINING_ITERATIONS; i++) {
            model.fit(trainingData);
        }

        INDArray output = model.output(testData.getFeatureMatrix());
        Evaluation eval = new Evaluation(CLASSES_COUNT);
        eval.eval(testData.getLabels(), output);
        System.out.println(eval.stats());
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment