Skip to content

Instantly share code, notes, and snippets.

@Elijas
Last active August 15, 2019 18:10
Show Gist options
  • Save Elijas/2e01e8b8056dacdc9625b7b8cc63cf00 to your computer and use it in GitHub Desktop.
Save Elijas/2e01e8b8056dacdc9625b7b8cc63cf00 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "deep-learning-keras-tutorial.ipynb",
"version": "0.3.2",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/Elijas/2e01e8b8056dacdc9625b7b8cc63cf00/deep-learning-keras-tutorial.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "u8XdSpINMhVj",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 337
},
"outputId": "9f73d61c-663c-4670-e238-51ae2692cc46"
},
"source": [
"# 3. Import libraries and modules\n",
"\n",
"\n",
"import numpy as np\n",
"from matplotlib import pyplot as plt\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense, Dropout, Activation, Flatten\n",
"from keras.layers import Convolution2D, MaxPooling2D\n",
"from keras.utils import np_utils\n",
"from keras.datasets import mnist\n",
"\n",
"np.random.seed(123)  # for reproducibility\n",
"\n",
"# 4. Load pre-shuffled MNIST data into train and test sets\n",
"(X_train, y_train), (X_test, y_test) = mnist.load_data()\n",
"\n",
"# Sanity check: show the first training image together with its label.\n",
"# A grayscale colormap is used so the digit is not false-colored, and the\n",
"# figure is shown explicitly so the cell does not echo the AxesImage repr.\n",
"fig, ax = plt.subplots()\n",
"ax.imshow(X_train[0], cmap='gray')\n",
"ax.set_title(f'First training sample (label: {y_train[0]})')\n",
"plt.show()"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Downloading data from https://s3.amazonaws.com/img-datasets/mnist.npz\n",
"11493376/11490434 [==============================] - 1s 0us/step\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7f48e715bbe0>"
]
},
"metadata": {
"tags": []
},
"execution_count": 2
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADoBJREFUeJzt3X2MXOV1x/HfyXq9jo1JvHHYboiL\nHeMEiGlMOjIgLKCiuA5CMiiKiRVFDiFxmuCktK4EdavGrWjlVgmRQynS0ri2I95CAsJ/0CR0FUGi\nwpbFMeYtvJlNY7PsYjZgQ4i9Xp/+sdfRBnaeWc/cmTu75/uRVjtzz71zj6792zszz8x9zN0FIJ53\nFd0AgGIQfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQU1r5M6mW5vP0KxG7hII5bd6U4f9kE1k\n3ZrCb2YrJG2W1CLpP9x9U2r9GZqls+2iWnYJIKHHuye8btVP+82sRdJNkj4h6QxJq83sjGofD0Bj\n1fKaf6mk5919j7sflnSHpJX5tAWg3moJ/8mSfjXm/t5s2e8xs7Vm1mtmvcM6VMPuAOSp7u/2u3uX\nu5fcvdSqtnrvDsAE1RL+fZLmjbn/wWwZgEmglvA/ImmRmS0ws+mSPi1pRz5tAai3qof63P2Ima2T\n9CONDvVtcfcnc+sMQF3VNM7v7vdJui+nXgA0EB/vBYIi/EBQhB8IivADQRF+ICjCDwRF+IGgCD8Q\nFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/\nEBThB4Ii/EBQhB8IivADQRF+IKiaZuk1sz5JByWNSDri7qU8mkJ+bFr6n7jl/XPruv9n/np+2drI\nzKPJbU9ZOJisz/yKJesv3zC9bG1n6c7ktvtH3kzWz75rfbJ+6l89nKw3g5rCn/kTd9+fw+MAaCCe\n9gNB1Rp+l/RjM3vUzNbm0RCAxqj1af8yd99nZidJut/MfuHuD45dIfujsFaSZmhmjbsDkJeazvzu\nvi/7PSjpHklLx1mny91L7l5qVVstuwOQo6rDb2azzGz2sduSlkt6Iq/GANRXLU/7OyTdY2bHHuc2\nd/9hLl0BqLuqw+/ueyR9LMdepqyW0xcl697Wmqy/dMF7k/W3zik/Jt3+nvR49U8/lh7vLtJ//WZ2\nsv4v/7YiWe8587aytReH30puu2ng4mT9Az/1ZH0yYKgPCIrwA0ERfiAowg8ERfiBoAg/EFQe3+oL\nb+TCjyfrN2y9KVn/cGv5r55OZcM+kqz//Y2fS9anvZkebjv3rnVla7P3HUlu27Y/PRQ4s7cnWZ8M\nOPMDQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCM8+eg7ZmXkvVHfzsvWf9w60Ce7eRqff85yfqeN9KX\n/t668Ptla68fTY/Td3z7f5L1epr8X9itjDM/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRl7o0b0TzR\n2v1su6hh+2sWQ1eem6wfWJG+vHbL7hOS9ce+cuNx93TM9fv/KFl/5IL0OP7Ia68n635u+au7930t\nuakWrH4svQLeoce7dcCH0nOXZzjzA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQFcf5zWyLpEslDbr7\n4mxZu6Q7Jc2X1Cdplbv/utLOoo7zV9Iy933J+sirQ8n6i7eVH6t/8vwtyW2X/vNXk/WTbiruO/U4\nfnmP82+V9PaJ0K+T1O3uiyR1Z/cBTCIVw+/uD0p6+6lnpaRt2e1tki7LuS8AdVbta/4Od+/Pbr8s\nqSOnfgA0SM1v+PnomwZl3zgws7Vm1mtmvcM6VOvuAOSk2vAPmFmnJGW/B8ut6O5d7l5y91Kr2qrc\nHYC8VRv+HZLWZLfXSLo3n3YANErF8JvZ7ZIe
kvQRM9trZldJ2iTpYjN7TtKfZvcBTCIVr9vv7qvL\nlBiwz8nI/ldr2n74wPSqt/3oZ55K1l+5uSX9AEdHqt43isUn/ICgCD8QFOEHgiL8QFCEHwiK8ANB\nMUX3FHD6tc+WrV15ZnpE9j9P6U7WL/jU1cn67DsfTtbRvDjzA0ERfiAowg8ERfiBoAg/EBThB4Ii\n/EBQjPNPAalpsl/98unJbf9vx1vJ+nXXb0/W/2bV5cm6//w9ZWvz/umh5LZq4PTxEXHmB4Ii/EBQ\nhB8IivADQRF+ICjCDwRF+IGgKk7RnSem6G4+Q58/N1m/9evfSNYXTJtR9b4/un1dsr7olv5k/cie\nvqr3PVXlPUU3gCmI8ANBEX4gKMIPBEX4gaAIPxAU4QeCqjjOb2ZbJF0qadDdF2fLNkr6oqRXstU2\nuPt9lXbGOP/k4+ctSdZP3LQ3Wb/9Qz+qet+n/eQLyfpH/qH8dQwkaeS5PVXve7LKe5x/q6QV4yz/\nlrsvyX4qBh9Ac6kYfnd/UNJQA3oB0EC1vOZfZ2a7zWyLmc3JrSMADVFt+G+WtFDSEkn9kr5ZbkUz\nW2tmvWbWO6xDVe4OQN6qCr+7D7j7iLsflXSLpKWJdbvcveTupVa1VdsngJxVFX4z6xxz93JJT+TT\nDoBGqXjpbjO7XdKFkuaa2V5JX5d0oZktkeSS+iR9qY49AqgDvs+PmrR0nJSsv3TFqWVrPdduTm77\nrgpPTD/z4vJk/fVlrybrUxHf5wdQEeEHgiL8QFCEHwiK8ANBEX4gKIb6UJjv7U1P0T3Tpifrv/HD\nyfqlX72m/GPf05PcdrJiqA9ARYQfCIrwA0ERfiAowg8ERfiBoAg/EFTF7/MjtqPL0pfufuFT6Sm6\nFy/pK1urNI5fyY1DZyXrM+/trenxpzrO/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOP8U5yVFifr\nz34tPdZ+y3nbkvXzZ6S/U1+LQz6crD88tCD9AEf7c+xm6uHMDwRF+IGgCD8QFOEHgiL8QFCEHwiK\n8ANBVRznN7N5krZL6pDkkrrcfbOZtUu6U9J8SX2SVrn7r+vXalzTFpySrL9w5QfK1jZecUdy20+e\nsL+qnvKwYaCUrD+w+Zxkfc629HX/kTaRM/8RSevd/QxJ50i62szOkHSdpG53XySpO7sPYJKoGH53\n73f3ndntg5KelnSypJWSjn38a5uky+rVJID8HddrfjObL+ksST2SOtz92OcnX9boywIAk8SEw29m\nJ0j6gaRr3P3A2JqPTvg37qR/ZrbWzHrNrHdYh2pqFkB+JhR+M2vVaPBvdfe7s8UDZtaZ1TslDY63\nrbt3uXvJ3UutasujZwA5qBh+MzNJ35H0tLvfMKa0Q9Ka7PYaSffm3x6AepnIV3rPk/RZSY+b2a5s\n2QZJmyR9z8yukvRLSavq0+LkN23+Hybrr/9xZ7J+xT/+MFn/8/fenazX0/r+9HDcQ/9efjivfev/\nJredc5ShvHqqGH53/5mkcvN9X5RvOwAahU/4AUERfiAowg8ERfiBoAg/EBThB4Li0t0TNK3zD8rW\nhrbMSm775QUPJOurZw9U1VMe1u1blqzvvDk9Rffc7z+RrLcfZKy+WXHmB4Ii/EBQhB8IivADQRF+\nICjCDwRF+IGgwozzH/6z9GWiD//lULK+4dT7ytaWv/vNqnrKy8DIW2Vr5+9Yn9z2tL/7RbLe/lp6\nnP5osopmxpkfCIrwA0ERfiAowg8ERfiBoAg/EBThB4IKM87fd1n679yzZ95Vt33f9NrCZH3zA8uT\ndRspd+X0Uadd/2LZ2qKBnuS2I8kqpjLO/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QlLl7egWzeZK2\nS+qQ5JK63H2zmW2U9EVJr2SrbnD38l96l3SitfvZxqzeQL30eLcO+FD6gyGZiXzI54ik9e6+08xm\nS3rUzO7P
at9y929U2yiA4lQMv7v3S+rPbh80s6clnVzvxgDU13G95jez+ZLOknTsM6PrzGy3mW0x\nszlltllrZr1m1jusQzU1CyA/Ew6/mZ0g6QeSrnH3A5JulrRQ0hKNPjP45njbuXuXu5fcvdSqthxa\nBpCHCYXfzFo1Gvxb3f1uSXL3AXcfcfejkm6RtLR+bQLIW8Xwm5lJ+o6kp939hjHLO8esdrmk9HSt\nAJrKRN7tP0/SZyU9bma7smUbJK02syUaHf7rk/SlunQIoC4m8m7/zySNN26YHNMH0Nz4hB8QFOEH\ngiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiCoipfuznVnZq9I+uWY\nRXMl7W9YA8enWXtr1r4keqtWnr2d4u7vn8iKDQ3/O3Zu1uvupcIaSGjW3pq1L4neqlVUbzztB4Ii\n/EBQRYe/q+D9pzRrb83al0Rv1Sqkt0Jf8wMoTtFnfgAFKST8ZrbCzJ4xs+fN7LoieijHzPrM7HEz\n22VmvQX3ssXMBs3siTHL2s3sfjN7Lvs97jRpBfW20cz2Zcdul5ldUlBv88zsJ2b2lJk9aWZ/kS0v\n9Ngl+irkuDX8ab+ZtUh6VtLFkvZKekTSand/qqGNlGFmfZJK7l74mLCZnS/pDUnb3X1xtuxfJQ25\n+6bsD+ccd7+2SXrbKOmNomduziaU6Rw7s7SkyyR9TgUeu0Rfq1TAcSvizL9U0vPuvsfdD0u6Q9LK\nAvpoeu7+oKShty1eKWlbdnubRv/zNFyZ3pqCu/e7+87s9kFJx2aWLvTYJfoqRBHhP1nSr8bc36vm\nmvLbJf3YzB41s7VFNzOOjmzadEl6WVJHkc2Mo+LMzY30tpmlm+bYVTPjdd54w++dlrn7xyV9QtLV\n2dPbpuSjr9maabhmQjM3N8o4M0v/TpHHrtoZr/NWRPj3SZo35v4Hs2VNwd33Zb8HJd2j5pt9eODY\nJKnZ78GC+/mdZpq5ebyZpdUEx66ZZrwuIvyPSFpkZgvMbLqkT0vaUUAf72Bms7I3YmRmsyQtV/PN\nPrxD0prs9hpJ9xbYy+9plpmby80srYKPXdPNeO3uDf+RdIlG3/F/QdLfFtFDmb4+JOmx7OfJonuT\ndLtGnwYOa/S9kaskvU9St6TnJP23pPYm6u27kh6XtFujQessqLdlGn1Kv1vSruznkqKPXaKvQo4b\nn/ADguINPyAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQf0/sEWOix6VKakAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "RRXrKqyGNxtP",
"colab_type": "code",
"colab": {}
},
"source": [
"\n",
"# 5. Preprocess input data\n",
"def to_model_input(images):\n",
"    \"\"\"Convert a batch of raw images into the array layout fed to the network.\n",
"\n",
"    Args:\n",
"        images: array of 28x28 images, one image per entry along the first axis.\n",
"\n",
"    Returns:\n",
"        A new float32 array of shape (n_samples, 1, 28, 28) holding the input\n",
"        values divided by 255 (i.e. in [0, 1] when the input is 8-bit pixels).\n",
"    \"\"\"\n",
"    reshaped = images.reshape(images.shape[0], 1, 28, 28)\n",
"    return reshaped.astype('float32') / 255\n",
"\n",
"\n",
"# Convert only while the arrays still hold integer pixel values, so that\n",
"# re-running this cell does not divide the data by 255 a second time.\n",
"if np.issubdtype(X_train.dtype, np.integer):\n",
"    X_train = to_model_input(X_train)\n",
"if np.issubdtype(X_test.dtype, np.integer):\n",
"    X_test = to_model_input(X_test)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "EpkYOH1AMsUH",
"colab_type": "code",
"colab": {}
},
"source": [
"\n",
"# 6. Preprocess class labels\n",
"# One-hot encode both integer label vectors into 10-class matrices.\n",
"Y_train, Y_test = [\n",
"    np_utils.to_categorical(labels, 10) for labels in (y_train, y_test)\n",
"]"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "IjMVRF-XMvD0",
"colab_type": "code",
"outputId": "ade88e06-0bfd-4ad1-bf4f-cb2eaebcd43c",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 289
}
},
"source": [
"\n",
"# 7. Define model architecture\n",
"# The images are fed as (channels, height, width) = (1, 28, 28). Every layer\n",
"# whose result depends on the axis order is therefore told the data is\n",
"# channels-first. Without this, under Keras' default channels-last setting,\n",
"# the second convolution and the pooling layer would treat the 32 feature\n",
"# maps produced by the first convolution as image height.\n",
"DATA_FORMAT = 'channels_first'\n",
"\n",
"model = Sequential()\n",
"\n",
"# Feature extractor: two 3x3 convolutions, 2x2 max-pooling, then dropout.\n",
"model.add(Convolution2D(32, (3, 3), activation='relu',\n",
"                        input_shape=(1, 28, 28), data_format=DATA_FORMAT))\n",
"model.add(Convolution2D(32, (3, 3), activation='relu', data_format=DATA_FORMAT))\n",
"model.add(MaxPooling2D(pool_size=(2, 2), data_format=DATA_FORMAT))\n",
"model.add(Dropout(0.25))\n",
"\n",
"# Classifier head: flatten, one hidden dense layer, softmax over the 10 classes.\n",
"model.add(Flatten())\n",
"model.add(Dense(128, activation='relu'))\n",
"model.add(Dropout(0.5))\n",
"model.add(Dense(10, activation='softmax'))"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"WARNING: Logging before flag parsing goes to stderr.\n",
"W0811 14:23:17.711136 139952059987840 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.\n",
"\n",
"W0811 14:23:17.754603 139952059987840 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n",
"\n",
"W0811 14:23:17.762784 139952059987840 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:4138: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.\n",
"\n",
"W0811 14:23:17.781716 139952059987840 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:174: The name tf.get_default_session is deprecated. Please use tf.compat.v1.get_default_session instead.\n",
"\n",
"W0811 14:23:17.782824 139952059987840 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:181: The name tf.ConfigProto is deprecated. Please use tf.compat.v1.ConfigProto instead.\n",
"\n",
"W0811 14:23:18.153770 139952059987840 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:3976: The name tf.nn.max_pool is deprecated. Please use tf.nn.max_pool2d instead.\n",
"\n",
"W0811 14:23:18.167092 139952059987840 deprecation.py:506] From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:3445: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "kGUfvB2qMyLJ",
"colab_type": "code",
"outputId": "404713bf-9a78-4704-bc6f-7dbb7a80b9eb",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 493
}
},
"source": [
"\n",
"# 8. Compile model\n",
"model.compile(loss='categorical_crossentropy',\n",
"              optimizer='adam',\n",
"              metrics=['accuracy'])\n",
"\n",
"# 9. Fit model on training data\n",
"# `epochs` is the current name of the argument formerly called `nb_epoch`\n",
"# (the old spelling triggers a UserWarning in `fit`).\n",
"# The returned History object is kept in a variable rather than echoed as\n",
"# the cell output, so the per-epoch metrics stay available afterwards.\n",
"history = model.fit(X_train, Y_train,\n",
"                    batch_size=32, epochs=10, verbose=1)"
],
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"text": [
"W0811 14:23:18.255020 139952059987840 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/keras/optimizers.py:790: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
"\n",
"/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:7: UserWarning: The `nb_epoch` argument in `fit` has been renamed `epochs`.\n",
" import sys\n",
"W0811 14:23:18.534861 139952059987840 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"60000/60000 [==============================] - 149s 2ms/step - loss: 0.2517 - acc: 0.9240\n",
"Epoch 2/10\n",
"60000/60000 [==============================] - 148s 2ms/step - loss: 0.1028 - acc: 0.9693\n",
"Epoch 3/10\n",
"60000/60000 [==============================] - 148s 2ms/step - loss: 0.0795 - acc: 0.9767\n",
"Epoch 4/10\n",
"60000/60000 [==============================] - 148s 2ms/step - loss: 0.0659 - acc: 0.9803\n",
"Epoch 5/10\n",
"60000/60000 [==============================] - 148s 2ms/step - loss: 0.0574 - acc: 0.9826\n",
"Epoch 6/10\n",
"60000/60000 [==============================] - 148s 2ms/step - loss: 0.0525 - acc: 0.9843\n",
"Epoch 7/10\n",
"60000/60000 [==============================] - 147s 2ms/step - loss: 0.0451 - acc: 0.9860\n",
"Epoch 8/10\n",
"60000/60000 [==============================] - 148s 2ms/step - loss: 0.0428 - acc: 0.9865\n",
"Epoch 9/10\n",
"60000/60000 [==============================] - 148s 2ms/step - loss: 0.0382 - acc: 0.9881\n",
"Epoch 10/10\n",
"60000/60000 [==============================] - 148s 2ms/step - loss: 0.0354 - acc: 0.9890\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7f48df84e0f0>"
]
},
"metadata": {
"tags": []
},
"execution_count": 6
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "QfqX-GVFD3Ny",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "a657d3f3-7dfc-4924-928c-474ec6e840e5"
},
"source": [
"# 10. Evaluate model on test data\n",
"\n",
"# The model is compiled with metrics=['accuracy'], so evaluate() yields two\n",
"# values: the test loss followed by the test accuracy.\n",
"loss, accuracy = model.evaluate(\n",
"    x=X_test,\n",
"    y=Y_test,\n",
"    verbose=0,\n",
")\n",
"print(f\"Test accuracy: {100*accuracy:.2f}%\")"
],
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"text": [
"Test accuracy: 99.01%\n"
],
"name": "stdout"
}
]
}
]
}
@Elijas
Copy link
Author

Elijas commented Aug 11, 2019

Note: I had to apply several fixes to the original tutorial code because of recent updates to the library.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment