Skip to content

Instantly share code, notes, and snippets.

@docete
Last active October 11, 2022 12:38
Show Gist options
  • Save docete/38bb39390da8e8c823c3211419b9a5ad to your computer and use it in GitHub Desktop.
Save docete/38bb39390da8e8c823c3211419b9a5ad to your computer and use it in GitHub Desktop.
import static org.opencv.core.CvType.CV_8U;
import static org.opencv.imgcodecs.Imgcodecs.IMREAD_COLOR;
import static org.opencv.imgproc.Imgproc.CHAIN_APPROX_SIMPLE;
import static org.opencv.imgproc.Imgproc.LINE_AA;
import static org.opencv.imgproc.Imgproc.RETR_TREE;
import static org.opencv.imgproc.Imgproc.THRESH_BINARY;
import static org.opencv.imgproc.Imgproc.fillConvexPoly;
import static org.opencv.imgproc.Imgproc.findContours;
import java.net.URL;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import javafx.util.Pair;
import org.opencv.core.Core;
import org.opencv.core.Core.MinMaxLocResult;
import org.opencv.core.Mat;
import org.opencv.core.MatOfPoint;
import org.opencv.core.Point;
import org.opencv.core.Scalar;
import org.opencv.core.Size;
import org.opencv.dnn.Dnn;
import org.opencv.dnn.Net;
import org.opencv.highgui.HighGui;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
/**
 * Multi-person pose estimation demo using OpenCV's DNN module with the OpenPose COCO Caffe model.
 *
 * <p>The pipeline implemented by {@link #main(String[])} is:
 * <ol>
 *   <li>run the network on one image and split its output into per-channel maps,</li>
 *   <li>extract key point candidates from the first {@link #nPoints} maps,</li>
 *   <li>pair candidates into limbs by scoring the part affinity field (PAF) maps along the
 *       segment joining them,</li>
 *   <li>chain the limbs into one key point list per person and draw the result.</li>
 * </ol>
 */
public class OpenposeApp {

    static {
        // The OpenCV native library has to be loaded before any OpenCV class is used.
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
    }

    /** Number of key point types produced by the model. */
    static final int nPoints = 18;

    /** Names of the 18 key points of the COCO output format, indexed by key point type. */
    static final String[] KEYPOINTS_MAPPING = {
        "Nose", "Neck",
        "R-Sho", "R-Elb", "R-Wr",
        "L-Sho", "L-Elb", "L-Wr",
        "R-Hip", "R-Knee", "R-Ank",
        "L-Hip", "L-Knee", "L-Ank",
        "R-Eye", "L-Eye",
        "R-Ear", "L-Ear"
    };

    /** The 19 possible limb connections, each given as a pair of key point type indices. */
    static final int[][] POSE_PAIRS = {
        {1, 2}, {1, 5}, {2, 3},
        {3, 4}, {5, 6}, {6, 7},
        {1, 8}, {8, 9}, {9, 10},
        {1, 11}, {11, 12}, {12, 13},
        {1, 0}, {0, 14},
        {14, 16}, {0, 15}, {15, 17},
        {2, 17}, {5, 16}
    };

    /**
     * Indices of the PAF maps corresponding to each entry of {@link #POSE_PAIRS},
     * e.g. for POSE_PAIRS {1, 2} the PAFs are located at indices (31, 32) of the output.
     */
    static final int[][] MAP_INDICES = {
        {31, 32}, {39, 40}, {33, 34},
        {35, 36}, {41, 42}, {43, 44},
        {19, 20}, {21, 22}, {23, 24},
        {25, 26}, {27, 28}, {29, 30},
        {47, 48}, {49, 50},
        {53, 54}, {51, 52}, {55, 56},
        {37, 38}, {45, 46}
    };

    /**
     * Number of leading {@link #POSE_PAIRS} entries that are drawn and that may start a new
     * person. The remaining entries (the shoulder-to-ear links) only extend a person that
     * already exists.
     */
    private static final int N_PRIMARY_PAIRS = nPoints - 1;

