@NduatiK
Created May 21, 2022 07:18

Fashion MNIST autoencoder

Mix.install(
  [
    {:axon, "~> 0.1.0-dev", github: "elixir-nx/axon", branch: "main"},
    {:exla, "~> 0.2.0", github: "elixir-nx/nx", sparse: "exla", override: true},
    {:nx, "~> 0.2.0", github: "elixir-nx/nx", sparse: "nx", override: true},
    {:scidata, "~> 0.1.5"}
  ],
  system_env: [
    {"XLA_TARGET", "cuda111"},
    {"EXLA_TARGET", "cuda"},
    {"LIBTORCH_TARGET", "cuda"}
  ]
)

Introduction

An autoencoder is a deep learning model that consists of two parts: an encoder and a decoder. The encoder compresses high-dimensional data into a low-dimensional representation and feeds it to the decoder. The decoder tries to recreate the original data from that low-dimensional representation. Autoencoders can be applied to the following problems:

  • Dimensionality reduction
  • Noise reduction
  • Generative models
  • Data augmentation

Let's walk through a basic autoencoder implementation in Axon to get a better understanding of how they work in practice.

Imports

First, we have to import some essential packages. Axon will be the primary tool to build and train the autoencoder. We also need to import EXLA for hardware acceleration, Nx for some basic tensor operations, and Scidata for downloading training data.

Configure platforms precedence

EXLA provides JIT compilation to hardware accelerators. We'll start by telling EXLA which order of accelerators to use. If you have a specific target in mind, such as :cuda or :rocm, you can adjust the precedence to reflect that. The order of precedence here is tpu > cuda > rocm > host:

EXLA.set_preferred_defn_options([:tpu, :cuda, :rocm, :host])

Downloading Data

To train and test our model, we use one of the most popular data sets: Fashion MNIST. It consists of small grayscale images of clothing. Loading this data set is simple: just call Scidata.FashionMNIST.download().

Preprocessing of the images

The next step is image preprocessing. First, we load the training data into a tensor using Nx.from_binary and reshape it with Nx.reshape. Then we normalize the pixel values to the range 0 to 1 with Nx.divide and split the dataset into a list of batches with Nx.to_batched_list.
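The preprocessing steps above can be mimicked on a tiny binary with plain Elixir lists. This is only an analogue of the Nx pipeline, not the notebook's implementation: the module PreprocessSketch and its transform/3 function are hypothetical names, and nested lists stand in for tensors.

```elixir
# Plain-Elixir analogue of the Nx preprocessing pipeline (illustration only).
defmodule PreprocessSketch do
  # Hypothetical helper: decodes unsigned 8-bit pixels, scales them to 0.0..1.0,
  # groups them into images of `pixels_per_image` values, then into batches.
  def transform(bin, pixels_per_image, batch_size) do
    # Roughly what Nx.from_binary followed by Nx.divide(255.0) does.
    pixels = for <<px::8 <- bin>>, do: px / 255

    pixels
    # Roughly what Nx.reshape does: one inner list per image.
    |> Enum.chunk_every(pixels_per_image)
    # Roughly what Nx.to_batched_list does: one list per batch.
    |> Enum.chunk_every(batch_size)
  end
end

# Four tiny "images" of 2 pixels each, batched in pairs.
bin = <<0, 255, 51, 102, 255, 0, 204, 153>>
batches = PreprocessSketch.transform(bin, 2, 2)
IO.inspect(batches)
```

Each pixel ends up between 0.0 and 1.0, and the four images are split into two batches of two images each.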

Encoder and decoder

First we need to define the encoder and decoder. Both are one-layer neural nets.

In the encoder, we start by flattening the input with Axon.flatten, because the input shape is initially {batch_size, 1, 28, 28} and we want to pass the input into a dense layer with Axon.dense. Our dense layer has only latent_dim neurons. Here latent_dim is the size of the latent space, which is the compressed representation of the data. Remember, we want our encoder to compress the input data into a lower-dimensional representation, so we choose a latent_dim that is smaller than the dimensionality of the input.

Next, we pass the output of the encoder to the decoder and try to reconstruct the compressed data into its original form. Since our original input had a dimensionality of 784, we use an Axon.dense layer with 784 neurons. Because our original data was normalized to have pixel values between 0 and 1, we use a :sigmoid activation in our dense layer to squeeze output values between 0 and 1. Our original input shape was 28x28, so we use Axon.reshape to convert the flattened representation of the outputs into an image with the correct width and height.
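To see why a sigmoid suits outputs that should be pixel values between 0 and 1, here is a minimal plain-Elixir sketch of the function. SigmoidSketch is a hypothetical module name used only for illustration; inside the model, Axon applies the activation for us.

```elixir
# Minimal sketch of the logistic sigmoid, written without Nx for illustration.
defmodule SigmoidSketch do
  # Maps any real number into the open interval (0, 1).
  def sigmoid(x), do: 1.0 / (1.0 + :math.exp(-x))
end

# Large negative inputs approach 0, large positive inputs approach 1.
for x <- [-6.0, 0.0, 6.0] do
  IO.puts("sigmoid(#{x}) = #{SigmoidSketch.sigmoid(x)}")
end
```

Whatever value the dense layer produces, the activation keeps it in the same range as the normalized input pixels.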

Model

If we just bind the encoder and decoder into a sequential model, we'll get the desired model. This was pretty smooth, wasn't it?
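As a rough sanity check on the size of this model, the arithmetic below counts the trainable parameters of the two dense layers. It assumes latent_dim is 64 (the value passed to build_model in the run function of this notebook) and that both dense layers keep their bias term; dense_params is a hypothetical helper.

```elixir
# Rough parameter count for the encoder/decoder pair (illustration only).
input_dim = 1 * 28 * 28
latent_dim = 64

# A dense layer with bias has inputs * outputs weights plus one bias per output.
dense_params = fn inputs, outputs -> inputs * outputs + outputs end

encoder_params = dense_params.(input_dim, latent_dim)
decoder_params = dense_params.(latent_dim, input_dim)

IO.inspect(encoder_params, label: "encoder parameters")
IO.inspect(decoder_params, label: "decoder parameters")
IO.inspect(input_dim / latent_dim, label: "compression factor")
```

Under these assumptions the 784 input values are compressed by a factor of 12.25, and the two layers hold 50,240 and 50,960 parameters respectively.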

Training the model

Finally, we can train the model. We'll use the :adam optimizer and the :mean_squared_error loss with Axon.Loop.trainer. Our loss function will measure the aggregate error between the pixels of the original images and the model's reconstructed images. We'll also track :mean_absolute_error using Axon.Loop.metric. Axon.Loop.run trains the model with the given training data.
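An autoencoder is trained to reproduce its own input, so each batch serves as both the input and the target. The train_model function in this notebook does this by zipping train_images with itself. The sketch below shows the shape of the resulting pairs, using placeholder strings (hypothetical values) instead of real tensor batches.

```elixir
# Stand-ins for batched image tensors (hypothetical placeholder values).
train_batches = ["batch_1", "batch_2", "batch_3"]

# Zipping the list with itself yields {input, target} tuples where both
# elements are the same batch.
pairs = train_batches |> Stream.zip(train_batches) |> Enum.to_list()
IO.inspect(pairs)
```

Each tuple has the {input, target} form that a supervised training loop consumes, with the target equal to the input.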

To better understand what mean absolute error and mean squared error are, let's go through an example.

defmodule Example do
  defp transform_images({bin, type, shape}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.reshape({elem(shape, 0), 1, 28, 28})
    |> Nx.divide(255.0)
    |> Nx.to_batched_list(32)
  end

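  # Squared error averaged along the last axis (each 28-pixel row),
  # then summed over the rows to give a single scalar.
  def calculate_mse(y_pred, y) do
    y_pred
    |> Nx.subtract(y)
    |> Nx.power(2)
    |> Nx.mean(axes: [-1])
    |> Nx.sum()
  end

  # Absolute error averaged along the last axis, then summed over the rows.
  def calculate_mae(y_pred, y) do
    y_pred
    |> Nx.subtract(y)
    |> Nx.abs()
    |> Nx.mean(axes: [-1])
    |> Nx.sum()
  end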

  def run_example do
    {images, _} = Scidata.FashionMNIST.download()

    train_images = transform_images(images)

    shoe_image =
      train_images
      |> hd()
      |> Nx.slice_axis(0, 1, 0)
      |> Nx.reshape({1, 1, 28, 28})

    # The same shoe with Gaussian noise (mean 0.0, standard deviation 0.05) added.
    noised_shoe_image =
      Nx.add(shoe_image, Nx.random_normal({1, 1, 28, 28}, 0.0, 0.05))

    random_image =
      train_images
      |> Enum.at(1)
      |> Nx.slice_axis(0, 1, 0)
      |> Nx.reshape({1, 1, 28, 28})

    same_images_mse_error = calculate_mse(shoe_image, shoe_image)
    similar_images_mse_error = calculate_mse(shoe_image, noised_shoe_image)
    different_images_mse_error = calculate_mse(shoe_image, random_image)

    same_images_mae_error = calculate_mae(shoe_image, shoe_image)
    similar_images_mae_error = calculate_mae(shoe_image, noised_shoe_image)
    different_images_mae_error = calculate_mae(shoe_image, random_image)

    %{
      same: {[mse: same_images_mse_error], [mae: same_images_mae_error]},
      similar: {[mse: similar_images_mse_error], [mae: similar_images_mae_error]},
      different: {[mse: different_images_mse_error], [mae: different_images_mae_error]}
    }
  end
end
Example.run_example()

We chose an image of a shoe as a reference. As we can see, the error for the same picture is 0 (both MAE and MSE). Indeed, when we compare two exact copies, no pixel differs, so the error is zero. The noised image has a non-zero MSE and MAE, but both are much smaller than the errors for two completely different pictures. The error therefore indicates how similar two images are: a small error implies a good reconstruction, while a large loss suggests a poor one.
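The same comparison can be worked through by hand on four-pixel "images". The sketch below uses plain lists and a hypothetical ErrorSketch module. Unlike calculate_mse and calculate_mae above, which average along the last axis and then sum, it averages over all values, which is the textbook definition of each metric.

```elixir
# Textbook MSE and MAE on flat lists of pixel values (illustration only).
defmodule ErrorSketch do
  # Mean of the squared differences between prediction and target.
  def mse(pred, target) do
    pred
    |> Enum.zip(target)
    |> Enum.map(fn {p, t} -> (p - t) * (p - t) end)
    |> mean()
  end

  # Mean of the absolute differences between prediction and target.
  def mae(pred, target) do
    pred
    |> Enum.zip(target)
    |> Enum.map(fn {p, t} -> abs(p - t) end)
    |> mean()
  end

  defp mean(values), do: Enum.sum(values) / length(values)
end

original = [0.0, 0.5, 1.0, 0.5]
similar = [0.0, 0.25, 1.0, 0.75]
different = [1.0, 0.0, 0.0, 1.0]

IO.inspect({ErrorSketch.mse(original, original), ErrorSketch.mae(original, original)}, label: "same")
IO.inspect({ErrorSketch.mse(similar, original), ErrorSketch.mae(similar, original)}, label: "similar")
IO.inspect({ErrorSketch.mse(different, original), ErrorSketch.mae(different, original)}, label: "different")
```

With these values the identical pair gives zero for both metrics, the similar pair gives an MSE of 0.03125 and an MAE of 0.125, and the different pair gives 0.625 and 0.75, following the same ordering as the image comparison.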

Evaluation of the model

defmodule Fashionmist do
  require Axon

  defp transform_images({bin, type, shape}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.reshape({elem(shape, 0), 1, 28, 28})
    |> Nx.divide(255.0)
    |> Nx.to_batched_list(256)
  end

  defp encoder(x, latent_dim) do
    x
    |> Axon.flatten()
    |> Axon.dense(latent_dim, activation: :relu)
  end

  defp decoder(x) do
    x
    |> Axon.dense(784, activation: :sigmoid)
    |> Axon.reshape({1, 28, 28})
  end

  def build_model(input_shape, latent_dim) do
    Axon.input(input_shape)
    |> encoder(latent_dim)
    |> decoder()
  end

  # Trains the autoencoder. Each batch is zipped with itself because the
  # model's target is its own input.
  defp train_model(model, train_images, epochs) do
    model
    |> Axon.Loop.trainer(:mean_squared_error, :adam)
    |> Axon.Loop.metric(:mean_absolute_error, "Error")
    |> Axon.Loop.run(Stream.zip(train_images, train_images), epochs: epochs, compiler: EXLA)
  end

  def run do
    {images, _} = Scidata.FashionMNIST.download()

    train_images = transform_images(images)

    model = build_model({nil, 1, 28, 28}, 64) |> IO.inspect()

    model_state = train_model(model, train_images, 5)

    sample_image =
      train_images
      |> hd()
      |> Nx.slice_axis(0, 1, 0)
      |> Nx.reshape({1, 1, 28, 28})

    sample_image |> Nx.to_heatmap() |> IO.inspect()

    model
    |> Axon.predict(model_state, sample_image, compiler: EXLA)
    |> Nx.to_heatmap()
    |> IO.inspect()
  end
end

Results analysis

As we can see, the generated image is similar to the input image. The only difference between them is the absence of a sign in the middle of the second shoe. The model treated the sign as noise and blended it into the plain shoe. To reproduce the outputs, run the following code.

Fashionmist.run()