Skip to content

Instantly share code, notes, and snippets.

@happiie
Last active March 10, 2021 11:24
Show Gist options
  • Save happiie/eeb6cd2b3db6c3c6aeb920e3554a6727 to your computer and use it in GitHub Desktop.
Save happiie/eeb6cd2b3db6c3c6aeb920e3554a6727 to your computer and use it in GitHub Desktop.
Hi everyone, I've run into an anomaly here: the code runs and shows a result, but the output states "(2 classes excluded from average)". What does this mean? Also, regarding the number of classes — how can 7 be correct? I've put the link to the gist below. Thank you for helping.
package ai.certifai.farhan.practise;
import org.datavec.api.records.reader.impl.collection.CollectionRecordReader;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.deeplearning4j.core.storage.StatsStorage;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.model.stats.StatsListener;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.ViewIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
 * Trains a small feed-forward classifier on the "C-Class" used-car CSV dataset
 * ({@code carPrice/cclass.csv}) using DataVec for preprocessing and DL4J for the model.
 *
 * <p>Pipeline: read CSV → declare {@link Schema} → {@link TransformProcess}
 * (drop "model", categorical→integer) → materialise one {@link DataSet} →
 * shuffle and split → min-max normalise → train an MLP → evaluate on both splits.
 *
 * <p>NOTE(review): the class name keeps its original non-idiomatic spelling
 * ({@code carPrice_cclass}) so existing run configurations keep working;
 * idiomatic Java would be {@code CarPriceCClass}.
 */
public class carPrice_cclass {

    static int seed = 123;              // RNG seed for shuffling / weight init
    static int numInput = 5;            // feature count: 6 transformed columns minus the 1 label column
    static int numClass = 7;            // label-class count fed to the output layer — see NOTE in main()
    static int epoch = 1000;            // training epochs; each epoch is one full-batch fit
    static double splitRatio = 0.8;     // train fraction; the remaining 20% becomes the test set
    static double learningRate = 1e-2;  // Adam step size (1e-2 -> 0.01)

    /**
     * Runs the full load/transform/train/evaluate pipeline.
     *
     * @param args unused
     * @throws Exception if the classpath CSV is missing or any DataVec/DL4J step fails
     */
    public static void main(String[] args) throws Exception {
        // Locate the CSV on the classpath; skip 1 header line, comma-delimited.
        File dataFile = new ClassPathResource("/carPrice/cclass.csv").getFile();
        FileSplit fileSplit = new FileSplit(dataFile);
        RecordReader rr = new CSVRecordReader(1, ',');
        rr.initialize(fileSplit);

        //=========================================================================
        // Step 1 : Build Schema describing the raw CSV columns
        //=========================================================================
        Schema sc = new Schema.Builder()
                .addColumnsString("model")
                .addColumnCategorical("year", Arrays.asList("1991", "1995", "1998", "2002", "2003", "2004", "2005", "2006", "2007", "2008", "2009", "2010", "2011", "2012", "2013", "2014", "2015", "2016", "2017", "2018", "2019", "2020"))
                .addColumnInteger("price")
                .addColumnCategorical("transmission", Arrays.asList("Automatic", "Manual", "Other", "Semi-Auto"))
                .addColumnsInteger("mileage")
                .addColumnCategorical("fuelType", Arrays.asList("Diesel", "Hybrid", "Other", "Petrol"))
                .addColumnsFloat("engineSize")
                .build();
        System.out.println("Before transformation: " + sc);

        //=========================================================================
        // Step 2 : Build TransformProcess to transform the data
        //=========================================================================
        // Drop the free-text "model" column; map each categorical column to an
        // integer index so every remaining column is numeric.
        TransformProcess tp = new TransformProcess.Builder(sc)
                .removeColumns("model")
                .categoricalToInteger("year")
                .categoricalToInteger("transmission")
                .categoricalToInteger("fuelType")
                .build();

        // Checking the post-transform schema (6 numeric columns).
        Schema outputSchema = tp.getFinalSchema();
        System.out.println("Final transformation: " + outputSchema);

        // Read every record into memory, then release the file handle:
        // the reader was previously left open (resource leak).
        List<List<Writable>> allData = new ArrayList<>();
        while (rr.hasNext()) {
            allData.add(rr.next());
        }
        rr.close();

        List<List<Writable>> processData = LocalTransformExecutor.execute(allData, tp);

        //========================================================================
        // Step 3 : Create Iterator, splitting trainData and testData
        //========================================================================
        CollectionRecordReader collectionRR = new CollectionRecordReader(processData);

        // Batch size = whole dataset, so one next() call yields everything.
        // NOTE(review): labelIndex -5 appears intended to select "price"
        // (6 columns, so index 1 == 5th from the end); confirm that
        // RecordReaderDataSetIterator actually supports negative label indices —
        // the documented API takes a non-negative column index.
        // NOTE(review): "price" is a raw sale price, not a 0..6 class id, yet it
        // is treated as a 7-class categorical label (numClass = 7). That is the
        // likely root of the "(2 classes excluded from average)" warning in
        // Evaluation.stats(): classes with no true examples in the test split are
        // dropped from the precision/recall averages. Presumably the intended
        // task is regression on price, or price should be bucketed into 7 ranges
        // first — TODO confirm against the dataset.
        DataSetIterator dataSetIterator = new RecordReaderDataSetIterator(collectionRR, processData.size(), -5, numClass);

        // Materialise the single full-size batch and shuffle it deterministically.
        DataSet fullDataset = dataSetIterator.next();
        fullDataset.shuffle(seed);

        // 80/20 split into train and test sets.
        SplitTestAndTrain testAndTrain = fullDataset.splitTestAndTrain(splitRatio);
        DataSet trainData = testAndTrain.getTrain();
        DataSet testData = testAndTrain.getTest();

        // Print out the feature-matrix shapes for a sanity check.
        System.out.println("Training vector : ");
        System.out.println(Arrays.toString(trainData.getFeatures().shape()));
        System.out.println("Test vector : ");
        System.out.println(Arrays.toString(testData.getFeatures().shape()));

        //========================================================================
        // Step 4 : DataNormalization
        //========================================================================
        // Fit the min/max statistics on the TRAIN split only, then apply the
        // same scaling to both splits (avoids test-set leakage).
        DataNormalization normalize = new NormalizerMinMaxScaler();
        normalize.fit(trainData);
        normalize.transform(trainData);
        normalize.transform(testData);
        System.out.println("normalize = " + normalize);

        //========================================================================
        // Step 5 : Network Configuration
        //========================================================================
        MultiLayerConfiguration config = getConfig(numInput, numClass, learningRate);
        MultiLayerNetwork model = new MultiLayerNetwork(config);
        model.init();

        //========================================================================
        // Step 6 : Setup UI, listeners
        //========================================================================
        // Attach an in-memory stats store to the DL4J training UI server.
        StatsStorage storage = new InMemoryStatsStorage();
        UIServer server = UIServer.getInstance();
        server.attach(storage);
        model.setListeners(new StatsListener(storage, 10));

        //========================================================================
        // Step 7 : Training
        //========================================================================
        // Full-batch training; after each epoch, report test-set accuracy.
        Evaluation eval;
        for (int i = 0; i < epoch; i++) {
            model.fit(trainData);
            eval = model.evaluate(new ViewIterator(testData, processData.size()));
            System.out.println("EPOCH: " + i + " Accuracy: " + eval.accuracy());
        }

        //========================================================================
        // Step 8 : Evaluation (confusion matrix + per-class stats)
        //========================================================================
        // Use println for the section headers so the stats block starts on its
        // own line (print() previously glued the header to the stats output).
        Evaluation evalTrain = model.evaluate(new ViewIterator(trainData, processData.size()));
        System.out.println("Train Data");
        System.out.println(evalTrain.stats());

        Evaluation evalTest = model.evaluate(new ViewIterator(testData, processData.size()));
        System.out.println("Test Data");
        System.out.println(evalTest.stats());
    }

    /**
     * Builds the MLP configuration: four ReLU dense layers (30-40-30-20 units)
     * followed by a softmax output layer trained with negative log-likelihood.
     *
     * @param numInputs    number of input features (width of the feature matrix)
     * @param numOutputs   number of output classes (softmax units)
     * @param learningRate Adam learning rate
     * @return the assembled {@link MultiLayerConfiguration}
     */
    public static MultiLayerConfiguration getConfig(int numInputs, int numOutputs, double learningRate) {
        MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .weightInit(WeightInit.XAVIER)
                .updater(new Adam(learningRate))
                .l2(0.001) // light weight decay on all layers
                .list()
                .layer(0, new DenseLayer.Builder()
                        .nIn(numInputs)
                        .nOut(30)
                        .activation(Activation.RELU)
                        .build())
                .layer(1, new DenseLayer.Builder()
                        .nIn(30)
                        .nOut(40)
                        .activation(Activation.RELU)
                        .build())
                .layer(2, new DenseLayer.Builder()
                        .nIn(40)
                        .nOut(30)
                        .activation(Activation.RELU)
                        .build())
                .layer(3, new DenseLayer.Builder()
                        .nIn(30)
                        .nOut(20)
                        .activation(Activation.RELU)
                        .build())
                .layer(4, new OutputLayer.Builder()
                        .nIn(20)
                        .nOut(numOutputs)
                        .lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .activation(Activation.SOFTMAX)
                        .build())
                .build();
        return config;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment