Last active
March 10, 2021 11:24
-
-
Save happiie/eeb6cd2b3db6c3c6aeb920e3554a6727 to your computer and use it in GitHub Desktop.
Hi everyone, I've run into an anomaly: the code runs and shows results, but the output says "(2 classes excluded from average)". What does this mean? Also, regarding the number of classes: why is the result correct when it is set to 7? I've put the link to the gist below. Thank you for your help.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package ai.certifai.farhan.practise; | |
import org.datavec.api.records.reader.impl.collection.CollectionRecordReader; | |
import org.datavec.api.transform.TransformProcess; | |
import org.datavec.api.transform.schema.Schema; | |
import org.datavec.api.records.reader.RecordReader; | |
import org.datavec.api.records.reader.impl.csv.CSVRecordReader; | |
import org.datavec.api.split.FileSplit; | |
import org.datavec.api.writable.Writable; | |
import org.datavec.local.transforms.LocalTransformExecutor; | |
import org.deeplearning4j.core.storage.StatsStorage; | |
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; | |
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
import org.deeplearning4j.nn.conf.layers.DenseLayer; | |
import org.deeplearning4j.nn.conf.layers.OutputLayer; | |
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.deeplearning4j.ui.api.UIServer; | |
import org.deeplearning4j.ui.model.stats.StatsListener; | |
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage; | |
import org.nd4j.common.io.ClassPathResource; | |
import org.nd4j.evaluation.classification.Evaluation; | |
import org.nd4j.linalg.activations.Activation; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.dataset.DataSet; | |
import org.nd4j.linalg.dataset.SplitTestAndTrain; | |
import org.nd4j.linalg.dataset.ViewIterator; | |
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; | |
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; | |
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler; | |
import org.nd4j.linalg.factory.Nd4j; | |
import org.nd4j.linalg.learning.config.Adam; | |
import org.nd4j.linalg.lossfunctions.LossFunctions; | |
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT; | |
import java.io.File; | |
import java.util.ArrayList; | |
import java.util.Arrays; | |
import java.util.List; | |
public class carPrice_cclass { | |
static int seed = 123; | |
static int numInput = 5; // # of features | |
static int numClass = 7; // # of label | |
static int epoch = 1000; | |
static double splitRatio = 0.8; | |
static double learningRate = 1e-2; // 1e-2 -> 0.01 | |
public static void main(String[] args) throws Exception { | |
// set filepath | |
File dataFile = new ClassPathResource("/carPrice/cclass.csv").getFile(); | |
// File split | |
FileSplit fileSplit = new FileSplit(dataFile); | |
// set CSV Record Reader and initialize it | |
RecordReader rr = new CSVRecordReader(1,','); | |
rr.initialize(fileSplit); | |
//========================================================================= | |
// Step 1 : Build Schema to prepare the data | |
//========================================================================= | |
// Build Schema to prepare the data | |
Schema sc = new Schema.Builder() | |
.addColumnsString("model") | |
.addColumnCategorical("year",Arrays.asList("1991","1995","1998","2002","2003","2004","2005","2006","2007","2008","2009","2010","2011","2012","2013","2014","2015","2016","2017","2018","2019","2020")) | |
.addColumnInteger("price") | |
.addColumnCategorical("transmission", Arrays.asList("Automatic","Manual","Other","Semi-Auto")) | |
.addColumnsInteger("mileage") | |
.addColumnCategorical("fuelType", Arrays.asList("Diesel","Hybrid","Other","Petrol")) | |
.addColumnsFloat("engineSize") | |
.build(); | |
System.out.println("Before transformation: " + sc); | |
//========================================================================= | |
// Step 2 : Build TransformProcess to transform the data | |
//========================================================================= | |
// remove column, category to integer | |
TransformProcess tp = new TransformProcess.Builder(sc) | |
.removeColumns("model") | |
.categoricalToInteger("year") | |
.categoricalToInteger("transmission") | |
.categoricalToInteger("fuelType") | |
.build(); | |
// Checking the Schema | |
Schema outputSchema = tp.getFinalSchema(); | |
System.out.println("Final transformation: " + outputSchema); | |
List<List<Writable>> allData = new ArrayList<>(); | |
while(rr.hasNext()){ | |
allData.add(rr.next()); | |
} | |
List<List<Writable>> processData = LocalTransformExecutor.execute(allData, tp); | |
//======================================================================== | |
// Step 3 : Create Iterator ,splitting trainData and testData | |
//======================================================================== | |
//Create iterator from process data | |
CollectionRecordReader collectionRR = new CollectionRecordReader(processData); | |
//Input batch size , label index , and number of label | |
// label index = -1 -> last column, -2 -> 2nd last column | |
DataSetIterator dataSetIterator = new RecordReaderDataSetIterator(collectionRR, processData.size(),-5,numClass); | |
//Create Iterator and shuffle the dat | |
DataSet fullDataset = dataSetIterator.next(); | |
fullDataset.shuffle(seed); | |
//Input split ratio | |
SplitTestAndTrain testAndTrain = fullDataset.splitTestAndTrain(splitRatio); | |
//Get train and test dataset | |
DataSet trainData = testAndTrain.getTrain(); | |
DataSet testData = testAndTrain.getTest(); | |
//printout size | |
System.out.println("Training vector : "); | |
System.out.println(Arrays.toString(trainData.getFeatures().shape())); | |
System.out.println("Test vector : "); | |
System.out.println(Arrays.toString(testData.getFeatures().shape())); | |
//======================================================================== | |
// Step 4 : DataNormalization | |
//======================================================================== | |
//Data normalization | |
DataNormalization normalize = new NormalizerMinMaxScaler(); | |
normalize.fit(trainData); // build function | |
normalize.transform(trainData); // normalize happen | |
normalize.transform(testData); // normalize happen | |
System.out.println("normalize = " + normalize); | |
//======================================================================== | |
// Step 5 : Network Configuration | |
//======================================================================== | |
//Get network configuration ( uncomment these lines ) | |
MultiLayerConfiguration config = getConfig(numInput, numClass, learningRate); | |
//Define network ( uncomment these lines ) | |
MultiLayerNetwork model = new MultiLayerNetwork(config); | |
model.init(); | |
//======================================================================== | |
// Step 6 : Setup UI , listeners | |
//======================================================================== | |
//UI-Evaluator | |
StatsStorage storage = new InMemoryStatsStorage(); | |
UIServer server = UIServer.getInstance(); | |
server.attach(storage); | |
//Set model listeners ( uncomment these lines ) | |
model.setListeners(new StatsListener(storage, 10)); | |
//======================================================================== | |
//Step 7 : Training | |
//======================================================================== | |
//Training | |
Evaluation eval; | |
for(int i=0; i < epoch; i++) { | |
model.fit(trainData); | |
eval = model.evaluate(new ViewIterator(testData, processData.size())); | |
System.out.println("EPOCH: " + i + " Accuracy: " + eval.accuracy()); | |
} | |
//======================================================================== | |
// Step 8 : Evaluation | |
//======================================================================== | |
//Confusion matrix | |
//TrainData | |
Evaluation evalTrain = model.evaluate(new ViewIterator(trainData, processData.size())); | |
System.out.print("Train Data"); | |
System.out.println(evalTrain.stats()); | |
//TestData | |
Evaluation evalTest = model.evaluate(new ViewIterator(testData, processData.size())); | |
System.out.print("Test Data"); | |
System.out.print(evalTest.stats()); | |
} | |
public static MultiLayerConfiguration getConfig(int numInputs, int numOutputs, double learningRate) { | |
MultiLayerConfiguration config = new NeuralNetConfiguration.Builder() | |
.seed(seed) | |
.weightInit(WeightInit.XAVIER) | |
.updater(new Adam(learningRate)) | |
.l2(0.001) | |
.list() | |
.layer(0, new DenseLayer.Builder() | |
.nIn(numInputs) | |
.nOut(30) | |
.activation(Activation.RELU) | |
.build()) | |
.layer(1, new DenseLayer.Builder() | |
.nIn(30) | |
.nOut(40) | |
.activation(Activation.RELU) | |
.build()) | |
.layer(2, new DenseLayer.Builder() | |
.nIn(40) | |
.nOut(30) | |
.activation(Activation.RELU) | |
.build()) | |
.layer(3, new DenseLayer.Builder() | |
.nIn(30) | |
.nOut(20) | |
.activation(Activation.RELU) | |
.build()) | |
.layer(4, new OutputLayer.Builder() | |
.nIn(20) | |
.nOut(numOutputs) | |
.lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) | |
.activation(Activation.SOFTMAX) | |
.build()) | |
.build(); | |
return config; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment