Skip to content

Instantly share code, notes, and snippets.

@NMZivkovic
Last active November 25, 2018 16:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save NMZivkovic/d184364fbe29f0303979801c25b1708c to your computer and use it in GitHub Desktop.
Save NMZivkovic/d184364fbe29f0303979801c25b1708c to your computer and use it in GitHub Desktop.
import numpy as np
from tensorflow.keras.datasets import fashion_mnist
from autoencoder_keras import Autoencoder
import matplotlib.pyplot as plt
# Import data: Fashion-MNIST train/test images. The labels are unused here
# because the autoencoder is only asked to reproduce its own input.
(x_train, _), (x_test, _) = fashion_mnist.load_data()
# Prepare input: cast each split to float32, divide by 255 (pixel values are
# presumably uint8 in 0..255), then flatten every image into one row vector.
x_train, x_test = [
    (split.astype('float32') / 255.).reshape(
        (len(split), np.prod(split.shape[1:]))
    )
    for split in (x_train, x_test)
]
# Keras implementation: size the autoencoder by the flattened input width,
# fit it, then run the test images through the encoder and the decoder.
input_width = x_train.shape[1]
# NOTE(review): 32 is presumably the size of the encoded representation — confirm
# against autoencoder_keras.Autoencoder.
autoencoder = Autoencoder(input_width, 32)
# NOTE(review): 256 and 50 look like batch size and epoch count — confirm
# against Autoencoder.train.
autoencoder.train(x_train, x_test, 256, 50)
encoded_imgs = autoencoder.getEncodedImage(x_test)
decoded_imgs = autoencoder.getDecodedImage(encoded_imgs)
# Keras implementation results: show the first `n_images` test images on the
# top row and the autoencoder's reconstructions of them on the bottom row.
n_images = 10
plt.figure(figsize=(20, 4))
for i in range(n_images):
    # Original (top row, positions 1..n_images).
    subplot = plt.subplot(2, n_images, i + 1)
    # Each row is a flattened 784-pixel image; restore the 28x28 grid for display.
    plt.imshow(x_test[i].reshape(28, 28), cmap='gray')
    subplot.get_xaxis().set_visible(False)
    subplot.get_yaxis().set_visible(False)
    # Reconstruction (bottom row, directly below its original).
    subplot = plt.subplot(2, n_images, i + 1 + n_images)
    plt.imshow(decoded_imgs[i].reshape(28, 28), cmap='gray')
    subplot.get_xaxis().set_visible(False)
    subplot.get_yaxis().set_visible(False)
# Display once, after all 2 * n_images panels have been drawn.
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment