Skip to content

Instantly share code, notes, and snippets.

@4rtemi5
Created November 20, 2024 07:51
Show Gist options
  • Select an option

  • Save 4rtemi5/efd2c7e504254d96546aa48500c85702 to your computer and use it in GitHub Desktop.

Select an option

Save 4rtemi5/efd2c7e504254d96546aa48500c85702 to your computer and use it in GitHub Desktop.
masked_convolution_with_dropout.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"authorship_tag": "ABX9TyNuSvvzHF/nEwE3pLDwRIDs",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/4rtemi5/efd2c7e504254d96546aa48500c85702/masked_convolution_with_dropout.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HflDeHxoPfaM"
},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
"\n",
"import keras\n",
"from keras import layers, ops\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"source": [
"class MaskedConv2DWithDropout(layers.Conv2D):\n",
" def __init__(self, rate=0.0, **kwargs):\n",
" super().__init__(**kwargs)\n",
" self.rate = rate\n",
" self.seed_generator = keras.random.SeedGenerator(42)\n",
"\n",
" def build(self, input_shape):\n",
" super().build(input_shape)\n",
" self.output_scaling = (\n",
" self.kernel_size[0] * self.kernel_size[1] * input_shape[-1]\n",
" )\n",
"\n",
" def call(self, inputs, training=None):\n",
" mask = ops.ones_like(inputs)\n",
" if training:\n",
" mask = keras.random.dropout(\n",
" mask, rate=self.rate, seed=self.seed_generator\n",
" )\n",
" inputs = inputs * mask\n",
"\n",
" mask_sum = self.convolution_op(\n",
" mask, ops.ones_like(self.kernel)\n",
" )\n",
" result = self.convolution_op(\n",
" inputs, self.kernel\n",
" )\n",
" result = result / ops.clip(mask_sum, 1, None) * self.output_scaling\n",
" if self.use_bias:\n",
" result = result + self.bias\n",
" return result"
],
"metadata": {
"id": "D1C2Z4sfSkER"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Model / data parameters\n",
"num_classes = 10\n",
"input_shape = (28, 28, 1)\n",
"batch_size = 64\n",
"epochs = 30\n",
"dropout_rate = 0.4\n",
"padding = \"valid\"\n",
"\n",
"# the data, split between train and test sets\n",
"(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n",
"\n",
"# Scale images to the [0, 1] range\n",
"x_train = x_train.astype(\"float32\") / 255\n",
"x_test = x_test.astype(\"float32\") / 255\n",
"# Make sure images have shape (28, 28, 1)\n",
"x_train = np.expand_dims(x_train, -1)\n",
"x_test = np.expand_dims(x_test, -1)\n",
"print(\"x_train shape:\", x_train.shape)\n",
"print(x_train.shape[0], \"train samples\")\n",
"print(x_test.shape[0], \"test samples\")\n",
"\n",
"# convert class vectors to binary class matrices\n",
"y_train = keras.utils.to_categorical(y_train, num_classes)\n",
"y_test = keras.utils.to_categorical(y_test, num_classes)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gdf7CRQMSWY_",
"outputId": "7c929c50-f650-4a2d-d036-da40cc691bb4"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"x_train shape: (60000, 28, 28, 1)\n",
"60000 train samples\n",
"10000 test samples\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"model = keras.Sequential(\n",
" [\n",
" keras.layers.Input(shape=input_shape),\n",
" MaskedConv2DWithDropout(rate=dropout_rate, filters=32, kernel_size=(3, 3), activation=\"relu\", padding=padding),\n",
" layers.MaxPooling2D(pool_size=(2, 2)),\n",
" MaskedConv2DWithDropout(rate=dropout_rate, filters=64, kernel_size=(3, 3), activation=\"relu\", padding=padding),\n",
" layers.MaxPooling2D(pool_size=(2, 2)),\n",
" layers.Flatten(),\n",
" layers.Dense(num_classes, activation=\"softmax\"),\n",
" ]\n",
")\n",
"\n",
"model.summary()\n",
"\n",
"model.compile(loss=\"categorical_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"])\n",
"\n",
"model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)\n",
"\n",
"eval_metrics = model.evaluate(x_test, y_test, verbose=0)\n",
"\n",
"print(f\"Test loss: {eval_metrics[0]}\")\n",
"print(f\"Test accuracy: {eval_metrics[1]}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "ujhf6da1S8HO",
"outputId": "a87a9482-befe-4391-9e8f-b4ac280705cc"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1mModel: \"sequential\"\u001b[0m\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential\"</span>\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩\n",
"│ masked_conv2d_with_dropout │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m26\u001b[0m, \u001b[38;5;34m26\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m320\u001b[0m │\n",
"│ (\u001b[38;5;33mMaskedConv2DWithDropout\u001b[0m) │ │ │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ max_pooling2d (\u001b[38;5;33mMaxPooling2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m13\u001b[0m, \u001b[38;5;34m13\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ masked_conv2d_with_dropout_1 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m18,496\u001b[0m │\n",
"│ (\u001b[38;5;33mMaskedConv2DWithDropout\u001b[0m) │ │ │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ max_pooling2d_1 (\u001b[38;5;33mMaxPooling2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m5\u001b[0m, \u001b[38;5;34m5\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ flatten (\u001b[38;5;33mFlatten\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1600\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m10\u001b[0m) │ \u001b[38;5;34m16,010\u001b[0m │\n",
"└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩\n",
"│ masked_conv2d_with_dropout │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">26</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">26</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">320</span> │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">MaskedConv2DWithDropout</span>) │ │ │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ max_pooling2d (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">MaxPooling2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">13</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">13</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ masked_conv2d_with_dropout_1 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">18,496</span> │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">MaskedConv2DWithDropout</span>) │ │ │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ max_pooling2d_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">MaxPooling2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">5</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">5</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ flatten (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Flatten</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1600</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ dense (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">10</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">16,010</span> │\n",
"└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m34,826\u001b[0m (136.04 KB)\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">34,826</span> (136.04 KB)\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m34,826\u001b[0m (136.04 KB)\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">34,826</span> (136.04 KB)\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 22ms/step - accuracy: 0.7885 - loss: 0.7025 - val_accuracy: 0.9627 - val_loss: 0.1433\n",
"Epoch 2/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 4ms/step - accuracy: 0.9466 - loss: 0.1779 - val_accuracy: 0.9688 - val_loss: 0.1109\n",
"Epoch 3/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9605 - loss: 0.1292 - val_accuracy: 0.9758 - val_loss: 0.0851\n",
"Epoch 4/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9661 - loss: 0.1145 - val_accuracy: 0.9798 - val_loss: 0.0746\n",
"Epoch 5/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 5ms/step - accuracy: 0.9677 - loss: 0.1048 - val_accuracy: 0.9782 - val_loss: 0.0811\n",
"Epoch 6/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9712 - loss: 0.0969 - val_accuracy: 0.9783 - val_loss: 0.0747\n",
"Epoch 7/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9717 - loss: 0.0909 - val_accuracy: 0.9798 - val_loss: 0.0709\n",
"Epoch 8/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9728 - loss: 0.0890 - val_accuracy: 0.9767 - val_loss: 0.0857\n",
"Epoch 9/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9739 - loss: 0.0881 - val_accuracy: 0.9770 - val_loss: 0.0780\n",
"Epoch 10/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9735 - loss: 0.0807 - val_accuracy: 0.9802 - val_loss: 0.0735\n",
"Epoch 11/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9751 - loss: 0.0804 - val_accuracy: 0.9790 - val_loss: 0.0756\n",
"Epoch 12/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 5ms/step - accuracy: 0.9747 - loss: 0.0798 - val_accuracy: 0.9822 - val_loss: 0.0645\n",
"Epoch 13/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 7ms/step - accuracy: 0.9762 - loss: 0.0775 - val_accuracy: 0.9813 - val_loss: 0.0704\n",
"Epoch 14/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - accuracy: 0.9752 - loss: 0.0766 - val_accuracy: 0.9850 - val_loss: 0.0594\n",
"Epoch 15/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9760 - loss: 0.0746 - val_accuracy: 0.9823 - val_loss: 0.0626\n",
"Epoch 16/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 5ms/step - accuracy: 0.9765 - loss: 0.0769 - val_accuracy: 0.9802 - val_loss: 0.0710\n",
"Epoch 17/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 7ms/step - accuracy: 0.9781 - loss: 0.0702 - val_accuracy: 0.9803 - val_loss: 0.0712\n",
"Epoch 18/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 5ms/step - accuracy: 0.9782 - loss: 0.0722 - val_accuracy: 0.9830 - val_loss: 0.0615\n",
"Epoch 19/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9761 - loss: 0.0713 - val_accuracy: 0.9837 - val_loss: 0.0596\n",
"Epoch 20/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9793 - loss: 0.0686 - val_accuracy: 0.9852 - val_loss: 0.0529\n",
"Test loss: 0.05142683535814285\n",
"Test accuracy: 0.9831000566482544\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"model = keras.Sequential(\n",
" [\n",
" keras.layers.Input(shape=input_shape),\n",
" layers.Dropout(dropout_rate),\n",
" layers.Conv2D(32, kernel_size=(3, 3), activation=\"relu\", padding=padding),\n",
" layers.MaxPooling2D(pool_size=(2, 2)),\n",
" layers.Dropout(dropout_rate),\n",
" layers.Conv2D(64, kernel_size=(3, 3), activation=\"relu\", padding=padding),\n",
" layers.MaxPooling2D(pool_size=(2, 2)),\n",
" layers.Flatten(),\n",
" layers.Dense(num_classes, activation=\"softmax\"),\n",
" ]\n",
")\n",
"\n",
"model.summary()\n",
"\n",
"model.compile(loss=\"categorical_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"])\n",
"\n",
"model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)\n",
"\n",
"eval_metrics = model.evaluate(x_test, y_test, verbose=0)\n",
"\n",
"print(f\"Test loss: {eval_metrics[0]}\")\n",
"print(f\"Test accuracy: {eval_metrics[1]}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "7nErbLprS8O1",
"outputId": "30f1c6e7-e05f-448c-efd2-2720c395855a"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1mModel: \"sequential_1\"\u001b[0m\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential_1\"</span>\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩\n",
"│ dropout (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ conv2d (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m26\u001b[0m, \u001b[38;5;34m26\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m320\u001b[0m │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ max_pooling2d_2 (\u001b[38;5;33mMaxPooling2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m13\u001b[0m, \u001b[38;5;34m13\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ dropout_1 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m13\u001b[0m, \u001b[38;5;34m13\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ conv2d_1 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m18,496\u001b[0m │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ max_pooling2d_3 (\u001b[38;5;33mMaxPooling2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m5\u001b[0m, \u001b[38;5;34m5\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ flatten_1 (\u001b[38;5;33mFlatten\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1600\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ dense_1 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m10\u001b[0m) │ \u001b[38;5;34m16,010\u001b[0m │\n",
"└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩\n",
"│ dropout (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dropout</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">28</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ conv2d (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">26</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">26</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">320</span> │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ max_pooling2d_2 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">MaxPooling2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">13</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">13</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ dropout_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dropout</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">13</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">13</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ conv2d_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">18,496</span> │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ max_pooling2d_3 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">MaxPooling2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">5</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">5</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ flatten_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Flatten</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1600</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n",
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
"│ dense_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">10</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">16,010</span> │\n",
"└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m34,826\u001b[0m (136.04 KB)\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">34,826</span> (136.04 KB)\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m34,826\u001b[0m (136.04 KB)\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">34,826</span> (136.04 KB)\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 12ms/step - accuracy: 0.7860 - loss: 0.7278 - val_accuracy: 0.9683 - val_loss: 0.1552\n",
"Epoch 2/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9535 - loss: 0.1478 - val_accuracy: 0.9773 - val_loss: 0.1407\n",
"Epoch 3/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 4ms/step - accuracy: 0.9648 - loss: 0.1125 - val_accuracy: 0.9797 - val_loss: 0.1379\n",
"Epoch 4/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step - accuracy: 0.9692 - loss: 0.0967 - val_accuracy: 0.9812 - val_loss: 0.1157\n",
"Epoch 5/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step - accuracy: 0.9723 - loss: 0.0851 - val_accuracy: 0.9802 - val_loss: 0.1207\n",
"Epoch 6/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9728 - loss: 0.0840 - val_accuracy: 0.9818 - val_loss: 0.1065\n",
"Epoch 7/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9768 - loss: 0.0729 - val_accuracy: 0.9818 - val_loss: 0.1065\n",
"Epoch 8/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9777 - loss: 0.0721 - val_accuracy: 0.9820 - val_loss: 0.1130\n",
"Epoch 9/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 4ms/step - accuracy: 0.9777 - loss: 0.0704 - val_accuracy: 0.9820 - val_loss: 0.0920\n",
"Epoch 10/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9787 - loss: 0.0652 - val_accuracy: 0.9833 - val_loss: 0.1043\n",
"Epoch 11/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 4ms/step - accuracy: 0.9800 - loss: 0.0622 - val_accuracy: 0.9830 - val_loss: 0.0939\n",
"Epoch 12/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 4ms/step - accuracy: 0.9813 - loss: 0.0588 - val_accuracy: 0.9817 - val_loss: 0.1109\n",
"Epoch 13/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 4ms/step - accuracy: 0.9825 - loss: 0.0557 - val_accuracy: 0.9817 - val_loss: 0.1005\n",
"Epoch 14/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9811 - loss: 0.0554 - val_accuracy: 0.9825 - val_loss: 0.0965\n",
"Epoch 15/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9829 - loss: 0.0533 - val_accuracy: 0.9840 - val_loss: 0.0987\n",
"Epoch 16/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9832 - loss: 0.0520 - val_accuracy: 0.9788 - val_loss: 0.1099\n",
"Epoch 17/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9830 - loss: 0.0512 - val_accuracy: 0.9838 - val_loss: 0.0947\n",
"Epoch 18/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9829 - loss: 0.0503 - val_accuracy: 0.9835 - val_loss: 0.1032\n",
"Epoch 19/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 4ms/step - accuracy: 0.9829 - loss: 0.0499 - val_accuracy: 0.9842 - val_loss: 0.1041\n",
"Epoch 20/20\n",
"\u001b[1m422/422\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step - accuracy: 0.9854 - loss: 0.0464 - val_accuracy: 0.9793 - val_loss: 0.1131\n",
"Test loss: 0.11808445304632187\n",
"Test accuracy: 0.9824000597000122\n"
]
}
]
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "NhDLuEqmUsrl"
},
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment