Skip to content

Instantly share code, notes, and snippets.

@UrusuLambda
Created May 13, 2018 11:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save UrusuLambda/c0cbe335ae73a33e83f51cc1bd5d9027 to your computer and use it in GitHub Desktop.
Save UrusuLambda/c0cbe335ae73a33e83f51cc1bd5d9027 to your computer and use it in GitHub Desktop.
# Visualise n x n grids of (a) original MNIST test digits, (b) the encoder's
# middle-layer output for each digit, and (c) the decoder's reconstruction.
# Relies on n, mnist, sess, encoder_op, decoder_op, X, np and plt being
# defined earlier in the file.

# Side length, in pixels, of one MNIST digit (each flattened row of batch_x is
# reshaped to _IMG_SIDE x _IMG_SIDE below).
_IMG_SIDE = 28
# Tile shape used to display one middle-layer vector; m[j] must contain
# exactly _CODE_ROWS * _CODE_COLS values for the reshape to succeed.
# NOTE(review): 16 * 8 = 128 — confirm this matches encoder_op's output width.
_CODE_ROWS, _CODE_COLS = 16, 8

# One canvas per view; every cell is overwritten below, so np.empty is safe.
canvas_orig = np.empty((_IMG_SIDE * n, _IMG_SIDE * n))
canvas_recon = np.empty((_IMG_SIDE * n, _IMG_SIDE * n))
canvas_middle = np.empty((_CODE_ROWS * n, _CODE_COLS * n))
for i in range(n):
    # Next n images from the MNIST test set; together they fill grid row i.
    batch_x, _ = mnist.test.next_batch(n)
    # Fetch the middle-layer (encoder) output and the reconstruction in a
    # single session call, so the shared encoder subgraph runs only once.
    m, g = sess.run([encoder_op, decoder_op], feed_dict={X: batch_x})
    img_rows = slice(i * _IMG_SIDE, (i + 1) * _IMG_SIDE)
    code_rows = slice(i * _CODE_ROWS, (i + 1) * _CODE_ROWS)
    for j in range(n):
        img_cols = slice(j * _IMG_SIDE, (j + 1) * _IMG_SIDE)
        code_cols = slice(j * _CODE_COLS, (j + 1) * _CODE_COLS)
        # Original digit.
        canvas_orig[img_rows, img_cols] = batch_x[j].reshape([_IMG_SIDE, _IMG_SIDE])
        # Middle-layer activations, laid out as a small 2-D tile.
        canvas_middle[code_rows, code_cols] = m[j].reshape([_CODE_ROWS, _CODE_COLS])
        # Reconstructed digit.
        canvas_recon[img_rows, img_cols] = g[j].reshape([_IMG_SIDE, _IMG_SIDE])

# Show the three grids in order: originals, middle layer, reconstructions.
for title, canvas in (
    ("Original Images", canvas_orig),
    ("Middle Images", canvas_middle),
    ("Reconstructed Images", canvas_recon),
):
    print(title)
    plt.figure(figsize=(n, n))
    plt.imshow(canvas, origin="upper", cmap="gray")
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment