@happiie
Created March 13, 2021 08:16
Classification using the k-fold method on the animalZoo dataset. Mean F1 score across folds is about 90%.
package ai.certifai.farhan.practise;
import org.datavec.api.records.reader.impl.collection.CollectionRecordReader;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.KFoldIterator;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class animalZoo {

    public static final String BLACK_BOLD = "\033[1;30m";
    public static final String BLUE_BOLD = "\033[1;34m";
    public static final String ANSI_RESET = "\u001B[0m";

    private static final double learningRate = 1e-2;
    private static final int numInput = 16;
    private static final int numHidden = 100;
    private static final int numOutput = 7;
    private static int epochs = 1000;

    public static void main(String[] args) throws Exception {

        // load data using getDataSet() method
        DataSet dataSet = getDataSet();

        // create kFold object
        KFoldIterator kFoldIterator = new KFoldIterator(5, dataSet);
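
        // KFoldIterator splits the single DataSet into k = 5 folds; each call to next()
        // returns the k-1 training folds merged into one DataSet, and testFold() returns
        // the held-out fold for that iteration.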

        // create neural network config
        MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
                .seed(123)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .weightInit(WeightInit.XAVIER)
                .updater(new Adam(learningRate))
                .list()
                .layer(0, new DenseLayer.Builder()
                        .nIn(numInput)
                        .nOut(numHidden)
                        .activation(Activation.TANH)
                        .build())
                .layer(1, new OutputLayer.Builder()
                        .nIn(numHidden)
                        .nOut(numOutput)
                        // softmax pairs with MCXENT for multi-class classification
                        .activation(Activation.SOFTMAX)
                        .lossFunction(LossFunctions.LossFunction.MCXENT)
                        .build())
                .build();
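
        // Resulting architecture: 16 inputs -> 100 tanh hidden units -> 7-way softmax output,
        // trained with multi-class cross-entropy.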

        // loop through the kFoldIterator and print out the observations for each training set and test set
        int i = 1;
        System.out.println("-------------------------------------------------------------");

        // initialize an empty list to store the F1 score of each fold
        ArrayList<Double> f1List = new ArrayList<>();

        while (kFoldIterator.hasNext()) {
            System.out.println(BLACK_BOLD + "Fold: " + i + ANSI_RESET);
            // for each fold, get features and labels from training set and test set
            DataSet currDataSet = kFoldIterator.next();
            INDArray trainFoldFeatures = currDataSet.getFeatures();
            INDArray trainFoldLabels = currDataSet.getLabels();
            INDArray testFoldFeatures = kFoldIterator.testFold().getFeatures();
            INDArray testFoldLabels = kFoldIterator.testFold().getLabels();
            DataSet trainDataSet = new DataSet(trainFoldFeatures, trainFoldLabels);

            System.out.println(BLUE_BOLD + "Training Fold: \n" + ANSI_RESET);
            System.out.println(trainFoldFeatures);
            System.out.println(BLUE_BOLD + "Test Fold: \n" + ANSI_RESET);
            System.out.println(testFoldFeatures);

            // scale or normalize dataset
            NormalizerMinMaxScaler scaler = new NormalizerMinMaxScaler();
            scaler.fit(trainDataSet);
            scaler.transform(trainFoldFeatures);
            scaler.transform(testFoldFeatures);
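
            // Note: the scaler is fit on the training fold only and then applied to both folds,
            // so statistics from the test fold never leak into the scaling parameters.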

            // initialize a fresh model for this fold
            MultiLayerNetwork model = new MultiLayerNetwork(config);
            model.init();

            // train the model on the training folds
            for (int j = 0; j < epochs; j++) {
                model.fit(trainDataSet);
            }

            // evaluate model with test set
            Evaluation eval = new Evaluation();
            eval.eval(testFoldLabels, model.output(testFoldFeatures));

            // print out evaluation result
            System.out.println(eval.stats());

            // save eval result
            f1List.add(eval.f1());

            i++;
            System.out.println("-------------------------------------------------------------");
        }

        INDArray f1scores = Nd4j.create(f1List);
        System.out.println("Average F1 scores for all folds: " + f1scores.mean(0));
    }

    private static DataSet getDataSet() throws IOException, InterruptedException {

        // define input data schema
        Schema inputDataSchema = new Schema.Builder()
                .addColumnString("animal_name")
                .addColumnsInteger("hair", "feathers", "eggs", "milk", "airborne", "aquatic", "predator", "toothed",
                        "backbone", "breathes", "venomous", "fins", "legs", "tail", "domestic", "catsize")
                .addColumnCategorical("class_type", Arrays.asList("1", "2", "3", "4", "5", "6", "7"))
                .build();

        // print out schema
        System.out.println("Before transform: " + inputDataSchema);

        TransformProcess tp = new TransformProcess.Builder(inputDataSchema)
                .removeColumns("animal_name")
                .categoricalToInteger("class_type")
                .build();
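
        // removeColumns drops the non-numeric animal_name column, and categoricalToInteger replaces
        // each class_type category with its index in the list above ("1" -> 0, ..., "7" -> 6).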

        // print out transformed schema
        Schema outputTransform = tp.getFinalSchema();
        System.out.println("After transform: " + outputTransform);

        // filePath, csv location
        File zooFile = new ClassPathResource("animal/zoo.csv").getFile();
        FileSplit fileSplit = new FileSplit(zooFile);

        // read, loading data set (CSVRecordReader skips the first line)
        RecordReader rr = new CSVRecordReader(1, ',');
        rr.initialize(fileSplit);

        // adding original data to list
        List<List<Writable>> originalData = new ArrayList<>();
        while (rr.hasNext()) {
            List<Writable> data = rr.next();
            originalData.add(data);
        }

        // apply the transform process to all records
        List<List<Writable>> transformedData = LocalTransformExecutor.execute(originalData, tp);

        // show info
        int numRows = transformedData.size();
        System.out.println(transformedData);
        System.out.println("Total number of rows: " + numRows);
        // create iterator from processedData
        CollectionRecordReader collectionRecordReader = new CollectionRecordReader(transformedData);
        RecordReaderDataSetIterator dataSetIter = new RecordReaderDataSetIterator(collectionRecordReader, numRows, 16, 7);

        return dataSetIter.next();
    }
}