Created
November 20, 2024 07:51
-
-
Save 4rtemi5/efd2c7e504254d96546aa48500c85702 to your computer and use it in GitHub Desktop.
masked_convolution_with_dropout.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "provenance": [], | |
| "gpuType": "T4", | |
| "authorship_tag": "ABX9TyNuSvvzHF/nEwE3pLDwRIDs", | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| }, | |
| "language_info": { | |
| "name": "python" | |
| }, | |
| "accelerator": "GPU" | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/4rtemi5/efd2c7e504254d96546aa48500c85702/masked_convolution_with_dropout.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "HflDeHxoPfaM" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import os\n", | |
| "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n", | |
| "\n", | |
| "import keras\n", | |
| "from keras import layers, ops\n", | |
| "import numpy as np" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "class MaskedConv2DWithDropout(layers.Conv2D):\n", | |
| " def __init__(self, rate=0.0, **kwargs):\n", | |
| " super().__init__(**kwargs)\n", | |
| " self.rate = rate\n", | |
| " self.seed_generator = keras.random.SeedGenerator(42)\n", | |
| "\n", | |
| " def build(self, input_shape):\n", | |
| " super().build(input_shape)\n", | |
| " self.output_scaling = (\n", | |
| " self.kernel_size[0] * self.kernel_size[1] * input_shape[-1]\n", | |
| " )\n", | |
| "\n", | |
| " def call(self, inputs, training=None):\n", | |
| " mask = ops.ones_like(inputs)\n", | |
| " if training:\n", | |
| " mask = keras.random.dropout(\n", | |
| " mask, rate=self.rate, seed=self.seed_generator\n", | |
| " )\n", | |
| " inputs = inputs * mask\n", | |
| "\n", | |
| " mask_sum = self.convolution_op(\n", | |
| " mask, ops.ones_like(self.kernel)\n", | |
| " )\n", | |
| " result = self.convolution_op(\n", | |
| " inputs, self.kernel\n", | |
| " )\n", | |
| " result = result / ops.clip(mask_sum, 1, None) * self.output_scaling\n", | |
| " if self.use_bias:\n", | |
| " result = result + self.bias\n", | |
| " return result" | |
| ], | |
| "metadata": { | |
| "id": "D1C2Z4sfSkER" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Model / data parameters\n", | |
| "num_classes = 10\n", | |
| "input_shape = (28, 28, 1)\n", | |
| "batch_size = 64\n", | |
| "epochs = 30\n", | |
| "dropout_rate = 0.4\n", | |
| "padding = \"valid\"\n", | |
| "\n", | |
| "# the data, split between train and test sets\n", | |
| "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n", | |
| "\n", | |
| "# Scale images to the [0, 1] range\n", | |
| "x_train = x_train.astype(\"float32\") / 255\n", | |
| "x_test = x_test.astype(\"float32\") / 255\n", | |
| "# Make sure images have shape (28, 28, 1)\n", | |
| "x_train = np.expand_dims(x_train, -1)\n", | |
| "x_test = np.expand_dims(x_test, -1)\n", | |
| "print(\"x_train shape:\", x_train.shape)\n", | |
| "print(x_train.shape[0], \"train samples\")\n", | |
| "print(x_test.shape[0], \"test samples\")\n", | |
| "\n", | |
| "# convert class vectors to binary class matrices\n", | |
| "y_train = keras.utils.to_categorical(y_train, num_classes)\n", | |
| "y_test = keras.utils.to_categorical(y_test, num_classes)" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "gdf7CRQMSWY_", | |
| "outputId": "7c929c50-f650-4a2d-d036-da40cc691bb4" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "x_train shape: (60000, 28, 28, 1)\n", | |
| "60000 train samples\n", | |
| "10000 test samples\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "model = keras.Sequential(\n", | |
| " [\n", | |
| " keras.layers.Input(shape=input_shape),\n", | |
| " MaskedConv2DWithDropout(rate=dropout_rate, filters=32, kernel_size=(3, 3), activation=\"relu\", padding=padding),\n", | |
| " layers.MaxPooling2D(pool_size=(2, 2)),\n", | |
| " MaskedConv2DWithDropout(rate=dropout_rate, filters=64, kernel_size=(3, 3), activation=\"relu\", padding=padding),\n", | |
| " layers.MaxPooling2D(pool_size=(2, 2)),\n", | |
| " layers.Flatten(),\n", | |
| " layers.Dense(num_classes, activation=\"softmax\"),\n", | |
| " ]\n", | |
| ")\n", | |
| "\n", | |
| "model.summary()\n", | |
| "\n", | |
| "model.compile(loss=\"categorical_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"])\n", | |
| "\n", | |
| "model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)\n", | |
| "\n", | |
| "eval_metrics = model.evaluate(x_test, y_test, verbose=0)\n", | |
| "\n", | |
| "print(f\"Test loss: {eval_metrics[0]}\")\n", | |
| "print(f\"Test accuracy: {eval_metrics[1]}\")" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 1000 | |
| }, | |
| "id": "ujhf6da1S8HO", | |
| "outputId": "a87a9482-befe-4391-9e8f-b4ac280705cc" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "\u001b[1mModel: \"sequential\"\u001b[0m\n" | |
| ], | |
| "text/html": [ | |
| "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential\"</span>\n", | |
| "</pre>\n" | |
| ] | |
| }, | |
| "metadata": {} | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓\n", | |
| "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", | |
| "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩\n", | |
| "│ masked_conv2d_with_dropout │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m26\u001b[0m, \u001b[38;5;34m26\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m320\u001b[0m │\n", | |
| "│ (\u001b[38;5;33mMaskedConv2DWithDropout\u001b[0m) │ │ │\n", | |
| "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", | |
| "│ max_pooling2d (\u001b[38;5;33mMaxPooling2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m13\u001b[0m, \u001b[38;5;34m13\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", | |
| "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", | |
| "│ masked_conv2d_with_dropout_1 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m18,496\u001b[0m │\n", | |
| "│ (\u001b[38;5;33mMaskedConv2DWithDropout\u001b[0m) │ │ │\n", | |
| "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", | |
| "│ max_pooling2d_1 (\u001b[38;5;33mMaxPooling2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m5\u001b[0m, \u001b[38;5;34m5\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", | |
| "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", | |
| "│ flatten (\u001b[38;5;33mFlatten\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1600\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", | |
| "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", | |
| "│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m10\u001b[0m) │ \u001b[38;5;34m16,010\u001b[0m │\n", | |
| "└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘\n" | |
| ], | |
| "text/html": [ | |
| "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓\n", | |
| "┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃\n", | |
| "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩\n", | |
| "│ masked_conv2d_with_dropout │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">26</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">26</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">320</span> │\n", | |
| "│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">MaskedConv2DWithDropout</span>) │ │ │\n", | |
| "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", | |
| "│ max_pooling2d (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">MaxPooling2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">13</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">13</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n", | |
| "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", | |
| "│ masked_conv2d_with_dropout_1 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">18,496</span> │\n", | |
| "│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">MaskedConv2DWithDropout</span>) │ │ │\n", | |
| "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", | |
| "│ max_pooling2d_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">MaxPooling2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">5</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">5</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n", | |
| "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", | |
| "│ flatten (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Flatten</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1600</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n", | |
| "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", | |
| "│ dense (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">10</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">16,010</span> │\n", | |
| "└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘\n", | |
| "</pre>\n" | |
| ] | |
| }, | |
| "metadata": {} | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m34,826\u001b[0m (136.04 KB)\n" | |
| ], | |
| "text/html": [ | |
| "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">34,826</span> (136.04 KB)\n", | |
| "</pre>\n" | |
| ] | |
| }, | |
| "metadata": {} | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m34,826\u001b[0m (136.04 KB)\n" | |
| ], | |
| "text/html": [ | |
| "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">34,826</span> (136.04 KB)\n", | |
| "</pre>\n" | |
| ] | |
| }, | |
| "metadata": {} | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" | |
| ], | |
| "text/html": [ | |
| "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n", | |
| "</pre>\n" | |
| ] | |
| }, | |
| "metadata": {} | |
| }, | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Epoch 1/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 22ms/step - accuracy: 0.7885 - loss: 0.7025 - val_accuracy: 0.9627 - val_loss: 0.1433\n", | |
| "Epoch 2/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 4ms/step - accuracy: 0.9466 - loss: 0.1779 - val_accuracy: 0.9688 - val_loss: 0.1109\n", | |
| "Epoch 3/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9605 - loss: 0.1292 - val_accuracy: 0.9758 - val_loss: 0.0851\n", | |
| "Epoch 4/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9661 - loss: 0.1145 - val_accuracy: 0.9798 - val_loss: 0.0746\n", | |
| "Epoch 5/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 5ms/step - accuracy: 0.9677 - loss: 0.1048 - val_accuracy: 0.9782 - val_loss: 0.0811\n", | |
| "Epoch 6/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9712 - loss: 0.0969 - val_accuracy: 0.9783 - val_loss: 0.0747\n", | |
| "Epoch 7/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9717 - loss: 0.0909 - val_accuracy: 0.9798 - val_loss: 0.0709\n", | |
| "Epoch 8/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9728 - loss: 0.0890 - val_accuracy: 0.9767 - val_loss: 0.0857\n", | |
| "Epoch 9/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9739 - loss: 0.0881 - val_accuracy: 0.9770 - val_loss: 0.0780\n", | |
| "Epoch 10/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9735 - loss: 0.0807 - val_accuracy: 0.9802 - val_loss: 0.0735\n", | |
| "Epoch 11/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9751 - loss: 0.0804 - val_accuracy: 0.9790 - val_loss: 0.0756\n", | |
| "Epoch 12/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 5ms/step - accuracy: 0.9747 - loss: 0.0798 - val_accuracy: 0.9822 - val_loss: 0.0645\n", | |
| "Epoch 13/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 7ms/step - accuracy: 0.9762 - loss: 0.0775 - val_accuracy: 0.9813 - val_loss: 0.0704\n", | |
| "Epoch 14/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - accuracy: 0.9752 - loss: 0.0766 - val_accuracy: 0.9850 - val_loss: 0.0594\n", | |
| "Epoch 15/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9760 - loss: 0.0746 - val_accuracy: 0.9823 - val_loss: 0.0626\n", | |
| "Epoch 16/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 5ms/step - accuracy: 0.9765 - loss: 0.0769 - val_accuracy: 0.9802 - val_loss: 0.0710\n", | |
| "Epoch 17/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 7ms/step - accuracy: 0.9781 - loss: 0.0702 - val_accuracy: 0.9803 - val_loss: 0.0712\n", | |
| "Epoch 18/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 5ms/step - accuracy: 0.9782 - loss: 0.0722 - val_accuracy: 0.9830 - val_loss: 0.0615\n", | |
| "Epoch 19/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9761 - loss: 0.0713 - val_accuracy: 0.9837 - val_loss: 0.0596\n", | |
| "Epoch 20/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9793 - loss: 0.0686 - val_accuracy: 0.9852 - val_loss: 0.0529\n", | |
| "Test loss: 0.05142683535814285\n", | |
| "Test accuracy: 0.9831000566482544\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "model = keras.Sequential(\n", | |
| " [\n", | |
| " keras.layers.Input(shape=input_shape),\n", | |
| " layers.Dropout(dropout_rate),\n", | |
| " layers.Conv2D(32, kernel_size=(3, 3), activation=\"relu\", padding=padding),\n", | |
| " layers.MaxPooling2D(pool_size=(2, 2)),\n", | |
| " layers.Dropout(dropout_rate),\n", | |
| " layers.Conv2D(64, kernel_size=(3, 3), activation=\"relu\", padding=padding),\n", | |
| " layers.MaxPooling2D(pool_size=(2, 2)),\n", | |
| " layers.Flatten(),\n", | |
| " layers.Dense(num_classes, activation=\"softmax\"),\n", | |
| " ]\n", | |
| ")\n", | |
| "\n", | |
| "model.summary()\n", | |
| "\n", | |
| "model.compile(loss=\"categorical_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"])\n", | |
| "\n", | |
| "model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)\n", | |
| "\n", | |
| "eval_metrics = model.evaluate(x_test, y_test, verbose=0)\n", | |
| "\n", | |
| "print(f\"Test loss: {eval_metrics[0]}\")\n", | |
| "print(f\"Test accuracy: {eval_metrics[1]}\")" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 1000 | |
| }, | |
| "id": "7nErbLprS8O1", | |
| "outputId": "30f1c6e7-e05f-448c-efd2-2720c395855a" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "\u001b[1mModel: \"sequential_1\"\u001b[0m\n" | |
| ], | |
| "text/html": [ | |
| "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential_1\"</span>\n", | |
| "</pre>\n" | |
| ] | |
| }, | |
| "metadata": {} | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓\n", | |
| "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", | |
| "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩\n", | |
| "│ dropout (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", | |
| "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", | |
| "│ conv2d (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m26\u001b[0m, \u001b[38;5;34m26\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m320\u001b[0m │\n", | |
| "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", | |
| "│ max_pooling2d_2 (\u001b[38;5;33mMaxPooling2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m13\u001b[0m, \u001b[38;5;34m13\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", | |
| "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", | |
| "│ dropout_1 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m13\u001b[0m, \u001b[38;5;34m13\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", | |
| "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", | |
| "│ conv2d_1 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m18,496\u001b[0m │\n", | |
| "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", | |
| "│ max_pooling2d_3 (\u001b[38;5;33mMaxPooling2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m5\u001b[0m, \u001b[38;5;34m5\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", | |
| "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", | |
| "│ flatten_1 (\u001b[38;5;33mFlatten\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1600\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", | |
| "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", | |
| "│ dense_1 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m10\u001b[0m) │ \u001b[38;5;34m16,010\u001b[0m │\n", | |
| "└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘\n" | |
| ], | |
| "text/html": [ | |
| "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓\n", | |
| "┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃\n", | |
| "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩\n", | |
| "│ dropout (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dropout</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n", | |
| "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", | |
| "│ conv2d (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">26</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">26</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">320</span> │\n", | |
| "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", | |
| "│ max_pooling2d_2 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">MaxPooling2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">13</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">13</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n", | |
| "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", | |
| "│ dropout_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dropout</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">13</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">13</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n", | |
| "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", | |
| "│ conv2d_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">18,496</span> │\n", | |
| "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", | |
| "│ max_pooling2d_3 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">MaxPooling2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">5</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">5</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n", | |
| "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", | |
| "│ flatten_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Flatten</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1600</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n", | |
| "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", | |
| "│ dense_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">10</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">16,010</span> │\n", | |
| "└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘\n", | |
| "</pre>\n" | |
| ] | |
| }, | |
| "metadata": {} | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m34,826\u001b[0m (136.04 KB)\n" | |
| ], | |
| "text/html": [ | |
| "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">34,826</span> (136.04 KB)\n", | |
| "</pre>\n" | |
| ] | |
| }, | |
| "metadata": {} | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m34,826\u001b[0m (136.04 KB)\n" | |
| ], | |
| "text/html": [ | |
| "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">34,826</span> (136.04 KB)\n", | |
| "</pre>\n" | |
| ] | |
| }, | |
| "metadata": {} | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" | |
| ], | |
| "text/html": [ | |
| "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n", | |
| "</pre>\n" | |
| ] | |
| }, | |
| "metadata": {} | |
| }, | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Epoch 1/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 12ms/step - accuracy: 0.7860 - loss: 0.7278 - val_accuracy: 0.9683 - val_loss: 0.1552\n", | |
| "Epoch 2/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9535 - loss: 0.1478 - val_accuracy: 0.9773 - val_loss: 0.1407\n", | |
| "Epoch 3/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 4ms/step - accuracy: 0.9648 - loss: 0.1125 - val_accuracy: 0.9797 - val_loss: 0.1379\n", | |
| "Epoch 4/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step - accuracy: 0.9692 - loss: 0.0967 - val_accuracy: 0.9812 - val_loss: 0.1157\n", | |
| "Epoch 5/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step - accuracy: 0.9723 - loss: 0.0851 - val_accuracy: 0.9802 - val_loss: 0.1207\n", | |
| "Epoch 6/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9728 - loss: 0.0840 - val_accuracy: 0.9818 - val_loss: 0.1065\n", | |
| "Epoch 7/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9768 - loss: 0.0729 - val_accuracy: 0.9818 - val_loss: 0.1065\n", | |
| "Epoch 8/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9777 - loss: 0.0721 - val_accuracy: 0.9820 - val_loss: 0.1130\n", | |
| "Epoch 9/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 4ms/step - accuracy: 0.9777 - loss: 0.0704 - val_accuracy: 0.9820 - val_loss: 0.0920\n", | |
| "Epoch 10/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9787 - loss: 0.0652 - val_accuracy: 0.9833 - val_loss: 0.1043\n", | |
| "Epoch 11/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 4ms/step - accuracy: 0.9800 - loss: 0.0622 - val_accuracy: 0.9830 - val_loss: 0.0939\n", | |
| "Epoch 12/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 4ms/step - accuracy: 0.9813 - loss: 0.0588 - val_accuracy: 0.9817 - val_loss: 0.1109\n", | |
| "Epoch 13/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 4ms/step - accuracy: 0.9825 - loss: 0.0557 - val_accuracy: 0.9817 - val_loss: 0.1005\n", | |
| "Epoch 14/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9811 - loss: 0.0554 - val_accuracy: 0.9825 - val_loss: 0.0965\n", | |
| "Epoch 15/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9829 - loss: 0.0533 - val_accuracy: 0.9840 - val_loss: 0.0987\n", | |
| "Epoch 16/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9832 - loss: 0.0520 - val_accuracy: 0.9788 - val_loss: 0.1099\n", | |
| "Epoch 17/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9830 - loss: 0.0512 - val_accuracy: 0.9838 - val_loss: 0.0947\n", | |
| "Epoch 18/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9829 - loss: 0.0503 - val_accuracy: 0.9835 - val_loss: 0.1032\n", | |
| "Epoch 19/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 4ms/step - accuracy: 0.9829 - loss: 0.0499 - val_accuracy: 0.9842 - val_loss: 0.1041\n", | |
| "Epoch 20/20\n", | |
| "\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step - accuracy: 0.9854 - loss: 0.0464 - val_accuracy: 0.9793 - val_loss: 0.1131\n", | |
| "Test loss: 0.11808445304632187\n", | |
| "Test accuracy: 0.9824000597000122\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [], | |
| "metadata": { | |
| "id": "NhDLuEqmUsrl" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| } | |
| ] | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment