Skip to content

Instantly share code, notes, and snippets.

@yoel-zeldes
Last active April 1, 2019 20:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yoel-zeldes/e721ed17c62c4497fb47cb1cc074dfaa to your computer and use it in GitHub Desktop.
def plot_samples(samples, num_epochs, image_shape=(28, 28)):
    """Plot a grid of sample images, one row per selected epoch.

    Rows are taken at evenly spaced indices across ``samples``. Each row
    shows the images stored for that epoch, one column per digit.

    Args:
        samples: Sequence indexed by epoch. ``samples[epoch]`` is an iterable
            of flat image arrays, one per digit. Presumably it holds
            ``NUM_DIGITS`` images per epoch; TODO confirm against the caller.
        num_epochs: Maximum number of epoch rows to display. If it exceeds
            ``len(samples)``, repeated indices are dropped so no row is
            drawn twice.
        image_shape: ``(height, width)`` that each flat image is reshaped to
            before display. Defaults to ``(28, 28)``, the previously
            hard-coded size.

    Returns:
        The matplotlib Figure containing the grid.

    Raises:
        ValueError: If ``samples`` is empty or ``num_epochs`` is less than 1.
    """
    if len(samples) == 0:
        raise ValueError('samples must not be empty')
    if num_epochs < 1:
        raise ValueError('num_epochs must be at least 1')

    image_width = 0.7  # figure inches allotted to each image cell

    # linspace repeats integer indices once num_epochs > len(samples);
    # np.unique removes the duplicates (and keeps ascending order) so the
    # same epoch is not plotted in several rows.
    epochs = np.unique(np.linspace(0, len(samples) - 1, num_epochs).astype(int))
    num_rows = len(epochs)

    fig = plt.figure(figsize=(image_width * NUM_DIGITS,
                              num_rows * image_width))
    for row, epoch in enumerate(epochs):
        for digit, image in enumerate(samples[epoch]):
            ax = plt.subplot(num_rows,
                             NUM_DIGITS,
                             row * NUM_DIGITS + digit + 1)
            ax.imshow(np.reshape(image, image_shape), cmap='Greys_r')
            ax.xaxis.set_visible(False)
            if digit == 0:
                # Only the first column keeps a (tick-less) y axis, so it can
                # carry the epoch label for the whole row.
                ax.yaxis.set_ticks([])
                ax.set_ylabel('epoch {}'.format(epoch + 1),
                              verticalalignment='center',
                              horizontalalignment='right',
                              rotation=0,
                              fontsize=14)
            else:
                ax.yaxis.set_visible(False)
    return fig
# Render the sample grid for 20 evenly spaced epochs.
plot_samples(samples, num_epochs=20)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment