Last active: May 14, 2020 23:58
-
-
Save yptheangel/ed220a3e7bf6635531df457e25ea8402 to your computer and use it in GitHub Desktop.
experiment
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.datavec.api.io.filters.BalancedPathFilter; | |
import org.datavec.api.io.labels.ParentPathLabelGenerator; | |
import org.datavec.api.split.FileSplit; | |
import org.datavec.api.split.InputSplit; | |
import org.datavec.image.loader.BaseImageLoader; | |
import org.datavec.image.recordreader.ImageRecordReader; | |
import org.datavec.image.transform.*; | |
import org.deeplearning4j.api.storage.StatsStorage; | |
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; | |
import org.deeplearning4j.nn.conf.layers.DenseLayer; | |
import org.deeplearning4j.nn.conf.layers.OutputLayer; | |
import org.deeplearning4j.nn.graph.ComputationGraph; | |
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration; | |
import org.deeplearning4j.nn.transferlearning.TransferLearning; | |
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | |
import org.deeplearning4j.ui.api.UIServer; | |
import org.deeplearning4j.ui.stats.StatsListener; | |
import org.deeplearning4j.ui.storage.FileStatsStorage; | |
import org.deeplearning4j.util.ModelSerializer; | |
import org.deeplearning4j.zoo.ZooModel; | |
import org.deeplearning4j.zoo.model.ResNet50; | |
import org.nd4j.evaluation.classification.Evaluation; | |
import org.nd4j.linalg.activations.Activation; | |
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; | |
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; | |
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler; | |
import org.nd4j.linalg.learning.config.Adam; | |
import org.nd4j.linalg.lossfunctions.LossFunctions; | |
import org.nd4j.linalg.primitives.Pair; | |
import org.nd4j.linalg.schedule.CycleSchedule; | |
import org.nd4j.linalg.schedule.MapSchedule; | |
import org.nd4j.linalg.schedule.ScheduleType; | |
import java.io.File; | |
import java.io.IOException; | |
import java.nio.file.Path; | |
import java.nio.file.Paths; | |
import java.util.*; | |
public class FoodClassifier { | |
public static int epochs = 30; | |
public static void main(String[] args) throws IOException { | |
String homePath = System.getProperty("user.home"); | |
Path trainPath = Paths.get(homePath, "dataset", "food-dataset", "train"); | |
Path testPath = Paths.get(homePath, "dataset", "food-dataset", "test"); | |
System.out.println(trainPath.toString()); | |
int seed = 1234; | |
Random randNumGen = new Random(seed); | |
String[] allowedExtensions = BaseImageLoader.ALLOWED_FORMATS; | |
DataNormalization scaler = new ImagePreProcessingScaler(0, 1); | |
FileSplit fileSplitTrain = new FileSplit(new File(trainPath.toString())); | |
FileSplit fileSplitTest = new FileSplit(new File(testPath.toString())); | |
ParentPathLabelGenerator labelGenerator = new ParentPathLabelGenerator(); | |
BalancedPathFilter balancedPathFilter = new BalancedPathFilter(randNumGen, allowedExtensions, labelGenerator); | |
InputSplit trainData = fileSplitTrain.sample(balancedPathFilter)[0]; | |
InputSplit testData = fileSplitTest.sample(balancedPathFilter)[0]; | |
// image augmentation | |
ImageTransform horizontalFlip = new FlipImageTransform(1); | |
ImageTransform verticalFlip = new FlipImageTransform(0); | |
// ImageTransform cropImage = new CropImageTransform(25); | |
ImageTransform rotateImage = new RotateImageTransform(randNumGen, 15); | |
ImageTransform showImage = new ShowImageTransform("Image", 1000); | |
boolean shuffle = false; | |
List<Pair<ImageTransform, Double>> pipeline = Arrays.asList( | |
new Pair<>(verticalFlip, 0.25), | |
new Pair<>(horizontalFlip, 0.5), | |
new Pair<>(rotateImage, 0.2) | |
// new Pair<>(rotateImage, 0.5), | |
// new Pair<>(cropImage, 0.3) | |
// ,new Pair<>(showImage,1.0) //uncomment this to show transform image | |
); | |
ImageTransform transform = new PipelineImageTransform(pipeline, shuffle); | |
ImageRecordReader trainRecordReader = new ImageRecordReader(224, 224, 3, labelGenerator); | |
ImageRecordReader testRecordReader = new ImageRecordReader(224, 224, 3, labelGenerator); | |
trainRecordReader.initialize(trainData, transform); | |
testRecordReader.initialize(testData); | |
// DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRecordReader, 64, 1, 5); | |
// DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRecordReader, 18, 1, 5); | |
DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRecordReader, 14, 1, 5); | |
DataSetIterator testIter = new RecordReaderDataSetIterator(testRecordReader, 14, 1, 5); | |
trainIter.setPreProcessor(scaler); | |
testIter.setPreProcessor(scaler); | |
ZooModel zooModel = ResNet50.builder().build(); | |
ComputationGraph resnet = (ComputationGraph) zooModel.initPretrained(); | |
// Map<Integer,Double> learningRateSchedule = new HashMap<>(); | |
// learningRateSchedule.put(0,0.0000001); | |
// learningRateSchedule.put(1,0.001000099); | |
// learningRateSchedule.put(2,0.002000098); | |
// learningRateSchedule.put(3,0.003000097); | |
// learningRateSchedule.put(4,0.004000096); | |
// learningRateSchedule.put(5,0.005000095); | |
// learningRateSchedule.put(6,0.006000094); | |
// learningRateSchedule.put(7,0.007000093); | |
// learningRateSchedule.put(8,0.008000092); | |
// learningRateSchedule.put(9,0.009000091); | |
// learningRateSchedule.put(10,0.01000009); | |
// learningRateSchedule.put(11,0.011000089); | |
// learningRateSchedule.put(12,0.012000088); | |
// learningRateSchedule.put(13,0.013000087); | |
// learningRateSchedule.put(14,0.014000086); | |
// learningRateSchedule.put(15,0.015000085); | |
// learningRateSchedule.put(16,0.016000084); | |
// learningRateSchedule.put(17,0.017000083); | |
// learningRateSchedule.put(18,0.018000082); | |
// learningRateSchedule.put(19,0.019000081); | |
// learningRateSchedule.put(20,0.02000008); | |
// learningRateSchedule.put(21,0.021000079); | |
// learningRateSchedule.put(22,0.022000078); | |
// learningRateSchedule.put(23,0.023000077); | |
// learningRateSchedule.put(24,0.024000076); | |
// learningRateSchedule.put(25,0.025000075); | |
// learningRateSchedule.put(26,0.026000074); | |
// learningRateSchedule.put(27,0.027000073); | |
// learningRateSchedule.put(28,0.028000072); | |
// learningRateSchedule.put(29,0.029000071); | |
// learningRateSchedule.put(30,0.03000007); | |
// learningRateSchedule.put(31,0.031000069); | |
// learningRateSchedule.put(32,0.032000068); | |
// learningRateSchedule.put(33,0.033000067); | |
// learningRateSchedule.put(34,0.034000066); | |
// learningRateSchedule.put(35,0.035000065); | |
// learningRateSchedule.put(36,0.036000064); | |
// learningRateSchedule.put(37,0.037000063); | |
// learningRateSchedule.put(38,0.038000062); | |
// learningRateSchedule.put(39,0.039000061); | |
// learningRateSchedule.put(40,0.04000006); | |
// learningRateSchedule.put(41,0.041000059); | |
// learningRateSchedule.put(42,0.042000058); | |
// learningRateSchedule.put(43,0.043000057); | |
// learningRateSchedule.put(44,0.044000056); | |
// learningRateSchedule.put(45,0.045000055); | |
// learningRateSchedule.put(46,0.046000054); | |
// learningRateSchedule.put(47,0.047000053); | |
// learningRateSchedule.put(48,0.048000052); | |
// learningRateSchedule.put(49,0.049000051); | |
// learningRateSchedule.put(50,0.05000005); | |
// learningRateSchedule.put(51,0.0510000490000001); | |
// learningRateSchedule.put(52,0.0520000480000001); | |
// learningRateSchedule.put(53,0.0530000470000001); | |
// learningRateSchedule.put(54,0.0540000460000001); | |
// learningRateSchedule.put(55,0.0550000450000001); | |
// learningRateSchedule.put(56,0.0560000440000001); | |
// learningRateSchedule.put(57,0.0570000430000001); | |
// learningRateSchedule.put(58,0.0580000420000001); | |
// learningRateSchedule.put(59,0.0590000410000001); | |
// learningRateSchedule.put(60,0.0600000400000001); | |
// learningRateSchedule.put(61,0.0610000390000001); | |
// learningRateSchedule.put(62,0.0620000380000001); | |
// learningRateSchedule.put(63,0.0630000370000001); | |
// learningRateSchedule.put(64,0.0640000360000001); | |
// learningRateSchedule.put(65,0.0650000350000001); | |
// learningRateSchedule.put(66,0.0660000340000001); | |
// learningRateSchedule.put(67,0.0670000330000001); | |
// learningRateSchedule.put(68,0.0680000320000001); | |
// learningRateSchedule.put(69,0.0690000310000001); | |
// learningRateSchedule.put(70,0.0700000300000001); | |
// learningRateSchedule.put(71,0.0710000290000001); | |
// learningRateSchedule.put(72,0.0720000280000001); | |
// learningRateSchedule.put(73,0.0730000270000001); | |
// learningRateSchedule.put(74,0.0740000260000001); | |
// learningRateSchedule.put(75,0.0750000250000001); | |
// learningRateSchedule.put(76,0.0760000240000001); | |
// learningRateSchedule.put(77,0.0770000230000001); | |
// learningRateSchedule.put(78,0.0780000220000001); | |
// learningRateSchedule.put(79,0.0790000210000001); | |
// learningRateSchedule.put(80,0.0800000200000001); | |
// learningRateSchedule.put(81,0.0810000190000001); | |
// learningRateSchedule.put(82,0.0820000180000001); | |
// learningRateSchedule.put(83,0.0830000170000001); | |
// learningRateSchedule.put(84,0.0840000160000001); | |
// learningRateSchedule.put(85,0.0850000150000001); | |
// learningRateSchedule.put(86,0.0860000140000001); | |
// learningRateSchedule.put(87,0.0870000130000001); | |
// learningRateSchedule.put(88,0.0880000120000001); | |
// learningRateSchedule.put(89,0.0890000110000001); | |
// learningRateSchedule.put(90,0.0900000100000001); | |
// learningRateSchedule.put(91,0.0910000090000001); | |
// learningRateSchedule.put(92,0.0920000080000001); | |
// learningRateSchedule.put(93,0.0930000070000001); | |
// learningRateSchedule.put(94,0.0940000060000001); | |
// learningRateSchedule.put(95,0.0950000050000001); | |
// learningRateSchedule.put(96,0.0960000040000001); | |
// learningRateSchedule.put(97,0.0970000030000001); | |
// learningRateSchedule.put(98,0.0980000020000001); | |
// learningRateSchedule.put(99,0.0990000010000001); | |
// learningRateSchedule.put(100,0.1); | |
FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder() | |
// .updater(new Adam(1e-2)) | |
// .updater(new Adam(new CycleSchedule(ScheduleType.EPOCH, 3e-5, 3e-4, epochs, (int) Math.round(epochs * 0.1), 0.1))) | |
.updater(new Adam(new CycleSchedule(ScheduleType.EPOCH, 1e-3, 1e-2, epochs, (int) Math.round(epochs * 0.1), 0.1))) | |
// .updater(new Adam(new MapSchedule(ScheduleType.ITERATION,learningRateSchedule))) | |
.seed(seed) | |
.build(); | |
// System.out.println(resnet.summary()); | |
ComputationGraph resnet50Transfer = new TransferLearning.GraphBuilder(resnet) | |
.fineTuneConfiguration(fineTuneConf) | |
.setFeatureExtractor("flatten_1") | |
.removeVertexKeepConnections("fc1000") | |
// .setFeatureExtractor("bn5b_branch2c") //"bn5b_branch2c" and below are frozen | |
// .addLayer("fc", new DenseLayer | |
// .Builder().activation(Activation.RELU).nIn(1000).nOut(32).build(), "fc1000") //add in a new dense layer | |
.addLayer("newpredictions", new OutputLayer | |
.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) | |
// .Builder(LossFunctions.LossFunction.MCXENT) | |
// .activation(Activation.SOFTMAX) | |
.nIn(2048) | |
.nOut(5) | |
.build(), "flatten_1") | |
.setOutputs("newpredictions") | |
.build(); | |
UIServer uiServer = UIServer.getInstance(); | |
StatsStorage statsStorage = new FileStatsStorage(new File(System.getProperty("java.io.tmpdir"), "ui-stats.dl4j")); | |
uiServer.attach(statsStorage); | |
resnet50Transfer.setListeners( | |
// new StatsListener(statsStorage) | |
new StatsListener(statsStorage), | |
new ScoreIterationListener(5) | |
); | |
// for (int i =0; i<100;i++){ | |
// resnet50Transfer.fit(trainIter.next()); | |
// System.out.println(resnet50Transfer.getLearningRate("newpredictions")+","+resnet50Transfer.score()); | |
// } | |
for (int i = 0; i < epochs; i++) { | |
resnet50Transfer.fit(trainIter); | |
// Evaluation trainEval = resnet50Transfer.evaluate(trainIter); | |
// Evaluation testEval = resnet50Transfer.evaluate(testIter); | |
// System.out.println(trainEval.stats()); | |
// System.out.println(testEval.stats()); | |
System.out.println("Completed epoch: " + (i + 1)); | |
} | |
ModelSerializer.writeModel(resnet50Transfer, homePath +"/food_"+resnet50Transfer.score()+".zip", true); | |
Evaluation trainEval = resnet50Transfer.evaluate(trainIter); | |
Evaluation testEval = resnet50Transfer.evaluate(testIter); | |
System.out.println(trainEval.stats()); | |
System.out.println(testEval.stats()); | |
} | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment