Skip to content

Instantly share code, notes, and snippets.

@yoel-zeldes
Last active November 12, 2018 20:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yoel-zeldes/fac58aedcb827fe34cd18dac6440d54d to your computer and use it in GitHub Desktop.
Save yoel-zeldes/fac58aedcb827fe34cd18dac6440d54d to your computer and use it in GitHub Desktop.
# Training loop: each epoch runs one pass of minibatch updates, then records
# both losses over the full training set and decodes one image per digit class
# from a single shared latent vector.
samples = []  # one decoder output per epoch (from `decoded_images`)
losses_auto_encode = []  # per-epoch value of `loss_auto_encode` on the training set
losses_digit_classifier = []  # per-epoch value of `loss_digit_classifier` on the training set
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Floor division keeps the batch count an int on both Python 2 and 3
    # (`/` yields a float under Python 3 or `from __future__ import division`).
    # The count does not change between epochs, so compute it once.
    batches_per_epoch = mnist.train.num_examples // params['batch_size']
    for epoch in range(params['epochs']):
        for _ in range(batches_per_epoch):
            batch_images, batch_digits = mnist.train.next_batch(params['batch_size'])
            sess.run(train_op, feed_dict={images: batch_images, digits: batch_digits})

        # Evaluate both losses over the whole training set in a single run.
        # NOTE(review): this feeds every training image at once; may be
        # memory-heavy for larger datasets -- consider batched evaluation.
        train_loss_auto_encode, train_loss_digit_classifier = sess.run(
            [loss_auto_encode, loss_digit_classifier],
            feed_dict={images: mnist.train.images, digits: mnist.train.labels})
        losses_auto_encode.append(train_loss_auto_encode)
        losses_digit_classifier.append(train_loss_digit_classifier)

        # Tile one random latent vector to `num_digits` identical rows and pair
        # it with the identity matrix (row i is one-hot in column i). The
        # decoded images therefore differ only in their `digit_prob` input.
        sample_z = np.tile(np.random.randn(1, params['z_dim']), reps=[num_digits, 1])
        gen_samples = sess.run(
            decoded_images,
            feed_dict={z: sample_z, digit_prob: np.eye(num_digits)})
        samples.append(gen_samples)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment