Last active
March 13, 2021 09:14
-
-
Save happiie/3cabfe64875ae0c059964901c7b9ca7e to your computer and use it in GitHub Desktop.
fishWeight exercise - encountered an anomaly here (2 classes excluded from average)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package ai.certifai.farhan.practise; | |
import org.datavec.api.records.reader.impl.collection.CollectionRecordReader; | |
import org.datavec.api.transform.TransformProcess; | |
import org.datavec.api.transform.schema.Schema; | |
import org.datavec.api.records.reader.RecordReader; | |
import org.datavec.api.records.reader.impl.csv.CSVRecordReader; | |
import org.datavec.api.split.FileSplit; | |
import org.datavec.api.writable.Writable; | |
import org.datavec.local.transforms.LocalTransformExecutor; | |
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; | |
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
import org.deeplearning4j.nn.conf.layers.DenseLayer; | |
import org.deeplearning4j.nn.conf.layers.OutputLayer; | |
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | |
import org.nd4j.common.io.ClassPathResource; | |
import org.nd4j.evaluation.regression.RegressionEvaluation; | |
import org.nd4j.linalg.activations.Activation; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.dataset.DataSet; | |
import org.nd4j.linalg.dataset.SplitTestAndTrain; | |
import org.nd4j.linalg.dataset.ViewIterator; | |
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; | |
import org.nd4j.linalg.learning.config.Adam; | |
import org.nd4j.linalg.lossfunctions.LossFunctions; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import java.io.File; | |
import java.io.IOException; | |
import java.util.ArrayList; | |
import java.util.Arrays; | |
import java.util.List; | |
public class fishWeight { | |
private static Logger log = LoggerFactory.getLogger(fishWeight.class); | |
private static final int batchSize = 10; | |
private static final int seed = 123; | |
private static final double learningRate = 1e-2; | |
private static final int numInput = 12; | |
private static final int numOutput = 1; | |
private static final int nEpoch = 500; | |
private static final int numHidden = 1000; | |
//private static final double regValue = 1e-4; | |
public static void main(String[] args) throws IOException, InterruptedException { | |
/* | |
* STEP 1: DATA PREPARATION | |
* | |
* */ | |
// set filePath | |
File fishData = new ClassPathResource("fish/Fish.csv").getFile(); | |
System.out.println("File location = " + fishData); | |
// File Split | |
FileSplit fileSplit = new FileSplit(fishData); | |
// set CSV record reader and initialize it | |
RecordReader rr = new CSVRecordReader(1,','); // skip label row from dataset | |
rr.initialize(fileSplit); | |
// build schema to prepare data | |
Schema fishSchema = new Schema.Builder() | |
.addColumnCategorical("Species", Arrays.asList("Bream","Parkki","Perch","Pike","Roach","Smelt","Whitefish")) | |
.addColumnsDouble("Weight","Length1","length2","Length3","Height","Width") | |
.build(); | |
System.out.println("Before transformation: " + fishSchema); | |
// transform data | |
TransformProcess fishTransform = new TransformProcess.Builder(fishSchema) | |
.categoricalToOneHot("Species") | |
.build(); | |
// Checking the Schema | |
Schema outputTransform = fishTransform.getFinalSchema(); | |
System.out.println("After transformation: " + outputTransform); | |
// adding original data to a list for later transform purpose | |
List<List<Writable>> originalData = new ArrayList<>(); | |
while (rr.hasNext()) { | |
List<Writable> data = rr.next(); | |
originalData.add(data); | |
} | |
List<List<Writable>> transformedData = LocalTransformExecutor.execute(originalData,fishTransform); | |
// Printing out the transformed data | |
for (int i = 0; i < transformedData.size(); i++) { | |
System.out.println("Transformed data : " + transformedData.get(i)); | |
} | |
// Preparing to split the dataset into training and test | |
CollectionRecordReader collectionRecordReader = new CollectionRecordReader(transformedData); | |
DataSetIterator iterator = new RecordReaderDataSetIterator(collectionRecordReader, transformedData.size(),9,9,true); | |
DataSet dataSet = iterator.next(); | |
dataSet.shuffle(); | |
SplitTestAndTrain testAndTrain = dataSet.splitTestAndTrain(0.8); | |
DataSet train = testAndTrain.getTrain(); | |
DataSet test = testAndTrain.getTest(); | |
INDArray features = train.getFeatures(); | |
System.out.println("\nFeature shape: " + features.shapeInfoToString() + "\n"); | |
// assigning dataset iterator for training purpose | |
ViewIterator trainIter = new ViewIterator(train, batchSize); | |
ViewIterator testIter = new ViewIterator(test,batchSize); | |
/* | |
* STEP 2: MODEL TRAINING | |
* | |
* */ | |
// configuring the structure if the model | |
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() | |
.seed(seed) | |
//.l2(regValue) | |
.weightInit(WeightInit.XAVIER) | |
.updater(new Adam(learningRate)) | |
.list() | |
.layer(0, new DenseLayer.Builder() // input layer | |
.nIn(numInput) | |
.nOut(numHidden) | |
.activation(Activation.RELU) | |
.build()) | |
.layer(1, new DenseLayer.Builder() // hidden layer 1 | |
.nOut(numHidden) | |
.activation(Activation.RELU) | |
.build()) | |
.layer(2, new DenseLayer.Builder() // hidden layer 2 | |
.nOut(numHidden) | |
.activation(Activation.RELU) | |
.build()) | |
.layer(3, new DenseLayer.Builder() // hidden layer 3 | |
.nOut(numHidden) | |
.activation(Activation.RELU) | |
.build()) | |
.layer(4,new OutputLayer.Builder() // output layer | |
.nOut(numOutput) | |
.activation(Activation.IDENTITY) | |
.lossFunction(LossFunctions.LossFunction.MSE) | |
.build()) | |
.build(); | |
// model | |
MultiLayerNetwork model = new MultiLayerNetwork(conf); | |
model.init(); | |
model.setListeners(new ScoreIterationListener(100)); | |
// Fitting the model for nEpochs | |
for(int i =0; i<nEpoch;i++){ | |
if(i%1000==0){ | |
System.out.println("Epoch: " + i); | |
} | |
model.fit(trainIter); | |
} | |
// Evaluating the outcome of our trained model | |
RegressionEvaluation regEval= model.evaluateRegression(testIter); | |
System.out.println(regEval.stats()); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment