public
anonymous / gist:6009944
Created

  • Download Gist
gistfile1.py
Python
1 2 3 4 5 6 7 8 9 10 11 12 13 14
import urllib2
import numpy as np
 
train_data_file = 'http://archive.ics.uci.edu/ml/machine-learning-databases/optdigits/optdigits.tra'
test_data_file = 'http://archive.ics.uci.edu/ml/machine-learning-databases/optdigits/optdigits.tes'


def parse_optdigits(source):
    """Parse comma-separated optdigits rows into features and labels.

    Each row holds the feature values followed by the class label in the
    last column.

    :param source: a file name or an open file-like object accepted by
        ``np.loadtxt``.
    :returns: ``(features, labels)``: a 2-D float array with one row per
        sample, and the 1-D float array taken from the last column.
    """
    # ndmin=2 keeps a single-row input 2-D so the column slicing still works.
    table = np.loadtxt(source, delimiter=',', ndmin=2)
    return table[:, :-1], table[:, -1]


def load_optdigits(url):
    """Download *url* and parse it with ``parse_optdigits``.

    The HTTP response is closed even if parsing raises.
    """
    response = urllib2.urlopen(url)
    try:
        return parse_optdigits(response)
    finally:
        response.close()


def nearest_neighbor_labels(train_data, train_labels, test_data):
    """Predict a label for each test row with a 1-nearest-neighbour rule.

    :param train_data: 2-D array, one training sample per row.
    :param train_labels: 1-D array with the label of each training row.
    :param test_data: 2-D array with the same number of columns as
        ``train_data``.
    :returns: 1-D array holding, for every test row, the label of the
        training row at the smallest squared Euclidean distance. Ties go to
        the training row that appears first.
    """
    # Squared distances via |a|^2 + |b|^2 - 2 a.b; shape (n_train, n_test).
    dists = ((train_data ** 2).sum(axis=1)[:, np.newaxis]
             + (test_data ** 2).sum(axis=1)
             - 2 * np.dot(train_data, test_data.T))
    # Only the single closest training row is needed, so argmin suffices; a
    # full argsort would cost O(n log n) per column and allocate an
    # (n_train, n_test) index array that is mostly discarded.
    return train_labels[np.argmin(dists, axis=0)]


def accuracy(predicted, expected):
    """Return the fraction of positions where *predicted* equals *expected*.

    :raises ValueError: if the inputs differ in shape or are empty.
    """
    predicted = np.asarray(predicted)
    expected = np.asarray(expected)
    if predicted.shape != expected.shape:
        raise ValueError('shape mismatch: %r vs %r' % (predicted.shape, expected.shape))
    if predicted.size == 0:
        raise ValueError('cannot compute accuracy of an empty set')
    return np.mean(predicted == expected)


def main():
    """Fetch the optdigits train/test split and print the 1-NN test accuracy."""
    train_data, train_labels = load_optdigits(train_data_file)
    test_data, test_labels = load_optdigits(test_data_file)
    predicted = nearest_neighbor_labels(train_data, train_labels, test_data)
    # A single-argument print(...) behaves the same on Python 2 and Python 3.
    print('1-NN accuracy: %s' % accuracy(predicted, test_labels))


if __name__ == '__main__':
    main()

Please sign in to comment on this gist.

Something went wrong with that request. Please try again.