@DiogenesAnalytics
Created December 5, 2023 07:42
Keras Autoencoder Jupyter Notebook Tutorials
{
"cells": [
{
"cell_type": "markdown",
"id": "1b2cf03f-8eaf-4adf-80ab-0d5d5c661deb",
"metadata": {},
"source": [
"# Building Autoencoders in Keras\n",
"The following `Jupyter Notebook` has been *adapted* from the [Keras blog article](https://blog.keras.io/building-autoencoders-in-keras.html) (and updated with code from a more [recent version](https://github.com/keras-team/keras-io/blob/master/examples/generative/vae.py)) written by *F. Chollet* on [autoencoders](https://en.wikipedia.org/wiki/Autoencoder)."
]
},
{
"cell_type": "markdown",
"id": "22f8d258-9de3-4717-9557-f8392815f6bd",
"metadata": {
"tags": []
},
"source": [
"## Variational Autoencoder (VAE)\n",
"Variational autoencoders are a slightly more modern and interesting take on autoencoding.\n",
"\n",
"What is a variational autoencoder, you ask? It's a type of autoencoder with added constraints on the encoded representations being learned. More precisely, it is an autoencoder that learns a [latent variable model](https://en.wikipedia.org/wiki/Latent_variable_model) for its input data. So instead of letting your neural network learn an arbitrary function, you are learning the parameters of a probability distribution modeling your data. If you sample points from this distribution, you can generate new input data samples: a VAE is a \"generative model\".\n",
"\n",
"How does a variational autoencoder work?\n",
"\n",
"First, an encoder network turns the input samples x into two parameters in a latent space, which we will denote z_mean and z_log_var (the log-variance). Then, we randomly sample similar points z from the latent normal distribution that is assumed to generate the data, via z = z_mean + exp(0.5 * z_log_var) * epsilon, where epsilon is a random tensor drawn from a standard normal distribution. Finally, a decoder network maps these latent space points back to the original input data.\n",
"\n",
"The parameters of the model are trained via two loss terms: a reconstruction loss forcing the decoded samples to match the initial inputs (just like in a standard autoencoder), and the KL divergence between the learned latent distribution and the prior distribution (a standard normal), acting as a regularization term. You could drop this latter term entirely, but it helps the model learn a well-formed latent space and reduces overfitting to the training data.\n",
"\n",
"Because a VAE is a more complex example, the full code is available on GitHub as a [standalone script](https://github.com/keras-team/keras-io/blob/master/examples/generative/vae.py). Here we review step by step how the model is created.\n",
"\n",
"First, here's our encoder network, mapping inputs to our latent distribution parameters:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5be52c92-8a39-41c0-90d7-9c0198620efd",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# get initial libs\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e6e5e7d3-663d-4132-9d05-7783c4780c35",
"metadata": {},
"outputs": [],
"source": [
"# create encoder layers\n",
"latent_dim = 2\n",
"encoder_inputs = keras.Input(shape=(28, 28, 1))\n",
"x = layers.Conv2D(32, 3, activation=\"relu\", strides=2, padding=\"same\")(encoder_inputs)\n",
"x = layers.Conv2D(64, 3, activation=\"relu\", strides=2, padding=\"same\")(x)\n",
"x = layers.Flatten()(x)\n",
"x = layers.Dense(16, activation=\"relu\")(x)\n",
"z_mean = layers.Dense(latent_dim, name=\"z_mean\")(x)\n",
"z_log_var = layers.Dense(latent_dim, name=\"z_log_var\")(x)"
]
},
{
"cell_type": "markdown",
"id": "18665e02-a50a-466f-8167-f682a2e305f5",
"metadata": {},
"source": [
"We can use these parameters to sample new similar points from the latent space:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e282db16-e39b-49fe-ba44-3f07108c23be",
"metadata": {},
"outputs": [],
"source": [
"# create custom sampling layer\n",
"class Sampling(layers.Layer):\n",
"    \"\"\"Uses (z_mean, z_log_var) to sample z, the vector encoding a digit.\"\"\"\n",
"\n",
"    def call(self, inputs):\n",
"        z_mean, z_log_var = inputs\n",
"        batch = tf.shape(z_mean)[0]\n",
"        dim = tf.shape(z_mean)[1]\n",
"        epsilon = tf.random.normal(shape=(batch, dim))\n",
"        # reparameterization trick: std = exp(0.5 * log_var), so z ~ N(z_mean, exp(z_log_var))\n",
"        return z_mean + tf.exp(0.5 * z_log_var) * epsilon\n",
"\n",
"# apply sampling and finish encoder\n",
"z = Sampling()([z_mean, z_log_var])\n",
"encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name=\"encoder\")\n",
"encoder.summary()"
]
},
{
"cell_type": "markdown",
"id": "fcc53270-4ac2-4ad2-9c12-efcc281c6bd9",
"metadata": {},
"source": [
"Finally, we can map these sampled latent points back to reconstructed inputs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0de3265a-c228-407c-9de1-0f8dec7d80b1",
"metadata": {},
"outputs": [],
"source": [
"# build decoder\n",
"latent_inputs = keras.Input(shape=(latent_dim,))\n",
"x = layers.Dense(7 * 7 * 64, activation=\"relu\")(latent_inputs)\n",
"x = layers.Reshape((7, 7, 64))(x)\n",
"x = layers.Conv2DTranspose(64, 3, activation=\"relu\", strides=2, padding=\"same\")(x)\n",
"x = layers.Conv2DTranspose(32, 3, activation=\"relu\", strides=2, padding=\"same\")(x)\n",
"decoder_outputs = layers.Conv2DTranspose(1, 3, activation=\"sigmoid\", padding=\"same\")(x)\n",
"decoder = keras.Model(latent_inputs, decoder_outputs, name=\"decoder\")\n",
"decoder.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eb6cb292-631d-4813-916a-e5ccabf7209c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# create new VAE model class\n",
"class VAE(keras.Model):\n",
"    def __init__(self, encoder, decoder, **kwargs):\n",
"        super().__init__(**kwargs)\n",
"        self.encoder = encoder\n",
"        self.decoder = decoder\n",
"        self.total_loss_tracker = keras.metrics.Mean(name=\"total_loss\")\n",
"        self.reconstruction_loss_tracker = keras.metrics.Mean(\n",
"            name=\"reconstruction_loss\"\n",
"        )\n",
"        self.kl_loss_tracker = keras.metrics.Mean(name=\"kl_loss\")\n",
"\n",
"    @property\n",
"    def metrics(self):\n",
"        return [\n",
"            self.total_loss_tracker,\n",
"            self.reconstruction_loss_tracker,\n",
"            self.kl_loss_tracker,\n",
"        ]\n",
"\n",
"    def train_step(self, data):\n",
"        with tf.GradientTape() as tape:\n",
"            z_mean, z_log_var, z = self.encoder(data)\n",
"            reconstruction = self.decoder(z)\n",
"            # per-pixel BCE, summed over each 28x28 image, then averaged over the batch\n",
"            reconstruction_loss = tf.reduce_mean(\n",
"                tf.reduce_sum(\n",
"                    keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)\n",
"                )\n",
"            )\n",
"            # closed-form KL divergence from N(z_mean, exp(z_log_var)) to the N(0, I) prior\n",
"            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))\n",
"            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))\n",
"            total_loss = reconstruction_loss + kl_loss\n",
"        # backpropagate the total loss and update the metric trackers\n",
"        grads = tape.gradient(total_loss, self.trainable_weights)\n",
"        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))\n",
"        self.total_loss_tracker.update_state(total_loss)\n",
"        self.reconstruction_loss_tracker.update_state(reconstruction_loss)\n",
"        self.kl_loss_tracker.update_state(kl_loss)\n",
"        return {\n",
"            \"loss\": self.total_loss_tracker.result(),\n",
"            \"reconstruction_loss\": self.reconstruction_loss_tracker.result(),\n",
"            \"kl_loss\": self.kl_loss_tracker.result(),\n",
"        }"
]
},
{
"cell_type": "markdown",
"id": "6417421d-896a-479c-92cf-5dbd6a548184",
"metadata": {},
"source": [
"What we've done so far allows us to instantiate 3 models:\n",
"\n",
"+ an end-to-end autoencoder (the `VAE` class) mapping inputs to reconstructions\n",
"+ an encoder (`encoder`) mapping inputs to the latent space\n",
"+ a generator (`decoder`) that takes points in the latent space and outputs the corresponding reconstructed samples.\n",
"\n",
"We train the end-to-end model with a custom `train_step` whose loss is the sum of a reconstruction term and the KL divergence regularization term."
]
},
{
"cell_type": "markdown",
"id": "8ed21506-144a-4d9a-9ca3-e7bb5f1b4e17",
"metadata": {},
"source": [
"We train our VAE on MNIST digits:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f607029-18bb-4a05-9436-9893d0a8145e",
"metadata": {},
"outputs": [],
"source": [
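"# load MNIST; labels are not needed, so train and test images are merged\n",
"(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()\n",
"mnist_digits = np.concatenate([x_train, x_test], axis=0)\n",
"# add a channel axis and scale pixels to [0, 1] to match the sigmoid output\n",
"mnist_digits = np.expand_dims(mnist_digits, -1).astype(\"float32\") / 255\n",
"\n",
"# build and train the VAE (the loss is computed inside train_step)\n",
"vae = VAE(encoder, decoder)\n",
"vae.compile(optimizer=keras.optimizers.Adam())\n",
"vae.fit(mnist_digits, epochs=30, batch_size=128)"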
]
},
{
"cell_type": "markdown",
"id": "90685acd-c59b-44fd-b4b2-614cc272f23f",
"metadata": {},
"source": [
"Because our latent space is two-dimensional, there are a few cool visualizations that can be done at this point. One is to look at the neighborhoods of different classes on the latent 2D plane: "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "71689f1e-6b20-40ce-9cfb-26923b91241c",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"def plot_label_clusters(vae, data, labels):\n",
"    # display a 2D plot of the digit classes in the latent space\n",
"    z_mean, _, _ = vae.encoder.predict(data, verbose=0)\n",
"    plt.figure(figsize=(12, 10))\n",
"    plt.scatter(z_mean[:, 0], z_mean[:, 1], c=labels)\n",
"    plt.colorbar()\n",
"    plt.xlabel(\"z[0]\")\n",
"    plt.ylabel(\"z[1]\")\n",
"    plt.show()\n",
"\n",
"\n",
"(x_train, y_train), _ = keras.datasets.mnist.load_data()\n",
"x_train = np.expand_dims(x_train, -1).astype(\"float32\") / 255\n",
"\n",
"plot_label_clusters(vae, x_train, y_train)"
]
},
{
"cell_type": "markdown",
"id": "2e1231f0-8084-425f-a18d-207fb42b1863",
"metadata": {},
"source": [
"Each of these colored clusters is a type of digit. Close clusters are digits that are structurally similar (i.e. digits that share information in the latent space).\n",
"\n",
"Because the VAE is a generative model, we can also use it to generate new digits! Here we will scan the latent plane, sampling latent points at regular intervals, and generating the corresponding digit for each of these points. This gives us a visualization of the latent manifold that \"generates\" the MNIST digits."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4eeab384-5cbf-489c-875f-33f5034079d6",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def plot_latent_space(vae, n=30, figsize=15):\n",
"    # display an n*n 2D manifold of digits\n",
"    digit_size = 28\n",
"    scale = 1.0\n",
"    figure = np.zeros((digit_size * n, digit_size * n))\n",
"    # linearly spaced coordinates corresponding to the 2D plot\n",
"    # of digit classes in the latent space\n",
"    grid_x = np.linspace(-scale, scale, n)\n",
"    grid_y = np.linspace(-scale, scale, n)[::-1]\n",
"\n",
"    for i, yi in enumerate(grid_y):\n",
"        for j, xi in enumerate(grid_x):\n",
"            z_sample = np.array([[xi, yi]])\n",
"            # verbose=0 suppresses one progress bar per decoded point\n",
"            x_decoded = vae.decoder.predict(z_sample, verbose=0)\n",
"            digit = x_decoded[0].reshape(digit_size, digit_size)\n",
"            figure[\n",
"                i * digit_size : (i + 1) * digit_size,\n",
"                j * digit_size : (j + 1) * digit_size,\n",
"            ] = digit\n",
"\n",
"    plt.figure(figsize=(figsize, figsize))\n",
"    start_range = digit_size // 2\n",
"    end_range = n * digit_size + start_range\n",
"    pixel_range = np.arange(start_range, end_range, digit_size)\n",
"    sample_range_x = np.round(grid_x, 1)\n",
"    sample_range_y = np.round(grid_y, 1)\n",
"    plt.xticks(pixel_range, sample_range_x)\n",
"    plt.yticks(pixel_range, sample_range_y)\n",
"    plt.xlabel(\"z[0]\")\n",
"    plt.ylabel(\"z[1]\")\n",
"    plt.imshow(figure, cmap=\"Greys_r\")\n",
"    plt.show()\n",
"\n",
"# run\n",
"plot_latent_space(vae)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
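In the notebook, the encoder halves the spatial resolution twice (28 → 14 → 7) and the decoder doubles it twice (7 → 14 → 28), which is why the decoder starts from a `Dense(7 * 7 * 64)` layer reshaped to `(7, 7, 64)`. The plain-Python sketch below reproduces that arithmetic, assuming the usual `padding="same"` rules: a strided `Conv2D` outputs `ceil(n / stride)` and a strided `Conv2DTranspose` outputs `n * stride`. The helper names are hypothetical, not Keras functions.

```python
import math


def conv_same_out(size: int, stride: int) -> int:
    """Output size of a Conv2D with padding="same" (hypothetical helper)."""
    return math.ceil(size / stride)


def conv_transpose_same_out(size: int, stride: int) -> int:
    """Output size of a Conv2DTranspose with padding="same" (hypothetical helper)."""
    return size * stride


# Encoder: two Conv2D layers with strides=2
encoder_sizes = [28]
for _ in range(2):
    encoder_sizes.append(conv_same_out(encoder_sizes[-1], stride=2))

# Decoder: two Conv2DTranspose layers with strides=2, then the final
# single-channel Conv2DTranspose, which uses the default strides=1
decoder_sizes = [7]
for stride in (2, 2, 1):
    decoder_sizes.append(conv_transpose_same_out(decoder_sizes[-1], stride))

print("encoder:", encoder_sizes)  # encoder: [28, 14, 7]
print("decoder:", decoder_sizes)  # decoder: [7, 14, 28, 28]
```

The final decoder layer keeps the 28x28 resolution and only reduces the channels to 1, so its output lines up pixel for pixel with the MNIST inputs used in the reconstruction loss.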
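The `Sampling` layer relies on the reparameterization trick: because the encoder predicts the log-variance, the standard deviation is `exp(0.5 * z_log_var)`, and `z = z_mean + exp(0.5 * z_log_var) * epsilon` with `epsilon ~ N(0, I)` is a sample from `N(z_mean, exp(z_log_var))`. All the randomness lives in `epsilon`, so gradients can flow through `z_mean` and `z_log_var`. The NumPy sketch below checks the resulting sampling statistics empirically; the values chosen for `z_mean` and `z_log_var` are made up for illustration and are not outputs of the trained encoder.

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# Made-up encoder outputs for a single input with a 2-D latent space
z_mean = np.array([1.5, -0.5])
z_log_var = np.array([np.log(0.25), np.log(4.0)])  # variances 0.25 and 4.0

# Reparameterization: draw noise from N(0, I), then scale and shift it
num_samples = 200_000
epsilon = rng.standard_normal(size=(num_samples, z_mean.shape[0]))
z = z_mean + np.exp(0.5 * z_log_var) * epsilon

print("empirical mean:", z.mean(axis=0))  # close to [1.5, -0.5]
print("empirical std: ", z.std(axis=0))   # close to [0.5, 2.0]
```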
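The `kl_loss` in `train_step` is the closed-form KL divergence between the diagonal Gaussian `N(z_mean, exp(z_log_var))` and the standard normal prior, summed over latent dimensions: `-0.5 * sum(1 + z_log_var - z_mean**2 - exp(z_log_var))`. As a sanity check, the sketch below compares that expression with a Monte Carlo estimate of `E_q[log q(z) - log p(z)]`, using made-up parameter values; `kl_closed_form` and `log_normal_pdf` are illustrative helpers, not part of the notebook.

```python
import numpy as np


def kl_closed_form(z_mean, z_log_var):
    # Same expression as kl_loss in VAE.train_step, summed over latent dims
    return -0.5 * np.sum(1 + z_log_var - np.square(z_mean) - np.exp(z_log_var))


def log_normal_pdf(z, mean, log_var):
    # Log-density of a diagonal Gaussian, summed over the last axis
    return -0.5 * np.sum(
        np.log(2 * np.pi) + log_var + np.square(z - mean) / np.exp(log_var), axis=-1
    )


rng = np.random.default_rng(seed=0)
z_mean = np.array([1.5, -0.5])
z_log_var = np.array([np.log(0.25), np.log(4.0)])

# Monte Carlo estimate: sample from q, then average log q(z) - log p(z)
eps = rng.standard_normal(size=(500_000, 2))
z = z_mean + np.exp(0.5 * z_log_var) * eps
log_q = log_normal_pdf(z, z_mean, z_log_var)
log_p = log_normal_pdf(z, np.zeros(2), np.zeros(2))  # prior N(0, I)
kl_mc = np.mean(log_q - log_p)

print(f"closed form: {kl_closed_form(z_mean, z_log_var):.4f}")  # closed form: 2.3750
print(f"Monte Carlo: {kl_mc:.4f}")
```

The closed form is exact and cheap, which is presumably why the notebook uses it instead of estimating the KL term from samples.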
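In `train_step`, `keras.losses.binary_crossentropy` averages over the last axis (the single channel here), which leaves a per-pixel loss of shape `(batch, 28, 28)`; `reduce_sum(..., axis=(1, 2))` then gives one loss per image, and `reduce_mean` averages over the batch. The NumPy sketch below mirrors those reductions on random stand-in arrays. The `binary_crossentropy` function here is an assumed NumPy approximation of the Keras loss (including clipping predictions with an epsilon of `1e-7`), not the Keras implementation itself.

```python
import numpy as np


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    # Element-wise BCE averaged over the last axis; predictions are clipped
    # to avoid log(0). A NumPy approximation of keras.losses.binary_crossentropy
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    bce = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return bce.mean(axis=-1)


rng = np.random.default_rng(seed=0)
batch = 4
data = rng.uniform(size=(batch, 28, 28, 1))            # stand-in for MNIST images
reconstruction = rng.uniform(size=(batch, 28, 28, 1))  # stand-in for decoder output

per_pixel = binary_crossentropy(data, reconstruction)  # shape (4, 28, 28)
per_image = per_pixel.sum(axis=(1, 2))                 # shape (4,)
reconstruction_loss = per_image.mean()                 # scalar, as in train_step

print(per_pixel.shape, per_image.shape, reconstruction_loss)
```

Summing over pixels (rather than averaging) keeps the reconstruction term on a per-image scale, comparable to the KL term, which is also summed over latent dimensions before averaging over the batch.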
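`plot_latent_space` calls `vae.decoder.predict` once per grid point, which means 900 calls for the default `n=30`. A possible alternative, sketched below assuming that decoding all `n * n` points in one batch fits in memory, builds the latent coordinates in the same order as the nested loops, decodes them in a single call, and tiles the results with reshapes. `latent_grid`, `tile_digits`, and `fake_decode` are hypothetical helpers; `fake_decode` only stands in for `vae.decoder.predict(z_grid, verbose=0)` so the sketch runs without a trained model.

```python
import numpy as np


def latent_grid(n=30, scale=1.0):
    # Same coordinates as plot_latent_space: x increases left to right,
    # y decreases top to bottom; flattened row by row into an (n*n, 2) array
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]
    xx, yy = np.meshgrid(grid_x, grid_y)  # xx[i, j] = grid_x[j], yy[i, j] = grid_y[i]
    return np.stack([xx.ravel(), yy.ravel()], axis=1)


def tile_digits(digits, n, digit_size=28):
    # (n*n, digit_size, digit_size) -> one (n*digit_size, n*digit_size) image,
    # with digit i*n + j placed at block row i, block column j
    grid = digits.reshape(n, n, digit_size, digit_size)
    return grid.transpose(0, 2, 1, 3).reshape(n * digit_size, n * digit_size)


def fake_decode(z_batch, digit_size=28):
    # Hypothetical stand-in for vae.decoder.predict(z_batch, verbose=0):
    # each "digit" is a constant image equal to the sum of its latent coordinates
    values = z_batch.sum(axis=1)[:, None, None, None]
    return np.ones((len(z_batch), digit_size, digit_size, 1)) * values


n = 5
z_grid = latent_grid(n)
decoded = fake_decode(z_grid)            # shape (25, 28, 28, 1)
figure = tile_digits(decoded[..., 0], n)
print(figure.shape)                      # (140, 140)
```

With the trained model, the decoding step would instead be `decoded = vae.decoder.predict(z_grid, verbose=0)`, after which `tile_digits(decoded[..., 0], n)` produces the same `figure` array that the nested loops build.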