@DiogenesAnalytics
Created December 5, 2023 07:42
Keras Autoencoder Jupyter Notebook Tutorials
{
"cells": [
{
"cell_type": "markdown",
"id": "1b2cf03f-8eaf-4adf-80ab-0d5d5c661deb",
"metadata": {},
"source": [
"# Building Autoencoders in Keras\n",
"The following `Jupyter Notebook` has been *adapted* from the [Keras blog article](https://blog.keras.io/building-autoencoders-in-keras.html) written by *F. Chollet* on [autoencoders](https://en.wikipedia.org/wiki/Autoencoder)."
]
},
{
"cell_type": "markdown",
"id": "6aa4cb57-b5b0-4d74-827f-2961537727a7",
"metadata": {},
"source": [
"## Deep Autoencoder\n",
"We do not have to limit ourselves to a single layer as encoder or decoder, we could instead use a stack of layers, such as:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5442e685-9586-4e8c-8509-26d7b125f5d3",
"metadata": {},
"outputs": [],
"source": [
"# get initial libs\n",
"import keras\n",
"from keras import layers\n",
"from keras.datasets import mnist\n",
"import numpy as np\n",
"import visualization.image"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a672d893-e9a0-43ec-aa73-ae602d9f3d7c",
"metadata": {},
"outputs": [],
"source": [
"input_img = keras.Input(shape=(784,))\n",
"encoded = layers.Dense(128, activation='relu')(input_img)\n",
"encoded = layers.Dense(64, activation='relu')(encoded)\n",
"encoded = layers.Dense(32, activation='relu')(encoded)\n",
"\n",
"decoded = layers.Dense(64, activation='relu')(encoded)\n",
"decoded = layers.Dense(128, activation='relu')(decoded)\n",
"decoded = layers.Dense(784, activation='sigmoid')(decoded)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "471c4ab2-7085-41b8-961d-50aa17a8f1cd",
"metadata": {},
"outputs": [],
"source": [
"(x_train, _), (x_test, _) = mnist.load_data()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7fec3333-b65d-4c0b-b641-1faded12c43a",
"metadata": {},
"outputs": [],
"source": [
"x_train = x_train.astype('float32') / 255.\n",
"x_test = x_test.astype('float32') / 255.\n",
"x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))\n",
"x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))\n",
"print(x_train.shape)\n",
"print(x_test.shape)"
]
},
{
"cell_type": "markdown",
"id": "2751d57c-1b33-45f2-9ad9-6d27e3c5a7ce",
"metadata": {},
"source": [
"Now train:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "546d6700-fc73-4983-88e8-829f8b1bd160",
"metadata": {},
"outputs": [],
"source": [
"autoencoder = keras.Model(input_img, decoded)\n",
"autoencoder.compile(optimizer='adam', loss='binary_crossentropy')\n",
"\n",
"autoencoder.fit(x_train, x_train,\n",
" epochs=100,\n",
" batch_size=256,\n",
" shuffle=True,\n",
" validation_data=(x_test, x_test))"
]
},
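{
"cell_type": "markdown",
"id": "f3a1b2c4-d5e6-4f70-8a91-b2c3d4e5f601",
"metadata": {},
"source": [
"*Aside (not from the original blog excerpt):* since the `Dense` layers above are shared, we can wrap the encoder half in its own `keras.Model` to look at the 32-dimensional codes the trained network produces. This is a minimal sketch, assuming `input_img` and `encoded` still refer to the layers defined earlier."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f3a1b2c4-d5e6-4f70-8a91-b2c3d4e5f602",
"metadata": {},
"outputs": [],
"source": [
"# standalone encoder model that reuses the (now trained) encoding layers\n",
"encoder = keras.Model(input_img, encoded)\n",
"\n",
"# compress the test digits into their 32-dim representations\n",
"encoded_imgs = encoder.predict(x_test)\n",
"print(encoded_imgs.shape)  # expected: (10000, 32)"
]
},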
{
"cell_type": "code",
"execution_count": null,
"id": "318f7cef-e1f0-42a6-8b88-8717fad52428",
"metadata": {},
"outputs": [],
"source": [
"visualization.image.compare_results(x_test, autoencoder.predict(x_test));"
]
},
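{
"cell_type": "markdown",
"id": "f3a1b2c4-d5e6-4f70-8a91-b2c3d4e5f603",
"metadata": {},
"source": [
"*Aside (an addition to the original text):* `Model.evaluate` gives a quick numeric check of the reconstruction loss on the held-out digits; it should roughly match the validation loss printed during training."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f3a1b2c4-d5e6-4f70-8a91-b2c3d4e5f604",
"metadata": {},
"outputs": [],
"source": [
"# binary cross-entropy reconstruction loss on the test set\n",
"autoencoder.evaluate(x_test, x_test, verbose=0)"
]
},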
{
"cell_type": "markdown",
"id": "9d6d8eb8-8737-44b6-a9a4-c6551dd3f4f5",
"metadata": {},
"source": [
"After 100 epochs, it reaches a train and validation loss of ~0.08, a bit better than our previous models. Our reconstructed digits look a bit better too."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}