Skip to content

Instantly share code, notes, and snippets.

@happiie
Last active March 13, 2021 09:14
Show Gist options
  • Save happiie/3cabfe64875ae0c059964901c7b9ca7e to your computer and use it in GitHub Desktop.
Save happiie/3cabfe64875ae0c059964901c7b9ca7e to your computer and use it in GitHub Desktop.
fishWeight exercise - encounter an anomaly here (2 classes excluded from average)
package ai.certifai.farhan.practise;
import org.datavec.api.records.reader.impl.collection.CollectionRecordReader;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.ViewIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class fishWeight {

    private static final Logger log = LoggerFactory.getLogger(fishWeight.class);

    // Hyperparameters and data-shape constants.
    private static final int batchSize = 10;
    private static final int seed = 123;
    private static final double learningRate = 1e-2;
    private static final int numInput = 12;   // 13 transformed columns minus the 1 label column
    private static final int numOutput = 1;   // single regression target: Weight
    private static final int nEpoch = 500;
    private static final int numHidden = 1000;
    //private static final double regValue = 1e-4;

    /**
     * Trains a feed-forward regression network on the Fish market dataset
     * (classpath resource {@code fish/Fish.csv}) to predict fish weight from
     * the species (one-hot encoded) and five body measurements.
     *
     * <p>Pipeline: read CSV (skipping header) -> one-hot encode Species ->
     * shuffle and split 80/20 -> train a 3-hidden-layer MLP with MSE loss ->
     * print {@link RegressionEvaluation} statistics on the held-out split.
     *
     * @param args unused
     * @throws IOException          if the CSV resource cannot be read
     * @throws InterruptedException propagated from record-reader initialization
     */
    public static void main(String[] args) throws IOException, InterruptedException {
        /*
         * STEP 1: DATA PREPARATION
         */
        File fishData = new ClassPathResource("fish/Fish.csv").getFile();
        System.out.println("File location = " + fishData);

        FileSplit fileSplit = new FileSplit(fishData);

        // Skip the single header row of the CSV.
        RecordReader rr = new CSVRecordReader(1, ',');
        rr.initialize(fileSplit);

        // Declare the raw column layout of Fish.csv.
        Schema fishSchema = new Schema.Builder()
                .addColumnCategorical("Species",
                        Arrays.asList("Bream", "Parkki", "Perch", "Pike", "Roach", "Smelt", "Whitefish"))
                // FIX: "length2" renamed to "Length2" for consistent casing with the sibling columns.
                .addColumnsDouble("Weight", "Length1", "Length2", "Length3", "Height", "Width")
                .build();
        System.out.println("Before transformation: " + fishSchema);

        // One-hot encode the categorical Species column (7 categories -> 7 binary columns).
        TransformProcess fishTransform = new TransformProcess.Builder(fishSchema)
                .categoricalToOneHot("Species")
                .build();

        Schema outputTransform = fishTransform.getFinalSchema();
        System.out.println("After transformation: " + outputTransform);

        // Materialize all records so the transform can be executed locally.
        List<List<Writable>> originalData = new ArrayList<>();
        while (rr.hasNext()) {
            originalData.add(rr.next());
        }
        List<List<Writable>> transformedData = LocalTransformExecutor.execute(originalData, fishTransform);

        for (List<Writable> row : transformedData) {
            System.out.println("Transformed data : " + row);
        }

        // After one-hot encoding the column order is:
        //   [0-6] Species one-hot, [7] Weight, [8] Length1, [9] Length2,
        //   [10] Length3, [11] Height, [12] Width.
        // FIX: the label index was previously 9, which selects Length2 - the
        // exercise predicts Weight, i.e. column 7. Training against the wrong
        // column is the likely source of the reported evaluation anomaly.
        final int labelIndex = 7;
        CollectionRecordReader collectionRecordReader = new CollectionRecordReader(transformedData);
        DataSetIterator iterator = new RecordReaderDataSetIterator(
                collectionRecordReader, transformedData.size(), labelIndex, labelIndex, true);

        // Read the whole dataset as one DataSet so we can shuffle and split it.
        DataSet dataSet = iterator.next();
        // FIX: seeded shuffle so the train/test split is reproducible across runs.
        dataSet.shuffle(seed);
        SplitTestAndTrain testAndTrain = dataSet.splitTestAndTrain(0.8);
        DataSet train = testAndTrain.getTrain();
        DataSet test = testAndTrain.getTest();

        INDArray features = train.getFeatures();
        System.out.println("\nFeature shape: " + features.shapeInfoToString() + "\n");

        // Mini-batch iterators over the in-memory splits.
        ViewIterator trainIter = new ViewIterator(train, batchSize);
        ViewIterator testIter = new ViewIterator(test, batchSize);

        /*
         * STEP 2: MODEL TRAINING
         */
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                //.l2(regValue)
                .weightInit(WeightInit.XAVIER)
                .updater(new Adam(learningRate))
                .list()
                .layer(0, new DenseLayer.Builder() // input layer
                        .nIn(numInput)
                        .nOut(numHidden)
                        .activation(Activation.RELU)
                        .build())
                .layer(1, new DenseLayer.Builder() // hidden layer 1
                        .nOut(numHidden)
                        .activation(Activation.RELU)
                        .build())
                .layer(2, new DenseLayer.Builder() // hidden layer 2
                        .nOut(numHidden)
                        .activation(Activation.RELU)
                        .build())
                .layer(3, new DenseLayer.Builder() // hidden layer 3
                        .nOut(numHidden)
                        .activation(Activation.RELU)
                        .build())
                .layer(4, new OutputLayer.Builder() // output layer: linear + MSE for regression
                        .nOut(numOutput)
                        .activation(Activation.IDENTITY)
                        .lossFunction(LossFunctions.LossFunction.MSE)
                        .build())
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(100));

        // FIX: the progress guard was `i % 1000 == 0`, which with nEpoch = 500
        // printed nothing after epoch 0; report every 100 epochs instead.
        for (int i = 0; i < nEpoch; i++) {
            if (i % 100 == 0) {
                System.out.println("Epoch: " + i);
            }
            model.fit(trainIter);
        }

        /*
         * STEP 3: EVALUATION
         */
        RegressionEvaluation regEval = model.evaluateRegression(testIter);
        System.out.println(regEval.stats());
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment