@mattslight
Last active June 14, 2022 11:43
MNIST
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "MNIST",
"provenance": [],
"collapsed_sections": [
"HJeM_36aA8lN"
],
"authorship_tag": "ABX9TyO7Odq05M3XMEOBhN5QDZPX",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"gpuClass": "standard"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/mattslight/2faf5a1f62cb73f7bf14885926edd759/mnist.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# The Traditional MNIST Digit Recognition ML Workflow\n",
"\n",
"## Outline\n",
"\n",
"1. Install wandb (a tool used to graph key learning and optimisation metrics)\n",
"2. Download the MNIST dataset\n",
"3. Check the size and shape of the dataset\n",
"4. Build the model\n",
"5. Compile the model\n",
"6. Train the model\n",
"7. Test the model\n"
],
"metadata": {
"id": "jZAe55h_-xYK"
}
},
{
"cell_type": "markdown",
"source": [
"## 1. Install wandb"
],
"metadata": {
"id": "HJeM_36aA8lN"
}
},
{
"cell_type": "code",
"source": [
"### Install and initialise wandb for graphing\n",
"!pip install wandb &>/dev/null\n",
"!wandb login\n",
"import wandb\n",
"from wandb.keras import WandbCallback\n",
"wandb.init()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 180
},
"id": "x9L_O0d7apDw",
"outputId": "b383665d-bf49-4f9b-eb45-84113b17dd43"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: You can find your API key in your browser here: https://wandb.ai/authorize\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: \n",
"\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmattslight\u001b[0m (\u001b[33mkpsworld\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Tracking run with wandb version 0.12.18"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Run data is saved locally in <code>/content/wandb/run-20220614_113912-1nf4wrb4</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Syncing run <strong><a href=\"https://wandb.ai/kpsworld/uncategorized/runs/1nf4wrb4\" target=\"_blank\">eternal-brook-17</a></strong> to <a href=\"https://wandb.ai/kpsworld/uncategorized\" target=\"_blank\">Weights & Biases</a> (<a href=\"https://wandb.me/run\" target=\"_blank\">docs</a>)<br/>"
]
},
"metadata": {}
},
{
"output_type": "execute_result",
"data": {
"text/html": [
"<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">Display W&B run</button><iframe src=\"https://wandb.ai/kpsworld/uncategorized/runs/1nf4wrb4?jupyter=true\" style=\"border:none;width:100%;height:420px;display:none;\"></iframe>"
],
"text/plain": [
"<wandb.sdk.wandb_run.Run at 0x7f6673145050>"
]
},
"metadata": {},
"execution_count": 1
}
]
},
{
"cell_type": "markdown",
"source": [
"## 2. Download the MNIST dataset"
],
"metadata": {
"id": "h0v8WwagAuHs"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7oTkBmSCXZij",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "4e8be988-a89f-4af1-b773-0fd376ea9e08"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n",
"11493376/11490434 [==============================] - 0s 0us/step\n",
"11501568/11490434 [==============================] - 0s 0us/step\n"
]
}
],
"source": [
"from tensorflow.keras.datasets import mnist\n",
"(train_images, train_labels), (test_images, test_labels) = mnist.load_data()"
]
},
{
"cell_type": "markdown",
"source": [
"## 3. Check the size and shape of the dataset"
],
"metadata": {
"id": "2KSipSkcBDYY"
}
},
{
"cell_type": "code",
"source": [
"print(train_images.shape)\n",
"print(train_labels.shape)\n",
"\n",
"print(test_images.shape)\n",
"print(test_labels.shape)\n",
"\n",
"## train_images: (60000 = number of training samples, 28 = height of image, 28 = width of image)\n",
"## train_labels: (60000 = number of labels,)\n",
"## The test arrays have the same layout, with 10000 samples"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "dfbd4oZ2XIIV",
"outputId": "95161eb4-121a-4b0e-e327-812592269c38"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"(60000, 28, 28)\n",
"(60000,)\n",
"(10000, 28, 28)\n",
"(10000,)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(train_images[0])"
],
"metadata": {
"id": "_oP5qY6wrKLS"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## 4. Build the model"
],
"metadata": {
"id": "jZfda6LkBHsa"
}
},
{
"cell_type": "code",
"source": [
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"\n",
"# Earlier fully connected baseline, kept for reference.\n",
"# It expects each image flattened to a 28 * 28 vector.\n",
"# model = keras.Sequential([\n",
"#     layers.Dense(512, activation=\"relu\"),\n",
"#     layers.Dense(10, activation=\"softmax\")\n",
"# ])\n",
"\n",
"# Convolutional model: two conv/conv/pool/dropout blocks, then a dense classifier\n",
"model = keras.Sequential()\n",
"\n",
"model.add(layers.Conv2D(filters=32, kernel_size=(5, 5), padding=\"same\", activation=\"relu\", input_shape=(28, 28, 1)))\n",
"model.add(layers.Conv2D(filters=32, kernel_size=(5, 5), padding=\"same\", activation=\"relu\"))\n",
"model.add(layers.MaxPool2D(pool_size=(2, 2)))\n",
"model.add(layers.Dropout(0.25))\n",
"model.add(layers.Conv2D(filters=64, kernel_size=(3, 3), padding=\"same\", activation=\"relu\"))\n",
"model.add(layers.Conv2D(filters=64, kernel_size=(3, 3), padding=\"same\", activation=\"relu\"))\n",
"model.add(layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2)))\n",
"model.add(layers.Dropout(0.25))\n",
"model.add(layers.Flatten())\n",
"model.add(layers.Dense(256, activation=\"relu\"))\n",
"# Output layer: one softmax probability per digit class (0-9)\n",
"model.add(layers.Dense(10, activation=\"softmax\"))"
],
"metadata": {
"id": "CaCtiGGXXj8t"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## 5. Compile the model"
],
"metadata": {
"id": "SuJiNf6eBMK2"
}
},
{
"cell_type": "code",
"source": [
"model.compile(optimizer=\"rmsprop\",\n",
" loss=\"sparse_categorical_crossentropy\",\n",
" metrics=[\"accuracy\"])"
],
"metadata": {
"id": "cNRfPqFcXt6t"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Scale pixel values from [0, 255] to [0, 1].\n",
"# The commented-out reshape lines are only needed for the fully connected model.\n",
"#train_images = train_images.reshape((60000, 28 * 28)) ### Flatten a 28 * 28 image into a 1-d vector\n",
"train_images = train_images.astype(\"float32\") / 255\n",
"#test_images = test_images.reshape((10000, 28 * 28)) ### Flatten a 28 * 28 image into a 1-d vector\n",
"test_images = test_images.astype(\"float32\") / 255"
],
"metadata": {
"id": "V1Z4-21HX0Ku"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## 6. Train (fit) the model \n"
],
"metadata": {
"id": "RoVBRTj7Bhhi"
}
},
{
"cell_type": "code",
"source": [
"model.fit(train_images, train_labels, epochs=5, batch_size=128, callbacks=[WandbCallback(log_weights=True)])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WmO8_jWTX7Ad",
"outputId": "5be4659b-03b0-4ccc-a63f-5187e879a454"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1/5\n",
"469/469 [==============================] - 17s 13ms/step - loss: 0.1677 - accuracy: 0.9482 - _timestamp: 1655204520.0000 - _runtime: 215.0000\n",
"Epoch 2/5\n",
"469/469 [==============================] - 5s 10ms/step - loss: 0.0482 - accuracy: 0.9849 - _timestamp: 1655204525.0000 - _runtime: 220.0000\n",
"Epoch 3/5\n",
"469/469 [==============================] - 5s 10ms/step - loss: 0.0345 - accuracy: 0.9894 - _timestamp: 1655204530.0000 - _runtime: 225.0000\n",
"Epoch 4/5\n",
"469/469 [==============================] - 5s 10ms/step - loss: 0.0276 - accuracy: 0.9912 - _timestamp: 1655204534.0000 - _runtime: 229.0000\n",
"Epoch 5/5\n",
"469/469 [==============================] - 5s 10ms/step - loss: 0.0242 - accuracy: 0.9931 - _timestamp: 1655204539.0000 - _runtime: 234.0000\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7fa65626bd50>"
]
},
"metadata": {},
"execution_count": 12
}
]
},
{
"cell_type": "markdown",
"source": [
"## 7. Test the model"
],
"metadata": {
"id": "N7AMRczMBnwv"
}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"# Predict class probabilities for the first 20 test images\n",
"probabilities = model.predict(test_images[0:20])\n",
"# argmax picks the most probable digit for each image\n",
"predictions = np.argmax(probabilities, axis=1)\n",
"print(predictions)\n",
"print(test_labels[0:20])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ZcEUu3PzbXiQ",
"outputId": "ab49d3c3-8164-48cd-921d-a5f40238e9c5"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[7 2 1 0 4 1 4 9 5 9 0 6 9 0 1 5 9 7 3 4]\n",
"[7 2 1 0 4 1 4 9 5 9 0 6 9 0 1 5 9 7 3 4]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"test_loss, test_acc = model.evaluate(test_images, test_labels)\n",
"print(f\"test_acc: {test_acc}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "p0TXIlkGcKfr",
"outputId": "64d2014d-369c-4e44-c4c4-ca6d7ac84e91"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"313/313 [==============================] - 11s 36ms/step - loss: 0.0316 - accuracy: 0.9889\n",
"test_acc: 0.9889000058174133\n"
]
}
]
}
]
}