Instantly share code, notes, and snippets.

Embed
What would you like to do?
KNN on MNIST with OpenCV
#include <fstream>
#include <iostream>
#include <limits>
#include <string>

#include <opencv2/opencv.hpp>
using namespace std;
using namespace cv;
#define PIXELS_IN_IMAGE 28*28
#define ENABLE_TRAIN 1
int reverseInt(int i);
int loadMNIST(const string pic_filename, const string label_filename, Mat& training_data, Mat& label_data);
int main(int argc, char* argv[])
{
try
{
Mat train_data_mat, train_label_mat;
Mat test_data_mat, test_label_mat;
loadMNIST("F:\\Datasets\\mnist\\train-images.idx3-ubyte", "F:\\Datasets\\mnist\\train-labels.idx1-ubyte", train_data_mat, train_label_mat);
loadMNIST("F:\\Datasets\\mnist\\t10k-images.idx3-ubyte", "F:\\Datasets\\mnist\\t10k-labels.idx1-ubyte", test_data_mat, test_label_mat);
train_data_mat.convertTo(train_data_mat, CV_32FC1);
train_label_mat.convertTo(train_label_mat, CV_32SC1);
test_data_mat.convertTo(test_data_mat, CV_32FC1);
Ptr<ml::KNearest> knn;
knn = ml::KNearest::create();
knn->setDefaultK(10);
knn->train(train_data_mat, ml::SampleTypes::ROW_SAMPLE, train_label_mat);
int correct_count = 0;
for (int idx = 0; idx < test_label_mat.rows; idx++) {
Mat result_mat;
float response = knn->findNearest(test_data_mat.row(idx), knn->getDefaultK(), result_mat);
if(test_label_mat.at<uchar>(idx, 0) == (uchar)response) {
correct_count++;
}
}
double correct_ratio = (double)correct_count / (double)test_label_mat.rows;
cout << correct_ratio << endl;
}
catch (const Exception& ex)
{
cout << "Error: " << ex.what() << endl;
}
cin.get();
return 0;
}
/// Reverses the byte order of a 32-bit integer (big-endian <-> little-endian).
///
/// @param i Value whose four bytes are to be swapped.
/// @return The value with its bytes in reverse order, e.g. 0x01020304 becomes
///         0x04030201.
int reverseInt(int i) {
    static_assert(sizeof(int) == 4, "reverseInt swaps exactly four bytes");

    // Operate on the unsigned representation: shifting a signed int into or
    // out of the sign bit is not portable, unsigned shifts are fully defined.
    const unsigned int value = static_cast<unsigned int>(i);
    const unsigned int swapped = ((value & 0x000000FFu) << 24) |
                                 ((value & 0x0000FF00u) << 8) |
                                 ((value & 0x00FF0000u) >> 8) |
                                 ((value & 0xFF000000u) >> 24);
    return static_cast<int>(swapped);
}
int loadMNIST(const string pic_filename, const string label_filename, Mat& training_data, Mat& label_data) {
std::ifstream pic_file(pic_filename, std::ios::binary);
std::ifstream label_file(label_filename, std::ios::binary);
if (pic_file.is_open() && label_file.is_open()) {
int magic_number = 0;
int number_of_images = 0;
int n_rows = 0;
int n_cols = 0;
label_file.read((char*)&magic_number, sizeof(magic_number));
pic_file.read((char*)&magic_number, sizeof(magic_number));
magic_number = reverseInt(magic_number);
label_file.read((char*)&number_of_images, sizeof(number_of_images));
pic_file.read((char*)&number_of_images, sizeof(number_of_images));
number_of_images = reverseInt(number_of_images);
pic_file.read((char*)&n_rows, sizeof(n_rows));
n_rows = reverseInt(n_rows);
pic_file.read((char*)&n_cols, sizeof(n_cols));
n_cols = reverseInt(n_cols);
int n_stride = n_cols * n_rows;
training_data = Mat(number_of_images, n_stride, CV_8U);
label_data = Mat(number_of_images, 1, CV_8U);
for (int i = 0; i < number_of_images; ++i) {
// for (int i = 0; i < 5000; ++i) {
unsigned char data_tmp[PIXELS_IN_IMAGE];
pic_file.read((char*)data_tmp, sizeof(unsigned char) * n_stride);
Mat row_image(1, n_stride, CV_8U, data_tmp);
row_image.row(0).copyTo(training_data.row(i));
char label = 0;
label_file.read((char*)&label, sizeof(label));
label_data.at<uchar>(i, 0) = label;
}
} else {
return 1;
}
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment