Skip to content

Instantly share code, notes, and snippets.

@Nick3523
Last active January 26, 2020 08:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Nick3523/70d3d98f8ecd861caa5bda5ef0a0d6ef to your computer and use it in GitHub Desktop.
Save Nick3523/70d3d98f8ecd861caa5bda5ef0a0d6ef to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# coding: utf-8
# Partial code; the original was written in a Jupyter notebook.
## Imports & parameters
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn import neighbors
# Download MNIST from OpenML. as_frame=False forces NumPy arrays: since
# scikit-learn 0.24 fetch_openml returns a pandas DataFrame/Series by default,
# and `mnist.data[sample]` below would then be a column lookup and fail.
# (The as_frame parameter requires scikit-learn >= 0.22.)
mnist = fetch_openml('mnist_784', as_frame=False)
# Draw 5000 *distinct* row indices. Sampling without replacement matters:
# a duplicated row could end up in both the train and the test split, letting
# k-NN trivially "recognise" a test image it has memorised and inflating scores.
sample = np.random.choice(mnist.data.shape[0], size=5000, replace=False)
data = mnist.data[sample]
target = mnist.target[sample]
# 80 % of the subset for training, 20 % held out for testing.
xtrain, xtest, ytrain, ytest = train_test_split(data, target, train_size=0.8, test_size=0.2)
# ### Computing the best K parameter:
# Fit one K-NN classifier per candidate K and record its test error rate in %.
# NOTE(review): K is selected by its error on the test split itself, so the
# winning error is an optimistic estimate; a separate validation split (or
# cross-validation on xtrain) would give an unbiased one.
k_values = range(2, 15)
errors = []
for k in k_values:
    knn = neighbors.KNeighborsClassifier(k)
    errors.append(100 * (1 - knn.fit(xtrain, ytrain).score(xtest, ytest)))
# Map the position of the (first) smallest error back to its K, instead of a
# hard-coded "+ 2" offset that must be kept in sync with the range start.
best_k = k_values[errors.index(min(errors))]
print("Le paramètre K minimisant le plus l'erreur est : ", best_k, "-NN", sep="")
# Rebuild a classifier using the selected K and train it on the training split.
knn = neighbors.KNeighborsClassifier(n_neighbors=best_k)
# fit() returns the estimator itself, so predicting on the test split can be
# chained directly onto it.
predicted = knn.fit(xtrain, ytrain).predict(xtest)
# Turn each flattened row of test pixels back into a 28x28 image.
images = xtest.reshape(-1, 28, 28)
# ### Result of the prediction:
# Show a grid of randomly chosen test images, each titled with the label the
# classifier predicted for it.
n_rows, n_cols = 3, 4
# Pick distinct indices (without replacement) so no image appears twice in the
# grid; min() guards against a test split smaller than the grid.
n_shown = min(n_rows * n_cols, images.shape[0])
select = np.random.choice(images.shape[0], size=n_shown, replace=False)
for index, value in enumerate(select):
    plt.subplot(n_rows, n_cols, index + 1)  # subplot positions start at 1, not 0
    plt.axis('off')
    # images[value] is a 28x28 array of pixel values; imshow() renders it, and
    # the reversed grey colormap maps high values to dark.
    plt.imshow(images[value], cmap=plt.cm.gray_r)
    plt.title('Predicted: {}'.format(predicted[value]))
# Display the whole figure once, after every subplot has been drawn.
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment