@joxer
Last active September 28, 2017 12:12
{
"cells": [
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import csv\n",
"import numpy as np\n",
"from sklearn.linear_model import SGDClassifier\n",
"from sklearn.metrics import precision_score"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
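"Load the training and test sets. Each CSV row holds a digit label followed by the 784 pixel values of a 28x28 image; the goal of this notebook is to measure the precision of an SGD classifier on this data."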
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def get_file(filename):\n",
"    # Read an MNIST CSV from mldata/: the first row is a header and is skipped;\n",
"    # every other row is [label, pixel_1, ..., pixel_784].\n",
"    # Returns a list of [label (str), pixels (float64 array of 784 values)] pairs.\n",
"    ret = []\n",
"    with open('mldata/' + filename) as file:\n",
"        reader = csv.reader(file)\n",
"        next(reader)  # skip the header row\n",
"        for row in reader:\n",
"            ret.append([row[0], np.array(row[1:], dtype=np.float64)])\n",
"    return ret\n",
"\n",
"test = get_file('mnist_test.csv')\n",
"train = get_file('train.csv')"
]
},
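{
"cell_type": "markdown",
"metadata": {},
"source": [
"A more compact alternative to the reader above. This is a sketch that assumes each file has exactly one header row and that every other row is `label,pixel_1,...,pixel_784`, which is the same layout the loop above relies on. `np.loadtxt` parses the whole file into a single float array, so labels and pixels can be separated by slicing columns. A tiny in-memory CSV (`io.StringIO`) with a hypothetical 4 pixels per row stands in for the real files, so the cell runs on its own. Note that this version returns integer labels rather than strings."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import io\n",
"import numpy as np\n",
"\n",
"# Hypothetical stand-in for an MNIST CSV: a header, then label + pixels\n",
"# (4 pixels per row instead of 784, to keep the example short).\n",
"csv_text = '''label,p0,p1,p2,p3\n",
"5,0,0,128,255\n",
"0,12,0,0,3\n",
"'''\n",
"\n",
"data = np.loadtxt(io.StringIO(csv_text), delimiter=',', skiprows=1)\n",
"labels = data[:, 0].astype(int)  # first column: the digit\n",
"pixels = data[:, 1:]             # remaining columns: pixel values (float64)\n",
"\n",
"print(labels)        # [5 0]\n",
"print(pixels.shape)  # (2, 4)"
]
},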
{
"cell_type": "markdown",
"metadata": {},
"source": [
"An example of the images in the dataset:\n",
"<img src=\"https://www.classes.cs.uchicago.edu/archive/2013/spring/12300-1/pa/pa1/digit.png\" />\n",
"\n",
"What we actually have, however, is each of these 28x28-pixel images flattened into a one-dimensional array of 784 values"
]
},
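{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sketch below shows how the flat array maps back to the 28x28 image. It assumes the pixels are stored row by row, which is the convention of common MNIST CSV exports, although this notebook does not state it. NumPy's `reshape` uses row-major order by default, so flat index `i` lands at row `i // 28`, column `i % 28`. A synthetic vector (`np.arange(784)`) stands in for a real sample such as `train[0][1]`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"flat = np.arange(28 * 28, dtype=np.float64)  # synthetic stand-in for one 784-value sample\n",
"image = flat.reshape(28, 28)  # row-major: row r holds flat[28*r : 28*r + 28]\n",
"\n",
"i = 100\n",
"row, col = divmod(i, 28)  # flat index -> (row, column)\n",
"assert image[row, col] == flat[i]\n",
"print(image.shape, row, col)  # (28, 28) 3 16"
]
},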
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 18. 30. 137. 137. 192. 86. 72. 1. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 13. 86. 250. 254. 254. 254. 254. 217.\n",
" 246. 151. 32. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 16. 179. 254. 254. 254.\n",
" 254. 254. 254. 254. 254. 254. 231. 54. 15. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 72.\n",
" 254. 254. 254. 254. 254. 254. 254. 254. 254. 254. 254. 254.\n",
" 104. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 61. 191. 254. 254. 254. 254. 254. 109. 83. 199.\n",
" 254. 254. 254. 254. 243. 85. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 172. 254. 254. 254. 202. 147.\n",
" 147. 45. 0. 11. 29. 200. 254. 254. 254. 171. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 174. 254.\n",
" 254. 89. 67. 0. 0. 0. 0. 0. 0. 128. 252. 254.\n",
" 254. 212. 76. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 47. 254. 254. 254. 29. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 83. 254. 254. 254. 153. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 80. 254. 254. 240. 24. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 25. 240. 254. 254. 153. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 64. 254. 254.\n",
" 186. 7. 0. 0. 0. 0. 0. 0. 0. 0. 0. 166.\n",
" 254. 254. 224. 12. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 14. 232. 254. 254. 254. 29. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 75. 254. 254. 254. 17. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 18. 254. 254. 254. 254. 29. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 48. 254. 254. 254. 17.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 2. 163. 254. 254.\n",
" 254. 29. 0. 0. 0. 0. 0. 0. 0. 0. 0. 48.\n",
" 254. 254. 254. 17. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 94. 254. 254. 254. 200. 12. 0. 0. 0. 0. 0.\n",
" 0. 0. 16. 209. 254. 254. 150. 1. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 15. 206. 254. 254. 254. 202. 66.\n",
" 0. 0. 0. 0. 0. 21. 161. 254. 254. 245. 31. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 60. 212.\n",
" 254. 254. 254. 194. 48. 48. 34. 41. 48. 209. 254. 254.\n",
" 254. 171. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 86. 243. 254. 254. 254. 254. 254. 233. 243.\n",
" 254. 254. 254. 254. 254. 86. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 114. 254. 254. 254.\n",
" 254. 254. 254. 254. 254. 254. 254. 239. 86. 11. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 13. 182. 254. 254. 254. 254. 254. 254. 254. 254. 243. 70.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 8. 76. 146. 254. 255. 254. 255.\n",
" 146. 19. 15. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0.]\n"
]
}
],
"source": [
"print(train[0][1])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Build the training and test data in the form the SGD classifier expects: feature vectors X and labels Y"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# build the two datasets, one for training and one for testing\n",
"# X: each sample is a flat vector of 28x28 = 784 pixel values\n",
"# Y: the classification target, i.e. the digit label\n",
"X_train = [x[1] for x in train]\n",
"Y_train = [y[0] for y in train]\n",
"\n",
"X_test = [x[1] for x in test]\n",
"Y_test = [y[0] for y in test]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Instantiate the SGD classifier and let it make at most 20 passes (epochs) over the training data (`max_iter=20`) while searching for the optimum"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sgd_clf = SGDClassifier(random_state=0, max_iter=20)\n",
"sgd_clf.fit(X_train, Y_train)\n",
"out = sgd_clf.predict(X_test)"
]
},
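{
"cell_type": "markdown",
"metadata": {},
"source": [
"For multiclass problems, scikit-learn's `SGDClassifier` fits one binary linear classifier per class (one-versus-all). It then predicts the class whose classifier gives the highest score. The sketch below illustrates this on scikit-learn's built-in `load_digits` dataset (8x8 digit images, 64 features), used only as a small stand-in for the MNIST CSVs. In it, `coef_` holds one weight vector per digit, and `predict` agrees with the arg-max of `decision_function`. Passing `tol=None` makes the model run exactly `max_iter` epochs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_digits\n",
"from sklearn.linear_model import SGDClassifier\n",
"\n",
"X, y = load_digits(return_X_y=True)  # 1797 samples, 64 pixel features, labels 0-9\n",
"clf = SGDClassifier(random_state=0, max_iter=20, tol=None)\n",
"clf.fit(X, y)\n",
"\n",
"print(clf.coef_.shape)  # (10, 64): one weight vector per class\n",
"\n",
"scores = clf.decision_function(X[:5])  # shape (5, 10): one score per class\n",
"manual = clf.classes_[np.argmax(scores, axis=1)]\n",
"assert (manual == clf.predict(X[:5])).all()"
]
},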
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compute the precision of the predictions on the test set"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.836883688369\n"
]
}
],
"source": [
"print(precision_score(Y_test, out, average='micro'))"
]
},
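{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why the printed number is also the accuracy: with `average='micro'`, `precision_score` pools true and false positives over all classes. When each sample gets exactly one predicted label, every prediction is either a true positive or a false positive for the class it predicts. So TP + FP equals the number of samples, and micro precision equals accuracy. Here is a toy check with made-up labels (strings, as in this notebook):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import accuracy_score, precision_score\n",
"\n",
"# Made-up labels: 4 of the 5 predictions are correct.\n",
"y_true = ['7', '2', '1', '0', '4']\n",
"y_pred = ['7', '2', '1', '0', '9']\n",
"\n",
"micro = precision_score(y_true, y_pred, average='micro')\n",
"acc = accuracy_score(y_true, y_pred)\n",
"print(micro, acc)  # 0.8 0.8\n",
"assert micro == acc"
]
},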
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create a new classifier with a different number of iterations to see whether the precision changes"
]
},
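{
"cell_type": "markdown",
"metadata": {},
"source": [
"Instead of editing `max_iter` by hand in separate cells, the comparison can be written as a loop. This sketch again uses the built-in `load_digits` data as a stand-in for the MNIST CSVs, with a hypothetical 75/25 split from `train_test_split`. Its scores are therefore not the MNIST numbers reported in this notebook. `tol=None` is passed so that exactly `max_iter` passes over the data are made, because since scikit-learn 0.21 the default `tol=1e-3` can stop training before `max_iter` is reached."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_digits\n",
"from sklearn.linear_model import SGDClassifier\n",
"from sklearn.metrics import precision_score\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"X, y = load_digits(return_X_y=True)\n",
"X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.25, random_state=0)\n",
"\n",
"results = {}\n",
"for n_epochs in (1, 10, 20):\n",
"    # Hypothetical sweep values mirroring the three runs in this notebook.\n",
"    clf = SGDClassifier(random_state=0, max_iter=n_epochs, tol=None)\n",
"    clf.fit(X_tr, y_tr)\n",
"    results[n_epochs] = precision_score(y_te, clf.predict(X_te), average='micro')\n",
"\n",
"for n_epochs, score in results.items():\n",
"    print(n_epochs, round(score, 3))"
]
},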
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sgd_clf = SGDClassifier(random_state=0, max_iter=10)\n",
"sgd_clf.fit(X_train, Y_train)\n",
"out = sgd_clf.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.872587258726\n"
]
}
],
"source": [
"print(precision_score(Y_test, out, average='micro'))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Precision increased, from 0.837 with `max_iter=20` to 0.873 with `max_iter=10`. Now let's check whether reducing the number of iterations even further makes it worse or better"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sgd_clf = SGDClassifier(random_state=0, max_iter=1)\n",
"sgd_clf.fit(X_train, Y_train)\n",
"out = sgd_clf.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.818681868187\n"
]
}
],
"source": [
"print(precision_score(Y_test, out, average='micro'))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Precision dropped to 0.819 with `max_iter=1`. Among the settings we tried (20, 10 and 1 iterations), the model with 10 iterations performs best"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.6",
"language": "python",
"name": "python36"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}