MNIST GAN Tutorial
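
A minimal GAN on MNIST in TensorFlow 1.x with tf.contrib.slim. The discriminator D is trained to minimize -E[log D(x)] - E[log(1 - D(G(z)))], while the generator G minimizes the non-saturating objective -E[log D(G(z))]; these are exactly the two loss tensors defined below. Both networks are single-hidden-layer fully connected nets, and the noise vector z is sampled uniformly from [-1, 1].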
import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data


def generator(inputs):
    """Maps a batch of noise vectors to a batch of flattened 28x28 images."""
    with tf.variable_scope("generator"):
        net = slim.fully_connected(inputs, 256, scope="fc1")
        # Sigmoid keeps the outputs in [0, 1], matching MNIST pixel values.
        net = slim.fully_connected(net, 784, scope="fake_images",
                                   activation_fn=tf.nn.sigmoid)
    return net


def discriminator(inputs):
    """Scores a batch of flattened images with the probability of being real."""
    with tf.variable_scope("discriminator"):
        net = slim.fully_connected(inputs, 256, scope="fc1")
        net = slim.fully_connected(net, 1, scope="predictions",
                                   activation_fn=tf.nn.sigmoid)
    return net


if __name__ == "__main__":
    mnist_loader = input_data.read_data_sets("MNIST_data")

    batch_size = 32
    z_dim = 100
    learning_rate = 0.0002
    num_iters = 100000

    random_z = tf.placeholder(shape=[batch_size, z_dim], dtype=tf.float32,
                              name="random_vector")
    real_images = tf.placeholder(shape=[batch_size, 784], dtype=tf.float32,
                                 name="real_images")

    fake_images = generator(random_z)

    # Run the discriminator once on real and fake images stacked together,
    # then split the predictions back into the two halves.
    predictions = discriminator(tf.concat([real_images, fake_images], axis=0))
    real_preds = tf.slice(predictions, [0, 0], [batch_size, -1])
    fake_preds = tf.slice(predictions, [batch_size, 0], [batch_size, -1])

    # Standard non-saturating GAN losses; the small epsilon guards against
    # log(0) when the discriminator's sigmoid saturates.
    eps = 1e-8
    gen_loss = -tf.reduce_mean(tf.log(fake_preds + eps))
    dis_loss = -tf.reduce_mean(tf.log(real_preds + eps) +
                               tf.log(1. - fake_preds + eps))

    # Each network is updated only through its own variables.
    gen_vars = slim.get_variables(scope="generator")
    dis_vars = slim.get_variables(scope="discriminator")
    optimizer = tf.train.AdamOptimizer(learning_rate)
    gen_train_op = optimizer.minimize(gen_loss, var_list=gen_vars)
    dis_train_op = optimizer.minimize(dis_loss, var_list=dis_vars)

    summaries = [
        tf.summary.scalar("gen_loss", gen_loss),
        tf.summary.scalar("dis_loss", dis_loss),
        tf.summary.image("real_images",
                         tf.reshape(real_images, [batch_size, 28, 28, 1])),
        tf.summary.image("fake_images",
                         tf.reshape(fake_images, [batch_size, 28, 28, 1]))
    ]
    summary_op = tf.summary.merge(summaries)
    summary_writer = tf.summary.FileWriter("log", graph=tf.get_default_graph())

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    for step in range(1, num_iters + 1):
        # Each step updates both networks on a fresh noise batch and a fresh
        # batch of real MNIST digits (labels are discarded).
        feed_dict = {
            random_z: np.random.uniform(-1., 1., size=[batch_size, z_dim]),
            real_images: mnist_loader.train.next_batch(batch_size=batch_size)[0]
        }
        _, _, _gen_loss, _dis_loss, summary = sess.run(
            [gen_train_op, dis_train_op, gen_loss, dis_loss, summary_op],
            feed_dict=feed_dict)
        summary_writer.add_summary(summary, step)

        if step % 50 == 0:
            print("Iteration [{:06d}/{:06d}]".format(step, num_iters))
            print("\t>> Generator Loss: {}".format(_gen_loss))
            print("\t>> Discriminator Loss: {}".format(_dis_loss))