Skip to content

Instantly share code, notes, and snippets.

@andrew-x
Created June 5, 2019 14:34
Show Gist options
  • Save andrew-x/0bb997b129647f3a7b7c0907b7e836fc to your computer and use it in GitHub Desktop.
Save andrew-x/0bb997b129647f3a7b7c0907b7e836fc to your computer and use it in GitHub Desktop.
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
import numpy as np
# Load scikit-learn's bundled handwritten-digits dataset.
digits = load_digits()

# Scale pixel intensities into [0, 1]. The digits images use integer
# intensities in 0..16 (not 0..255), and digits.data already holds one flat
# feature row per image, so a single vectorized division is all that is needed.
x = digits.data / 16.0
y = digits.target

# Hold out 33% of the samples; the fixed seed keeps the split reproducible.
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.33, random_state=42
)

model = KNeighborsClassifier(n_neighbors=3)
model.fit(x_train, y_train)

y_train_prediction = model.predict(x_train)
y_test_prediction = model.predict(x_test)

train_accuracy = accuracy_score(y_train, y_train_prediction)
test_accuracy = accuracy_score(y_test, y_test_prediction)

# Pin the label order so both matrices always have rows/columns for 0..9,
# regardless of which classes happen to appear in a given split.
digit_labels = np.arange(10)
train_confusion = confusion_matrix(y_train, y_train_prediction, labels=digit_labels)
test_confusion = confusion_matrix(y_test, y_test_prediction, labels=digit_labels)

print('train accuracy: {}% | test accuracy: {}%'.format(train_accuracy * 100, test_accuracy * 100))
def print_matrix(mat):
    """Print a 2-D matrix as a text table with numeric row and column labels.

    The first printed line is a header of column indices; each following
    line starts with the row index and lists that row's values. Every cell
    is left-justified to three characters and followed by one space.

    The table size is taken from ``mat`` itself rather than assumed to be
    10x10, so any rectangular matrix (a NumPy array or a sequence of
    sequences) can be printed. An empty matrix prints only the blank
    header line.

    Args:
        mat: Rectangular 2-D collection of values; rows are iterated and
            each value is converted with ``str``.
    """
    # Column count comes from the first row; guard the empty case because
    # there is no first row to inspect.
    n_cols = len(mat[0]) if len(mat) else 0

    # Two leading spaces reserve the row-label column.
    header = ''.join(str(col).ljust(3) + ' ' for col in range(n_cols))
    print('  ' + header)

    for label, row in enumerate(mat):
        cells = ''.join(str(value).ljust(3) + ' ' for value in row)
        print(str(label) + ' ' + cells)
# Separator, then one titled confusion-matrix table per data split.
print('===')
for title, matrix in (('train:', train_confusion), ('test:', test_confusion)):
    print(title)
    print_matrix(matrix)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment