@propella
Created January 15, 2019 07:35
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# いろんな方法で MNIST の数字画像を分類してみる\n",
"\n",
"流行りに遅れてるかもしれませんが、機械学習について色々調べています。どれくらい凄いことが出来るのかざっと確かめるために MNIST と呼ばれる数字画像を色々な方法で分類してみました。"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"from sklearn import datasets\n",
"from sklearn.linear_model import LogisticRegression, LinearRegression\n",
"from sklearn.metrics import accuracy_score, f1_score\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.neural_network import MLPClassifier\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.svm import LinearSVC, SVC\n",
"from tensorflow import keras\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import tensorflow as tf"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## データのダウンロードと前処理\n",
"\n",
"まず scikit-learn のライブラリを使って数字データをダウンロードします。ちょっと時間がかかります。X に画像データ、y に正解ラベルが入ります。"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 23.4 s, sys: 1.88 s, total: 25.3 s\n",
"Wall time: 25.5 s\n"
]
}
],
"source": [
"# Load data from https://www.openml.org/d/554\n",
"%time X, y = datasets.fetch_openml('mnist_784', version=1, return_X_y=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"画像データは 70,000 * 784 の個の配列です。numpy の世界では配列の形は `[70000, 784]` のように外側 → 内側の順に書きます。\n",
"\n",
"* 1ピクセルは 0 - 255 です。\n",
"* 一つの数字は 28 * 28 = 784 個のピクセルです\n",
"* 全部で 70,000 個の数字画像が含まれています。"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"正解データ: ['5' '0' '4' ... '4' '5' '6']\n",
"画像の次元: (70000, 784)\n",
"最初のデータ: \n",
"[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 3. 18.\n",
" 18. 18. 126. 136. 175. 26. 166. 255. 247. 127. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 30. 36. 94. 154. 170. 253.\n",
" 253. 253. 253. 253. 225. 172. 253. 242. 195. 64. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 49. 238. 253. 253. 253. 253. 253.\n",
" 253. 253. 253. 251. 93. 82. 82. 56. 39. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 18. 219. 253. 253. 253. 253. 253.\n",
" 198. 182. 247. 241. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 80. 156. 107. 253. 253. 205.\n",
" 11. 0. 43. 154. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 14. 1. 154. 253. 90.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 139. 253. 190.\n",
" 2. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 11. 190. 253.\n",
" 70. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 35. 241.\n",
" 225. 160. 108. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 81.\n",
" 240. 253. 253. 119. 25. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 45. 186. 253. 253. 150. 27. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 16. 93. 252. 253. 187. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 249. 253. 249. 64. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 46. 130. 183. 253. 253. 207. 2. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 39. 148.\n",
" 229. 253. 253. 253. 250. 182. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 24. 114. 221. 253.\n",
" 253. 253. 253. 201. 78. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 23. 66. 213. 253. 253. 253.\n",
" 253. 198. 81. 2. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 18. 171. 219. 253. 253. 253. 253. 195.\n",
" 80. 9. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 55. 172. 226. 253. 253. 253. 253. 244. 133. 11.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 136. 253. 253. 253. 212. 135. 132. 16. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
"最初のデータ画像\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x11345f4e0>"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADltJREFUeJzt3W+MlOW5x/HfBeI/igplD1kpuj1oTDYkghnhJBhFOUVrqsAbgzGIxoAvQE4TiAflhbzwhdHTNiqmyWIJcFJpGyoREnMsEo0hnhgG5axQpf7JYiH8WUKxVl+g9Dov9qHZ6s49w8wz88xyfT/JZmee67nnuTLsj2dm7pm5zd0FIJ4RRTcAoBiEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUBe08mDjx4/3rq6uVh4SCKWvr08nTpywWvZtKPxmdoekZyWNlPSiuz+V2r+rq0vlcrmRQwJIKJVKNe9b98N+Mxsp6QVJP5bULeleM+uu9/YAtFYjz/mnS/rY3T9199OSfiNpbj5tAWi2RsI/UdKfB10/lG37J2a2xMzKZlbu7+9v4HAA8tT0V/vdvcfdS+5e6ujoaPbhANSokfAfljRp0PUfZNsADAONhH+3pGvN7IdmdqGkBZK25dMWgGare6rP3b8xs2WSXtPAVN96d9+fW2cAmqqheX53f1XSqzn1AqCFeHsvEBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQTW0Sq+Z9Un6QtIZSd+4eymPppCfM2fOJOuff/55U4+/du3airWvvvoqOfbAgQPJ+gsvvJCsr1y5smJt8+bNybEXX3xxsr5q1apk/YknnkjW20FD4c/c6u4ncrgdAC3Ew34gqEbD75L+YGZ7zGxJHg0BaI1GH/bf5O6HzexfJO0wsw/d/a3BO2T/KSyRpKuuuqrBwwHIS0Nnfnc/nP0+LmmrpOlD7NPj7iV3L3V0dDRyOAA5qjv8ZjbazMacvSxpjqR9eTUGoLkaedg/QdJWMzt7Oy+5+//k0hWApqs7/O7+qaTrc+zlvPXZZ58l66dPn07W33777WR9165dFWunTp1Kjt2yZUuyXqRJkyYl64888kiyvnXr1oq1MWPGJMdef336T/uWW25J1ocDpvqAoAg/EBThB4Ii/EBQhB8IivADQeXxqb7w3nvvvWT9tttuS9ab/bHadjVy5Mhk/cknn0zWR48enazfd999FWtXXnllcuzYsWOT9euuuy5ZHw448wNBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUMzz5+Dqq69O1sePH5+st/M8/4wZM5L1avPhb7zxRsXahRdemBy7cOHCZB2N4cwPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0Exz5+DcePGJevPPPNMsr59+/Zkfdq0acn68uXLk/WUqVOnJuuvv/56sl7tM/X79lVex+W5555LjkVzceYHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaCqzvOb2XpJP5F03N2nZNvGSfqtpC5JfZLucfe/NK/N4W3evHnJerXv9a+2nHRvb2/F2osvvpgcu3LlymS92jx+NVOmTKlY6+npaei20ZhazvwbJN3xrW2rJO1092sl7cyuAxhGqobf3d+SdPJbm+dK2phd3igpfWoD0Hbqfc4/wd2PZJePSpqQUz8AWqThF/zc3SV5pbqZLTGzspmV+/v7Gz0cgJzUG/5jZtYpSdnv45V2dPcedy+5e6mjo6POwwHIW73h3yZpUXZ5kaRX8mkHQKtUDb+ZbZb0v5KuM7NDZvaQpKck/cjMPpL079l1AMNI1Xl+d7+3Qml2zr2EddlllzU0/vLLL697bLX3ASxYsCBZHzGC94kNV/zLAUERfiAowg8ERfiBoAg/EBThB4Liq7vPA2vWrKlY27NnT3Lsm2++maxX++ruOXPmJOtoX5z5gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAo5vnPA6mv1163bl1y7A033JCsL168OFm/9dZbk/VSqVSxtnTp0uRYM0vW0RjO/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFPP857nJkycn6xs2bEjWH3zwwWR906ZNdde//PLL5Nj7778/We/s7EzWkcaZHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCqjrPb2brJf1E0nF3n5JtWyNpsaT+bLfH3f3VZjWJ5pk/f36yfs011yTrK1asSNZT3/v/2GOPJccePHgwWV+9enWyPnHixGQ9ulrO/Bsk3THE9l+4+9Tsh+ADw0zV8Lv7W5JOtqAXAC3UyHP+ZWbWa2brzWxsbh0BaIl6w/9LSZMlTZV0RNLPKu1oZkvMrGxm5f7+/kq7AWixusLv7sfc/Yy7/13SOknTE/v2uHvJ3UsdHR319gkgZ3WF38wGf5xqvqR9+bQDoFVqmerbLGmWpPFmdkjSE5JmmdlUSS6pT9LDTewRQBOYu7fsYKVSycvlcsuOh+Y7depUsr59+/aKtQceeCA5ttrf5uzZs5P1HTt2JOvno1KppHK5XNOCB7zDDwiK8ANBEX4gKMIPBEX4gaAIPxAUU30ozEUXXZSsf/3118n6qFGjkvXXXnutYm3WrFnJscMVU30AqiL8QFCEHwiK8ANBEX4gKMIPBEX4gaBYohtJvb29yfqWLVuS9d27d1esVZvHr6a7uztZv/nmmxu6/fMdZ34gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIp5/vPcgQMHkvXnn38+WX/55ZeT9aNHj55zT7W64IL0n2dnZ2eyPmIE57YU7h0gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCKrqPL+ZTZK0SdIESS6px92fNbNxkn4rqUtSn6R73P0vzWs1rmpz6S+99FLF2tq1a5Nj+/r66mkpFzfeeGOyvnr16mT97rvvzrOdcGo5838jaYW7d0v6N0lLzaxb0ipJO939Wkk7s+sAhomq4Xf3I+7+bnb5C0kfSJooaa6kjdluGyXNa1aTAPJ3Ts/5zaxL0jRJ70ia4O5HstJRDTwtADBM1Bx+M/uepN9L+qm7/3VwzQcW/Bty0T8zW2JmZTMr9/f3N9QsgPzUFH4zG6WB4P/a3c9+0uOYmXVm9U5Jx4ca6+497l5y91JHR0cePQPIQdXwm5lJ+pWkD9z954NK2yQtyi4vkvRK/u0BaJZaPtI7U9JCSe+b2d5s2+OSnpL0OzN7SNJBSfc0p8Xh79ixY8n6/v37k/Vly5Yl6x9++OE595SXGTNmJOuPPvpoxdrcuXOTY/lIbnNVDb+775JUab3v2fm2A6BV+K8VCIrwA0ERfiAowg8ERfiBoAg/EBRf3V2jkydPVqw9/PDDybF79+5N1j/55JO6esrDzJkzk/UVK1Yk67fffnuyfskll5xzT2gNzvxAUIQfCIrwA0ERfiAowg8ER
fiBoAg/EFSYef533nknWX/66aeT9d27d1esHTp0qK6e8nLppZdWrC1fvjw5ttrXY48ePbquntD+OPMDQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFBh5vm3bt3aUL0R3d3dyfpdd92VrI8cOTJZX7lyZcXaFVdckRyLuDjzA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQ5u7pHcwmSdokaYIkl9Tj7s+a2RpJiyX1Z7s+7u6vpm6rVCp5uVxuuGkAQyuVSiqXy1bLvrW8yecbSSvc/V0zGyNpj5ntyGq/cPf/qrdRAMWpGn53PyLpSHb5CzP7QNLEZjcGoLnO6Tm/mXVJmibp7HdiLTOzXjNbb2ZjK4xZYmZlMyv39/cPtQuAAtQcfjP7nqTfS/qpu/9V0i8lTZY0VQOPDH421Dh373H3kruXOjo6cmgZQB5qCr+ZjdJA8H/t7i9Lkrsfc/cz7v53SeskTW9emwDyVjX8ZmaSfiXpA3f/+aDtnYN2my9pX/7tAWiWWl7tnylpoaT3zezsWtOPS7rXzKZqYPqvT1J6nWoAbaWWV/t3SRpq3jA5pw+gvfEOPyAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFBVv7o714OZ9Us6OGjTeEknWtbAuWnX3tq1L4ne6pVnb1e7e03fl9fS8H/n4GZldy8V1kBCu/bWrn1J9FavonrjYT8QFOEHgio6/D0FHz+lXXtr174keqtXIb0V+pwfQHGKPvMDKEgh4TezO8zsgJl9bGariuihEjPrM7P3zWyvmRW6pHC2DNpxM9s3aNs4M9thZh9lv4dcJq2g3taY2eHsvttrZncW1NskM3vDzP5oZvvN7D+y7YXed4m+CrnfWv6w38xGSvqTpB9JOiRpt6R73f2PLW2kAjPrk1Ry98LnhM3sZkl/k7TJ3adk256WdNLdn8r+4xzr7v/ZJr2tkfS3olduzhaU6Ry8srSkeZIeUIH3XaKve1TA/VbEmX+6pI/d/VN3Py3pN5LmFtBH23P3tySd/NbmuZI2Zpc3auCPp+Uq9NYW3P2Iu7+bXf5C0tmVpQu97xJ9FaKI8E+U9OdB1w+pvZb8dkl/MLM9Zrak6GaGMCFbNl2SjkqaUGQzQ6i6cnMrfWtl6ba57+pZ8TpvvOD3XTe5+w2Sfixpafbwti35wHO2dpquqWnl5lYZYmXpfyjyvqt3xeu8FRH+w5ImDbr+g2xbW3D3w9nv45K2qv1WHz52dpHU7Pfxgvv5h3ZauXmolaXVBvddO614XUT4d0u61sx+aGYXSlogaVsBfXyHmY3OXoiRmY2WNEftt/rwNkmLssuLJL1SYC//pF1Wbq60srQKvu/absVrd2/5j6Q7NfCK/yeSVhfRQ4W+/lXS/2U/+4vuTdJmDTwM/FoDr408JOn7knZK+kjS65LGtVFv/y3pfUm9GghaZ0G93aSBh/S9kvZmP3cWfd8l+irkfuMdfkBQvOAHBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiCo/wcmwWArzGoGmwAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"print(f'正解データ: {y}')\n",
"print(f'画像の次元: {X.shape}')\n",
"print(f'最初のデータ: \\n{X[0]}')\n",
"\n",
"# 画像データは 784 個のピクセルが一列に並んでいるので、\n",
"# 表示するには X[0].reshape(28,28) のように reshape で 28 x 28 の二次元配列に変換します。\n",
"print(f'最初のデータ画像')\n",
"plt.imshow(X[0].reshape(28,28), cmap=plt.cm.gray_r)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"データを教師データとテストデータに分けておきます。その際にピクセル値が 0 - 1 になるように調整します。また、正解データが文字列なので、そのまま数値に変換します。"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"X_train の長さ: 52500\n",
"X_test の長さ: 17500\n",
"X_train の内容: [[0. 0. 0. ... 0. 0. 0.]\n",
" [0. 0. 0. ... 0. 0. 0.]\n",
" [0. 0. 0. ... 0. 0. 0.]\n",
" ...\n",
" [0. 0. 0. ... 0. 0. 0.]\n",
" [0. 0. 0. ... 0. 0. 0.]\n",
" [0. 0. 0. ... 0. 0. 0.]]\n",
"X_test の内容: [[0. 0. 0. ... 0. 0. 0.]\n",
" [0. 0. 0. ... 0. 0. 0.]\n",
" [0. 0. 0. ... 0. 0. 0.]\n",
" ...\n",
" [0. 0. 0. ... 0. 0. 0.]\n",
" [0. 0. 0. ... 0. 0. 0.]\n",
" [0. 0. 0. ... 0. 0. 0.]]\n",
"y_train の内容: [2 1 2 ... 6 5 5]\n",
"y_test の内容: [7 2 7 ... 3 7 6]\n"
]
}
],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(X / 255, # ピクセル値が 0 - 1 になるようにする\n",
" y.astype('int64'), # 正解データを数値にする\n",
" stratify = y,\n",
" random_state=0)\n",
"print(f'X_train の長さ: {len(X_train)}')\n",
"print(f'X_test の長さ: {len(X_test)}')\n",
"print(f'X_train の内容: {X_train}')\n",
"print(f'X_test の内容: {X_test}')\n",
"print(f'y_train の内容: {y_train}')\n",
"print(f'y_test の内容: {y_test}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## scikit-learn でロジスティック回帰\n",
"\n",
"基本のロジスティック回帰をしてみます。784 個の点で単純に予測するだけなのに意外と 92.5 % と健闘します。"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"accuracy_score: 0.925\n",
"CPU times: user 54.8 s, sys: 2.46 s, total: 57.3 s\n",
"Wall time: 14.8 s\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"~/.local/share/virtualenvs/predictor-VCYBZ8Xn/lib/python3.6/site-packages/sklearn/linear_model/logistic.py:758: ConvergenceWarning: lbfgs failed to converge. Increase the number of iterations.\n",
" \"of iterations.\", ConvergenceWarning)\n"
]
}
],
"source": [
"def sklearn_logistic():\n",
" clf = LogisticRegression(solver='lbfgs', multi_class='auto')\n",
" clf.fit(X_train, y_train) # 学習\n",
" print('accuracy_score: %.3f' % clf.score(X_test, y_test)) # 検証\n",
"%time sklearn_logistic()"
]
},
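{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an aside, the ConvergenceWarning above means lbfgs hit the default iteration limit before converging, just as the message says. A minimal sketch of the suggested fix follows; max_iter=1000 is an arbitrary choice of mine, and training takes correspondingly longer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Same model as above, but with a raised iteration limit so lbfgs can converge.\n",
"clf = LogisticRegression(solver='lbfgs', multi_class='auto', max_iter=1000)\n",
"clf.fit(X_train, y_train)\n",
"print('accuracy_score: %.3f' % clf.score(X_test, y_test))"
]
},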
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## scikit-learn で Support Vector Machines (SVC)\n",
"\n",
"同様に SVC というやつを試します。めちゃくちゃ時間がかかります(私のマシンで 11 分)。784 個の点で単純に予測するだけなのに 94.6% です。"
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"accuracy_score: 0.946\n",
"CPU times: user 11min 30s, sys: 1.91 s, total: 11min 32s\n",
"Wall time: 11min 33s\n"
]
}
],
"source": [
"def sklearn_svc():\n",
" clf = SVC(kernel='rbf', gamma='auto', random_state=0, C=2)\n",
" clf.fit(X_train, y_train)\n",
" print('accuracy_score: %.3f' % clf.score(X_test, y_test))\n",
"%time sklearn_svc()"
]
},
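{
"cell_type": "markdown",
"metadata": {},
"source": [
"If 11 minutes is too long to wait, LinearSVC (already imported at the top) swaps the RBF kernel for a linear one and trains much faster, usually at some cost in accuracy. A minimal, untuned sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Linear SVM as a faster alternative to the RBF-kernel SVC above.\n",
"clf = LinearSVC(random_state=0)\n",
"clf.fit(X_train, y_train)\n",
"print('accuracy_score: %.3f' % clf.score(X_test, y_test))"
]
},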
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## scikit-learn でニュラルネットワーク\n",
"\n",
"scikit-learn にはお手軽にニュラルネットワークを試せる MLPClassifier というのがあります。97.6% です。さすが!"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iteration 1, loss = 0.42688226\n",
"Iteration 2, loss = 0.20116195\n",
"Iteration 3, loss = 0.15006803\n",
"Iteration 4, loss = 0.11915271\n",
"Iteration 5, loss = 0.09875778\n",
"Iteration 6, loss = 0.08330753\n",
"Iteration 7, loss = 0.07107738\n",
"Iteration 8, loss = 0.06181587\n",
"Iteration 9, loss = 0.05333600\n",
"Iteration 10, loss = 0.04795776\n",
"Iteration 11, loss = 0.04195001\n",
"Iteration 12, loss = 0.03634990\n",
"Iteration 13, loss = 0.03223744\n",
"Iteration 14, loss = 0.02872861\n",
"Iteration 15, loss = 0.02420542\n",
"Iteration 16, loss = 0.02191879\n",
"Iteration 17, loss = 0.01904267\n",
"Iteration 18, loss = 0.01703576\n",
"Iteration 19, loss = 0.01456136\n",
"Iteration 20, loss = 0.01302070\n",
"accuracy_score: 0.976\n",
"CPU times: user 1min 4s, sys: 10.1 s, total: 1min 14s\n",
"Wall time: 21.9 s\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"~/.local/share/virtualenvs/predictor-VCYBZ8Xn/lib/python3.6/site-packages/sklearn/neural_network/multilayer_perceptron.py:562: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (20) reached and the optimization hasn't converged yet.\n",
" % self.max_iter, ConvergenceWarning)\n"
]
}
],
"source": [
"def sklearn_mlp():\n",
" clf = MLPClassifier(hidden_layer_sizes=(128,), solver='adam', max_iter=20, verbose=10, random_state=0)\n",
" clf.fit(X_train, y_train)\n",
" print('accuracy_score: %.3f' % clf.score(X_test, y_test))\n",
"%time sklearn_mlp()"
]
},
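{
"cell_type": "markdown",
"metadata": {},
"source": [
"The ConvergenceWarning here is expected, since max_iter=20 deliberately caps training time. If you would rather let the optimizer decide when to stop, MLPClassifier also supports early stopping on a held-out validation fraction. A minimal sketch with illustrative settings:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Stop training automatically once the validation score stops improving.\n",
"clf = MLPClassifier(hidden_layer_sizes=(128,), solver='adam', max_iter=200, early_stopping=True, random_state=0)\n",
"clf.fit(X_train, y_train)\n",
"print('accuracy_score: %.3f' % clf.score(X_test, y_test))"
]
},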
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tensorflow でニューラルネットワーク\n",
"\n",
"Tensorflow で同じようなネットワークを作ります。97.7% と同じような値が出ます。"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"52500/52500 [==============================] - 3s 63us/step - loss: 0.2745 - acc: 0.9220\n",
"Epoch 2/20\n",
"52500/52500 [==============================] - 3s 52us/step - loss: 0.1233 - acc: 0.9634\n",
"Epoch 3/20\n",
"52500/52500 [==============================] - 3s 52us/step - loss: 0.0834 - acc: 0.9747\n",
"Epoch 4/20\n",
"52500/52500 [==============================] - 3s 53us/step - loss: 0.0621 - acc: 0.9805\n",
"Epoch 5/20\n",
"52500/52500 [==============================] - 3s 53us/step - loss: 0.0473 - acc: 0.9851: 1\n",
"Epoch 6/20\n",
"52500/52500 [==============================] - 3s 54us/step - loss: 0.0368 - acc: 0.9887\n",
"Epoch 7/20\n",
"52500/52500 [==============================] - 3s 54us/step - loss: 0.0290 - acc: 0.9911\n",
"Epoch 8/20\n",
"52500/52500 [==============================] - 3s 53us/step - loss: 0.0239 - acc: 0.9929\n",
"Epoch 9/20\n",
"52500/52500 [==============================] - 3s 54us/step - loss: 0.0194 - acc: 0.9938\n",
"Epoch 10/20\n",
"52500/52500 [==============================] - 3s 55us/step - loss: 0.0141 - acc: 0.9960\n",
"Epoch 11/20\n",
"52500/52500 [==============================] - 3s 54us/step - loss: 0.0140 - acc: 0.9956\n",
"Epoch 12/20\n",
"52500/52500 [==============================] - 3s 52us/step - loss: 0.0114 - acc: 0.9962\n",
"Epoch 13/20\n",
"52500/52500 [==============================] - 3s 53us/step - loss: 0.0096 - acc: 0.9971\n",
"Epoch 14/20\n",
"52500/52500 [==============================] - 3s 54us/step - loss: 0.0082 - acc: 0.9974\n",
"Epoch 15/20\n",
"52500/52500 [==============================] - 3s 54us/step - loss: 0.0079 - acc: 0.9974\n",
"Epoch 16/20\n",
"52500/52500 [==============================] - 3s 55us/step - loss: 0.0077 - acc: 0.9977\n",
"Epoch 17/20\n",
"52500/52500 [==============================] - 3s 52us/step - loss: 0.0054 - acc: 0.9985\n",
"Epoch 18/20\n",
"52500/52500 [==============================] - 3s 52us/step - loss: 0.0066 - acc: 0.9981\n",
"Epoch 19/20\n",
"52500/52500 [==============================] - 3s 54us/step - loss: 0.0018 - acc: 0.9998\n",
"Epoch 20/20\n",
"52500/52500 [==============================] - 3s 51us/step - loss: 0.0068 - acc: 0.9975\n",
"accuracy_score: 0.977\n",
"CPU times: user 1min 26s, sys: 15.8 s, total: 1min 42s\n",
"Wall time: 57.1 s\n"
]
}
],
"source": [
"def tensorflow_mlp():\n",
"\n",
" model = keras.Sequential([\n",
" keras.layers.Dense(128, activation=tf.nn.relu),\n",
" keras.layers.Dense(10, activation=tf.nn.softmax)\n",
" ])\n",
" \n",
" model.compile(optimizer=tf.train.AdamOptimizer(), \n",
" loss='sparse_categorical_crossentropy',\n",
" metrics=['accuracy'])\n",
"\n",
" model.fit(X_train, y_train, epochs=20)\n",
" y_test_predict_list = model.predict(X_test)\n",
" y_test_predict = np.argmax(y_test_predict_list, axis=1)\n",
"\n",
" print('accuracy_score: %.3f' % accuracy_score(y_test, y_test_predict))\n",
"\n",
"%time tensorflow_mlp()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tensorflow で CNN\n",
"\n",
"単純なニューラルネットワークで 97.9% も出てしまいましたが、今流行りのディープラーニングの一つ CNN という手法も試してみます。99.2% 出ます。やりましたね!\n",
"\n",
"これまではピクセル全部を雑にまとめて学習していましたが、Convolution 層というので近くのピクセルから特徴量を作るのがポイントです。\n",
"Convolution 層の計算は結構重くて、GPU が無いとちょっと面倒です。https://colab.research.google.com を使うなどして工夫しましょう。逆に、Convolution 層を使わない場合自分のマックで計算してもあまり速くなりませんでした。以下ポイントです。\n",
"\n",
"* Convolution 層 に keras.layers.Conv2D を使うのですが、入力次元が (縦, 横, チャンネル数) と決まっています。モノクロの場合でも (28, 28, 1) のようにチャンネル数 1 の次元を作る必要があります。\n",
"* model.fit の時は Conv2D に必要な次元に合わせて X_train.reshape(-1, 28, 28, 1) のように次元を変更します。-1 を指定すると配列の長さから適当に調整してくれます。\n",
"* 意外とスコアが上がらないので keras.layers.Dropout という過学習を防ぐ仕組みを入れると何とか 99% 超え出来ました。"
]
},
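{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before the model itself, here is a quick shape check of the reshape mentioned in the second bullet, using a dummy array of the same shape as X_train; numpy infers the -1 dimension from the total element count."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"flat = np.zeros((52500, 784)) # a dummy array with the same shape as X_train\n",
"images = flat.reshape(-1, 28, 28, 1) # add height, width and channel axes\n",
"print(images.shape) # (52500, 28, 28, 1): the -1 dimension is inferred as 52500\n",
"print(images.size == flat.size) # True: the element count is preserved"
]
},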
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"conv2d_33 (Conv2D) (None, 28, 28, 32) 832 \n",
"_________________________________________________________________\n",
"max_pooling2d_22 (MaxPooling (None, 14, 14, 32) 0 \n",
"_________________________________________________________________\n",
"conv2d_34 (Conv2D) (None, 14, 14, 32) 25632 \n",
"_________________________________________________________________\n",
"max_pooling2d_23 (MaxPooling (None, 7, 7, 32) 0 \n",
"_________________________________________________________________\n",
"flatten_36 (Flatten) (None, 1568) 0 \n",
"_________________________________________________________________\n",
"dense_46 (Dense) (None, 128) 200832 \n",
"_________________________________________________________________\n",
"dropout_1 (Dropout) (None, 128) 0 \n",
"_________________________________________________________________\n",
"dense_47 (Dense) (None, 10) 1290 \n",
"=================================================================\n",
"Total params: 228,586\n",
"Trainable params: 228,586\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n",
"Epoch 1/10\n",
"52500/52500 [==============================] - 58s 1ms/step - loss: 0.1440 - acc: 0.9545\n",
"Epoch 2/10\n",
"52500/52500 [==============================] - 54s 1ms/step - loss: 0.0495 - acc: 0.9847\n",
"Epoch 3/10\n",
"52500/52500 [==============================] - 52s 994us/step - loss: 0.0345 - acc: 0.9888\n",
"Epoch 4/10\n",
"52500/52500 [==============================] - 52s 984us/step - loss: 0.0273 - acc: 0.9911\n",
"Epoch 5/10\n",
"52500/52500 [==============================] - 52s 995us/step - loss: 0.0208 - acc: 0.9934\n",
"Epoch 6/10\n",
"52500/52500 [==============================] - 55s 1ms/step - loss: 0.0182 - acc: 0.9940\n",
"Epoch 7/10\n",
"52500/52500 [==============================] - 53s 1ms/step - loss: 0.0145 - acc: 0.9954:\n",
"Epoch 8/10\n",
"52500/52500 [==============================] - 53s 1ms/step - loss: 0.0136 - acc: 0.9955\n",
"Epoch 9/10\n",
"52500/52500 [==============================] - 53s 1ms/step - loss: 0.0116 - acc: 0.9963\n",
"Epoch 10/10\n",
"52500/52500 [==============================] - 53s 1ms/step - loss: 0.0109 - acc: 0.9966\n",
"accuracy_score: 0.992\n",
"CPU times: user 30min 52s, sys: 3min 22s, total: 34min 15s\n",
"Wall time: 8min 59s\n"
]
}
],
"source": [
"def tensorflow_cnn():\n",
"\n",
" model = keras.Sequential([\n",
" keras.layers.Conv2D(32, activation=tf.nn.relu, kernel_size=(5,5), padding='same', input_shape=(28, 28, 1)),\n",
" keras.layers.MaxPooling2D(pool_size=(2, 2)),\n",
" keras.layers.Conv2D(32, activation=tf.nn.relu, kernel_size=(5,5), padding='same'),\n",
" keras.layers.MaxPooling2D(pool_size=(2, 2)),\n",
" keras.layers.Flatten(),\n",
" keras.layers.Dense(128, activation=tf.nn.relu),\n",
" keras.layers.Dropout(0.2),\n",
" keras.layers.Dense(10, activation=tf.nn.softmax),\n",
" ])\n",
" \n",
" model.compile(optimizer=tf.train.AdamOptimizer(), \n",
" loss='sparse_categorical_crossentropy',\n",
" metrics=['accuracy'])\n",
" model.summary()\n",
" \n",
" model.fit(X_train.reshape(-1, 28, 28, 1), y_train, epochs=10)\n",
" y_test_predict_list = model.predict(X_test.reshape(-1, 28, 28, 1))\n",
" y_test_predict = np.argmax(y_test_predict_list, axis=1)\n",
"\n",
" print('accuracy_score: %.3f' % accuracy_score(y_test, y_test_predict))\n",
"\n",
"%time tensorflow_cnn()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## まとめ\n",
"\n",
"* scikit-learn でロジスティック回帰: 92.5%\n",
"* scikit-learn で Support Vector Machines (SVC): 94.6%\n",
"* scikit-learn でニュラルネットワーク: 97.6%\n",
"* Tensorflow でニューラルネットワーク: 97.7%\n",
"* Tensorflow で CNN: 99.2%"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 参考\n",
"\n",
"* 使うデータセット: https://www.openml.org/d/554\n",
"* 資料:\n",
" * [MNIST手書き数字データをnumpyで書いたロジスティック回帰で学習して結果を分析する](https://qiita.com/phyblas/items/375ab130e53b0d04f784)\n",
" * [データサイエンティスト育成講座](https://weblab.t.u-tokyo.ac.jp/gci_contents/) 15 総合演習問題\n",
" * [Visualization of MLP weights on MNIST](https://scikit-learn.org/stable/auto_examples/neural_networks/plot_mnist_filters.html)\n",
" * [Build a Convolutional Neural Network using Estimators](https://www.tensorflow.org/tutorials/estimators/cnn)\n",
" * [Applying Convolutional Neural Network on the MNIST dataset](https://yashk2810.github.io/Applying-Convolutional-Neural-Network-on-the-MNIST-dataset/)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}