Created
March 13, 2021 08:16
-
-
Save happiie/c59a5c65c2d43f272e0934ac837d6204 to your computer and use it in GitHub Desktop.
Classification using the k-fold cross-validation method on the animalZoo dataset. Mean F1 score is about 90%.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package ai.certifai.farhan.practise; | |
import org.datavec.api.records.reader.impl.collection.CollectionRecordReader; | |
import org.datavec.api.transform.TransformProcess; | |
import org.datavec.api.transform.schema.Schema; | |
import org.datavec.api.records.reader.RecordReader; | |
import org.datavec.api.records.reader.impl.csv.CSVRecordReader; | |
import org.datavec.api.split.FileSplit; | |
import org.datavec.api.writable.Writable; | |
import org.datavec.local.transforms.LocalTransformExecutor; | |
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; | |
import org.deeplearning4j.nn.api.OptimizationAlgorithm; | |
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
import org.deeplearning4j.nn.conf.layers.DenseLayer; | |
import org.deeplearning4j.nn.conf.layers.OutputLayer; | |
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.nd4j.common.io.ClassPathResource; | |
import org.nd4j.evaluation.classification.Evaluation; | |
import org.nd4j.linalg.activations.Activation; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.dataset.DataSet; | |
import org.nd4j.linalg.dataset.api.iterator.KFoldIterator; | |
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler; | |
import org.nd4j.linalg.factory.Nd4j; | |
import org.nd4j.linalg.learning.config.Adam; | |
import org.nd4j.linalg.lossfunctions.LossFunctions; | |
import java.io.File; | |
import java.io.IOException; | |
import java.util.ArrayList; | |
import java.util.Arrays; | |
import java.util.List; | |
public class animalZoo { | |
public static final String BLACK_BOLD = "\033[1;30m"; | |
public static final String BLUE_BOLD = "\033[1;34m"; | |
public static final String ANSI_RESET = "\u001B[0m"; | |
private static final double learningRate = 1e-2; | |
private static final int numInput = 16; | |
private static final int numHidden = 100; | |
private static final int numOutput = 7; | |
private static int epochs = 1000; | |
public static void main(String[] args) throws Exception { | |
// load data using getDataSet() method | |
// call method here | |
DataSet dataSet = getDataSet(); | |
// create kFold object | |
KFoldIterator kFoldIterator = new KFoldIterator(5,dataSet); | |
// create neural network config | |
MultiLayerConfiguration config = new NeuralNetConfiguration.Builder() | |
.seed(123) | |
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | |
.weightInit(WeightInit.XAVIER) | |
.updater(new Adam(learningRate)) | |
.list() | |
.layer(0, new DenseLayer.Builder() | |
.nIn(numInput) | |
.nOut(numHidden) | |
.activation(Activation.TANH) | |
.build()) | |
.layer(1, new OutputLayer.Builder() | |
.nIn(numHidden) | |
.nOut(numOutput) | |
.lossFunction(LossFunctions.LossFunction.MCXENT) | |
.build()) | |
.build(); | |
// loop through the kFoldIterator and print out the observations for each training set and test set | |
int i = 1; | |
System.out.println("-------------------------------------------------------------"); | |
//initialize an empty list to store the F1 score | |
ArrayList<Double> f1List = new ArrayList<>(); | |
while (kFoldIterator.hasNext()) { | |
System.out.println(BLACK_BOLD + "Batch: " + i + ANSI_RESET); | |
// for each fold, get features and labels from training set and test set | |
DataSet currDataSet = kFoldIterator.next(); | |
INDArray trainFoldFeatures = currDataSet.getFeatures(); | |
INDArray trainFoldLabels = currDataSet.getLabels(); | |
INDArray testFoldFeatures = kFoldIterator.testFold().getFeatures(); | |
INDArray testFoldLabels = kFoldIterator.testFold().getLabels(); | |
DataSet trainDataSet = new DataSet(trainFoldFeatures,trainFoldLabels); | |
System.out.println(BLUE_BOLD + "Training Fold: \n" + ANSI_RESET); | |
System.out.println(trainFoldFeatures); | |
System.out.println(BLUE_BOLD + "Test Fold: \n" + ANSI_RESET); | |
System.out.println(testFoldFeatures); | |
// scale or normalize dataset | |
NormalizerMinMaxScaler scaler = new NormalizerMinMaxScaler(); | |
scaler.fit(trainDataSet); | |
scaler.transform(trainFoldFeatures); | |
scaler.transform(testFoldFeatures); | |
// initialize model | |
MultiLayerNetwork model = new MultiLayerNetwork(config); | |
model.init(); | |
// train data | |
for (int j = 0; j < epochs; j++) { | |
model.fit(trainDataSet); | |
} | |
// evaluate model with test set | |
Evaluation eval = new Evaluation(); | |
eval.eval(testFoldLabels, model.output(testFoldFeatures)); | |
// print out evaluation result | |
System.out.println(eval.stats()); | |
// save eval result | |
f1List.add(eval.f1()); | |
i++; | |
System.out.println("-------------------------------------------------------------"); | |
} | |
INDArray f1scores = Nd4j.create(f1List); | |
System.out.println("Average F1 scores for all folds: " + f1scores.mean(0)); | |
} | |
private static DataSet getDataSet() throws IOException, InterruptedException { | |
// define input data schema | |
Schema inputDataSchema = new Schema.Builder() | |
.addColumnString("animal_name") | |
.addColumnsInteger("hair","feathers","eggs","milk","airborne","aquatic","predator","toothed","backbone","breathes","venomous","fins","legs","tail","domestic","catsize") | |
.addColumnCategorical("class_type", Arrays.asList("1","2","3","4","5","6","7")) | |
.build(); | |
// print out schema | |
System.out.println("Before transform: " + inputDataSchema); | |
TransformProcess tp = new TransformProcess.Builder(inputDataSchema) | |
.removeColumns("animal_name") | |
.categoricalToInteger("class_type") | |
.build(); | |
// print out output transform | |
Schema outputTransform = tp.getFinalSchema(); | |
System.out.println("After transform: " + outputTransform); | |
// filePath, csv location | |
File zooFile = new ClassPathResource("animal/zoo.csv").getFile(); | |
FileSplit fileSplit = new FileSplit(zooFile); | |
// read, loading data set | |
RecordReader rr = new CSVRecordReader(1,','); | |
rr.initialize(fileSplit); | |
// adding original data to list | |
List<List<Writable>> originalData = new ArrayList<>(); | |
while(rr.hasNext()){ | |
List<Writable> data = rr.next(); | |
originalData.add(data); | |
} | |
List<List<Writable>> transformedData = LocalTransformExecutor.execute(originalData,tp); | |
// show info | |
int numRows = transformedData.size(); | |
System.out.println(transformedData); | |
System.out.println("Total number of rows: " + numRows); | |
// create iterator from processedData | |
CollectionRecordReader collectionRecordReader = new CollectionRecordReader(transformedData); | |
RecordReaderDataSetIterator dataSetIter = new RecordReaderDataSetIterator(collectionRecordReader, numRows, 16,7); | |
return dataSetIter.next(); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment