Skip to content

Instantly share code, notes, and snippets.

@choowilson
Created February 17, 2020 10:02
Load a pretrained YOLOv2 (Darknet-19) model, run real-time object detection on a video, and save the annotated frames to a new output video file.
import org.bytedeco.ffmpeg.global.avcodec;
import org.bytedeco.ffmpeg.global.avutil;
import org.bytedeco.javacv.*;
import org.bytedeco.opencv.opencv_core.*;
import static org.bytedeco.opencv.global.opencv_imgproc.*;
import static org.bytedeco.opencv.helper.opencv_core.RGB;
import org.bytedeco.opencv.opencv_videoio.VideoWriter;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.transform.ColorConversionTransform;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.objdetect.DetectedObject;
import org.deeplearning4j.nn.layers.objdetect.YoloUtils;
import org.deeplearning4j.zoo.ZooModel;
import org.deeplearning4j.zoo.model.TinyYOLO;
import org.deeplearning4j.zoo.model.YOLO2;
import org.deeplearning4j.zoo.util.darknet.COCOLabels;
import org.deeplearning4j.zoo.util.darknet.VOCLabels;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.opencv.video.Video;
import java.awt.event.KeyEvent;
import java.util.List;
import java.util.Random;
public class VideoObjectDetection
{
private static int seed = 123;
private static Thread thread;
private static final int gridWidth = 13;
private static final int gridHeight = 13;
private static double detectionThreshold = 0.5;
private static final int tinyyolowidth = 416;
private static final int tinyyoloheight = 416;
private static Random rng = new Random();
//// First 20 colours
// private static final Scalar AliceBlue = RGB(240, 248, 255);
// private static final Scalar Chartreuse = RGB(127, 255, 0);
// private static final Scalar Aqua = RGB(0, 255, 255);
// private static final Scalar Aquamarine = RGB(127, 255, 212);
// private static final Scalar Azure = RGB(240, 255, 255);
// private static final Scalar Beige = RGB(245, 245, 220);
// private static final Scalar Bisque = RGB(255, 228, 196);
// private static final Scalar BlanchedAlmond = RGB(255, 235, 205);
// private static final Scalar BlueViolet = RGB(138, 43, 226);
// private static final Scalar BurlyWood = RGB(222, 184, 135 );
// private static final Scalar CadetBlue = RGB(95, 158, 160);
// private static final Scalar AntiqueWhite = RGB( 250, 235, 215);
// private static final Scalar Chocolate= RGB(210, 105, 30);
// private static final Scalar Coral = RGB(255, 127, 80);
// private static final Scalar CornflowerBlue = RGB(100, 149, 237);
// private static final Scalar Cornsilk = RGB(255, 248, 220);
// private static final Scalar Crimson = RGB( 220, 20, 60);
// private static final Scalar Cyan = RGB(0, 255, 255);
// private static final Scalar DarkCyan = RGB( 0, 139, 139);
// private static final Scalar DarkGoldenRod = RGB(72.2, 52.5, 4.3);
//
//// 21st to 40th colours
// private static final Scalar DarkGray = RGB(169, 169, 169);
// private static final Scalar DarkKhaki = RGB(189, 183, 107);
// private static final Scalar DarkOrange= RGB(255, 140, 0);
// private static final Scalar DarkOrchid= RGB(153, 50, 204);
// private static final Scalar DarkSalmon= RGB(233, 150, 122);
// private static final Scalar DarkSeaGreen= RGB(0, 206, 209);
// private static final Scalar DarkTurquoise= RGB(19, 84, 78);
// private static final Scalar DarkViolet= RGB(148, 0, 211);
// private static final Scalar DeepPink= RGB(255, 20, 147);
// private static final Scalar DeepSkyBlue= RGB(0, 191, 255);
// private static final Scalar DodgerBlue= RGB(30, 144, 255);
// private static final Scalar FireBrick= RGB(178, 34, 34);
// private static final Scalar FloralWhite= RGB(255, 250, 240);
// private static final Scalar ForestGreen= RGB( 34, 139, 34);
// private static final Scalar Fuchsia= RGB(255, 0, 255);
// private static final Scalar Gainsboro= RGB(220, 220, 220);
// private static final Scalar GhostWhite= RGB(248, 248, 255);
// private static final Scalar Gold = RGB(255, 215, 0);
// private static final Scalar GoldenRod = RGB(218, 165, 32);
// private static final Scalar Salmon = RGB( 250, 128, 114);
//
// // 41st to 60th colour
// private static final Scalar Tan = RGB(210, 180, 140 );
// private static final Scalar HoneyDew = RGB(240, 255, 240 );
// private static final Scalar HotPink = RGB(255, 105, 180);
// private static final Scalar IndianRed = RGB(205, 92, 92 );
// private static final Scalar Ivory = RGB(255, 255, 240);
// private static final Scalar Khaki = RGB(240, 230, 140);
// private static final Scalar Lavender = RGB(230, 230, 250 );
// private static final Scalar LavenderBlush = RGB(255, 240, 245 );
// private static final Scalar LawnGreen = RGB(124, 252, 0 );
// private static final Scalar LemonChiffon = RGB(255, 250, 205);
// private static final Scalar LightBlue = RGB(173, 216, 230);
// private static final Scalar LightCoral = RGB(240, 128, 128);
// private static final Scalar LightCyan = RGB(224, 255, 255);
// private static final Scalar LightGoldenRodYellow = RGB( 250, 250, 210);
// private static final Scalar LightGray = RGB(211, 211, 211 );
// private static final Scalar LightGrey = RGB(211, 211, 211);
// private static final Scalar LightGreen = RGB(144, 238, 144);
// private static final Scalar LightPink = RGB(255, 182, 193);
// private static final Scalar LightSalmon = RGB(255, 160, 122);
// private static final Scalar LightSeaGreen = RGB(32, 178, 170);
// // 61st to 80th colour
// private static final Scalar LightSkyBlue = RGB(135, 206, 250);
// private static final Scalar LightSlateGray = RGB( 119, 136, 153);
// private static final Scalar LightSlateGrey = RGB(119, 136, 153);
// private static final Scalar LightSteelBlue = RGB(176, 196, 222);
// private static final Scalar LightYellow = RGB(255, 255, 224);
// private static final Scalar Lime = RGB( 0, 255, 0);
// private static final Scalar LimeGreen = RGB(50, 205, 50);
// private static final Scalar Linen = RGB( 250, 240, 230);
// private static final Scalar Magenta = RGB( 255, 0, 255);
// private static final Scalar MediumAquaMarine = RGB( 102, 205, 170);
// private static final Scalar MediumOrchid = RGB( 186, 85, 211);
// private static final Scalar MediumPurple = RGB(147, 112, 219);
// private static final Scalar MediumSeaGreen = RGB(60, 179, 113);
// private static final Scalar MediumSlateBlue = RGB(123, 104, 238);
// private static final Scalar MediumSpringGreen = RGB(0, 250, 154);
// private static final Scalar MediumTurquoise = RGB(72, 209, 204);
// private static final Scalar MediumVioletRed = RGB(199, 21, 133);
// private static final Scalar MintCream = RGB(245, 255, 250);
// private static final Scalar MistyRose = RGB(255, 228, 225);
// private static final Scalar Moccasin = RGB(255, 228, 181);
//
// private static Scalar[] colormap = {
// AliceBlue,Chartreuse,Aqua,Aquamarine,Azure,
// Beige,Bisque,BlanchedAlmond,BlueViolet,BurlyWood,
// CadetBlue,AntiqueWhite,Chocolate,Coral,CornflowerBlue,
// Cornsilk,Crimson,Cyan, DarkCyan,DarkGoldenRod,
//
// DarkGray, DarkKhaki, DarkOrange, DarkOrchid, DarkSalmon,
// DarkSeaGreen, DarkTurquoise,DarkViolet, DeepPink, DeepSkyBlue,
// DodgerBlue, FireBrick, FloralWhite, ForestGreen, Fuchsia,
// Gainsboro, GhostWhite, Gold, GoldenRod, Salmon,
//
// Tan,HoneyDew, HotPink, IndianRed, Ivory,
// Khaki, Lavender, LavenderBlush, LawnGreen, LemonChiffon,
// LightBlue, LightCoral, LightCyan, LightGoldenRodYellow, LightGray,
// LightGrey, LightGreen, LightPink, LightSalmon, LightSeaGreen,
//
// LightSkyBlue,LightSlateGray, LightSlateGrey, LightSteelBlue,LightYellow,
// Lime,LimeGreen, Linen, Magenta, MediumAquaMarine,
// MediumOrchid,MediumPurple, MediumSeaGreen, MediumSlateBlue,MediumSpringGreen,
// MediumTurquoise,MediumVioletRed, MintCream, MistyRose,Moccasin};
//
private static String labeltext = null;
public static void main(String[] args) throws Exception {
// Create a colormap by using a RNG to randomly generate RGB colours
Scalar[] colormap = new Scalar[80];
for (int i=0;i<80;i++){
int redValue= rng.nextInt(256);
int greenValue= rng.nextInt(256);
int blueValue= rng.nextInt(256);
colormap[i]=RGB(redValue,greenValue,blueValue);
}
String videoPath = "C:\\Users\\choowilson\\Downloads\\pavilion2.mp4";
FFmpegFrameGrabber grabber = new FFmpegFrameGrabber(videoPath);
grabber.setFormat("mp4");
OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
grabber.start();
FFmpegFrameRecorder recorder = new FFmpegFrameRecorder("output.mp4", 1920, 1080, 0);
recorder.setVideoCodec(avcodec.AV_CODEC_ID_MPEG4);
recorder.setVideoBitrate(9000);
recorder.setFormat("mp4");
recorder.setVideoQuality(0); // maximum quality
recorder.setFrameRate(15);
recorder.start();
String winName = "Object Detection";
CanvasFrame canvas = new CanvasFrame(winName);
int w = grabber.getImageWidth();
int h = grabber.getImageHeight();
canvas.setCanvasSize(w, h);
ZooModel model = YOLO2.builder().numClasses(0).build();
ComputationGraph initializedModel = (ComputationGraph) model.initPretrained();
System.out.println(initializedModel.summary());
NativeImageLoader loader = new NativeImageLoader(tinyyolowidth, tinyyoloheight, 3, new ColorConversionTransform(COLOR_BGR2RGB));
ImagePreProcessingScaler scaler = new ImagePreProcessingScaler(0, 1);
COCOLabels labels = new COCOLabels();
System.out.println("Start running video");
while ((grabber.grab()) != null)
{
Frame frame = grabber.grabImage();
//if a thread is null, create new thread
Mat rawImage = converter.convert(frame);
Mat resizeImage = new Mat();//rawImage);
resize(rawImage, resizeImage, new Size(tinyyolowidth, tinyyoloheight));
INDArray inputImage = loader.asMatrix(resizeImage);
scaler.transform(inputImage);
INDArray outputs = initializedModel.outputSingle(inputImage);
List<DetectedObject> objs = YoloUtils.getPredictedObjects(Nd4j.create(((YOLO2) model).getPriorBoxes()), outputs, detectionThreshold, 0.4);
for (DetectedObject obj : objs) {
double[] xy1 = obj.getTopLeftXY();
double[] xy2 = obj.getBottomRightXY();
String label = labels.getLabel(obj.getPredictedClass());
int x1 = (int) Math.round(w * xy1[0] / gridWidth);
int y1 = (int) Math.round(h * xy1[1] / gridHeight);
int x2 = (int) Math.round(w * xy2[0] / gridWidth);
int y2 = (int) Math.round(h * xy2[1] / gridHeight);
//Draw bounding box
rectangle(rawImage, new Point(x1, y1), new Point(x2, y2), colormap[obj.getPredictedClass()], 2, 0, 0);
//Display label text
labeltext =label+" "+(Math.round(obj.getConfidence()*100.0)*100.0)/100.0+"%";
int[] baseline ={0};
Size textSize=getTextSize(labeltext, FONT_HERSHEY_DUPLEX, 1,1,baseline);
rectangle(rawImage, new Point(x1 + 2, y2 - 2), new Point(x1 + 2+textSize.get(0), y2 - 2-textSize.get(1)), colormap[obj.getPredictedClass()], FILLED,0,0);
putText(rawImage, labeltext, new Point(x1 + 2, y2 - 2), FONT_HERSHEY_DUPLEX, 1, RGB(0,0,0));
}
recorder.record(converter.convert(rawImage));
canvas.showImage(converter.convert(rawImage));
KeyEvent t = canvas.waitKey(33);
if ((t != null) && (t.getKeyCode() == KeyEvent.VK_Q)) {
recorder.stop();
break;
}
}
recorder.stop();
canvas.dispose();
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment