Skip to content

Instantly share code, notes, and snippets.

@androuino
Created September 27, 2018 08:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save androuino/f93f0239e8c2f459fa46962c5f4b52e7 to your computer and use it in GitHub Desktop.
Save androuino/f93f0239e8c2f459fa46962c5f4b52e7 to your computer and use it in GitHub Desktop.
darknet19
package org.deeplearning4j.examples.convolution.objectdetection;
import java.io.File;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.List;
import java.util.Random;
import org.bytedeco.javacv.CanvasFrame;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.api.io.filters.RandomPathFilter;
import org.datavec.api.records.metadata.RecordMetaDataImageURI;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.objdetect.ObjectDetectionRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.examples.convolution.ObjectDetectionLabelProvider;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.objdetect.DetectedObject;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.zoo.model.Darknet19;
import org.deeplearning4j.zoo.model.TinyYOLO;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static org.bytedeco.javacpp.opencv_core.*;
import static org.bytedeco.javacpp.opencv_imgproc.*;
/**
* Example transfer learning from a Tiny YOLO model pretrained on ImageNet and Pascal VOC
* to perform object detection with bounding boxes on The Street View House Numbers (SVHN) Dataset.
* <p>
* References: <br>
* - YOLO: Real-Time Object Detection: https://pjreddie.com/darknet/yolo/ <br>
* - The Street View House Numbers (SVHN) Dataset: http://ufldl.stanford.edu/housenumbers/ <br>
* <p>
* Please note, cuDNN should be used to obtain reasonable performance: https://deeplearning4j.org/cudnn
*
* @author saudet
*/
public class HouseNumberDetection {
private static final Logger log = LoggerFactory.getLogger(HouseNumberDetection.class);
public static void main(String[] args) throws Exception {
// parameters matching the pretrained TinyYOLO model
int width = 416;
int height = 416;
int nChannels = 3;
int gridWidth = 13;
int gridHeight = 13;
int[] inputShape = {3, 416, 416};
// number classes (digits) for the SVHN datasets
int nClasses = 1; //10;
// parameters for the Yolo2OutputLayer
int nBoxes = 5;
double lambdaNoObj = 0.5;
double lambdaCoord = 1.0;
double[][] priorBoxes = {{2, 5}, {2.5, 6}, {3, 7}, {3.5, 8}, {4, 9}};
double detectionThreshold = 0.5;
// parameters for the training phase
int batchSize = 8;
int nEpochs = 10;
double learningRate = 1e-2; // 0.0001
double lrMomentum = 0.9;
int seed = 123;
Random rng = new Random(seed);
String dataDir = "/home/remote/dataset/";
File imgDir = new File(dataDir, "JPEGImages");
log.info("Load data...");
RandomPathFilter pathFilter = new RandomPathFilter(rng) {
@Override
protected boolean accept(String name) {
name = name.replace("/JPEGImages/", "/Annotations/").replace(".jpg", ".xml");
try {
return new File(new URI(name)).exists();
} catch (URISyntaxException ex) {
throw new RuntimeException(ex);
}
}
};
InputSplit[] data = new FileSplit(imgDir, NativeImageLoader.ALLOWED_FORMATS, rng).sample(pathFilter, 0.8, 0.2);
InputSplit trainData = data[0];
InputSplit testData = data[1];
ObjectDetectionRecordReader recordReaderTrain = new ObjectDetectionRecordReader(height, width, nChannels,
gridHeight, gridWidth, new ObjectDetectionLabelProvider(dataDir));
recordReaderTrain.initialize(trainData);
ObjectDetectionRecordReader recordReaderTest = new ObjectDetectionRecordReader(height, width, nChannels,
gridHeight, gridWidth, new ObjectDetectionLabelProvider(dataDir));
recordReaderTest.initialize(testData);
// ObjectDetectionRecordReader performs regression, so we need to specify it here
RecordReaderDataSetIterator train = new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 1, 1, true);
train.setPreProcessor(new ImagePreProcessingScaler(0, 1));
RecordReaderDataSetIterator test = new RecordReaderDataSetIterator(recordReaderTest, 1, 1, 1, true);
test.setPreProcessor(new ImagePreProcessingScaler(0, 1));
ComputationGraph model;
String modelFilename = "resources/darknet19.zip";
if (new File(modelFilename).exists()) {
log.info("Load model...");
model = ModelSerializer.restoreComputationGraph(modelFilename);
} else {
log.info("Build model...");
ComputationGraph pretrained = (ComputationGraph) Darknet19.builder()
.numClasses(30)
.updater(new Nesterovs(learningRate, lrMomentum))
.workspaceMode(WorkspaceMode.NONE)
.inputShape(inputShape)
.build()
.init();
INDArray priors = Nd4j.create(priorBoxes);
pretrained.setInputs(priors);
//model = new TransferLearning.GraphBuilder(pretrained).build();
model = pretrained;
System.out.println(model.summary(InputType.convolutional(height, width, nChannels)));
log.info("Train model...");
model.setListeners(new ScoreIterationListener(1));
for (int i = 0; i < nEpochs; i++) {
train.reset();
while (train.hasNext()) {
model.fit(train.next());
}
log.info("*** Completed epoch {} ***", i);
}
// save the model
ModelSerializer.writeModel(model, modelFilename, true);
}
// visualize results on the test set
NativeImageLoader imageLoader = new NativeImageLoader();
CanvasFrame frame = new CanvasFrame("HouseNumberDetection");
OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout =
(org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer)model.getOutputLayer(0);
List<String> labels = train.getLabels();
test.setCollectMetaData(true);
while (test.hasNext() && frame.isVisible()) {
org.nd4j.linalg.dataset.DataSet ds = test.next();
RecordMetaDataImageURI metadata = (RecordMetaDataImageURI)ds.getExampleMetaData().get(0);
INDArray features = ds.getFeatures();
INDArray results = model.outputSingle(features);
List<DetectedObject> objs = yout.getPredictedObjects(results, detectionThreshold);
File file = new File(metadata.getURI());
log.info(file.getName() + ": " + objs);
Mat mat = imageLoader.asMat(features);
Mat convertedMat = new Mat();
mat.convertTo(convertedMat, CV_8U, 255, 0);
int w = metadata.getOrigW() * 2;
int h = metadata.getOrigH() * 2;
Mat image = new Mat();
resize(convertedMat, image, new Size(w, h));
for (DetectedObject obj : objs) {
double[] xy1 = obj.getTopLeftXY();
double[] xy2 = obj.getBottomRightXY();
String label = labels.get(obj.getPredictedClass());
int x1 = (int) Math.round(w * xy1[0] / gridWidth);
int y1 = (int) Math.round(h * xy1[1] / gridHeight);
int x2 = (int) Math.round(w * xy2[0] / gridWidth);
int y2 = (int) Math.round(h * xy2[1] / gridHeight);
rectangle(image, new Point(x1, y1), new Point(x2, y2), Scalar.RED);
putText(image, label, new Point(x1 + 2, y2 - 2), FONT_HERSHEY_DUPLEX, 1, Scalar.GREEN);
}
frame.setTitle(new File(metadata.getURI()).getName() + " - HouseNumberDetection");
frame.setCanvasSize(w, h);
frame.showImage(converter.convert(image));
frame.waitKey();
}
frame.dispose();
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment