Skip to content

Instantly share code, notes, and snippets.

@NMZivkovic
Last active November 25, 2018 16:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save NMZivkovic/03378abd6e635c7cb47503b945cc7616 to your computer and use it in GitHub Desktop.
Save NMZivkovic/03378abd6e635c7cb47503b945cc7616 to your computer and use it in GitHub Desktop.
import numpy as np
from tensorflow.keras.datasets import fashion_mnist
from autoencoder_convonutional import Autoencoder
import matplotlib.pyplot as plt
# Import data (the label arrays are discarded; only the images are used)
(x_train, _), (x_test, _) = fashion_mnist.load_data()

# Prepare input: cast to float32, divide by 255 and reshape every sample to
# (28, 28, 1), i.e. add an explicit trailing axis of size 1.
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = np.reshape(x_train, (len(x_train), 28, 28, 1))
x_test = np.reshape(x_test, (len(x_test), 28, 28, 1))

# Convolutional implementation
autoencoder = Autoencoder()
# NOTE(review): 256 and 50 are presumably the batch size and the number of
# epochs -- confirm against Autoencoder.train.
autoencoder.train(x_train, x_test, 256, 50)
decoded_imgs = autoencoder.getDecodedImage(x_test)

# Convolutional implementation results: a 2 x 10 grid with the first ten test
# images on the top row and the corresponding decoded images on the bottom row.
plt.figure(figsize=(20, 4))
for i in range(10):
    # Position i + 1 is the original (top row); position i + 11 is the
    # reconstruction directly below it (bottom row).
    for position, images in ((i + 1, x_test), (i + 11, decoded_imgs)):
        subplot = plt.subplot(2, 10, position)
        plt.imshow(images[i].reshape(28, 28))
        plt.gray()
        # Hide the axes so only the image itself is shown.
        subplot.get_xaxis().set_visible(False)
        subplot.get_yaxis().set_visible(False)
# Show the figure once, after every subplot has been drawn.
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment