Last active
November 24, 2018 15:18
-
-
Save YHaruoka/05b1d4a41f50951f43704853831f35f6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
0 255 0
0 0 255
255 0 0
0 255 255
255 255 0
255 0 255
80 70 180
250 80 190
245 145 50
70 150 250
50 190 190
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "mask_r_cnn.h" | |
int main(int argc, char* argv[]) | |
{ | |
Mask_R_CNN_Detector mask_r_cnn_detector_; | |
String text_graph = "..\\..\\model\\mask_rcnn_inception_v2_coco_2018_01_28.pbtxt"; | |
String model_weights = "..\\..\\model\\mask_rcnn_inception_v2_coco_2018_01_28\\frozen_inference_graph.pb"; | |
String classes_file = "..\\..\\model\\mscoco_labels.txt"; | |
String colors_file = "..\\..\\model\\colors.txt"; | |
mask_r_cnn_detector_.init(text_graph, model_weights, classes_file, colors_file, 0.5, 0.3); | |
mask_r_cnn_detector_.imageExec("test.jpg", "output.jpg"); | |
return 0; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "mask_r_cnn.h" | |
// Initialization: stores thresholds and file paths, loads the class names and
// per-class colors, and loads the TensorFlow network on the OpenCV CPU backend.
//
// text_graph_    : path to the .pbtxt text graph description
// model_weights_ : path to the frozen .pb weights
// classes_file_  : text file with one class name per line
// colors_file_   : text file with one whitespace-separated "r g b" triple per line
// conf_th        : minimum detection score for a detection to be drawn
// mask_th        : threshold applied to the per-pixel mask values
void Mask_R_CNN_Detector::init(String text_graph_, string model_weights_,
    string classes_file_, string colors_file_, float conf_th, float mask_th){
    // Store the paths in the members (not in shadowing locals).
    text_graph = text_graph_;
    model_weights = model_weights_;
    classes_file = classes_file_;
    colors_file = colors_file_;

    // Thresholds
    conf_threshold = conf_th;
    mask_threshold = mask_th;

    // Class names, one per line. Cleared first so init() can be called again.
    classes.clear();
    ifstream ifs(classes_file.c_str());
    if (!ifs) {
        cerr << "【Warning】:cannot open classes file " << classes_file << endl;
    }
    string line;
    while (getline(ifs, line)) classes.push_back(line);

    // Colors: three numbers per line. Each value must be parsed from where the
    // previous one ended; malformed lines are skipped.
    colors.clear();
    ifstream colorFptr(colors_file.c_str());
    if (!colorFptr) {
        cerr << "【Warning】:cannot open colors file " << colors_file << endl;
    }
    while (getline(colorFptr, line)) {
        istringstream iss(line);
        double r = 0.0, g = 0.0, b = 0.0;
        if (iss >> r >> g >> b) {
            colors.push_back(Scalar(r, g, b, 255.0));
        }
    }
    // drawBox indexes colors modulo colors.size(), so never leave it empty.
    if (colors.empty()) {
        colors.push_back(Scalar(0.0, 255.0, 0.0, 255.0));
    }

    // Load the network from TensorFlow files and run it with OpenCV on the CPU.
    net = readNetFromTensorflow(model_weights, text_graph);
    net.setPreferableBackend(DNN_BACKEND_OPENCV);
    net.setPreferableTarget(DNN_TARGET_CPU);
}
// Runs Mask R-CNN on one image, writes the annotated image to disk and shows it
// in a window (blocking until a key is pressed).
//
// image_filename  : image to run detection on
// output_filename : path the annotated image is written to
// Returns true on success; false if the image could not be read, the network is
// not loaded, or the output image could not be written.
bool Mask_R_CNN_Detector::imageExec(String image_filename, String output_filename)
{
    // Load the input image (flag 1 == load as 3-channel color).
    Mat frame = imread(image_filename, 1);
    if (frame.empty()) {
        cerr << "【Error】:" << image_filename << "が見つからない" << endl;
        return false;
    }
    if (net.empty()) {
        cerr << "【Error】:network is not loaded; call init() first" << endl;
        return false;
    }

    // Build the input blob at the image's own size (scale 1.0, swapRB=true, no crop).
    Mat blob;
    blobFromImage(frame, blob, 1.0, Size(frame.cols, frame.rows), Scalar(), true, false);
    net.setInput(blob);

    // Fetch detections and masks in one forward pass; postprocess relies on this order.
    const std::vector<String> outNames = { "detection_out_final", "detection_masks" };
    vector<Mat> outs;
    net.forward(outs, outNames);

    // Draw boxes, labels and masks onto frame.
    postprocess(frame, outs);

    // Save the result; still display it even if saving failed.
    const bool written = imwrite(output_filename, frame);
    if (!written) {
        cerr << "【Error】:" << output_filename << " could not be written" << endl;
    }

    // Show the result and wait for a key press.
    static const string window_name = "Mask R-CNN Result";
    namedWindow(window_name);
    imshow(window_name, frame);
    waitKey(-1);
    return written;
}
// Post-processes the network outputs and draws every detection whose score
// exceeds conf_threshold.
//
// frame : image to annotate in place
// outs  : forward() results; outs[0] is read as the detection table and
//         outs[1] as the per-class masks
void Mask_R_CNN_Detector::postprocess(Mat& frame, const vector<Mat>& outs)
{
    if (outs.size() < 2) {
        cerr << "【Error】:expected detection and mask outputs" << endl;
        return;
    }
    Mat outDetections = outs[0];
    Mat outMasks = outs[1];

    // Detection count lives in dimension 2; the masks have one plane per class
    // in dimension 1.
    const int numDetections = outDetections.size[2];
    const int numClasses = outMasks.size[1];

    // View the detections as rows of 7 floats. Column 1 = class id,
    // column 2 = score, columns 3..6 = left/top/right/bottom relative to the
    // frame size.
    outDetections = outDetections.reshape(1, static_cast<int>(outDetections.total() / 7));

    for (int i = 0; i < numDetections; ++i)
    {
        const float score = outDetections.at<float>(i, 2);
        if (!(score > conf_threshold)) continue;

        // Masks are indexed by class; skip ids that have no mask plane.
        const int classId = static_cast<int>(outDetections.at<float>(i, 1));
        if (classId < 0 || classId >= numClasses) continue;

        // Convert relative coordinates to pixels and clamp to the image.
        int left   = static_cast<int>(frame.cols * outDetections.at<float>(i, 3));
        int top    = static_cast<int>(frame.rows * outDetections.at<float>(i, 4));
        int right  = static_cast<int>(frame.cols * outDetections.at<float>(i, 5));
        int bottom = static_cast<int>(frame.rows * outDetections.at<float>(i, 6));
        left   = max(0, min(left, frame.cols - 1));
        top    = max(0, min(top, frame.rows - 1));
        right  = max(0, min(right, frame.cols - 1));
        bottom = max(0, min(bottom, frame.rows - 1));

        // A box with non-positive extent would make frame(box) throw later.
        if (right < left || bottom < top) continue;
        const Rect box(left, top, right - left + 1, bottom - top + 1);

        // Header over the low-resolution mask for this detection and class (no copy).
        Mat objectMask(outMasks.size[2], outMasks.size[3], CV_32F, outMasks.ptr<float>(i, classId));

        // Draw the box, label and mask outline.
        drawBox(frame, classId, score, box, objectMask);
    }
}
// Draws one detection onto frame. This covers:
//   - the bounding box,
//   - a "class:score" label,
//   - the instance mask blended in the class color, with its contours.
//
// frame      : image to draw on (modified in place)
// classId    : class index used for the label text and color selection
// conf       : detection score shown in the label
// box        : bounding box in frame pixel coordinates
// objectMask : low-resolution CV_32F mask; it is resized into a local Mat and
//              therefore left unchanged
void Mask_R_CNN_Detector::drawBox(Mat& frame, int classId, float conf, Rect box, Mat& objectMask)
{
    // Keep the ROI inside the frame so frame(box) below cannot throw.
    box &= Rect(0, 0, frame.cols, frame.rows);
    if (box.width <= 0 || box.height <= 0) return;

    rectangle(frame, Point(box.x, box.y), Point(box.x + box.width, box.y + box.height), Scalar(0, 0, 255), 3);

    // Label: "<class name>:<score>", or "<class id>:<score>" for an unknown id.
    string label = format("%.2f", conf);
    if (!classes.empty())
    {
        if (classId >= 0 && classId < static_cast<int>(classes.size()))
            label = classes[classId] + ":" + label;
        else
            label = format("%d", classId) + ":" + label;
    }

    // Label background and text. A separate y coordinate is used so that the
    // box itself (and hence the mask ROI) is not shifted. The background is
    // 1.5x the size measured at scale 0.5 because the text is drawn at 0.75.
    int baseLine = 0;
    const Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
    const int labelY = max(box.y, labelSize.height);
    rectangle(frame,
              Point(box.x, labelY - cvRound(1.5 * labelSize.height)),
              Point(box.x + cvRound(1.5 * labelSize.width), labelY + baseLine),
              Scalar(255, 255, 255), FILLED);
    putText(frame, label, Point(box.x, labelY), FONT_HERSHEY_SIMPLEX, 0.75, Scalar(0, 0, 0), 1);

    // Per-class color, cycling through the palette; green if no palette is loaded.
    const Scalar color = colors.empty()
        ? Scalar(0, 255, 0, 255)
        : colors[static_cast<size_t>(classId) % colors.size()];

    // Upsample the mask to the box size (into a local) and binarize it (CV_8U, 0/255).
    Mat resizedMask;
    resize(objectMask, resizedMask, box.size());
    Mat mask = (resizedMask > mask_threshold);

    // Blend the class color into the box region: 30% color, 70% image.
    Mat coloredRoi = (0.3 * color + 0.7 * frame(box));
    coloredRoi.convertTo(coloredRoi, CV_8UC3);

    // Outline the mask components in the class color.
    vector<Mat> contours;
    Mat hierarchy;
    findContours(mask, contours, hierarchy, RETR_CCOMP, CHAIN_APPROX_SIMPLE);
    drawContours(coloredRoi, contours, -1, color, 5, LINE_8, hierarchy, 100);

    // Copy the blended pixels back only where the mask is set.
    coloredRoi.copyTo(frame(box), mask);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once | |
#include <fstream> | |
#include <sstream> | |
#include <iostream> | |
#include <string.h> | |
#include <opencv2/dnn.hpp> | |
#include <opencv2/imgproc.hpp> | |
#include <opencv2/highgui.hpp> | |
using namespace cv; | |
using namespace dnn; | |
using namespace std; | |
// Mask_R_CNN_Detectorクラス | |
class Mask_R_CNN_Detector | |
{ | |
private: | |
// モデル関連のファイルへのパス | |
String text_graph; | |
String model_weights; | |
String classes_file; | |
String colors_file; | |
// クラス名と色設定用ベクトル | |
vector<string> classes; | |
vector<Scalar> colors; | |
// 閾値 | |
float conf_threshold; | |
float mask_threshold; | |
// ネットワークインスタンス | |
Net net; | |
public: | |
void Mask_R_CNN_Detector::init(String text_graph_, string model_weights_, | |
string classes_file_, string colors_file_, float conf_th, float mask_th); | |
bool Mask_R_CNN_Detector::imageExec(String image_filename, String output_filename); | |
void Mask_R_CNN_Detector::postprocess(Mat& frame, const vector<Mat>& outs); | |
void Mask_R_CNN_Detector::drawBox(Mat& frame, int classId, float conf, Rect box, Mat& objectMask); | |
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
person
bicycle
car
motorcycle
airplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
couch
potted plant
bed
dining table
toilet
tv
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment