/*
 * Last active
 * October 11, 2022 12:38
 * -
 * -
 * Save docete/38bb39390da8e8c823c3211419b9a5ad to your computer and use it in GitHub Desktop.
 * This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
 * Learn more about bidirectional Unicode characters
 */
import static org.opencv.core.CvType.CV_8U;
import static org.opencv.imgcodecs.Imgcodecs.IMREAD_COLOR;
import static org.opencv.imgproc.Imgproc.CHAIN_APPROX_SIMPLE;
import static org.opencv.imgproc.Imgproc.LINE_AA;
import static org.opencv.imgproc.Imgproc.RETR_TREE;
import static org.opencv.imgproc.Imgproc.THRESH_BINARY;
import static org.opencv.imgproc.Imgproc.fillConvexPoly;
import static org.opencv.imgproc.Imgproc.findContours;

import java.io.FileNotFoundException;
import java.net.URL;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import javafx.util.Pair;
import org.opencv.core.Core;
import org.opencv.core.Core.MinMaxLocResult;
import org.opencv.core.Mat;
import org.opencv.core.MatOfPoint;
import org.opencv.core.Point;
import org.opencv.core.Scalar;
import org.opencv.core.Size;
import org.opencv.dnn.Dnn;
import org.opencv.dnn.Net;
import org.opencv.highgui.HighGui;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
public class OpenposeApp { | |
static { | |
System.loadLibrary(Core.NATIVE_LIBRARY_NAME); | |
} | |
final static int nPoints = 18; | |
// 18 key points of COCO output format | |
final static String[] KEYPOINTS_MAPPING = { | |
"Nose", "Neck", | |
"R-Sho", "R-Elb", "R-Wr", | |
"L-Sho", "L-Elb", "L-Wr", | |
"R-Hip", "R-Knee", "R-Ank", | |
"L-Hip", "L-Knee", "L-Ank", | |
"R-Eye", "L-Eye", | |
"R-Ear", "L-Ear" | |
}; | |
// 19 possible limb connections | |
final static int POSE_PAIRS[][] = { | |
{1,2}, {1,5}, {2,3}, | |
{3,4}, {5,6}, {6,7}, | |
{1,8}, {8,9}, {9,10}, | |
{1,11}, {11,12}, {12,13}, | |
{1,0}, {0,14}, | |
{14,16}, {0,15}, {15,17}, | |
{2, 17}, {5, 16} | |
}; | |
// indices of PAFs corresponding to pose pairs | |
// e.g. for POSE_PAIRS {1, 2}, the PAFs are located at indices (31,32) of output. | |
final static int MAP_INDICES[][] = { | |
{31,32}, {39,40}, {33,34}, | |
{35,36}, {41,42}, {43,44}, | |
{19,20}, {21,22}, {23,24}, | |
{25,26}, {27,28}, {29,30}, | |
{47,48}, {49,50}, | |
{53,54}, {51,52}, {55,56}, | |
{37,38}, {45,46} | |
}; | |
public static void main(String[] args) throws Exception { | |
// read the network model | |
Net net = Dnn.readNetFromCaffe( | |
getResourceFilePath("models/pose/coco/pose_deploy_linevec.prototxt"), | |
getResourceFilePath("models/pose/coco/pose_iter_440000.caffemodel")); | |
Mat frame = Imgcodecs.imread(getResourceFilePath("images/three_person.jpg"), IMREAD_COLOR); | |
Mat inputBlob = Dnn.blobFromImage(frame, 1.0 / 255, new Size((368.0 * frame.cols() / frame.rows()),368), new Scalar(0, 0, 0), true, false); | |
net.setInput(inputBlob); | |
Mat output = net.forward(); | |
List<Mat> parts = splitNetOutputBlobToParts(output, new Size(frame.cols(), frame.rows())); | |
int keyPointId = 0; | |
List<List<KeyPoint>> detectedKeypoints = new ArrayList<>(); // group by positions ? | |
List<KeyPoint> keyPointsList = new ArrayList<>(); | |
for (int i = 0; i < nPoints; i++) { | |
List<KeyPoint> keyPoints = getKeyPoints(parts.get(i), 0.1); | |
for (int j = 0; j < keyPoints.size(); j++, keyPointId++) { | |
keyPoints.get(j).setId(keyPointId); | |
} | |
System.out.println(String.format("Keypoints - %s : %s", KEYPOINTS_MAPPING[i], keyPoints)); | |
detectedKeypoints.add(keyPoints); | |
keyPointsList.addAll(keyPoints); | |
} | |
// Draw key points with random colors | |
List<Scalar> colors = populateColorPalette(nPoints); | |
Mat outputFrame = frame.clone(); | |
for (int i = 0; i < nPoints; i++) { | |
for (int j = 0; j < detectedKeypoints.get(i).size(); j++) { | |
Imgproc.circle(outputFrame, detectedKeypoints.get(i).get(j).point, 5, colors.get(i), -1, LINE_AA); | |
Imgproc.putText( | |
outputFrame, | |
Integer.toString(detectedKeypoints.get(i).get(j).id), | |
detectedKeypoints.get(i).get(j).point, | |
Imgproc.FONT_HERSHEY_SIMPLEX, | |
0.3, | |
new Scalar(0, 0, 0)); | |
} | |
} | |
List<List<ValidPair>> validPairs = new ArrayList<>(); | |
Set<Integer> invalidPairs = new HashSet<>(); | |
getValidPairs(parts, detectedKeypoints, validPairs, invalidPairs); | |
List<List<Integer>> personwiseKeypoints = getPersonwiseKeypoints(validPairs, invalidPairs); | |
for (int i = 0; i < nPoints - 1; i++) { | |
for (int n = 0; n < personwiseKeypoints.size(); n++) { | |
int indexA = personwiseKeypoints.get(n).get(POSE_PAIRS[i][0]); | |
int indexB = personwiseKeypoints.get(n).get(POSE_PAIRS[i][1]); | |
if (indexA == -1 || indexB == -1) { | |
continue; | |
} | |
KeyPoint kpA = keyPointsList.get(indexA); | |
KeyPoint kpB = keyPointsList.get(indexB); | |
Imgproc.line(outputFrame, kpA.point, kpB.point, colors.get(i), 3, LINE_AA); | |
} | |
} | |
HighGui.imshow("show", outputFrame); | |
HighGui.waitKey(0); | |
} | |
public static String getResourceFilePath(String fileName) throws Exception { | |
ClassLoader classLoader = OpenposeDemo.class.getClassLoader(); | |
URL resource = classLoader.getResource(fileName); | |
return resource.toURI().getPath(); | |
} | |
public static List<Mat> splitNetOutputBlobToParts(Mat netOutputBlob, Size targetSize) { | |
int nParts = netOutputBlob.size(1); // 18 KeyPoints + 1 Background + 19 * 2 PAFs | |
int h = netOutputBlob.size(2); | |
int w = netOutputBlob.size(3); | |
netOutputBlob = netOutputBlob.reshape(1, nParts); | |
List<Mat> parts = new ArrayList<>(); | |
for(int i = 0; i < nParts; i++) { | |
Mat part = netOutputBlob.row(i).reshape(1, 46); | |
Mat resizedPart = new Mat(); | |
Imgproc.resize(part, resizedPart, targetSize); | |
parts.add(resizedPart); | |
} | |
return parts; | |
} | |
public static List<KeyPoint> getKeyPoints(Mat probMap, double threshold) { | |
Mat smoothProbMap = new Mat(); | |
Imgproc.GaussianBlur(probMap, smoothProbMap, new Size(3, 3), 0, 0); | |
Mat maskedProbMap = new Mat(); | |
Imgproc.threshold(smoothProbMap, maskedProbMap, threshold, 255, THRESH_BINARY); | |
maskedProbMap.convertTo(maskedProbMap, CV_8U, 1); | |
List<MatOfPoint> contours = new ArrayList<>(); | |
Mat hierarchy = new Mat(); | |
findContours(maskedProbMap, contours, hierarchy, RETR_TREE, CHAIN_APPROX_SIMPLE); | |
List<KeyPoint> keyPoints = new ArrayList<>(); | |
for (int i = 0; i < contours.size(); i++) { | |
Mat blobMask = Mat.zeros(smoothProbMap.rows(), smoothProbMap.cols(), smoothProbMap.type()); | |
fillConvexPoly(blobMask, contours.get(i), new Scalar(1)); | |
MinMaxLocResult result = Core.minMaxLoc(smoothProbMap.mul(blobMask)); | |
keyPoints.add(new KeyPoint(-1, result.maxLoc, (float) probMap.get((int)result.maxLoc.y, (int) result.maxLoc.x)[0])); | |
} | |
return keyPoints; | |
} | |
public static List<Scalar> populateColorPalette(int nColors) { | |
List<Scalar> colors = new ArrayList<>(); | |
Random rnd = new Random(); | |
for (int i = 0; i < nColors; i++) { | |
colors.add(new Scalar( | |
(double) (rnd.nextInt(200 - 64) + 64), | |
(double) (rnd.nextInt(255 - 100) + 100), | |
(double) (rnd.nextInt(255 - 100) + 100))); | |
} | |
return colors; | |
} | |
public static void getValidPairs( | |
final List<Mat> netOutputParts, | |
final List<List<KeyPoint>> detectedKeypoints, | |
List<List<ValidPair>> validPairs, | |
Set<Integer> invalidPairs) { | |
int nInterpSamples = 10; | |
float pafScoreTh = 0.1f; | |
float confTh = 0.7f; | |
for (int k = 0; k < MAP_INDICES.length; k++) { | |
//A->B constitute a limb | |
Mat pafA = netOutputParts.get(MAP_INDICES[k][0]); | |
Mat pafB = netOutputParts.get(MAP_INDICES[k][1]); | |
List<KeyPoint> candA = detectedKeypoints.get(POSE_PAIRS[k][0]); | |
List<KeyPoint> candB = detectedKeypoints.get(POSE_PAIRS[k][1]); | |
int nA = candA.size(); | |
int nB = candB.size(); | |
if (nA != 0 && nB != 0) { | |
List<ValidPair> localValidPairs = new ArrayList<>(); | |
for (int i = 0; i < nA; i++) { | |
int maxJ = -1; | |
float maxScore = -1f; | |
boolean found = false; | |
for (int j = 0; j < nB; j++) { | |
float distanceX = (float) (candB.get(j).point.x - candA.get(i).point.x); | |
float distanceY = (float) (candB.get(j).point.y - candA.get(i).point.y); | |
float norm = (float) Math.sqrt(distanceX * distanceX + distanceY * distanceY); | |
if (Math.signum(norm) == 0) { | |
continue; | |
} | |
distanceX /= norm; | |
distanceY /= norm; | |
// find p(u) | |
List<Point> interpCoords = populateInterpPoints(candA.get(i).point, candB.get(j).point, nInterpSamples); | |
// find L(p(u)) | |
List<Pair<Float, Float>> pafInterp = new ArrayList<>(); | |
for (int l = 0; l < interpCoords.size(); l++) { | |
pafInterp.add( | |
new Pair<Float, Float>( | |
(float)pafA.get((int)interpCoords.get(l).y, (int)interpCoords.get(l).x)[0], | |
(float)pafB.get((int)interpCoords.get(l).y, (int)interpCoords.get(l).x)[0])); | |
} | |
List<Float> pafScores = new ArrayList<>(); | |
float sumOfPafScores = 0f; | |
int numOverTh = 0; | |
for (int l = 0; l < pafInterp.size(); l++) { | |
float score = pafInterp.get(l).getKey() * distanceX + pafInterp.get(l).getValue() * distanceY; | |
sumOfPafScores += score; | |
if (score > pafScoreTh) { | |
numOverTh += 1; | |
} | |
pafScores.add(score); | |
} | |
float avgPafScore = sumOfPafScores/((float)pafInterp.size()); | |
if (((float)numOverTh / (float)nInterpSamples) > confTh) { | |
if (avgPafScore > maxScore) { | |
maxJ = j; | |
maxScore = avgPafScore; | |
found = true; | |
} | |
} | |
} // end of j | |
if (found) { | |
localValidPairs.add(new ValidPair(candA.get(i).id, candB.get(maxJ).id, maxScore)); | |
} | |
} // end of i | |
validPairs.add(localValidPairs); | |
} else { | |
invalidPairs.add(k); | |
validPairs.add(new ArrayList<>()); | |
} | |
} // end of k | |
} | |
public static List<Point> populateInterpPoints(Point a, Point b, int numPoints) { | |
List<Point> interpCoords = new ArrayList<>(); | |
float xStep = ((float)(b.x - a.x))/(float)(numPoints-1); | |
float yStep = ((float)(b.y - a.y))/(float)(numPoints-1); | |
interpCoords.add(a); | |
for (int i = 0; i < numPoints; i++) { | |
interpCoords.add(new Point(a.x + xStep * i, a.y + yStep * i)); | |
} | |
interpCoords.add(b); | |
return interpCoords; | |
} | |
public static List<List<Integer>> getPersonwiseKeypoints( | |
List<List<ValidPair>> validPairs, Set<Integer> invalidPairs) { | |
List<List<Integer>> personwiseKeypoints = new ArrayList<>(); | |
for (int k = 0; k < MAP_INDICES.length; k++) { | |
if (invalidPairs.contains(k)) { | |
continue; | |
} | |
List<ValidPair> localValidPairs = validPairs.get(k); | |
int indexA = POSE_PAIRS[k][0]; | |
int indexB = POSE_PAIRS[k][1]; | |
for (int i = 0; i < localValidPairs.size(); i++) { | |
boolean found = false; | |
int personIdx = -1; | |
for (int j = 0; !found && j < personwiseKeypoints.size(); j++) { | |
if (indexA < personwiseKeypoints.get(j).size() | |
&& personwiseKeypoints.get(j).get(indexA) == localValidPairs.get(i).aId) { | |
personIdx = j; | |
found = true; | |
} | |
} | |
if (found) { | |
personwiseKeypoints.get(personIdx).set(indexB, localValidPairs.get(i).bId); | |
} else { | |
List<Integer> lpkp = new ArrayList<Integer>(18); | |
for (int l = 0; l < 18; l++) { | |
lpkp.add(-1); | |
} | |
lpkp.set(indexA, localValidPairs.get(i).aId); | |
lpkp.set(indexB, localValidPairs.get(i).bId); | |
personwiseKeypoints.add(lpkp); | |
} | |
} | |
} | |
return personwiseKeypoints; | |
} | |
private static class KeyPoint { | |
private int id; | |
private Point point; | |
private float probability; | |
public KeyPoint(int id, Point point, float probability) { | |
this.id = id; | |
this.point = point; | |
this.probability = probability; | |
} | |
public void setId(int id) { | |
this.id = id; | |
} | |
@Override | |
public String toString() { | |
return String.format("Id: %d, Point: %s, Prob: %f", id, point, probability); | |
} | |
} | |
private static class ValidPair { | |
private int aId; | |
private int bId; | |
private float score; | |
public ValidPair(int aId, int bId, float score) { | |
this.aId = aId; | |
this.bId = bId; | |
this.score = score; | |
} | |
@Override | |
public String toString() { | |
return String.format("A: %d, B: %d, Score: %f", aId, bId, score); | |
} | |
} | |
} |
/*
 * Sign up for free
 * to join this conversation on GitHub.
 * Already have an account?
 * Sign in to comment
 */