Skip to content

Instantly share code, notes, and snippets.

@duckworthd
Created May 14, 2012 08:27
Show Gist options
Save duckworthd/2692699 to your computer and use it in GitHub Desktop.
Reproducing KNN (KNeighborsClassifier) warnings: uint8 vs. float features
from pprint import pprint
from sklearn.datasets import load_digits
from sklearn.neighbors import KNeighborsClassifier
from sklearn.cross_validation import KFold
# Load scikit-learn's digits dataset. X receives the feature matrix narrowed
# to unsigned 8-bit integers (the representation under test); Y receives the
# class labels unchanged.
digits = load_digits()
X, Y = digits.data.astype('uint8'), digits.target
# Cross-validate a 1-nearest-neighbour classifier twice per fold: once on the
# uint8 features as loaded, and once on a float copy of the same features.
# Each fold appends a (uint8_score, float_score) pair, so any divergence
# between the two columns isolates the effect of the input dtype.
# NOTE(review): presumably the integer dtype is what triggers the KNN
# warnings this script reproduces -- confirm against the sklearn version used.
# NOTE(review): KFold(n=..., k=...) is the old sklearn.cross_validation API;
# modern scikit-learn uses sklearn.model_selection.KFold(n_splits=5).split(X).
scores = []
for (train, test) in KFold(n=X.shape[0], k=5):
    # train/test select the rows of X and Y belonging to this fold.
    X_train, X_test, Y_train, Y_test = X[train], X[test], Y[train], Y[test]
    # Convert once per fold and reuse the copies for both the dtype report
    # and the second fit/score, instead of calling astype(float) three times.
    X_train_float = X_train.astype(float)
    X_test_float = X_test.astype(float)
    # algorithm='brute' requests a brute-force neighbour search rather than
    # a tree-based index.
    clf = KNeighborsClassifier(n_neighbors=1, algorithm='brute')
    # Single-argument print(...) produces the same output on Python 2 and 3.
    print('as type {}'.format(X_train.dtype))
    score1 = clf.fit(X_train, Y_train).score(X_test, Y_test)
    print('as type {}'.format(X_train_float.dtype))
    # Refit the same estimator on the float copies of this fold's data.
    score2 = clf.fit(X_train_float, Y_train).score(X_test_float, Y_test)
    scores.append((score1, score2))
print(scores)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment