Last active
August 1, 2019 04:10
-
-
Save hpcslag/6d92d9e52b02def8025afc11d5850e07 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# 檢查 GPU 數量" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"['/device:GPU:0', '/device:GPU:1']" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from tensorflow.python.client import device_lib\n", | |
"\n", | |
"def get_available_gpus():\n", | |
"    \"\"\"Return the names of the local GPU devices known to TensorFlow, e.g. '/device:GPU:0'.\"\"\"\n", | |
"    gpu_names = []\n", | |
"    for device in device_lib.list_local_devices():\n", | |
"        # the local device list may also contain non-GPU devices (e.g. CPU); keep GPUs only\n", | |
"        if device.device_type == 'GPU':\n", | |
"            gpu_names.append(device.name)\n", | |
"    return gpu_names\n", | |
"\n", | |
"get_available_gpus()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# TensorFlow / tf.keras (imported first)\n", | |
"import tensorflow as tf\n", | |
"from tensorflow import keras\n", | |
"from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dropout, Flatten, Dense\n", | |
"# MNIST file helpers. NOTE(review): `tensorflow.examples` and `tensorflow.contrib` exist only\n", | |
"# in TensorFlow 1.x, so this cell and the dataset-loading cell require a TF 1.x install.\n", | |
"from tensorflow.examples.tutorials.mnist import input_data\n", | |
"from tensorflow.contrib.learn.python.learn.datasets.mnist import extract_images, extract_labels\n", | |
"\n", | |
"# Standard library\n", | |
"import os\n", | |
"import sys\n", | |
"\n", | |
"# Numerics, data handling, images, plotting and display\n", | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"import matplotlib.pyplot as plt\n", | |
"import seaborn as sns\n", | |
"import cv2\n", | |
"import IPython\n", | |
"from six.moves import urllib\n", | |
"\n", | |
"# NOTE(review): os, sys, pd, sns, cv2, IPython, urllib and input_data are not used anywhere\n", | |
"# in this notebook; dropping them would shrink the dependency list." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# 導入自己的 dataset" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting ./newDataset/train-images-idx3-ubyte.gz\n", | |
"Extracting ./newDataset/train-labels-idx1-ubyte.gz\n", | |
"Extracting ./newDataset/t10k-images-idx3-ubyte.gz\n", | |
"Extracting ./newDataset/t10k-labels-idx1-ubyte.gz\n" | |
] | |
} | |
], | |
"source": [ | |
"def load_idx_pair(images_path, labels_path):\n", | |
"    \"\"\"Load one gzip-compressed MNIST-format (IDX) image file and its matching label file.\n", | |
"\n", | |
"    Returns (images, labels), with images reshaped to (N, 28, 28, 1) so they can later be\n", | |
"    concatenated with the reshaped Keras MNIST arrays.\n", | |
"    Raises ValueError if the two files hold a different number of samples.\n", | |
"    \"\"\"\n", | |
"    with open(images_path, 'rb') as f:\n", | |
"        images = extract_images(f)\n", | |
"    # reshape raises if an image does not contain exactly 28*28 values\n", | |
"    images = images.reshape(images.shape[0], 28, 28, 1)\n", | |
"    with open(labels_path, 'rb') as f:\n", | |
"        labels = extract_labels(f)\n", | |
"    # Catch a mismatched file pair here instead of as a confusing error inside model.fit\n", | |
"    if len(images) != len(labels):\n", | |
"        raise ValueError('%s has %d images but %s has %d labels'\n", | |
"                         % (images_path, len(images), labels_path, len(labels)))\n", | |
"    return images, labels\n", | |
"\n", | |
"\n", | |
"train_images_new, train_labels_new = load_idx_pair('./newDataset/train-images-idx3-ubyte.gz',\n", | |
"                                                   './newDataset/train-labels-idx1-ubyte.gz')\n", | |
"test_images_new, test_labels_new = load_idx_pair('./newDataset/t10k-images-idx3-ubyte.gz',\n", | |
"                                                 './newDataset/t10k-labels-idx1-ubyte.gz')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# 合併自己與 MNIST 的 Dataset" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Type is same, concat/merge datasets...\n", | |
"Merge successed.\n" | |
] | |
} | |
], | |
"source": [ | |
"(train_images_old, train_labels_old), (test_images_old, test_labels_old) = keras.datasets.mnist.load_data()\n", | |
"\n", | |
"# np.concatenate needs matching trailing dimensions, so give the Keras MNIST images\n", | |
"# the same (N, 28, 28, 1) layout as our own images.\n", | |
"# https://stackoverflow.com/questions/43153076/how-to-concatenate-numpy-arrays-into-a-specific-shape\n", | |
"train_images_old = train_images_old.reshape(train_images_old.shape[0], 28, 28, 1)\n", | |
"test_images_old = test_images_old.reshape(test_images_old.shape[0], 28, 28, 1)\n", | |
"\n", | |
"# Own samples first, then stock MNIST. np.concatenate raises on incompatible shapes, so\n", | |
"# bad inputs fail here, at the merge, rather than in a later cell.\n", | |
"print(\"Merging custom and MNIST datasets...\")\n", | |
"train_images = np.concatenate((train_images_new, train_images_old))\n", | |
"train_labels = np.concatenate((train_labels_new, train_labels_old))\n", | |
"test_images = np.concatenate((test_images_new, test_images_old))\n", | |
"test_labels = np.concatenate((test_labels_new, test_labels_old))\n", | |
"assert len(train_images) == len(train_labels), (train_images.shape, train_labels.shape)\n", | |
"assert len(test_images) == len(test_labels), (test_images.shape, test_labels.shape)\n", | |
"print(\"Merge succeeded: %d training and %d test samples.\" % (len(train_images), len(test_images)))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# 前處理圖片並且打印出來看看" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def preprocess_images(imgs):\n", | |
"    \"\"\"Divide pixel values by 255, mapping 0-255 intensities to 0.0-1.0.\n", | |
"\n", | |
"    Accepts a single image shaped (28, 28) or (28, 28, 1), or a batch shaped\n", | |
"    (N, 28, 28) or (N, 28, 28, 1), including an empty batch. Returns a new\n", | |
"    array; the input array is not modified.\n", | |
"    \"\"\"\n", | |
"    image_shapes = [(28, 28, 1), (28, 28)]\n", | |
"    # Decide single-vs-batch from the full shape rather than ndim, so a single (28, 28, 1)\n", | |
"    # image and an empty batch are both accepted (no imgs[0] indexing needed).\n", | |
"    per_image_shape = imgs.shape if imgs.shape in image_shapes else imgs.shape[1:]\n", | |
"    assert per_image_shape in image_shapes, imgs.shape  # must be 28x28, single-channel\n", | |
"    return imgs / 255.0\n", | |
"\n", | |
"# NOTE: not idempotent - re-running this cell divides the already-scaled arrays by 255 again;\n", | |
"# re-run the merge cell first.\n", | |
"train_images = preprocess_images(train_images)\n", | |
"test_images = preprocess_images(test_images)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAHIAAAB8CAYAAAC11+QNAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAADRklEQVR4nO3cMWsTYRzH8f9fxcWhTYlzgkMNuBakSxYn30BR+hY6SPe+gtKpXbJlCQ5Z8woydGkR1w5CIiiIocXQ/XGp511Imlx7yV1+9/2AcMed7QNfnue83BkPIRjW35O8B4BsEFIEIUUQUgQhRRBSxLM0J1er1VCv15c0FMwzGAxsNBr5tGOpQtbrdbu8vMxmVEhtZ2dn5jGWVhGEFEFIEYQUQUgRhBRBSBGEFEFIEYQUQUgRhBRBSBGEFEFIEYQUQUgRhBRBSBGEFEFIEaneolPmPvUtw9Ty+t9tzEgRhBRBSBGS18jxeBxtb2xsPPrnnZycJPYPDw9nnhu/1g4Gg8SxWq326LHMwowUQUgRa7u0np2dRdvdbjdxrN/vL/Qz4ktfVsve/v5+tH18fJw41mw2o+29vb1Mft8/zEgRhBRBSBFrc41M8xFa/Ho3eQuwbJ1OZ+bv5vYDcxFSRKGW1qurq8R+q9Va6O+dnp4m9g8ODjIb07pgRoogpAhCiijUNbLRaCx8Lt8zm8SMFEFIEbkvrUdHRwudV6SldDgcRtuT383Hy1d4FEKKWMnSGl9uJj/8vri4iLZ3d3cTx87Pz5c7sAfa2tqKtre3t3McyX/MSBGEFEFIESu5Rsavi/c9IC7SLUbc5NOUXq8Xba/6wfUszEgRhBSx8k92irp83uf29jaxf319ndNIZmNGiiCkCEKKyP3pR1HFb5MmX+5qt9srHs18zEgRhBRBSBGEFEFIEYQUQUgRhBRBSBF8snMnqy8VzAszUgQhRRBSBCFFEFIEIUVw+3Hn5uYmsV+pVHIaycMwI0UQUgQhRXCNvLO5uZnYX7cXqZmRIggpgpAiCCmCkCIIKYKQIggpgpAiCCmCkCIIKYKQIggpgpAiCCnC0zxAdfffZjaceyKWpRZCeDntQKqQKC6WVhGEFCEf0t1fu/vX2J+xu3/Ke1xZK9U10t2fmtkPM3sbQpD6R5v8jJzwzsy+qUU0K1/ID2b2Oe9BLENpllZ3f25mP83sTQjhV97jyVqZZuR7M/uiGNGsXCE/muiyalaSpdXdX5jZdzN7FUL4k/d4lqEUIcugTEurNEKKIKQIQoogpAhCiiCkCEKK+AuVb59wiRrWcQAAAABJRU5ErkJggg==\n", | |
"text/plain": [ | |
"<Figure size 720x144 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAEsAAABWCAYAAACHBmuvAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAACtklEQVR4nO2csU7bUABFz4PODCidGzHQH8jGD+QjMlEhJgYkFr6AjSEfUMqAyi9krRAVEhvdoVKlRC2qEjEmwQzIFaSg9jqx/ZzcI0UoMi/v6eQ4seNASJIE838slb2AKmFZApYlYFkCliVgWQJvlF+u1WpJvV7PaSlxcHNzw+3tbXhpmySrXq9zeXk5m1VFSqPReHWbd0MByxKwLAHLErAsAcsSsCwByxJYKFnj8ZjxeJx5/ELJmpaFkBVCeHbLykLImhXSiXTMpBde+v0+AKurq39tmxaXJWBZAnOzGy4tPT7v6YeTeVwPdVkCc1NWyvX1dW6P7bIEKl/W6ekpkM9r1CQuS6ByZaUFpactKysrmcequCyBypWVVpH+VF6rJsfe398/u/8vXJZAZcpqtVoAnJycANO9+2Ud67IEKlPW+fk5UMzx1Gu4LIHoy0rfqdrtdskrcVkS0ZbV6/UA2NjYAGBnZ2fqx+x0OgDs7e0BcHV1JY13WQLRlJVe/Dw+PgZgbW0NgLOzs5nN0Ww2AR9nFUI0ZS0vLwNwcXEBwObm5sznmPYYzWUJRFNWejw1Go1KXsnruCyB0sva3t4Gij3nOzg4AGB/f18a57IESisryyeds2JrayvTOJclYFkChe+Gw+EQgN3d3aKn/kOtVss0zmUJFF7WYDAA4PDwsOipp8ZlCRRS1tOLmFX+pxsuSyDXsrrdLvD4d8fzgMsSyKWs9AsX6+vrANzd3eUxTeG4LIFcyjo6OgLmp6gUlyUQxC+D/QK+57ecKHiXJMnblzZIshYd74YCliUQpawQwscQws8Qwrey1/KUKGUBn4Bm2YuYJEpZSZJ8AX6XvY5JopQVK5YlYFkCliUQpawQwmfgK/A+hPAjhPCh7DWBT3ckoiwrVixLwLIELEvAsgQsS8CyBCxL4AFsZKRr2a3O0gAAAABJRU5ErkJggg==\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAEsAAABWCAYAAACHBmuvAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAC6klEQVR4nO3aP0tbbRzG8e9P+wrEbkIzCN1EUBdHF/sW/DeJDoogOLs4ObpHSkGxb0EXBwcLwUHkwSUgPtKpFV+Cd5fGY9JIc0Fy54ReHziQeDL8cnmd2/PHSClhnRnq9wCDxGEJHJbAYQkclsBhCd4pHx4dHU2VSqVHo5TD/f09j4+P0W6fFFalUuHq6qo7U5XU9PT0m/t8GAoclsBhCRyWwGEJHJbAYQkclsBhCRyWoG9hRQQRQbVapVqt9msMiZslkC6ku2Fra6vp/fr6OgBra2u5R5G5WQKHJchyGL5+Nnl7e9u07/DwMMcIXeFmCbI06/n5+eX1+fk5ACMjIwCsrq4CRfsi2t7RLQU3S5ClWa/bMjExAcDNzU3TvkH4nws3S5ClWUNDxe9kbm4OKJo1SNwsQfbLne3tbQAODg6afr6ysgLAzs4OAJOTk3kH64CbJcjerLGxMeDP86rj42MAlpaWco/UMTdLkL1Zw8PDQHFWPz8/D8DZ2RkAm5ubAMzOzgJwdHSUe8Q3uVmC7M1qaJx7nZ6eAsXadXd31/S5xlq2vLyccbr23CxFSqnjbWpqKvVKvV5P9Xo9AW23XH5/x7bf380S9G3NajU+Pg4Uf/0uLi4AXh6T7e/vA8UDjsb9sJzcLMVbx2e7rZdrVqtarZZqtVr2NcxrVpeUZs1qNTMzAxTXkCcnJ0BxV+Jv9+p3d3dfXu/t7XVlJjdLUNpmtVpYWABgcXERgMvLSwA2NjYAeHp6AuDh4aFnM7hZgoFpVusa1bgrcX19nW0GN0vgsAQOS+CwBA5L4LAEDkvgsAQOS+CwBA5L4LAEDkvgsATRuG3b0YcjfgL/926cUviQUnrfbocU1r/Oh6HAYQlKGVZEfI6IHxHxX79nea2UYQFfgE/9HqJVKcNKKV0AT/2eo1UpwyorhyVwWAKHJShlWBHxFfgGfIyI7xGx2u+ZwJc7klI2q6wclsBhCRyWwGEJHJbAYQkcluAXPTuMW72oY/sAAAAASUVORK5CYII=\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAEsAAABWCAYAAACHBmuvAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAADOklEQVR4nO2cP0scURxFzwuCWomgCDbZRoJa2NjYKaikEQs/giJoISiIIihio3ZiY6NLGknrB7BKsyiWKyhYrBhQE4mFaKVMCpmsq5OwN9mdeUN+p3xv4F0u5zF/GRcEAUZ5vEs6QJqwsgSsLAErS8DKErCyBGqUg5uamoJMJlOlKH5QKBS4ublxUXNSWZlMhqOjo8qk8pTu7u7fztk2FLCyBKwsAStLwMoSsLIErCwBK0tAuij1kdXVVQAWFhYAmJubA2Btba3ia5lZAqk16+7uDoDNzU0AnHu+ndvY2ACgra3t17Gjo6MVWdPMEkidWY+PjwBsbW0BcH19XTLf0tICQE9PT8XXNrMErCyB1G3DXC4HwPz8fOR8uD07OjoqvraZJZAaswqFAgBTU1OR8/39/QD09fVVLYOZJZAas4aGhgA4Pj4uGW9oaABgdnYWgPr6+qplMLMEUmNWPp8Hirc1IRMTEwAMDAxUPYOZJeC9WTMzM5Hj4dlvaWkptixmloC3Zk1OTgKwt7dXMt7V1QXA7u4uAHV1dbFlMrMEvDPr8PAQKBp1dXVVMj8+Pg5Ac3NzvMEwsyS8MyubzQJweXlZMt7e3g7A8PBw7JlCzCwBb8wKXzTs7OwAb6/U9/f3AWhtbY032AvMLIHEzbq4uABge3sbgKenJwBqap6jjY2NAckaFWJmCSRm1tnZGVB8TnV6eloyPz09DcD6+nq8wf6AmSWQmFknJyfAW6NCQuN8wswSSMys29vbyPHe3l4AOjs7Y0xTHmaWQGJmLS4uRo6Hz7EaGxvjjFMWZpaAlSUQ+zYMX2nd39+XjC8vLwMwMjISd6SyMbMEYjfr4OAAKH4TGlJbWwu8fTTjE2aWQOxmhV8Or6ysAPDw8ADA4OBg3FFkzCyBxC5Kz8/Pk1r6rzGzBKwsAStLwMoSsLIEnPKzMefcdyB9pzGN90EQRH51IpX1v2PbUMDKEvCyLOdc1jn3zTmXTzrLS7wsC/gEfEw6xGu8LCsIgi/Aj6RzvMbLsnzFyhKwsgSsLAEvy3LOfQZywAfn3FfnXGX+YvGP2O2OgJdm+YqVJWBlCVhZAlaWgJUlYGUJWFkCPwF4v6VzxmnTyAAAAABJRU5ErkJggg==\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAEsAAABWCAYAAACHBmuvAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAACo0lEQVR4nO3asUqbYRTG8f/R4mSc7CjJZNeADgquQgfFCxCnTt6HdyGUgkpHB8ni2CFdxKkXkEKnNhQU4qZfh/YzNaSQBzTvCX1+kCGJw8nfE+P7kaiqCpvMXOkBZoljCRxL4FgCxxI4luCV8sPLy8tVq9V6oVFy6PV69Pv9GPecFKvVanF1dfU8UyW1vr7+z+f8NhQ4lsCxBI4lcCyBYwkcS+BYguKxqqpiVi5AFo81S4rH6nQ6dDodIoKIsUeyNIrHmiXSQfol7O7ulh5hYt4sgWMJir8Na9vb2wA8PDwAMDeX7/eYb6LEim3W0dHRk/uXl5eFJpmcN0tQbLPu7++B4d+m+h/SzEcfb5aiPshOcltbW6ueG/DkVtqf1zj29XuzBI4lcCxB8U/Di4sLAPb29oDcn4reLEGxzZqfnwdgZ2cHGJ4JM/NmCdLE6na7dLvdx/v1ZebBYMBgMCg42VCaWLMgzfWszc3NsY+fnJwAsL+/D0Cj0ZjaTKO8WYI0m1Xr9XrA728ZAhweHgLDK6nerBmRbrOazSYAS0tLANze3gLQ7/cBWFlZAWBhYWHqs3mzBOk2q3ZzcwPA4uIiABsbG8DwU/H09HTqM3mzBGk3q3Z9fQ1Au90G4Ozs7Mnjo1cntra2ADg+Pn72WbxZgvSbtbq6CsDd3R0ABwcHAJyfnwM8nhun8f+XN0vgWIL0b8NR9cG6BG+WwLEEjiVwLIFjCRxL4FgCxxI4lsCxBI4lCOWrPRHxA/j6cuOk0Kyq6vW4J6RY/zu/DQWOJUgZKyLeR8T3iPhSepa/pYwFfADelh5iVMpYVVV9An6WnmNUylhZOZbAsQSOJUgZKyI+Ap+BNxHxLSLelZ4JfNyRpNysrBxL4FgCxxI4lsCxBI4lcCzBLwamjoaX8YNZAAAAAElFTkSuQmCC\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Preview the first few samples of the merged training set, with their labels, in one row.\n", | |
"n_preview = 5\n", | |
"# squeeze=False keeps `axes` 2-D even when n_preview == 1\n", | |
"fig, axes = plt.subplots(1, n_preview, figsize=(10, 2), squeeze=False)\n", | |
"for ax, image, label in zip(axes[0], train_images[:n_preview], train_labels[:n_preview]):\n", | |
"    ax.set_xticks([])\n", | |
"    ax.set_yticks([])\n", | |
"    ax.grid(False)\n", | |
"    ax.imshow(image.reshape(28, 28), cmap=plt.cm.binary)\n", | |
"    ax.set_xlabel(label)\n", | |
"# Show once, after every panel is drawn: calling plt.show() inside the loop would emit the\n", | |
"# figure early and push the remaining samples into separate default-size figures.\n", | |
"plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# 模型參數設定" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Two 5x5 conv layers -> 2x2 max-pool -> dropout -> dense head with a 10-way softmax.\n", | |
"# With the default 'valid' padding the feature maps shrink 28 -> 24 -> 20 -> 10 (see the summary below).\n", | |
"model = keras.Sequential([\n", | |
"    Conv2D(32, kernel_size=(5, 5), activation='relu', input_shape=(28, 28, 1)),\n", | |
"    Conv2D(64, (5, 5), activation='relu'),\n", | |
"    MaxPooling2D(pool_size=(2, 2)),   # keep the strongest response in each 2x2 window\n", | |
"    Dropout(0.25),                    # randomly drop 25% of activations to curb overfitting\n", | |
"    Flatten(),                        # (10, 10, 64) feature maps -> 6400-long vector\n", | |
"    Dense(128, activation='relu'),\n", | |
"    Dropout(0.5),\n", | |
"    Dense(10, activation='softmax'),  # one probability per class\n", | |
"])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# 設定模型優化器" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Model: \"sequential_1\"\n", | |
"_________________________________________________________________\n", | |
"Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
"conv2d_2 (Conv2D) (None, 24, 24, 32) 832 \n", | |
"_________________________________________________________________\n", | |
"conv2d_3 (Conv2D) (None, 20, 20, 64) 51264 \n", | |
"_________________________________________________________________\n", | |
"max_pooling2d_1 (MaxPooling2 (None, 10, 10, 64) 0 \n", | |
"_________________________________________________________________\n", | |
"dropout_2 (Dropout) (None, 10, 10, 64) 0 \n", | |
"_________________________________________________________________\n", | |
"flatten_1 (Flatten) (None, 6400) 0 \n", | |
"_________________________________________________________________\n", | |
"dense_2 (Dense) (None, 128) 819328 \n", | |
"_________________________________________________________________\n", | |
"dropout_3 (Dropout) (None, 128) 0 \n", | |
"_________________________________________________________________\n", | |
"dense_3 (Dense) (None, 10) 1290 \n", | |
"=================================================================\n", | |
"Total params: 872,714\n", | |
"Trainable params: 872,714\n", | |
"Non-trainable params: 0\n", | |
"_________________________________________________________________\n" | |
] | |
} | |
], | |
"source": [ | |
"# Adam with a small step size. The *sparse* categorical cross-entropy expects integer\n", | |
"# class ids as targets rather than one-hot vectors.\n", | |
"# NOTE(review): newer TF/Keras releases name this argument `learning_rate`; `lr` is the legacy name.\n", | |
"model.compile(\n", | |
"    optimizer=tf.keras.optimizers.Adam(lr=0.0001),\n", | |
"    loss='sparse_categorical_crossentropy',\n", | |
"    metrics=['accuracy'],\n", | |
")\n", | |
"model.summary()  # per-layer output shapes and parameter counts" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# 設定可以用 Tensorboard 看" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Write training logs to ./Graph for TensorBoard; inspect them with:\n", | |
"#   tensorboard --logdir path_to_current_dir/Graph\n", | |
"tbCallBack = keras.callbacks.TensorBoard(\n", | |
"    log_dir='./Graph',\n", | |
"    histogram_freq=0,  # 0 = no weight histograms\n", | |
"    write_graph=True,\n", | |
"    write_images=True,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# 進行模型訓練" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 1/25\n", | |
"60287/60287 [==============================] - 2s 35us/sample - loss: 2.1346 - acc: 0.3353\n", | |
"Epoch 2/25\n", | |
"60287/60287 [==============================] - 2s 33us/sample - loss: 1.5568 - acc: 0.6111\n", | |
"Epoch 3/25\n", | |
"60287/60287 [==============================] - 2s 34us/sample - loss: 0.9872 - acc: 0.7214\n", | |
"Epoch 4/25\n", | |
"60287/60287 [==============================] - 2s 33us/sample - loss: 0.6987 - acc: 0.7907\n", | |
"Epoch 5/25\n", | |
"60287/60287 [==============================] - 2s 33us/sample - loss: 0.5612 - acc: 0.8320\n", | |
"Epoch 6/25\n", | |
"60287/60287 [==============================] - 2s 33us/sample - loss: 0.4838 - acc: 0.8567\n", | |
"Epoch 7/25\n", | |
"60287/60287 [==============================] - 2s 33us/sample - loss: 0.4330 - acc: 0.8714\n", | |
"Epoch 8/25\n", | |
"60287/60287 [==============================] - 2s 33us/sample - loss: 0.3883 - acc: 0.8852\n", | |
"Epoch 9/25\n", | |
"60287/60287 [==============================] - 2s 33us/sample - loss: 0.3512 - acc: 0.8977\n", | |
"Epoch 10/25\n", | |
"60287/60287 [==============================] - 2s 33us/sample - loss: 0.3210 - acc: 0.9068\n", | |
"Epoch 11/25\n", | |
"60287/60287 [==============================] - 2s 33us/sample - loss: 0.2948 - acc: 0.9142\n", | |
"Epoch 12/25\n", | |
"60287/60287 [==============================] - 2s 33us/sample - loss: 0.2677 - acc: 0.9230\n", | |
"Epoch 13/25\n", | |
"60287/60287 [==============================] - 2s 33us/sample - loss: 0.2465 - acc: 0.9284\n", | |
"Epoch 14/25\n", | |
"60287/60287 [==============================] - 2s 33us/sample - loss: 0.2312 - acc: 0.9328\n", | |
"Epoch 15/25\n", | |
"60287/60287 [==============================] - 2s 33us/sample - loss: 0.2168 - acc: 0.9368\n", | |
"Epoch 16/25\n", | |
"60287/60287 [==============================] - 2s 33us/sample - loss: 0.2015 - acc: 0.9418\n", | |
"Epoch 17/25\n", | |
"60287/60287 [==============================] - 2s 33us/sample - loss: 0.1898 - acc: 0.9461\n", | |
"Epoch 18/25\n", | |
"60287/60287 [==============================] - 2s 33us/sample - loss: 0.1767 - acc: 0.9491\n", | |
"Epoch 19/25\n", | |
"60287/60287 [==============================] - 2s 33us/sample - loss: 0.1735 - acc: 0.9500\n", | |
"Epoch 20/25\n", | |
"60287/60287 [==============================] - 2s 33us/sample - loss: 0.1621 - acc: 0.9530\n", | |
"Epoch 21/25\n", | |
"60287/60287 [==============================] - 2s 33us/sample - loss: 0.1567 - acc: 0.9548\n", | |
"Epoch 22/25\n", | |
"60287/60287 [==============================] - 2s 33us/sample - loss: 0.1495 - acc: 0.9571\n", | |
"Epoch 23/25\n", | |
"60287/60287 [==============================] - 2s 33us/sample - loss: 0.1464 - acc: 0.9578\n", | |
"Epoch 24/25\n", | |
"60287/60287 [==============================] - 2s 33us/sample - loss: 0.1398 - acc: 0.9599\n", | |
"Epoch 25/25\n", | |
"60287/60287 [==============================] - 2s 33us/sample - loss: 0.1378 - acc: 0.9601\n" | |
] | |
} | |
], | |
"source": [ | |
"# Train on the merged training set, logging to TensorBoard through tbCallBack.\n", | |
"# epochs and batch_size are worth tuning for your own data.\n", | |
"history = model.fit(\n", | |
"    train_images,\n", | |
"    train_labels,\n", | |
"    batch_size=4000,\n", | |
"    epochs=25,\n", | |
"    callbacks=[tbCallBack],\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# 測試 test 資料集總準確度" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"10035/10035 [==============================] - 0s 44us/sample - loss: 0.0750 - acc: 0.9767\n", | |
"Test accuracy: 0.9766816\n" | |
] | |
} | |
], | |
"source": [ | |
"# Loss and accuracy of the trained model on the merged (custom + MNIST) test set.\n", | |
"test_loss, test_acc = model.evaluate(x=test_images, y=test_labels)\n", | |
"print('Test accuracy:', test_acc)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# 儲存模型" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Write architecture, weights and optimizer state to a single HDF5 file (reloaded in the next cell).\n", | |
"model.save('test_model.h5')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# 載入儲存的模型確認是否正常讀取" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"W0801 11:52:42.644349 140171754252096 deprecation.py:506] From /home/vk/.local/lib/python3.6/site-packages/tensorflow/python/ops/init_ops.py:97: calling GlorotUniform.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Call initializer instance with the dtype argument instead of passing it to the constructor\n", | |
"W0801 11:52:42.645456 140171754252096 deprecation.py:506] From /home/vk/.local/lib/python3.6/site-packages/tensorflow/python/ops/init_ops.py:97: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Call initializer instance with the dtype argument instead of passing it to the constructor\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Model: \"sequential_1\"\n", | |
"_________________________________________________________________\n", | |
"Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
"conv2d_2 (Conv2D) (None, 24, 24, 32) 832 \n", | |
"_________________________________________________________________\n", | |
"conv2d_3 (Conv2D) (None, 20, 20, 64) 51264 \n", | |
"_________________________________________________________________\n", | |
"max_pooling2d_1 (MaxPooling2 (None, 10, 10, 64) 0 \n", | |
"_________________________________________________________________\n", | |
"dropout_2 (Dropout) (None, 10, 10, 64) 0 \n", | |
"_________________________________________________________________\n", | |
"flatten_1 (Flatten) (None, 6400) 0 \n", | |
"_________________________________________________________________\n", | |
"dense_2 (Dense) (None, 128) 819328 \n", | |
"_________________________________________________________________\n", | |
"dropout_3 (Dropout) (None, 128) 0 \n", | |
"_________________________________________________________________\n", | |
"dense_3 (Dense) (None, 10) 1290 \n", | |
"=================================================================\n", | |
"Total params: 872,714\n", | |
"Trainable params: 872,714\n", | |
"Non-trainable params: 0\n", | |
"_________________________________________________________________\n", | |
"10035/10035 [==============================] - 0s 45us/sample - loss: 0.0750 - acc: 0.9767\n", | |
"Restored model, accuracy: 97.67%\n" | |
] | |
} | |
], | |
"source": [ | |
"# Reload the saved model and verify that it really reproduces the trained one.\n", | |
"test_new_model = keras.models.load_model('test_model.h5')\n", | |
"test_new_model.summary()\n", | |
"test_loss, test_acc = test_new_model.evaluate(test_images, test_labels)\n", | |
"print(\"Restored model, accuracy: {:5.2f}%\".format(100*test_acc))\n", | |
"\n", | |
"# Identical weights must give (numerically) identical predictions; the small tolerances\n", | |
"# allow for non-deterministic GPU kernels.\n", | |
"np.testing.assert_allclose(test_new_model.predict(test_images), model.predict(test_images),\n", | |
"                           rtol=1e-5, atol=1e-6)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.8" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment