Skip to content

Instantly share code, notes, and snippets.

@sekia
Last active July 26, 2017 10:15
Show Gist options
  • Save sekia/ad3668df6a615a975bff4a543e366a58 to your computer and use it in GitHub Desktop.
Save sekia/ad3668df6a615a975bff4a543e366a58 to your computer and use it in GitHub Desktop.
Baseline MNIST classifier using the k-nearest neighbors (kNN) method.
#include <arpa/inet.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <exception>
#include <fstream>
#include <istream>
#include <memory>
#include <numeric>
#include <set>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
using namespace std;
/// Reads `n` objects of type `T` from `s` as raw bytes into `store`.
///
/// @tparam T     Element type; its object representation is overwritten
///               byte for byte with the stream contents.
/// @param s      Input stream to read from.
/// @param store  Destination buffer with room for at least `n` elements.
/// @param n      Number of elements to read (defaults to one).
/// @throws std::runtime_error if the stream cannot deliver every requested
///         byte (end of input or an I/O error).
template <typename T>
void ReadBytes(std::istream& s, T *store, std::size_t n = 1) {
  const auto num_bytes = static_cast<std::streamsize>(sizeof(T) * n);
  s.read(reinterpret_cast<char *>(store), num_bytes);
  // read() puts the stream into a failed state when fewer bytes than requested
  // were available; report that instead of handing back a partially filled
  // buffer.
  if (!s) {
    throw std::runtime_error("Failed to read the requested number of bytes.");
  }
}
// Declares a variable `var` of the 32-bit integer type `type` and fills it
// with the next four bytes of stream `fs`, converted from big-endian (network)
// byte order to host byte order.
// `var` starts at zero so that a read which extracts nothing never leaves it
// holding an indeterminate value.
#define READ_32BIT_BE(type, var, fs) \
type var = 0; \
do { \
  ReadBytes(fs, &var); \
  var = static_cast<type>(ntohl(static_cast<std::uint32_t>(var))); \
} while (0)
/// Loads all labels from a label file.
///
/// The file is expected to start with two big-endian 32-bit integers (the
/// magic number 2049, then the number of labels) followed by one byte per
/// label.
///
/// @param filename  Path of the label file.
/// @return The labels in file order.
/// @throws std::runtime_error if the file cannot be opened, does not carry the
///         expected magic number, or ends before all labels were read.
std::vector<std::uint8_t> ReadLabels(const std::string& filename) {
  std::ifstream fs(filename, std::ifstream::binary);
  if (!fs) {
    throw std::runtime_error("Cannot open label file: " + filename);
  }

  READ_32BIT_BE(std::int32_t, magic, fs);
  // Test the stream first: after a failed read `magic` is not meaningful.
  if (!fs || magic != 2049) { throw std::runtime_error("Invalid as label file."); }

  READ_32BIT_BE(std::uint32_t, num_labels, fs);
  if (!fs) { throw std::runtime_error("Invalid as label file."); }

  std::vector<std::uint8_t> labels(num_labels);
  ReadBytes(fs, labels.data(), num_labels);
  if (!fs) { throw std::runtime_error("Truncated label file."); }
  return labels;
}
/// Read-only collection of equally sized images whose pixel values are stored
/// back to back in one contiguous buffer.
class Images {
 private:
  std::uint32_t height_;
  std::uint32_t width_;
  std::vector<int> pixels_;

 public:
  Images() = delete;

  /// @param height  Rows per image.
  /// @param width   Columns per image.
  /// @param pixels  Pixel values of all images, one image after another. The
  ///                buffer is moved into the new object, not copied.
  Images(std::uint32_t height, std::uint32_t width, std::vector<int>&& pixels)
      : height_(height), width_(width), pixels_(std::move(pixels)) {}

  using PixelIterator = decltype(pixels_)::const_iterator;
  /// Half-open [first, second) range over the pixels of a single image.
  using Image = std::pair<PixelIterator, PixelIterator>;

  std::uint32_t Height() const { return height_; }
  std::uint32_t Width() const { return width_; }

  /// Number of pixels in one image.
  std::uint32_t ImageSize() const { return Height() * Width(); }

  /// Number of complete images held; 0 when the image size itself is 0.
  std::uint32_t Size() const {
    const std::uint32_t image_size = ImageSize();
    if (image_size == 0) { return 0; }  // avoid dividing by zero
    return static_cast<std::uint32_t>(pixels_.size() / image_size);
  }

  /// Returns the pixel range of image `i` without any bounds check.
  Image operator[](std::size_t i) const {
    const std::size_t image_size = ImageSize();
    const auto first =
        pixels_.cbegin() + static_cast<std::ptrdiff_t>(image_size * i);
    return Image(first, first + static_cast<std::ptrdiff_t>(image_size));
  }

  /// Bounds-checked variant of operator[].
  /// @throws std::range_error if `i` is not below Size().
  Image At(std::size_t i) const {
    if (i >= Size()) { throw std::range_error("Index out of range"); }
    return operator[](i);
  }
};
/// Loads all images from an image file.
///
/// The file is expected to start with four big-endian 32-bit integers (the
/// magic number 2051, the number of images, the image height and the image
/// width) followed by one byte per pixel.
///
/// @param filename  Path of the image file.
/// @return The images, with every pixel byte widened to `int`.
/// @throws std::runtime_error if the file cannot be opened, does not carry the
///         expected magic number, or ends before all pixels were read.
Images ReadImages(const std::string& filename) {
  std::ifstream fs(filename, std::ifstream::binary);
  if (!fs) {
    throw std::runtime_error("Cannot open image file: " + filename);
  }

  READ_32BIT_BE(std::int32_t, magic, fs);
  // Test the stream first: after a failed read `magic` is not meaningful.
  if (!fs || magic != 2051) { throw std::runtime_error("Invalid as image file."); }

  READ_32BIT_BE(std::uint32_t, num_images, fs);
  READ_32BIT_BE(std::uint32_t, height, fs);
  READ_32BIT_BE(std::uint32_t, width, fs);
  // Once a read fails the later ones extract nothing, so one check covers all
  // three header fields.
  if (!fs) { throw std::runtime_error("Invalid as image file."); }

  // Widen before multiplying so the product is not computed in 32 bits.
  const std::size_t num_pixels =
      static_cast<std::size_t>(num_images) * height * width;

  std::vector<std::uint8_t> raw_pixels(num_pixels);
  ReadBytes(fs, raw_pixels.data(), num_pixels);
  if (!fs) { throw std::runtime_error("Truncated image file."); }

  std::vector<int> pixel_ints(raw_pixels.begin(), raw_pixels.end());
  return Images(height, width, std::move(pixel_ints));
}
// The header-reading macro is used only by the two file loaders above; remove
// it so it is not visible to the code that follows.
#undef READ_32BIT_BE
/// Computes the sum of squared per-pixel differences between two images; the
/// larger the result, the less alike the images are.
///
/// @param image_size  Kept for interface compatibility and not consulted: the
///                    number of pixels compared is the length of `image1`.
/// @param image1      [first, second) pixel range of the first image.
/// @param image2      Pixel range of the second image; it must hold at least
///                    as many pixels as `image1`.
/// @return Sum over all pixels of (a - b)^2, accumulated as `int`.
template <typename Iterator1, typename Iterator2>
int ComputeUnsimilarity(
    std::size_t image_size,
    std::pair<Iterator1, Iterator1> image1,
    std::pair<Iterator2, Iterator2> image2) {
  (void)image_size;
  // Fold the squared differences straight into the running total; no scratch
  // buffer is needed, so nothing is allocated per call.
  return std::inner_product(
      image1.first, image1.second, image2.first, 0,
      [](int sum, int squared_diff) { return sum + squared_diff; },
      [](int a, int b) { const int diff = a - b; return diff * diff; });
}
template <typename Iterator>
uint8_t EstimateLabel(
size_t k, const vector<uint8_t>& labels, const Images& images,
pair<Iterator, Iterator> image) {
if (k < 1) { throw logic_error("k must be > 1."); }
k = min(k, labels.size());
vector<pair<int, uint8_t>> unsimilarities;
unsimilarities.reserve(labels.size());
for (size_t i = 0; i < labels.size(); ++i) {
auto us = ComputeUnsimilarity(images.ImageSize(), images[i], image);
unsimilarities.emplace_back(make_pair(us, labels[i]));
}
sort(unsimilarities.begin(), unsimilarities.end());
multiset<uint8_t> votes;
for (size_t i = 0; i < k; ++i) { votes.insert(unsimilarities[i].second); }
uint8_t best = *(votes.begin());
size_t max_votes = 0;
for (auto label : votes) {
auto votes_for_label = votes.count(label);
if (votes_for_label > max_votes) {
best = label;
max_votes = votes_for_label;
}
}
return best;
}
/// Classifies every test image with a 5-nearest-neighbour vote over the
/// training set, printing each prediction and finally the overall accuracy.
///
/// The four data files are looked up by fixed names in the current working
/// directory.
///
/// @return 0 on success, 1 if loading or classification raised an exception.
int main() {
  try {
    const auto training_labels = ReadLabels("train-labels-idx1-ubyte");
    const auto training_images = ReadImages("train-images-idx3-ubyte");
    const auto test_labels = ReadLabels("t10k-labels-idx1-ubyte");
    const auto test_images = ReadImages("t10k-images-idx3-ubyte");

    // The loop below indexes the images by label position, so every label
    // needs an image. (With an image size of zero each image is an empty
    // range and any index is valid.)
    if (test_images.ImageSize() != 0 &&
        test_images.Size() < test_labels.size()) {
      throw std::runtime_error("Test set has fewer images than labels.");
    }

    std::size_t total = 0, correct = 0;
    for (std::size_t i = 0; i < test_labels.size(); ++i) {
      const auto got = EstimateLabel(
          5, training_labels, training_images, test_images[i]);
      std::printf("Got: %d, Expected: %d\n", got, test_labels[i]);
      ++total;
      if (got == test_labels[i]) { ++correct; }
    }

    // Report 0% rather than NaN when the test set is empty.
    const double percentage =
        total == 0
            ? 0.0
            : static_cast<double>(correct) * 100 / static_cast<double>(total);
    // %zu is the conversion that matches size_t.
    std::printf("Correct: %zu/%zu (%f%%)\n", correct, total, percentage);
    return 0;
  } catch (const std::exception& e) {
    std::fprintf(stderr, "Error: %s\n", e.what());
    return 1;
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment