Last active
July 26, 2017 10:15
-
-
Save sekia/ad3668df6a615a975bff4a543e366a58 to your computer and use it in GitHub Desktop.
MNIST baseline classifier using the k-nearest neighbors (kNN) method.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <arpa/inet.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <istream>
#include <memory>
#include <numeric>
#include <set>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

using namespace std;
// Reads `n` objects of type T from the stream `s` as raw bytes into `store`.
//
// `store` must point to writable storage for at least `n` objects of type T.
// The bytes are copied verbatim; no byte-order conversion is performed here.
//
// Throws std::runtime_error if the stream cannot deliver all
// `sizeof(T) * n` bytes (for example because the input is truncated or the
// stream is already in a failed state).
template <typename T>
void ReadBytes(std::istream& s, T *store, std::size_t n = 1) {
  // Overwriting an object's bytes is only well-defined for trivially
  // copyable types.
  static_assert(
      std::is_trivially_copyable<T>::value,
      "ReadBytes can only fill trivially copyable objects");
  if (n == 0) { return; }  // Nothing requested, nothing to validate.
  const auto num_bytes = static_cast<std::streamsize>(sizeof(T) * n);
  if (!s.read(reinterpret_cast<char *>(store), num_bytes)) {
    throw std::runtime_error(
        "Failed to read the requested number of bytes from the stream.");
  }
}
// Declares a variable `var` of the 32-bit integer type `type` in the current
// scope, fills it with sizeof(type) bytes read from the stream `fs` via
// ReadBytes, and converts the value from network (big-endian) byte order to
// host byte order with ntohl.
//
// `var` is zero-initialised so that it never holds an indeterminate value,
// even if the read does not succeed.
//
// Because the expansion starts with a declaration, the macro must be used as
// a statement of its own inside a block (not as the unbraced body of an
// `if`/`for`), and the caller supplies the terminating semicolon.
#define READ_32BIT_BE(type, var, fs) \
  type var = 0; \
  do { \
    ReadBytes(fs, &var); \
    var = static_cast<type>(ntohl(static_cast<uint32_t>(var))); \
  } while (0)
vector<uint8_t> ReadLabels(const string& filename) { | |
ifstream fs(filename, ifstream::binary); | |
READ_32BIT_BE(int32_t, magic, fs); | |
if (magic != 2049) { throw runtime_error("Invalid as label file."); } | |
READ_32BIT_BE(uint32_t, num_labels, fs); | |
vector<uint8_t> labels(num_labels); | |
ReadBytes(fs, labels.data(), num_labels); | |
return labels; | |
} | |
// A collection of equally sized images stored back to back in one flat
// pixel buffer. Image `i` occupies the pixel range
// [i * ImageSize(), (i + 1) * ImageSize()).
class Images {
 private:
  std::uint32_t height_;
  std::uint32_t width_;
  std::vector<int> pixels_;

 public:
  Images() = delete;

  // Takes ownership of `pixels`; the buffer is moved, not copied.
  Images(std::uint32_t height, std::uint32_t width, std::vector<int>&& pixels) :
      height_(height), width_(width), pixels_(std::move(pixels)) {}

  using PixelIterator = decltype(pixels_)::const_iterator;
  // Half-open [first, second) range over the pixels of one image. The
  // iterators point into this object's buffer, so they must not be used
  // after the Images object has been destroyed.
  using Image = std::pair<PixelIterator, PixelIterator>;

  std::uint32_t Height() const { return height_; }
  std::uint32_t Width() const { return width_; }

  // Number of pixels in a single image.
  std::uint32_t ImageSize() const { return Height() * Width(); }

  // Number of complete images held in the buffer. Reports 0 when the image
  // dimensions are degenerate (zero pixels per image) instead of dividing by
  // zero.
  std::uint32_t Size() const {
    const std::uint32_t pixels_per_image = ImageSize();
    if (pixels_per_image == 0) { return 0; }
    return static_cast<std::uint32_t>(pixels_.size() / pixels_per_image);
  }

  // Unchecked access: the caller must guarantee i < Size().
  Image operator[](std::size_t i) const {
    const auto first = pixels_.cbegin() + ImageSize() * i;
    return std::make_pair(first, first + ImageSize());
  }

  // Checked access: throws std::range_error when i >= Size().
  Image At(std::size_t i) const {
    if (i >= Size()) { throw std::range_error("Index out of range"); }
    return operator[](i);
  }
};
// Loads all images from the binary image file `filename`.
//
// Layout expected by this reader (all header fields are 32-bit big-endian):
//   1. magic number, which must equal 2051
//   2. number of images
//   3. image height
//   4. image width
//   5. one unsigned byte per pixel, images stored back to back
//
// Each pixel byte is widened to int in the returned Images object.
// Throws std::runtime_error if the file cannot be opened, the magic number
// does not match, or the file ends before all announced data was read.
Images ReadImages(const std::string& filename) {
  std::ifstream fs(filename, std::ifstream::binary);
  if (!fs) {
    throw std::runtime_error("Cannot open image file: " + filename);
  }

  READ_32BIT_BE(std::int32_t, magic, fs);
  // Check the stream first so a failed read is never mistaken for a value.
  if (!fs || magic != 2051) {
    throw std::runtime_error("Invalid as image file.");
  }

  READ_32BIT_BE(std::uint32_t, num_images, fs);
  READ_32BIT_BE(std::uint32_t, height, fs);
  READ_32BIT_BE(std::uint32_t, width, fs);
  if (!fs) {
    throw std::runtime_error("Image file header is truncated.");
  }

  // Widen before multiplying: the product of three 32-bit values can exceed
  // the 32-bit range.
  const std::size_t num_pixels =
      static_cast<std::size_t>(num_images) * height * width;

  std::vector<std::uint8_t> raw_pixels(num_pixels);
  ReadBytes(fs, raw_pixels.data(), num_pixels);
  if (!fs) {
    throw std::runtime_error("Image file is truncated.");
  }

  // The range constructor converts every uint8_t pixel to int.
  std::vector<int> pixel_ints(raw_pixels.begin(), raw_pixels.end());
  return Images(height, width, std::move(pixel_ints));
}

// The header-reading macro is an implementation detail of the file readers
// above; retire it so it cannot leak into the rest of the file.
#undef READ_32BIT_BE
// Returns the sum of squared per-pixel differences between two images (the
// squared Euclidean distance); 0 means the images are identical.
//
// `image1` is a half-open [first, second) range that determines how many
// pixels are compared. Only `image2.first` is used, so the second image must
// provide at least as many pixels as the first.
//
// `image_size` is kept for interface compatibility only. The sum is
// accumulated in a single pass, so no scratch buffer of that size is
// allocated and a value smaller than the real range length can no longer
// cause an out-of-bounds write.
template <typename Iterator1, typename Iterator2>
int ComputeUnsimilarity(
    std::size_t image_size,
    std::pair<Iterator1, Iterator1> image1, std::pair<Iterator2, Iterator2> image2) {
  static_cast<void>(image_size);
  return std::inner_product(
      image1.first, image1.second, image2.first, 0,
      [](int sum, int term) { return sum + term; },
      [](int a, int b) { const int diff = a - b; return diff * diff; });
}
template <typename Iterator> | |
uint8_t EstimateLabel( | |
size_t k, const vector<uint8_t>& labels, const Images& images, | |
pair<Iterator, Iterator> image) { | |
if (k < 1) { throw logic_error("k must be > 1."); } | |
k = min(k, labels.size()); | |
vector<pair<int, uint8_t>> unsimilarities; | |
unsimilarities.reserve(labels.size()); | |
for (size_t i = 0; i < labels.size(); ++i) { | |
auto us = ComputeUnsimilarity(images.ImageSize(), images[i], image); | |
unsimilarities.emplace_back(make_pair(us, labels[i])); | |
} | |
sort(unsimilarities.begin(), unsimilarities.end()); | |
multiset<uint8_t> votes; | |
for (size_t i = 0; i < k; ++i) { votes.insert(unsimilarities[i].second); } | |
uint8_t best = *(votes.begin()); | |
size_t max_votes = 0; | |
for (auto label : votes) { | |
auto votes_for_label = votes.count(label); | |
if (votes_for_label > max_votes) { | |
best = label; | |
max_votes = votes_for_label; | |
} | |
} | |
return best; | |
} | |
int main() { | |
auto training_labels = ReadLabels("train-labels-idx1-ubyte"); | |
auto training_images = ReadImages("train-images-idx3-ubyte"); | |
auto test_labels = ReadLabels("t10k-labels-idx1-ubyte"); | |
auto test_images = ReadImages("t10k-images-idx3-ubyte"); | |
size_t total = 0, correct = 0; | |
for (size_t i = 0; i < test_labels.size(); ++i) { | |
auto got = EstimateLabel( | |
5, training_labels, training_images, test_images[i]); | |
printf("Got: %d, Expected: %d\n", got, test_labels[i]); | |
++total; | |
if (got == test_labels[i]) { ++correct; } | |
} | |
printf( | |
"Correct: %ld/%ld (%f%%)\n", | |
correct, total, static_cast<double>(correct) * 100 / total); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment