Created
August 6, 2017 11:28
-
-
Save johnhany/ff6835b2191e58f96699645b2d36c1a1 to your computer and use it in GitHub Desktop.
SVM on MNIST with OpenCV
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <fstream>
#include <iostream>
#include <string>

#include <opencv2/opencv.hpp>
using namespace std; | |
using namespace cv; | |
// Number of pixels in one 28x28 image. The expression is parenthesised so the
// macro expands as a single value inside larger expressions
// (e.g. `x / PIXELS_IN_IMAGE`).
#define PIXELS_IN_IMAGE (28 * 28)
// Set to 1 to train a new SVM and save it to disk; set to 0 to load the
// previously saved model instead.
#define ENABLE_TRAIN 1
int reverseInt(int i); | |
int loadMNIST(const string pic_filename, const string label_filename, Mat& training_data, Mat& label_data); | |
int main(int argc, char* argv[]) | |
{ | |
try | |
{ | |
Mat train_data_mat, train_label_mat; | |
Mat test_data_mat, test_label_mat; | |
loadMNIST("F:\\Datasets\\mnist\\train-images.idx3-ubyte", "F:\\Datasets\\mnist\\train-labels.idx1-ubyte", train_data_mat, train_label_mat); | |
loadMNIST("F:\\Datasets\\mnist\\t10k-images.idx3-ubyte", "F:\\Datasets\\mnist\\t10k-labels.idx1-ubyte", test_data_mat, test_label_mat); | |
train_data_mat.convertTo(train_data_mat, CV_32FC1); | |
train_label_mat.convertTo(train_label_mat, CV_32SC1); | |
test_data_mat.convertTo(test_data_mat, CV_32FC1); | |
Ptr<ml::SVM> svm; | |
if(ENABLE_TRAIN) { | |
svm = ml::SVM::create(); | |
svm->setType(ml::SVM::C_SVC); | |
svm->setKernel(ml::SVM::LINEAR); | |
svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 500, 1e-6)); | |
svm->train(train_data_mat, ml::SampleTypes::ROW_SAMPLE, train_label_mat); | |
svm->save("SVM_MNIST.xml"); | |
} else { | |
svm = Algorithm::load<ml::SVM>("SVM_MNIST.xml"); | |
} | |
int correct_count = 0; | |
for (int idx = 0; idx < test_label_mat.rows; idx++) { | |
float response = svm->predict(test_data_mat.row(idx)); | |
if (test_label_mat.at<uchar>(idx, 0) == (uchar)response) { | |
correct_count++; | |
} | |
} | |
double correct_ratio = (double)correct_count / (double)test_label_mat.rows; | |
cout << correct_ratio << endl; | |
} | |
catch (const Exception& ex) | |
{ | |
cout << "Error: " << ex.what() << endl; | |
} | |
cin.get(); | |
return 0; | |
} | |
/**
 * Reverses the order of the four low bytes of an integer (32-bit byte swap).
 *
 * The swap is done on an unsigned copy of the value so that no signed value is
 * ever shifted into the sign bit.
 *
 * @param i value whose bytes are to be swapped.
 * @return the value with byte 0 and byte 3 exchanged and byte 1 and byte 2
 *         exchanged.
 */
int reverseInt(int i) {
    const unsigned int value = static_cast<unsigned int>(i);
    const unsigned int swapped = ((value & 0x000000FFu) << 24)
                               | ((value & 0x0000FF00u) << 8)
                               | ((value & 0x00FF0000u) >> 8)
                               | ((value >> 24) & 0x000000FFu);
    return static_cast<int>(swapped);
}
int loadMNIST(const string pic_filename, const string label_filename, Mat& training_data, Mat& label_data) { | |
std::ifstream pic_file(pic_filename, std::ios::binary); | |
std::ifstream label_file(label_filename, std::ios::binary); | |
if (pic_file.is_open() && label_file.is_open()) { | |
int magic_number = 0; | |
int number_of_images = 0; | |
int n_rows = 0; | |
int n_cols = 0; | |
label_file.read((char*)&magic_number, sizeof(magic_number)); | |
pic_file.read((char*)&magic_number, sizeof(magic_number)); | |
magic_number = reverseInt(magic_number); | |
label_file.read((char*)&number_of_images, sizeof(number_of_images)); | |
pic_file.read((char*)&number_of_images, sizeof(number_of_images)); | |
number_of_images = reverseInt(number_of_images); | |
pic_file.read((char*)&n_rows, sizeof(n_rows)); | |
n_rows = reverseInt(n_rows); | |
pic_file.read((char*)&n_cols, sizeof(n_cols)); | |
n_cols = reverseInt(n_cols); | |
int n_stride = n_cols * n_rows; | |
training_data = Mat(number_of_images, n_stride, CV_8U); | |
label_data = Mat(number_of_images, 1, CV_8U); | |
for (int i = 0; i < number_of_images; ++i) { | |
// for (int i = 0; i < 5000; ++i) { | |
unsigned char data_tmp[PIXELS_IN_IMAGE]; | |
pic_file.read((char*)data_tmp, sizeof(unsigned char) * n_stride); | |
Mat row_image(1, n_stride, CV_8U, data_tmp); | |
row_image.row(0).copyTo(training_data.row(i)); | |
char label = 0; | |
label_file.read((char*)&label, sizeof(label)); | |
label_data.at<uchar>(i, 0) = label; | |
} | |
} else { | |
return 1; | |
} | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment