Skip to content

Instantly share code, notes, and snippets.

@spaghetti-source
Last active December 17, 2015 14:19
Show Gist options
  • Save spaghetti-source/5624007 to your computer and use it in GitHub Desktop.
Save spaghetti-source/5624007 to your computer and use it in GitHub Desktop.
k-NN classification for MNIST dataset
//
// Nearest Neighbour Classification for MNIST dataset
//
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <map>
#include <vector>
#include <queue>
#include <cstring>
#include <functional>
#include <algorithm>
#include <cassert>
using namespace std;
#define ALL(c) c.begin(), c.end()
#define FOR(i,c) for(typeof(c.begin())i=c.begin();i!=c.end();++i)
#define REP(i,n) for(int i=0;i<n;++i)
#define fst first
#define snd second
// === elapsed-time helper ===
#include <ctime>
// Returns the processor time, in seconds, that clock() reports as having
// elapsed since the previous call to tick(). The remembered timestamp starts
// out zero-initialized, so the very first call reports clock()'s current
// reading converted to seconds.
double tick() {
  static clock_t previous;  // timestamp captured by the preceding call
  const clock_t current = clock();
  const double seconds = 1.0 * (current - previous) / CLOCKS_PER_SEC;
  previous = current;
  return seconds;
}
// Reverses, in place, the order of the four low bytes of x
// (0xAABBCCDD becomes 0xDDCCBBAA).
void endianSwap(unsigned int &x) {
  const unsigned int top_to_bottom = x >> 24;                 // byte 3 -> byte 0
  const unsigned int upper_to_lower = (x >> 8) & 0x0000FF00;  // byte 2 -> byte 1
  const unsigned int lower_to_upper = (x << 8) & 0x00FF0000;  // byte 1 -> byte 2
  const unsigned int bottom_to_top = x << 24;                 // byte 0 -> byte 3
  x = top_to_bottom | lower_to_upper | upper_to_lower | bottom_to_top;
}
// One picture stored as a flat array: the readers below size it to col*row
// and fill one element per pixel.
typedef std::vector<unsigned int> Image;
// Class label attached to a picture (read from the label file one byte each).
typedef unsigned char Label;
// Picture dimensions taken from the most recently read image-file header.
unsigned int row;
unsigned int col;
// Training set: image[k] is the k-th picture and label[k] its class.
std::vector<Image> image;
std::vector<Label> label;
// Loads the training set into the globals `image` and `label`, and stores the
// picture dimensions from the image-file header in `row` and `col`.
//
//   imageFile  path of the image file (header: magic, count, rows, cols,
//              followed by one byte per pixel)
//   labelFile  path of the label file (header: magic, count, followed by one
//              byte per label)
//
// Any failure (file cannot be opened, bad magic number, truncated data,
// image/label count mismatch) is reported on stderr and terminates the
// program with EXIT_FAILURE. The elapsed loading time is printed to stderr.
//
// NOTE: header words are read as raw 4-byte values and the magic numbers are
// compared *before* byte-swapping, so this code assumes a little-endian host
// with a 4-byte unsigned int, exactly like the original.
void readTraining(const char *imageFile, const char *labelFile) {
  tick();  // restart the timer so the report below covers only the loading

  // fopen must not be wrapped in assert(): with NDEBUG defined the whole
  // expression disappears and the files would never be opened.
  FILE *fimage = fopen(imageFile, "rb");
  if (fimage == NULL) {
    perror(imageFile);
    exit(EXIT_FAILURE);
  }
  FILE *flabel = fopen(labelFile, "rb");
  if (flabel == NULL) {
    perror(labelFile);
    fclose(fimage);
    exit(EXIT_FAILURE);
  }

  unsigned int imageMagic = 0, labelMagic = 0, num = 0, labelNum = 0;
  const bool headerRead =
      fread(&imageMagic, 4, 1, fimage) == 1 &&
      fread(&num, 4, 1, fimage) == 1 &&
      fread(&row, 4, 1, fimage) == 1 &&
      fread(&col, 4, 1, fimage) == 1 &&
      fread(&labelMagic, 4, 1, flabel) == 1 &&
      fread(&labelNum, 4, 1, flabel) == 1;
  // 0x03080000 / 0x01080000 are the byte-swapped forms of 0x00000803 and
  // 0x00000801, because the comparison happens on the unswapped raw words.
  if (!headerRead || imageMagic != 0x03080000 || labelMagic != 0x01080000) {
    fprintf(stderr, "readTraining: bad or truncated header in %s / %s\n",
            imageFile, labelFile);
    exit(EXIT_FAILURE);
  }
  endianSwap(num);
  endianSwap(labelNum);
  endianSwap(row);
  endianSwap(col);
  if (labelNum != num) {
    fprintf(stderr, "readTraining: %u images but %u labels\n", num, labelNum);
    exit(EXIT_FAILURE);
  }

  printf("num %u\n", num);
  printf("col %u\n", col);
  printf("row %u\n", row);

  const std::size_t pixels = static_cast<std::size_t>(col) * row;
  image.assign(num, Image(pixels));
  label.resize(num);

  // Labels are single bytes, so the whole label section is read in one call.
  if (num > 0 && fread(&label[0], 1, num, flabel) != num) {
    fprintf(stderr, "readTraining: truncated label file %s\n", labelFile);
    exit(EXIT_FAILURE);
  }

  // Each picture is read into a byte buffer and then widened element by
  // element; reading a byte straight into an unsigned int slot would only
  // produce the right value on little-endian machines.
  std::vector<unsigned char> raw(pixels);
  for (unsigned int k = 0; k < num; ++k) {
    if (pixels > 0 && fread(&raw[0], 1, pixels, fimage) != pixels) {
      fprintf(stderr, "readTraining: truncated image file %s (image %u)\n",
              imageFile, k);
      exit(EXIT_FAILURE);
    }
    std::copy(raw.begin(), raw.end(), image[k].begin());
  }

  fprintf(stderr, "training: %lf[sec]\n", tick());
  fclose(fimage);
  fclose(flabel);
}
// Squared Euclidean distance between two images: the sum over all pixels of
// (a[i] - b[i])^2.
//
// The images are taken by const reference (the original copied both vectors
// on every call). The per-pixel difference is computed in signed arithmetic:
// Image elements are unsigned, so a[i] - b[i] would wrap around to a huge
// value whenever a[i] < b[i], and squaring that through pow() overflowed the
// int accumulator.
//
// If the two images differ in length only the common prefix is compared, so
// no element is ever read out of bounds.
//
// Assumes pixel values are small enough that the sum fits in an int (true for
// 8-bit pixels up to roughly 33000 pixels per image).
int dist2(const Image &a, const Image &b) {
  const std::size_t n = std::min(a.size(), b.size());
  int total = 0;
  for (std::size_t i = 0; i < n; ++i) {
    const int delta = static_cast<int>(a[i]) - static_cast<int>(b[i]);
    total += delta * delta;  // plain multiply: exact, and far cheaper than pow()
  }
  return total;
}
// Chooses one of three labels by vote. When the second and third agree,
// that shared label is returned; in every other case (a matches one of the
// others, or all three differ) the first label is returned. classify() passes
// the labels ordered nearest-first, so `a` is the nearest neighbour's label.
int majority(int a, int b, int c) {
  if (b == c) {
    return b;
  }
  return a;
}
int classify(Image img) {
// top 3, hard-coded
vector< pair<double,int> > order;
REP(l, image.size()) order.push_back( make_pair(dist2(img, image[l]), label[l]) );
partial_sort(order.begin(), order.begin()+3, order.end());
return majority(order[0].snd, order[1].snd, order[2].snd);
}
// Nearest Neighbour
void readTest(const char *imageFile, const char *labelFile) {
tick();
FILE *fimage, *flabel;
assert( fimage = fopen(imageFile, "rb") );
assert( flabel = fopen(labelFile, "rb") );
unsigned int magic, num;
fread(&magic, 4, 1, fimage);
assert(magic == 0x03080000);
fread(&magic, 4, 1, flabel);
assert(magic == 0x01080000);
fread(&num, 4, 1, flabel); // dust
fread(&num, 4, 1, fimage); endianSwap(num);
fread(&row, 4, 1, fimage); endianSwap(row);
fread(&col, 4, 1, fimage); endianSwap(col);
//num = 10;
printf("num %d\n", num);
printf("col %d\n", col);
printf("row %d\n", row);
int miss = 0;
REP(k, num) {
Image img(row*col);
Label lbl;
REP(i, col) REP(j, row)
fread(&img[i*row+j], 1, 1, fimage);
fread(&lbl, 1, 1, flabel);
tick();
int x = classify(img);
if (x != lbl) ++miss;
printf("error rate: %4d/%4d = %.2lf; classify per image: %.4lf[sec]\n", miss, k+1, 100.0*miss / (k+1), tick());
}
}
// Entry point: loads the training set, then evaluates the classifier on the
// test set.
//
// Usage:
//   prog
//       uses the default MNIST file names in the current directory
//   prog TRAIN_IMAGES TRAIN_LABELS TEST_IMAGES TEST_LABELS
//       uses the four given paths instead
// Any other argument count prints a usage message and returns EXIT_FAILURE.
int main(int argc, char **argv) {
  const char *trainImages = "train-images-idx3-ubyte";
  const char *trainLabels = "train-labels-idx1-ubyte";
  const char *testImages = "t10k-images-idx3-ubyte";
  const char *testLabels = "t10k-labels-idx1-ubyte";

  if (argc == 5) {
    trainImages = argv[1];
    trainLabels = argv[2];
    testImages = argv[3];
    testLabels = argv[4];
  } else if (argc != 1) {
    fprintf(stderr,
            "usage: %s [TRAIN_IMAGES TRAIN_LABELS TEST_IMAGES TEST_LABELS]\n",
            argc > 0 ? argv[0] : "knn");
    return EXIT_FAILURE;
  }

  readTraining(trainImages, trainLabels);
  readTest(testImages, testLabels);
  return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment