Created April 9, 2018 13:36
-
-
Save YHaruoka/b2c3dfeb4929c641d9618bc84d40abb1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include <opencv2/core/utils/trace.hpp>
#include <opencv2/dnn.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/imgproc.hpp>

using namespace cv;
using namespace cv::dnn;
using namespace std;
static void getMaxClass(const Mat &probBlob, int *classId, double *classProb) | |
{ | |
Mat probMat = probBlob.reshape(1, 1); | |
Point classNumber; | |
minMaxLoc(probMat, NULL, classProb, NULL, &classNumber); | |
*classId = classNumber.x; | |
} | |
static std::vector<String> readClassNames(const char *filename = "synset_words.txt") | |
{ | |
std::vector<String> classNames; | |
std::ifstream fp(filename); | |
if (!fp.is_open()) | |
{ | |
std::cerr << "File with classes labels not found: " << filename << std::endl; | |
exit(-1); | |
} | |
std::string name; | |
while (!fp.eof()) | |
{ | |
std::getline(fp, name); | |
if (name.length()) | |
classNames.push_back(name.substr(name.find(' ') + 1)); | |
} | |
fp.close(); | |
return classNames; | |
} | |
int main(int argc, char **argv) | |
{ | |
CV_TRACE_FUNCTION(); | |
String modelTxt = "bvlc_googlenet.prototxt"; | |
String modelBin = "bvlc_googlenet.caffemodel"; | |
String imageFile = "test1.jpg"; | |
Net net; | |
try { | |
net = dnn::readNetFromCaffe(modelTxt, modelBin); | |
} | |
catch (cv::Exception& e) { | |
std::cerr << "Exception: " << e.what() << std::endl; | |
if (net.empty()) | |
{ | |
exit(-1); | |
} | |
} | |
net.setPreferableTarget(DNN_TARGET_OPENCL); | |
Mat img = imread(imageFile); | |
if (img.empty()) | |
{ | |
std::cerr << "Can't read image from the file: " << imageFile << std::endl; | |
exit(-1); | |
} | |
Mat inputBlob = blobFromImage(img, 1.0f, Size(224, 224), | |
Scalar(104, 117, 123), false); | |
net.setInput(inputBlob, "data"); | |
Mat prob = net.forward("prob"); | |
int classId; | |
double classProb; | |
getMaxClass(prob, &classId, &classProb); | |
vector<String> classNames = readClassNames(); | |
cout << "Best class: #" << classId << " '" << classNames.at(classId) << "'" << endl; | |
cout << "Probability: " << classProb * 100 << "%" << endl; | |
resize(img, img, cv::Size(), 0.2, 0.2); | |
imshow("Image", img); | |
waitKey(0); | |
return 0; | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.