Skip to content

Instantly share code, notes, and snippets.

@YHaruoka
Last active November 24, 2018 15:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save YHaruoka/05b1d4a41f50951f43704853831f35f6 to your computer and use it in GitHub Desktop.
Save YHaruoka/05b1d4a41f50951f43704853831f35f6 to your computer and use it in GitHub Desktop.
0 255 0
0 0 255
255 0 0
0 255 255
255 255 0
255 0 255
80 70 180
250 80 190
245 145 50
70 150 250
50 190 190
#include "mask_r_cnn.h"
int main(int argc, char* argv[])
{
Mask_R_CNN_Detector mask_r_cnn_detector_;
String text_graph = "..\\..\\model\\mask_rcnn_inception_v2_coco_2018_01_28.pbtxt";
String model_weights = "..\\..\\model\\mask_rcnn_inception_v2_coco_2018_01_28\\frozen_inference_graph.pb";
String classes_file = "..\\..\\model\\mscoco_labels.txt";
String colors_file = "..\\..\\model\\colors.txt";
mask_r_cnn_detector_.init(text_graph, model_weights, classes_file, colors_file, 0.5, 0.3);
mask_r_cnn_detector_.imageExec("test.jpg", "output.jpg");
return 0;
}
#include "mask_r_cnn.h"
// Initialization.
//
// Stores the model file paths and the two thresholds in the object, loads the
// class names and the per-class colors from text files, and reads the
// TensorFlow network, selecting the OpenCV backend and the CPU target.
//
// Parameters:
//   text_graph_    - path to the text graph description (.pbtxt)
//   model_weights_ - path to the frozen model weights (.pb)
//   classes_file_  - text file with one class name per line
//   colors_file_   - text file with three numbers per line (one color per line)
//   conf_th        - detections must score strictly above this to be drawn
//   mask_th        - mask values strictly above this count as object pixels
//
// A class or color file that cannot be opened is reported on cerr and leaves
// the corresponding table empty; initialization continues in that case.
void Mask_R_CNN_Detector::init(String text_graph_, string model_weights_,
    string classes_file_, string colors_file_, float conf_th, float mask_th){
    // Assign the members directly. Declaring locals with these names here
    // would shadow the members and leave them empty.
    text_graph = text_graph_;
    model_weights = model_weights_;
    classes_file = classes_file_;
    colors_file = colors_file_;

    // Thresholds
    conf_threshold = conf_th;
    mask_threshold = mask_th;

    // Start from empty tables so a repeated init() does not append duplicates.
    classes.clear();
    colors.clear();

    string line;

    // Load the class names, one per line.
    ifstream ifs(classes_file.c_str());
    if (!ifs) {
        cerr << "【Error】:" << classes_file << "が見つからない" << endl;
    }
    while (getline(ifs, line)) classes.push_back(line);

    // Load the colors, three numbers per line.
    ifstream colorFptr(colors_file.c_str());
    if (!colorFptr) {
        cerr << "【Error】:" << colors_file << "が見つからない" << endl;
    }
    while (getline(colorFptr, line)) {
        const char* start = line.c_str();
        char* pEnd = NULL;
        const double r = strtod(start, &pEnd);
        // Nothing numeric on this line (e.g. a blank line): skip it rather
        // than storing an all-zero color.
        if (pEnd == start) continue;
        // Advance the cursor after each component; parsing the third value
        // from the same position as the second would just repeat it.
        const double g = strtod(pEnd, &pEnd);
        const double b = strtod(pEnd, NULL);
        colors.push_back(Scalar(r, g, b, 255.0));
    }

    // Read the network from the TensorFlow files.
    net = readNetFromTensorflow(model_weights, text_graph);
    net.setPreferableBackend(DNN_BACKEND_OPENCV);
    net.setPreferableTarget(DNN_TARGET_CPU);
}
// Execution.
//
// Reads image_filename, runs the network on it, draws the detections onto the
// image, writes the annotated image to output_filename and shows it in a
// window until a key is pressed.
//
// Return value: a status flag used like an exit code, NOT a success flag.
//   false - the image was processed and written
//   true  - the input image could not be read, or the output could not be
//           written
// (The "true on failure" convention is kept for compatibility with the
// original implementation, which returned -1 / 0 through the bool.)
bool Mask_R_CNN_Detector::imageExec(String image_filename, String output_filename)
{
    // Load the image to analyze; flag 1 requests a 3-channel color image.
    Mat frame = imread(image_filename, 1);
    if (frame.empty()) {
        cerr << "【Error】:" << image_filename << "が見つからない"<< endl;
        return true;
    }

    // Build the input blob: scale factor 1.0, the frame's own size, default
    // (zero) mean, swapRB = true, crop = false.
    Mat blob;
    blobFromImage(frame, blob, 1.0, Size(frame.cols, frame.rows), Scalar(), true, false);
    net.setInput(blob);

    // Run Mask R-CNN and fetch the two output layers consumed by postprocess().
    std::vector<String> outNames(2);
    outNames[0] = "detection_out_final";
    outNames[1] = "detection_masks";
    vector<Mat> outs;
    net.forward(outs, outNames);

    // Draw boxes, labels and masks onto the frame.
    postprocess(frame, outs);

    // Save the result image. A failed write is reported and reflected in the
    // returned status, but the result is still displayed below.
    bool failed = false;
    if (!imwrite(output_filename, frame)) {
        cerr << "【Error】:" << output_filename << "に書き込めない" << endl;
        failed = true;
    }

    // Show the result and block until a key is pressed.
    static const string window_name = "Mask R-CNN Result";
    namedWindow(window_name);
    imshow(window_name, frame);
    waitKey(-1);

    return failed;
}
// Post-processing for the displayed image.
//
// Walks over the detections in outs[0], and for each one whose score is
// strictly above conf_threshold converts the normalized box corners to pixel
// coordinates clamped to the frame, picks the matching mask from outs[1] and
// hands both to drawBox().
//
// Parameters:
//   frame - image the results are drawn onto (modified in place)
//   outs  - network outputs: outs[0] holds the detections, outs[1] the masks
//
// Each detection is read as a row of 7 floats: column 1 is the class id,
// column 2 the score, columns 3..6 the left/top/right/bottom corners as
// fractions of the frame width/height. Masks are addressed as
// outMasks(detection, class, y, x).
//
// Missing or malformed outputs, detections with a class id outside the mask
// tensor, and degenerate boxes are skipped silently.
void Mask_R_CNN_Detector::postprocess(Mat& frame, const vector<Mat>& outs)
{
    // Both outputs and a non-empty frame are required.
    if (frame.empty() || outs.size() < 2) return;

    Mat outDetections = outs[0];
    Mat outMasks = outs[1];

    // The indexing below needs size[2] of the detections and size[0..3] of the
    // masks; an empty tensor also means there is nothing to draw.
    if (outDetections.empty() || outMasks.empty()) return;
    if (outDetections.dims < 3 || outMasks.dims < 4) return;

    const int numClasses = outMasks.size[1];
    const int maskRows = outMasks.size[2];
    const int maskCols = outMasks.size[3];

    // Never iterate past the number of per-detection masks available.
    int numDetections = min(outDetections.size[2], outMasks.size[0]);

    // View the detections as a 2-D matrix with 7 values per row.
    outDetections = outDetections.reshape(1, static_cast<int>(outDetections.total() / 7));
    numDetections = min(numDetections, outDetections.rows);

    for (int i = 0; i < numDetections; ++i)
    {
        const float score = outDetections.at<float>(i, 2);
        // Written as a negated '>' so that a NaN score is skipped as well.
        if (!(score > conf_threshold)) continue;

        // The class id selects the mask plane; an id outside the mask tensor
        // would read out of bounds, so such detections are dropped.
        const int classId = static_cast<int>(outDetections.at<float>(i, 1));
        if (classId < 0 || classId >= numClasses) continue;

        // Extract the bounding box in pixels and clamp it to the frame.
        int left = static_cast<int>(frame.cols * outDetections.at<float>(i, 3));
        int top = static_cast<int>(frame.rows * outDetections.at<float>(i, 4));
        int right = static_cast<int>(frame.cols * outDetections.at<float>(i, 5));
        int bottom = static_cast<int>(frame.rows * outDetections.at<float>(i, 6));
        left = max(0, min(left, frame.cols - 1));
        top = max(0, min(top, frame.rows - 1));
        right = max(0, min(right, frame.cols - 1));
        bottom = max(0, min(bottom, frame.rows - 1));

        // An inverted box would give a non-positive size, which cannot be
        // used as a resize target or as a region of the frame.
        if (right < left || bottom < top) continue;
        Rect box = Rect(left, top, right - left + 1, bottom - top + 1);

        // Wrap (without copying) the mask plane of this detection and class.
        Mat objectMask(maskRows, maskCols, CV_32F, outMasks.ptr<float>(i, classId));

        // Draw the bounding box, label and mask outline.
        drawBox(frame, classId, score, box, objectMask);
    }
}
// Draws one detection: bounding box, label and mask outline.
//
// Parameters:
//   frame      - image to draw on (modified in place)
//   classId    - class index; used to look up the label and the color
//   conf       - detection score, printed with two decimals
//   box        - bounding box in frame pixel coordinates
//   objectMask - float mask for this detection; it is resized in place to the
//                box size
//
// When class names are loaded, classId must index into them (CV_Assert).
// When no colors are loaded, the red used for the box outline is used for the
// mask as well.
void Mask_R_CNN_Detector::drawBox(Mat& frame, int classId, float conf, Rect box, Mat& objectMask)
{
    rectangle(frame, Point(box.x, box.y), Point(box.x + box.width, box.y + box.height), Scalar(0, 0, 255), 3);

    // Build the label: "<class>:<score>" or just the score without class names.
    string label = format("%.2f", conf);
    if (!classes.empty())
    {
        CV_Assert(classId >= 0 && classId < static_cast<int>(classes.size()));
        label = classes[classId] + ":" + label;
    }

    // Draw the label on a white background at the top of the box. The label
    // baseline is kept in its own variable: moving box.y itself would also
    // shift the mask region used below and could push it outside the frame.
    int baseLine = 0;
    const Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
    const int labelY = max(box.y, labelSize.height);
    // The size is measured at scale 0.5 but the text is drawn at 0.75, hence
    // the 1.5 factors on the background rectangle.
    const int labelTop = labelY - static_cast<int>(round(1.5*labelSize.height));
    const int labelRight = box.x + static_cast<int>(round(1.5*labelSize.width));
    rectangle(frame, Point(box.x, labelTop), Point(labelRight, labelY + baseLine), Scalar(255, 255, 255), FILLED);
    putText(frame, label, Point(box.x, labelY), FONT_HERSHEY_SIMPLEX, 0.75, Scalar(0, 0, 0), 1);

    // The mask overlay needs a non-empty mask and a region inside the frame.
    // Intersecting with the frame leaves an in-frame box unchanged.
    box &= Rect(0, 0, frame.cols, frame.rows);
    if (objectMask.empty() || box.width <= 0 || box.height <= 0) return;

    // Pick the class color; guard the modulo against an empty color table.
    const Scalar color = colors.empty()
        ? Scalar(0, 0, 255, 255.0)
        : colors[classId % colors.size()];

    // Resize the mask to the box and threshold it into a binary mask.
    resize(objectMask, objectMask, Size(box.width, box.height));
    Mat mask = (objectMask > mask_threshold);

    // Blend the class color into the boxed region of the frame.
    Mat coloredRoi = (0.3 * color + 0.7 * frame(box));
    coloredRoi.convertTo(coloredRoi, CV_8UC3);

    // Outline the mask, then copy the blended region back where the mask is set.
    vector<Mat> contours;
    Mat hierarchy;
    mask.convertTo(mask, CV_8U);
    findContours(mask, contours, hierarchy, RETR_CCOMP, CHAIN_APPROX_SIMPLE);
    drawContours(coloredRoi, contours, -1, color, 5, LINE_8, hierarchy, 100);
    coloredRoi.copyTo(frame(box), mask);
}
#pragma once
#include <fstream>
#include <sstream>
#include <iostream>
#include <string.h>
#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
using namespace cv;
using namespace dnn;
using namespace std;
// Mask_R_CNN_Detector: runs a Mask R-CNN network through OpenCV's dnn module
// and draws the detections (box, label, mask outline) onto an image.
//
// Usage: call init() once with the model/label/color files and thresholds,
// then imageExec() for an image file.
class Mask_R_CNN_Detector
{
private:
    // Paths to the model-related files
    String text_graph;
    String model_weights;
    String classes_file;
    String colors_file;
    // Class names and per-class colors
    vector<string> classes;
    vector<Scalar> colors;
    // Thresholds: detection confidence and mask cut-off
    float conf_threshold;
    float mask_threshold;
    // Network instance
    Net net;
public:
    // Gives the thresholds defined values before init() is called; init()
    // overwrites both.
    Mask_R_CNN_Detector() : conf_threshold(0.5f), mask_threshold(0.3f) {}

    // Member declarations must not repeat the class name as a qualifier
    // ("Mask_R_CNN_Detector::init" inside the class is ill-formed C++).

    // Stores paths and thresholds, loads class names and colors, reads the network.
    void init(String text_graph_, string model_weights_,
        string classes_file_, string colors_file_, float conf_th, float mask_th);
    // Processes one image file, writes the annotated result and displays it.
    bool imageExec(String image_filename, String output_filename);
    // Draws every detection above the confidence threshold onto frame.
    void postprocess(Mat& frame, const vector<Mat>& outs);
    // Draws a single detection: bounding box, label and mask outline.
    void drawBox(Mat& frame, int classId, float conf, Rect box, Mat& objectMask);
};
person
bicycle
car
motorcycle
airplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
couch
potted plant
bed
dining table
toilet
tv
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment