Last active
January 26, 2020 08:34
-
-
Save Nick3523/70d3d98f8ecd861caa5bda5ef0a0d6ef to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# coding: utf-8 | |
# Partial code; the original was written in a Jupyter notebook
## Imports & Param's | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from sklearn.datasets import fetch_openml | |
from sklearn.model_selection import train_test_split | |
from sklearn import neighbors | |
# Download the MNIST digits dataset from OpenML.
mnist = fetch_openml('mnist_784')
# Newer scikit-learn releases may return pandas objects here; convert to plain
# arrays so that positional indexing (and the later 28x28 reshape) behaves the
# same on every version.
all_data = np.asarray(mnist.data)
all_target = np.asarray(mnist.target)
# Draw the subsample WITHOUT replacement: with randint() the same row could be
# picked twice and end up on both sides of the train/test split, which would
# make the measured error look better than it is.
n_total = all_data.shape[0]
sample = np.random.choice(n_total, size=min(5000, n_total), replace=False)
data = all_data[sample]
target = all_target[sample]
# 80 % of the subsample is used for training, 20 % is held out for testing.
xtrain, xtest, ytrain, ytest = train_test_split(data, target, train_size=0.8, test_size=0.2)
# ### Computing the best K parameter :
# Candidate neighbour counts. Keeping them in one sequence means the error
# list and the lookup below cannot drift apart if the range is ever changed
# (the previous "+ 2" offset silently depended on the range starting at 2).
k_values = range(2, 15)
errors = []
for k in k_values:
    knn = neighbors.KNeighborsClassifier(k)
    # Error rate, in percent, on the held-out split.
    errors.append(100 * (1 - knn.fit(xtrain, ytrain).score(xtest, ytest)))
# errors[i] belongs to k_values[i]; index() returns the first minimum, so ties
# resolve to the smallest k.
best_k = k_values[errors.index(min(errors))]
# NOTE(review): k is selected on the same split that is used afterwards to
# inspect predictions, so the result is optimistic -- consider a separate
# validation split or cross-validation.
print("Le paramètre K minimisant le plus l'erreur est : ",best_k,"-NN",sep="")
# Rebuild the classifier with the best-performing k and train it
# (fit() returns the estimator itself, so the two steps can be chained).
knn = neighbors.KNeighborsClassifier(n_neighbors=best_k).fit(xtrain, ytrain)
# Predicted labels for the held-out test samples.
predicted = knn.predict(xtest)
# View every flat test row as a 28x28 image; -1 lets NumPy infer the number
# of images.
images = xtest.reshape(-1, 28, 28)
# ### Result of the prediction :
# Pick 12 random positions in the test set.
select = np.random.randint(images.shape[0], size=12)
# Draw each chosen digit in a 3x4 grid, titled with its predicted label.
for slot, picked in enumerate(select, start=1):  # subplot slots are 1-based
    plt.subplot(3, 4, slot)
    plt.axis('off')
    # images[picked] is the 28x28 array that imshow() renders as a picture.
    plt.imshow(images[picked], cmap=plt.cm.gray_r)
    plt.title('Predicted: {}'.format(predicted[picked]))
# Display the completed grid once all 12 panels are filled.
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment