Skip to content

Instantly share code, notes, and snippets.

@hristian-carabulea
Forked from jzuern/fashion-mnist.py
Created October 29, 2019 17:43
Show Gist options
  • Save hristian-carabulea/336518de1caa7a44190b295b5a05117f to your computer and use it in GitHub Desktop.
Save hristian-carabulea/336518de1caa7a44190b295b5a05117f to your computer and use it in GitHub Desktop.
from keras.layers import Input, Dense
from keras.models import Model
import matplotlib.pyplot as plt
import numpy as np
from sklearn.neighbors import LSHForest
import matplotlib.pyplot as plt
from keras.datasets import fashion_mnist
# Autoencoder model definition.
# Fully connected, symmetric layout: 784 -> 128 -> 64 -> 32 -> 64 -> 128 -> 784.
input_img = Input(shape=(784,))

# Encoder half. The 32-unit bottleneck layer is explicitly named 'encoded'
# so that it can be looked up by name after training.
enc_hidden_128 = Dense(128, activation='relu')(input_img)
enc_hidden_64 = Dense(64, activation='relu')(enc_hidden_128)
encoded = Dense(32, activation='relu', name='encoded')(enc_hidden_64)

# Decoder half mirrors the encoder. The final layer uses a sigmoid, so every
# reconstructed value lies in (0, 1).
dec_hidden_64 = Dense(64, activation='relu')(encoded)
dec_hidden_128 = Dense(128, activation='relu')(dec_hidden_64)
decoded = Dense(784, activation='sigmoid')(dec_hidden_128)

# End-to-end model mapping an input vector to its reconstruction.
autoencoder = Model(inputs=input_img, outputs=decoded)
autoencoder.compile(loss='binary_crossentropy', optimizer='adadelta')
# Load the Fashion-MNIST dataset. The labels are discarded: the autoencoder
# is trained to reproduce its own input, so only the images are needed.
(train_images, _), (test_images, _) = fashion_mnist.load_data()

# Convert to float32 and divide by 255 (presumably mapping 0-255 pixel
# intensities into [0, 1] -- confirm against the dataset documentation), then
# flatten every image into a single row vector per sample.
x_train = train_images.astype('float32') / 255.
x_train = x_train.reshape(len(x_train), -1)

x_test = test_images.astype('float32') / 255.
x_test = x_test.reshape(len(x_test), -1)
# Train the autoencoder. The training images serve as both input and target,
# and the test images are used the same way for validation. Only a single
# epoch is run.
autoencoder.fit(
    x_train,
    x_train,
    batch_size=128,
    epochs=1,
    shuffle=True,
    callbacks=[],
    validation_data=(x_test, x_test),
)
# Define the encoder as a new model: it shares the autoencoder's input and
# stops at the output of the layer named 'encoded'.
layer_name = 'encoded'
bottleneck_output = autoencoder.get_layer(layer_name).output
encoder = Model(inputs=autoencoder.input, outputs=bottleneck_output)

# Generate feature vectors for the test dataset.
x_test_encoded = encoder.predict(x_test)

# Restore the test images to 28x28 so they can be displayed with imshow.
x_test = x_test.reshape(-1, 28, 28)
# Build a locality-sensitive hashing index over the encoded test vectors for
# fast (approximate) neighbourhood search. The fixed random_state keeps the
# index reproducible between runs.
# NOTE(review): LSHForest appears to be missing from recent scikit-learn
# releases (removed around 0.21) -- confirm the installed version provides it.
lshf = LSHForest(random_state=42).fit(x_test_encoded)
# Pick a random query image from anywhere in the test set. randint's upper
# bound is exclusive, so the number of encoded rows is the correct limit
# (a hard-coded bound would restrict the query to a prefix of the set).
n_results = 4  # number of neighbours displayed next to the query
random_query = np.random.randint(0, x_test_encoded.shape[0])
query_features = np.expand_dims(x_test_encoded[random_query, :], axis=0)

# Request one neighbour more than is displayed: the first hit is skipped
# below because it is presumably the query image itself -- the index was
# built from the same vectors the query is drawn from.
distances, indices = lshf.kneighbors(query_features,
                                     n_neighbors=n_results + 1)

# Make grayscale the default colormap once, for every image drawn below.
plt.gray()

# Show the query image on its own figure.
plt.imshow(x_test[random_query, :, :])
plt.title('Query image')
plt.show()

# Show the retrieved neighbours side by side in a single row, each titled
# with its distance to the query and with the axis ticks hidden.
for i in range(1, n_results + 1):
    ax = plt.subplot(1, n_results, i)
    plt.imshow(x_test[indices[0][i], :, :])
    plt.title('Distance = ' + str(distances[0][i]))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment