@yptheangel
Last active May 14, 2020 23:58
experiment
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.BaseImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.image.transform.*;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.FileStatsStorage;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.zoo.ZooModel;
import org.deeplearning4j.zoo.model.ResNet50;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.schedule.CycleSchedule;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
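// The two imports below are needed only by the inference sketch at the bottom of this file.
import org.datavec.image.loader.NativeImageLoader;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Transfer learning experiment: fine-tune a pretrained ResNet50 (ImageNet weights) on a
 * 5-class food dataset laid out as ~/dataset/food-dataset/{train,test}/<label>/<images>,
 * using image augmentation, a 1cycle learning-rate schedule, and the DL4J training UI.
 */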
public class FoodClassifier {

    public static int epochs = 30;

    public static void main(String[] args) throws IOException {
        String homePath = System.getProperty("user.home");
        Path trainPath = Paths.get(homePath, "dataset", "food-dataset", "train");
        Path testPath = Paths.get(homePath, "dataset", "food-dataset", "test");
        System.out.println("Training data path: " + trainPath);

        int seed = 1234;
        Random randNumGen = new Random(seed);
        String[] allowedExtensions = BaseImageLoader.ALLOWED_FORMATS;
        DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
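        // ImagePreProcessingScaler(0, 1) rescales raw 0-255 pixel values into the [0, 1] range.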
        FileSplit fileSplitTrain = new FileSplit(new File(trainPath.toString()));
        FileSplit fileSplitTest = new FileSplit(new File(testPath.toString()));
        ParentPathLabelGenerator labelGenerator = new ParentPathLabelGenerator();
        BalancedPathFilter balancedPathFilter = new BalancedPathFilter(randNumGen, allowedExtensions, labelGenerator);
        InputSplit trainData = fileSplitTrain.sample(balancedPathFilter)[0];
        InputSplit testData = fileSplitTest.sample(balancedPathFilter)[0];
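        // ParentPathLabelGenerator derives each image's label from its parent directory name, and
        // BalancedPathFilter undersamples so every class contributes the same number of images.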
        // Image augmentation transforms
        ImageTransform horizontalFlip = new FlipImageTransform(1);
        ImageTransform verticalFlip = new FlipImageTransform(0);
        // ImageTransform cropImage = new CropImageTransform(25);
        ImageTransform rotateImage = new RotateImageTransform(randNumGen, 15);
        ImageTransform showImage = new ShowImageTransform("Image", 1000);
        boolean shuffle = false;
        List<Pair<ImageTransform, Double>> pipeline = Arrays.asList(
                new Pair<>(verticalFlip, 0.25),
                new Pair<>(horizontalFlip, 0.5),
                new Pair<>(rotateImage, 0.2)
                // new Pair<>(rotateImage, 0.5),
                // new Pair<>(cropImage, 0.3)
                // , new Pair<>(showImage, 1.0) // uncomment to preview each transformed image
        );
        ImageTransform transform = new PipelineImageTransform(pipeline, shuffle);
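        // PipelineImageTransform applies each transform independently, with the paired value as its
        // per-image probability (e.g. the horizontal flip fires on ~50% of images); shuffle = false
        // keeps the transforms in list order.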
        ImageRecordReader trainRecordReader = new ImageRecordReader(224, 224, 3, labelGenerator);
        ImageRecordReader testRecordReader = new ImageRecordReader(224, 224, 3, labelGenerator);
        trainRecordReader.initialize(trainData, transform); // augmentation applied to training data only
        testRecordReader.initialize(testData);
        // DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRecordReader, 64, 1, 5);
        // DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRecordReader, 18, 1, 5);
        DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRecordReader, 14, 1, 5);
        DataSetIterator testIter = new RecordReaderDataSetIterator(testRecordReader, 14, 1, 5);
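        // Iterator arguments: batch size 14, label at index 1 of each record, 5 output classes.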
        trainIter.setPreProcessor(scaler);
        testIter.setPreProcessor(scaler);
        ZooModel zooModel = ResNet50.builder().build();
        ComputationGraph resnet = (ComputationGraph) zooModel.initPretrained();
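        // initPretrained() downloads the ImageNet weights on first use and caches them locally
        // (by default under the user's .deeplearning4j directory).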
        // Alternative updater schedule kept from earlier experiments: a MapSchedule ramping the
        // learning rate linearly from 1e-7 at iteration 0 to 0.1 at iteration 100, apparently
        // for a learning-rate range test (see the commented single-batch loop further down).
        // Equivalent to the original hand-enumerated map:
        // Map<Integer, Double> learningRateSchedule = new HashMap<>();
        // for (int i = 0; i <= 100; i++) {
        //     learningRateSchedule.put(i, 1e-7 + i * (0.1 - 1e-7) / 100.0);
        // }
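        // The active updater below uses ND4J's CycleSchedule, a 1cycle-style policy: the learning
        // rate cycles between a base rate of 1e-3 and a peak of 1e-2 across the training run, with
        // the last ~10% of epochs spent annealing down towards a fraction (0.1) of the base rate.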
        FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
                // .updater(new Adam(1e-2))
                // .updater(new Adam(new CycleSchedule(ScheduleType.EPOCH, 3e-5, 3e-4, epochs, (int) Math.round(epochs * 0.1), 0.1)))
                .updater(new Adam(new CycleSchedule(ScheduleType.EPOCH, 1e-3, 1e-2, epochs, (int) Math.round(epochs * 0.1), 0.1)))
                // .updater(new Adam(new MapSchedule(ScheduleType.ITERATION, learningRateSchedule)))
                .seed(seed)
                .build();
        // System.out.println(resnet.summary()); // uncomment to inspect the graph before surgery
        ComputationGraph resnet50Transfer = new TransferLearning.GraphBuilder(resnet)
                .fineTuneConfiguration(fineTuneConf)
                .setFeatureExtractor("flatten_1") // freeze everything up to and including "flatten_1"
                .removeVertexKeepConnections("fc1000") // drop the original 1000-way ImageNet classifier
                // .setFeatureExtractor("bn5b_branch2c") // alternative: freeze "bn5b_branch2c" and below
                // .addLayer("fc", new DenseLayer.Builder()
                //         .activation(Activation.RELU).nIn(1000).nOut(32).build(), "fc1000") // optional extra dense layer
                .addLayer("newpredictions", new OutputLayer
                        .Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        // .Builder(LossFunctions.LossFunction.MCXENT) // equivalent loss given softmax outputs
                        .activation(Activation.SOFTMAX) // negative log-likelihood expects softmax probabilities
                        .nIn(2048)
                        .nOut(5)
                        .build(), "flatten_1")
                .setOutputs("newpredictions")
                .build();
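        // The graph now reuses ResNet50 as a frozen feature extractor and trains only the new
        // 2048 -> 5 softmax classifier attached to "flatten_1".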
        UIServer uiServer = UIServer.getInstance();
        StatsStorage statsStorage = new FileStatsStorage(new File(System.getProperty("java.io.tmpdir"), "ui-stats.dl4j"));
        uiServer.attach(statsStorage);
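        // The DL4J training UI is served at http://localhost:9000 by default.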
        resnet50Transfer.setListeners(
                new StatsListener(statsStorage),
                new ScoreIterationListener(5) // log the score every 5 iterations
        );
        // Apparently a learning-rate range test: fit one batch at a time and print LR vs. score.
        // for (int i = 0; i < 100; i++) {
        //     resnet50Transfer.fit(trainIter.next());
        //     System.out.println(resnet50Transfer.getLearningRate("newpredictions") + "," + resnet50Transfer.score());
        // }
        for (int i = 0; i < epochs; i++) {
            resnet50Transfer.fit(trainIter);
            // Evaluation trainEval = resnet50Transfer.evaluate(trainIter);
            // Evaluation testEval = resnet50Transfer.evaluate(testIter);
            // System.out.println(trainEval.stats());
            // System.out.println(testEval.stats());
            System.out.println("Completed epoch: " + (i + 1));
        }
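        // fit(DataSetIterator) trains one full epoch; RecordReaderDataSetIterator supports reset(),
        // so the same iterator can be reused across epochs (and by evaluate() below).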
        // Save the fine-tuned model (with updater state) using the final score in the filename.
        ModelSerializer.writeModel(resnet50Transfer, homePath + "/food_" + resnet50Transfer.score() + ".zip", true);

        Evaluation trainEval = resnet50Transfer.evaluate(trainIter);
        Evaluation testEval = resnet50Transfer.evaluate(testIter);
        System.out.println(trainEval.stats());
        System.out.println(testEval.stats());
    }
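
    // A minimal inference sketch (not part of the original gist): load the saved model and classify
    // a single image. Assumes the same 224x224x3 input size and 0-1 scaling used during training;
    // the method name and paths are illustrative placeholders.
    private static void predict(String modelPath, String imagePath) throws IOException {
        ComputationGraph model = ModelSerializer.restoreComputationGraph(modelPath);
        NativeImageLoader loader = new NativeImageLoader(224, 224, 3);
        INDArray image = loader.asMatrix(new File(imagePath)); // NCHW tensor of shape [1, 3, 224, 224]
        new ImagePreProcessingScaler(0, 1).transform(image);   // apply the same 0-1 pixel scaling
        INDArray probabilities = model.outputSingle(image);    // softmax over the 5 food classes
        System.out.println("Class probabilities: " + probabilities);
    }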
}