    /**
     * Runs the demo on the bundled sample image and shows the annotated result in a window.
     *
     * @param args unused
     * @throws Exception if a model or image resource cannot be located or read
     */
    public static void main(String[] args) throws Exception {
        // read the network model
        Net net = Dnn.readNetFromCaffe(
                getResourceFilePath("models/pose/coco/pose_deploy_linevec.prototxt"),
                getResourceFilePath("models/pose/coco/pose_iter_440000.caffemodel"));

        String imagePath = getResourceFilePath("images/three_person.jpg");
        Mat frame = Imgcodecs.imread(imagePath, IMREAD_COLOR);
        if (frame.empty()) {
            // imread signals failure with an empty Mat rather than an exception
            throw new IllegalStateException("Could not read image: " + imagePath);
        }

        // Network input: height fixed at 368, width scaled to keep the image's aspect ratio.
        Size inputSize = new Size(368.0 * frame.cols() / frame.rows(), 368);
        Mat inputBlob = Dnn.blobFromImage(frame, 1.0 / 255, inputSize, new Scalar(0, 0, 0), true, false);
        net.setInput(inputBlob);
        Mat output = net.forward();
        List<Mat> parts = splitNetOutputBlobToParts(output, new Size(frame.cols(), frame.rows()));

        // Key point ids are assigned sequentially across all types, so an id is also the
        // key point's index in keyPointsList.
        int keyPointId = 0;
        List<List<KeyPoint>> detectedKeypoints = new ArrayList<>(); // grouped by key point type
        List<KeyPoint> keyPointsList = new ArrayList<>();           // flat, indexed by id
        for (int i = 0; i < nPoints; i++) {
            List<KeyPoint> keyPoints = getKeyPoints(parts.get(i), 0.1);
            for (KeyPoint keyPoint : keyPoints) {
                keyPoint.setId(keyPointId++);
            }
            System.out.println(String.format("Keypoints - %s : %s", KEYPOINTS_MAPPING[i], keyPoints));
            detectedKeypoints.add(keyPoints);
            keyPointsList.addAll(keyPoints);
        }

        // Draw key points with random colors, one color per key point type
        List<Scalar> colors = populateColorPalette(nPoints);
        Mat outputFrame = frame.clone();
        for (int i = 0; i < nPoints; i++) {
            for (KeyPoint keyPoint : detectedKeypoints.get(i)) {
                Imgproc.circle(outputFrame, keyPoint.point, 5, colors.get(i), -1, LINE_AA);
                Imgproc.putText(
                        outputFrame,
                        Integer.toString(keyPoint.id),
                        keyPoint.point,
                        Imgproc.FONT_HERSHEY_SIMPLEX,
                        0.3,
                        new Scalar(0, 0, 0));
            }
        }

        List<List<ValidPair>> validPairs = new ArrayList<>();
        Set<Integer> invalidPairs = new HashSet<>();
        getValidPairs(parts, detectedKeypoints, validPairs, invalidPairs);
        List<List<Integer>> personwiseKeypoints = getPersonwiseKeypoints(validPairs, invalidPairs);

        // Draw the limbs of every person; -1 marks a key point that was not found for a person.
        for (int i = 0; i < N_PRIMARY_PAIRS; i++) {
            for (List<Integer> person : personwiseKeypoints) {
                int indexA = person.get(POSE_PAIRS[i][0]);
                int indexB = person.get(POSE_PAIRS[i][1]);
                if (indexA == -1 || indexB == -1) {
                    continue;
                }
                KeyPoint kpA = keyPointsList.get(indexA);
                KeyPoint kpB = keyPointsList.get(indexB);
                Imgproc.line(outputFrame, kpA.point, kpB.point, colors.get(i), 3, LINE_AA);
            }
        }

        HighGui.imshow("show", outputFrame);
        HighGui.waitKey(0);
    }

    /**
     * Resolves a classpath resource to a file system path.
     *
     * @param fileName resource name relative to the classpath root
     * @return the path component of the resource's URI
     * @throws IllegalArgumentException if the resource is not on the classpath
     * @throws Exception if the resource URL cannot be converted to a URI
     */
    public static String getResourceFilePath(String fileName) throws Exception {
        URL resource = OpenposeApp.class.getClassLoader().getResource(fileName);
        if (resource == null) {
            throw new IllegalArgumentException("Resource not found on classpath: " + fileName);
        }
        return resource.toURI().getPath();
    }

    /**
     * Splits the 4-D network output into one 2-D map per channel, each resized to
     * {@code targetSize}.
     *
     * @param netOutputBlob network output; dimension 1 is read as the channel count and
     *     dimension 2 as the map height (for the COCO model the channels are expected to be
     *     18 key points + 1 background + 19 * 2 PAFs)
     * @param targetSize size every map is resized to
     * @return the resized maps, in channel order
     */
    public static List<Mat> splitNetOutputBlobToParts(Mat netOutputBlob, Size targetSize) {
        int nParts = netOutputBlob.size(1);
        int h = netOutputBlob.size(2);

        // One row per channel, each row holding that channel's flattened map.
        Mat channelRows = netOutputBlob.reshape(1, nParts);

        List<Mat> parts = new ArrayList<>(nParts);
        for (int i = 0; i < nParts; i++) {
            // Restore the row to its 2-D shape using the blob's own height, so any network
            // input size works (not only the 368-pixel height that yields 46 rows).
            Mat part = channelRows.row(i).reshape(1, h);
            Mat resizedPart = new Mat();
            Imgproc.resize(part, resizedPart, targetSize);
            parts.add(resizedPart);
        }
        return parts;
    }

    /**
     * Finds key point candidates in a probability map: the map is blurred and thresholded,
     * and the location of the maximum inside each resulting blob becomes one candidate.
     *
     * @param probMap probability map of one key point type
     * @param threshold minimum blurred probability for a pixel to belong to a blob
     * @return one key point per blob, with id -1 and the probability read from the
     *     unblurred {@code probMap} at the blob's maximum
     */
    public static List<KeyPoint> getKeyPoints(Mat probMap, double threshold) {
        Mat smoothProbMap = new Mat();
        Imgproc.GaussianBlur(probMap, smoothProbMap, new Size(3, 3), 0, 0);

        Mat maskedProbMap = new Mat();
        Imgproc.threshold(smoothProbMap, maskedProbMap, threshold, 255, THRESH_BINARY);
        maskedProbMap.convertTo(maskedProbMap, CV_8U, 1);

        List<MatOfPoint> contours = new ArrayList<>();
        Mat hierarchy = new Mat();
        findContours(maskedProbMap, contours, hierarchy, RETR_TREE, CHAIN_APPROX_SIMPLE);

        List<KeyPoint> keyPoints = new ArrayList<>(contours.size());
        for (MatOfPoint contour : contours) {
            // Zero out everything outside this blob, then take the maximum of what is left.
            Mat blobMask = Mat.zeros(smoothProbMap.rows(), smoothProbMap.cols(), smoothProbMap.type());
            fillConvexPoly(blobMask, contour, new Scalar(1));
            Mat blobProbMap = smoothProbMap.mul(blobMask);
            MinMaxLocResult result = Core.minMaxLoc(blobProbMap);
            float probability = (float) probMap.get((int) result.maxLoc.y, (int) result.maxLoc.x)[0];
            keyPoints.add(new KeyPoint(-1, result.maxLoc, probability));

            // Free the native buffers of the per-blob temporaries right away.
            blobMask.release();
            blobProbMap.release();
        }

        smoothProbMap.release();
        maskedProbMap.release();
        hierarchy.release();
        return keyPoints;
    }

    /**
     * Creates a list of random colors.
     *
     * @param nColors number of colors to generate
     * @return {@code nColors} colors with channel values in [64, 200), [100, 255) and [100, 255)
     */
    public static List<Scalar> populateColorPalette(int nColors) {
        List<Scalar> colors = new ArrayList<>(nColors);
        Random rnd = new Random();
        for (int i = 0; i < nColors; i++) {
            double c0 = rnd.nextInt(200 - 64) + 64;
            double c1 = rnd.nextInt(255 - 100) + 100;
            double c2 = rnd.nextInt(255 - 100) + 100;
            colors.add(new Scalar(c0, c1, c2));
        }
        return colors;
    }

    /**
     * For every limb type, pairs each candidate of the limb's first key point with the best
     * matching candidate of its second key point.
     *
     * <p>A pair A-B is scored by sampling the limb's two PAF maps along the segment A-B and
     * projecting the sampled vectors onto the unit vector from A to B. The pair is accepted
     * when more than 70% of the samples score above 0.1; among accepted candidates for A the
     * one with the highest average score wins.
     *
     * @param netOutputParts per-channel network output maps, see {@link #MAP_INDICES}
     * @param detectedKeypoints key point candidates grouped by key point type
     * @param validPairs output: receives exactly one list per limb type, appended in
     *     {@link #POSE_PAIRS} order (an empty list when the limb type is invalid)
     * @param invalidPairs output: receives the index of every limb type for which one of the
     *     two key point types has no candidates
     */
    public static void getValidPairs(
            final List<Mat> netOutputParts,
            final List<List<KeyPoint>> detectedKeypoints,
            List<List<ValidPair>> validPairs,
            Set<Integer> invalidPairs) {
        int nInterpSamples = 10;
        float pafScoreTh = 0.1f;
        float confTh = 0.7f;

        for (int k = 0; k < MAP_INDICES.length; k++) {
            // A->B constitute a limb
            Mat pafA = netOutputParts.get(MAP_INDICES[k][0]);
            Mat pafB = netOutputParts.get(MAP_INDICES[k][1]);
            List<KeyPoint> candA = detectedKeypoints.get(POSE_PAIRS[k][0]);
            List<KeyPoint> candB = detectedKeypoints.get(POSE_PAIRS[k][1]);

            if (candA.isEmpty() || candB.isEmpty()) {
                invalidPairs.add(k);
                validPairs.add(new ArrayList<>());
                continue;
            }

            List<ValidPair> localValidPairs = new ArrayList<>();
            for (KeyPoint a : candA) {
                KeyPoint bestB = null;
                float maxScore = -1f;

                for (KeyPoint b : candB) {
                    float distanceX = (float) (b.point.x - a.point.x);
                    float distanceY = (float) (b.point.y - a.point.y);
                    float norm = (float) Math.sqrt(distanceX * distanceX + distanceY * distanceY);
                    if (norm == 0f) {
                        // coincident points have no direction to project onto
                        continue;
                    }
                    distanceX /= norm;
                    distanceY /= norm;

                    // find p(u): sample points on the segment A-B
                    List<Point> interpCoords = populateInterpPoints(a.point, b.point, nInterpSamples);

                    // find L(p(u)) and project it onto the unit vector A->B
                    float sumOfPafScores = 0f;
                    int numOverTh = 0;
                    for (Point p : interpCoords) {
                        int row = (int) p.y;
                        int col = (int) p.x;
                        float score = (float) pafA.get(row, col)[0] * distanceX
                                + (float) pafB.get(row, col)[0] * distanceY;
                        sumOfPafScores += score;
                        if (score > pafScoreTh) {
                            numOverTh++;
                        }
                    }

                    // Both ratios use the actual number of samples so they stay consistent.
                    int nSamples = interpCoords.size();
                    float avgPafScore = sumOfPafScores / (float) nSamples;
                    float fractionOverTh = (float) numOverTh / (float) nSamples;
                    if (fractionOverTh > confTh && avgPafScore > maxScore) {
                        bestB = b;
                        maxScore = avgPafScore;
                    }
                }

                if (bestB != null) {
                    localValidPairs.add(new ValidPair(a.id, bestB.id, maxScore));
                }
            }
            validPairs.add(localValidPairs);
        }
    }

    /**
     * Returns {@code numPoints} evenly spaced points on the segment from {@code a} to
     * {@code b}, both endpoints included exactly once.
     *
     * @param a first endpoint; the returned list starts with this very instance
     * @param b second endpoint; the returned list ends with this very instance
     * @param numPoints total number of points to return, at least 2
     * @return the sample points, ordered from {@code a} to {@code b}
     * @throws IllegalArgumentException if {@code numPoints} is less than 2
     */
    public static List<Point> populateInterpPoints(Point a, Point b, int numPoints) {
        if (numPoints < 2) {
            // fewer than two points would make the step a division by zero
            throw new IllegalArgumentException("numPoints must be at least 2, got " + numPoints);
        }
        float xStep = ((float) (b.x - a.x)) / (float) (numPoints - 1);
        float yStep = ((float) (b.y - a.y)) / (float) (numPoints - 1);

        List<Point> interpCoords = new ArrayList<>(numPoints);
        interpCoords.add(a);
        // Interior points only: i = 0 and i = numPoints - 1 are the endpoints themselves.
        for (int i = 1; i < numPoints - 1; i++) {
            interpCoords.add(new Point(a.x + xStep * i, a.y + yStep * i));
        }
        interpCoords.add(b);
        return interpCoords;
    }

    /**
     * Chains the valid limb pairs into one key point list per person.
     *
     * <p>Limb types are processed in {@link #POSE_PAIRS} order. A pair whose first key point
     * already belongs to a person extends that person; otherwise it starts a new person,
     * unless the limb is one of the trailing shoulder-to-ear links, which never start one.
     *
     * @param validPairs valid pairs per limb type, indexed like {@link #POSE_PAIRS}
     * @param invalidPairs indices of limb types to skip
     * @return one list per person holding, for each of the {@link #nPoints} key point types,
     *     the id of that person's key point or -1 if none was assigned
     */
    public static List<List<Integer>> getPersonwiseKeypoints(
            List<List<ValidPair>> validPairs, Set<Integer> invalidPairs) {
        List<List<Integer>> personwiseKeypoints = new ArrayList<>();

        for (int k = 0; k < MAP_INDICES.length; k++) {
            if (invalidPairs.contains(k)) {
                continue;
            }
            int indexA = POSE_PAIRS[k][0];
            int indexB = POSE_PAIRS[k][1];

            for (ValidPair pair : validPairs.get(k)) {
                // Look for the person that already owns key point A of this pair.
                List<Integer> owner = null;
                for (List<Integer> person : personwiseKeypoints) {
                    if (person.get(indexA).intValue() == pair.aId) {
                        owner = person;
                        break;
                    }
                }

                if (owner != null) {
                    owner.set(indexB, pair.bId);
                } else if (k < N_PRIMARY_PAIRS) {
                    List<Integer> lpkp = new ArrayList<>(nPoints);
                    for (int l = 0; l < nPoints; l++) {
                        lpkp.add(-1);
                    }
                    lpkp.set(indexA, pair.aId);
                    lpkp.set(indexB, pair.bId);
                    personwiseKeypoints.add(lpkp);
                }
            }
        }
        return personwiseKeypoints;
    }

    /** A key point candidate: its id, its location and the probability read at that location. */
    private static class KeyPoint {
        private int id;
        private final Point point;
        private final float probability;

        public KeyPoint(int id, Point point, float probability) {
            this.id = id;
            this.point = point;
            this.probability = probability;
        }

        /** Sets the id; candidates are created with a placeholder id and numbered afterwards. */
        public void setId(int id) {
            this.id = id;
        }

        @Override
        public String toString() {
            return String.format("Id: %d, Point: %s, Prob: %f", id, point, probability);
        }
    }

    /** A scored connection between two key points, identified by their ids. */
    private static class ValidPair {
        private final int aId;
        private final int bId;
        private final float score;

        public ValidPair(int aId, int bId, float score) {
            this.aId = aId;
            this.bId = bId;
            this.score = score;
        }

        @Override
        public String toString() {
            return String.format("A: %d, B: %d, Score: %f", aId, bId, score);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment