ketyi/18cb4462992152d98837bf21e82a9454
An implementation of InfoGAN.
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# InfoGAN Tutorial"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This tutorial walks through an implementation of InfoGAN as described in [InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets](https://arxiv.org/abs/1606.03657).\n",
    "\n",
    "To learn more about InfoGAN, see this [Medium post](https://medium.com/p/dd710852db46). To learn more about GANs generally, see [this one](https://medium.com/@awjuliani/generative-adversarial-networks-explained-with-a-classic-spongebob-squarepants-episode-54deab2fce39#.692jyamki)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Import the libraries we will need.\n",
    "# This notebook targets TensorFlow 1.x: tf.contrib.slim and the MNIST tutorial reader were removed in TensorFlow 2.\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow.contrib.slim as slim\n",
    "import os\n",
    "import scipy.misc\n",
    "import scipy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the MNIST dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "mnist = input_data.read_data_sets(\"MNIST_data/\", one_hot=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Helper Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# This function performs a leaky ReLU activation, which is used in the discriminator network.\n",
    "def lrelu(x, leak=0.2, name=\"lrelu\"):\n",
    "    with tf.variable_scope(name):\n",
    "        f1 = 0.5 * (1 + leak)\n",
    "        f2 = 0.5 * (1 - leak)\n",
    "        # Equivalent to max(x, leak * x) for 0 <= leak <= 1.\n",
    "        return f1 * x + f2 * abs(x)\n",
    "\n",
    "# The functions below are adapted from carpedm20's implementation: https://github.com/carpedm20/DCGAN-tensorflow\n",
    "# They save grids of sample images from the generator so we can follow training progress.\n",
    "def save_images(images, size, image_path):\n",
    "    return imsave(inverse_transform(images), size, image_path)\n",
    "\n",
    "def imsave(images, size, path):\n",
    "    # scipy.misc.imsave was removed in SciPy 1.2, so matplotlib writes the grid as a grayscale image instead.\n",
    "    return plt.imsave(path, merge(images, size), cmap='gray')\n",
    "\n",
    "def inverse_transform(images):\n",
    "    # Map generator outputs from [-1, 1] back to [0, 1].\n",
    "    return (images + 1.) / 2.\n",
    "\n",
    "def merge(images, size):\n",
    "    # Tile a batch of HxW images into a grid of size[0] rows and size[1] columns.\n",
    "    h, w = images.shape[1], images.shape[2]\n",
    "    img = np.zeros((h * size[0], w * size[1]))\n",
    "\n",
    "    for idx, image in enumerate(images):\n",
    "        i = idx % size[1]\n",
    "        j = idx // size[1]\n",
    "        img[j*h:j*h+h, i*w:i*w+w] = image\n",
    "\n",
    "    return img"
   ]
  },
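  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The leaky ReLU above relies on the identity `0.5*(1+leak)*x + 0.5*(1-leak)*|x| = max(x, leak*x)`, which holds for `0 <= leak <= 1`. The NumPy-only sketch below is an illustration added for this walkthrough and is not part of the original implementation. `lrelu_np` and `merge_np` are hypothetical helpers that mirror `lrelu` and `merge`. The sketch checks the identity and shows how `merge` tiles a batch of images into a grid, filling each row from left to right."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def lrelu_np(x, leak=0.2):\n",
    "    # Same closed form as lrelu(): f1*x + f2*|x|.\n",
    "    f1 = 0.5 * (1 + leak)\n",
    "    f2 = 0.5 * (1 - leak)\n",
    "    return f1 * x + f2 * np.abs(x)\n",
    "\n",
    "def merge_np(images, size):\n",
    "    # Same tiling as merge(): image idx goes to row idx // size[1], column idx % size[1].\n",
    "    h, w = images.shape[1], images.shape[2]\n",
    "    img = np.zeros((h * size[0], w * size[1]))\n",
    "    for idx, image in enumerate(images):\n",
    "        i = idx % size[1]\n",
    "        j = idx // size[1]\n",
    "        img[j*h:j*h+h, i*w:i*w+w] = image\n",
    "    return img\n",
    "\n",
    "x = np.linspace(-3.0, 3.0, 13)\n",
    "# For 0 <= leak <= 1 the closed form matches max(x, leak*x).\n",
    "assert np.allclose(lrelu_np(x), np.maximum(x, 0.2 * x))\n",
    "\n",
    "# Six constant 2x2 images (image k is filled with the value k), tiled into 2 rows x 3 columns.\n",
    "batch = np.arange(6, dtype=float).reshape(6, 1, 1) * np.ones((6, 2, 2))\n",
    "grid = merge_np(batch, [2, 3])\n",
    "print(grid.shape)  # (4, 6)"
   ]
  },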
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining the Adversarial Networks"
   ]
  },
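  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generator Network\n",
    "\n",
    "The generator takes the latent vector (random noise z concatenated with the categorical and continuous codes) and transforms it into a 32x32 image. A fully connected layer first projects the input to a 4x4x256 tensor. Each following layer applies a 3x3 convolution, batch normalization, and a rectified nonlinearity, then doubles the spatial resolution with `tf.depth_to_space`, which rearranges channels into space instead of using a strided transpose convolution. A final convolution with a tanh activation produces a single-channel image with values in [-1, 1]. TensorFlow's slim library allows us to easily define each of these layers."
   ]
  },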
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def generator(z):\n",
    "    # Project the latent vector to a 4x4x256 tensor.\n",
    "    zP = slim.fully_connected(z, 4*4*256, normalizer_fn=slim.batch_norm,\n",
    "                              activation_fn=tf.nn.relu, scope='g_project',\n",
    "                              weights_initializer=initializer)\n",
    "    zCon = tf.reshape(zP, [-1, 4, 4, 256])\n",
    "\n",
    "    # 3x3 convolution + batch norm + ReLU, then depth_to_space: 4x4 -> 8x8.\n",
    "    gen1 = slim.convolution2d(\n",
    "        zCon, num_outputs=128, kernel_size=[3, 3],\n",
    "        padding=\"SAME\", normalizer_fn=slim.batch_norm,\n",
    "        activation_fn=tf.nn.relu, scope='g_conv1', weights_initializer=initializer)\n",
    "    gen1 = tf.depth_to_space(gen1, 2)\n",
    "\n",
    "    # 8x8 -> 16x16.\n",
    "    gen2 = slim.convolution2d(\n",
    "        gen1, num_outputs=64, kernel_size=[3, 3],\n",
    "        padding=\"SAME\", normalizer_fn=slim.batch_norm,\n",
    "        activation_fn=tf.nn.relu, scope='g_conv2', weights_initializer=initializer)\n",
    "    gen2 = tf.depth_to_space(gen2, 2)\n",
    "\n",
    "    # 16x16 -> 32x32.\n",
    "    gen3 = slim.convolution2d(\n",
    "        gen2, num_outputs=32, kernel_size=[3, 3],\n",
    "        padding=\"SAME\", normalizer_fn=slim.batch_norm,\n",
    "        activation_fn=tf.nn.relu, scope='g_conv3', weights_initializer=initializer)\n",
    "    gen3 = tf.depth_to_space(gen3, 2)\n",
    "\n",
    "    # Single-channel 32x32 output with values in [-1, 1].\n",
    "    g_out = slim.convolution2d(\n",
    "        gen3, num_outputs=1, kernel_size=[32, 32], padding=\"SAME\",\n",
    "        biases_initializer=None, activation_fn=tf.nn.tanh,\n",
    "        scope='g_out', weights_initializer=initializer)\n",
    "\n",
    "    return g_out"
   ]
  },
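  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see how `tf.depth_to_space` grows the image, here is a NumPy sketch of the same rearrangement. It assumes TensorFlow's default NHWC layout, and `depth_to_space_np` is a hypothetical helper written for illustration, not a TensorFlow function. Each call trades a factor of 4 in channels for a factor of 2 in height and width, so the generator goes from 4x4 to 8x8 to 16x16 to 32x32."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def depth_to_space_np(x, block):\n",
    "    # Sketch of tf.depth_to_space for NHWC input: each group of block*block\n",
    "    # channels becomes a block x block spatial patch.\n",
    "    n, h, w, c = x.shape\n",
    "    c_out = c // (block * block)\n",
    "    x = x.reshape(n, h, w, block, block, c_out)\n",
    "    x = x.transpose(0, 1, 3, 2, 4, 5)\n",
    "    return x.reshape(n, h * block, w * block, c_out)\n",
    "\n",
    "# One pixel with 4 channels [0, 1, 2, 3] becomes the 2x2 patch [[0, 1], [2, 3]].\n",
    "tiny = np.arange(4.0).reshape(1, 1, 1, 4)\n",
    "patch = depth_to_space_np(tiny, 2)[0, :, :, 0]\n",
    "assert np.array_equal(patch, [[0, 1], [2, 3]])\n",
    "\n",
    "# Trace the generator using random stand-ins for the outputs of g_conv1, g_conv2, g_conv3.\n",
    "shapes = []\n",
    "for size, channels in [(4, 128), (8, 64), (16, 32)]:\n",
    "    conv_out = np.random.randn(2, size, size, channels)\n",
    "    shapes.append(depth_to_space_np(conv_out, 2).shape)\n",
    "print(shapes)  # [(2, 8, 8, 32), (2, 16, 16, 16), (2, 32, 32, 8)]"
   ]
  },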
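  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Discriminator Network\n",
    "The discriminator network takes a 32x32 image as input and transforms it into a single probability that the image came from the real data rather than from the generator. Each convolutional layer is followed by `tf.space_to_depth`, which halves the spatial resolution (32x32 to 16x16 to 8x8 to 4x4). The same function also builds InfoGAN's Q network. From the shared 1024-unit hidden layer, a separate branch predicts the latent codes, with one softmax output per categorical variable and a tanh output for the continuous variables. Again we use tf.slim to define the convolutional layers, batch normalization, and weight initialization."
   ]
  },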
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def discriminator(bottom, cat_list, conts, reuse=False):\n",
    "\n",
    "    dis1 = slim.convolution2d(bottom, 32, [3, 3], padding=\"SAME\",\n",
    "                              biases_initializer=None, activation_fn=lrelu,\n",
    "                              reuse=reuse, scope='d_conv1', weights_initializer=initializer)\n",
    "    dis1 = tf.space_to_depth(dis1, 2)\n",
    "\n",
    "    dis2 = slim.convolution2d(dis1, 64, [3, 3], padding=\"SAME\",\n",
    "                              normalizer_fn=slim.batch_norm, activation_fn=lrelu,\n",
    "                              reuse=reuse, scope='d_conv2', weights_initializer=initializer)\n",
    "    dis2 = tf.space_to_depth(dis2, 2)\n",
    "\n",
    "    dis3 = slim.convolution2d(dis2, 128, [3, 3], padding=\"SAME\",\n",
    "                              normalizer_fn=slim.batch_norm, activation_fn=lrelu,\n",
    "                              reuse=reuse, scope='d_conv3', weights_initializer=initializer)\n",
    "    dis3 = tf.space_to_depth(dis3, 2)\n",
    "\n",
    "    dis4 = slim.fully_connected(slim.flatten(dis3), 1024, activation_fn=lrelu,\n",
    "                                reuse=reuse, scope='d_fc1', weights_initializer=initializer)\n",
    "\n",
    "    # Probability that the input image is real.\n",
    "    d_out = slim.fully_connected(dis4, 1, activation_fn=tf.nn.sigmoid,\n",
    "                                 reuse=reuse, scope='d_out', weights_initializer=initializer)\n",
    "\n",
    "    # Shared hidden layer of the Q network (batch norm followed by slim's default ReLU).\n",
    "    q_a = slim.fully_connected(dis4, 128, normalizer_fn=slim.batch_norm,\n",
    "                               reuse=reuse, scope='q_fc1', weights_initializer=initializer)\n",
    "\n",
    "    ## Here we define the unique layers used for the Q network. The number of outputs depends on the number of\n",
    "    ## latent variables we choose to define.\n",
    "    q_cat_outs = []\n",
    "    for idx, var in enumerate(cat_list):\n",
    "        q_outA = slim.fully_connected(q_a, var, activation_fn=tf.nn.softmax,\n",
    "                                      reuse=reuse, scope='q_out_cat_'+str(idx), weights_initializer=initializer)\n",
    "        q_cat_outs.append(q_outA)\n",
    "\n",
    "    q_cont_outs = None\n",
    "    if conts > 0:\n",
    "        q_cont_outs = slim.fully_connected(q_a, conts, activation_fn=tf.nn.tanh,\n",
    "                                           reuse=reuse, scope='q_out_cont_'+str(conts), weights_initializer=initializer)\n",
    "\n",
    "    return d_out, q_cat_outs, q_cont_outs"
   ]
  },
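  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The discriminator mirrors the generator's upsampling with `tf.space_to_depth`, which folds each 2x2 patch into the channel dimension. The sketch below re-implements it in NumPy under the same NHWC assumption. `space_to_depth_np` and `depth_to_space_np` are hypothetical helpers, not TensorFlow functions. The sketch traces the shapes from 32x32 down to 4x4 and checks that the two rearrangements invert each other."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def space_to_depth_np(x, block):\n",
    "    # Sketch of tf.space_to_depth for NHWC input: each block x block patch\n",
    "    # is folded into the channel dimension.\n",
    "    n, h, w, c = x.shape\n",
    "    x = x.reshape(n, h // block, block, w // block, block, c)\n",
    "    x = x.transpose(0, 1, 3, 2, 4, 5)\n",
    "    return x.reshape(n, h // block, w // block, block * block * c)\n",
    "\n",
    "def depth_to_space_np(x, block):\n",
    "    # Inverse rearrangement, used here only to check the round trip.\n",
    "    n, h, w, c = x.shape\n",
    "    c_out = c // (block * block)\n",
    "    x = x.reshape(n, h, w, block, block, c_out)\n",
    "    x = x.transpose(0, 1, 3, 2, 4, 5)\n",
    "    return x.reshape(n, h * block, w * block, c_out)\n",
    "\n",
    "# Trace the discriminator using random stand-ins for the outputs of d_conv1, d_conv2, d_conv3.\n",
    "shapes = []\n",
    "for size, channels in [(32, 32), (16, 64), (8, 128)]:\n",
    "    conv_out = np.random.randn(2, size, size, channels)\n",
    "    shapes.append(space_to_depth_np(conv_out, 2).shape)\n",
    "print(shapes)  # [(2, 16, 16, 128), (2, 8, 8, 256), (2, 4, 4, 512)]\n",
    "# slim.flatten(dis3) therefore passes 4 * 4 * 512 = 8192 features to d_fc1.\n",
    "\n",
    "x = np.random.randn(3, 8, 8, 5)\n",
    "assert np.array_equal(depth_to_space_np(space_to_depth_np(x, 2), 2), x)"
   ]
  },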
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Connecting them together"
   ]
  },
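  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next cell builds the generator's input by concatenating three parts: a one-hot encoding of each categorical code, the noise vector z, and the continuous codes. The NumPy sketch below mirrors that layout under the settings used there (one 10-way categorical code, `z_size = 64`, two continuous codes), which gives 10 + 64 + 2 = 76 inputs per example. The `_demo` names are hypothetical and were chosen so that running this cell does not overwrite any of the notebook's variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical demo sizes, matching the settings in the next cell.\n",
    "demo_batch = 4\n",
    "demo_z_size = 64\n",
    "demo_categorical = [10]   # one categorical code with 10 values\n",
    "demo_continuous = 2       # two continuous codes\n",
    "\n",
    "z_demo = np.random.uniform(-1.0, 1.0, size=[demo_batch, demo_z_size])\n",
    "cat_demo = np.random.randint(0, 10, [demo_batch, len(demo_categorical)])\n",
    "cont_demo = np.random.uniform(-1.0, 1.0, [demo_batch, demo_continuous])\n",
    "\n",
    "# One one-hot block per categorical variable (the role tf.one_hot plays in the graph).\n",
    "oh_demo = [np.eye(size)[cat_demo[:, idx]] for idx, size in enumerate(demo_categorical)]\n",
    "\n",
    "# Same order as z_lats: categorical one-hots, then z, then the continuous codes.\n",
    "z_lat_demo = np.concatenate(oh_demo + [z_demo, cont_demo], axis=1)\n",
    "print(z_lat_demo.shape)  # (4, 76)"
   ]
  },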
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "\n",
    "z_size = 64 # Size of the initial z vector used for the generator.\n",
    "\n",
    "# Define latent variables.\n",
    "categorical_list = [10] # Each entry in this list defines a categorical variable of a specific size.\n",
    "number_continuous = 2 # The number of continuous variables.\n",
    "\n",
    "# This initializer is used to initialize all the weights of the network.\n",
    "initializer = tf.truncated_normal_initializer(stddev=0.02)\n",
    "\n",
    "# These placeholders are used for input into the generator and discriminator, respectively.\n",
    "z_in = tf.placeholder(shape=[None, z_size], dtype=tf.float32) # Random vector\n",
    "real_in = tf.placeholder(shape=[None, 32, 32, 1], dtype=tf.float32) # Real images\n",
    "\n",
    "# These placeholders load the latent variables.\n",
    "latent_cat_in = tf.placeholder(shape=[None, len(categorical_list)], dtype=tf.int32)\n",
    "latent_cat_list = tf.split(latent_cat_in, len(categorical_list), axis=1)\n",
    "latent_cont_in = tf.placeholder(shape=[None, number_continuous], dtype=tf.float32)\n",
    "\n",
    "oh_list = []\n",
    "for idx, var in enumerate(categorical_list):\n",
    "    latent_oh = tf.one_hot(tf.reshape(latent_cat_list[idx], [-1]), var)\n",
    "    oh_list.append(latent_oh)\n",
    "\n",
    "# Concatenate all c and z variables.\n",
    "z_lats = oh_list[:]\n",
    "z_lats.append(z_in)\n",
    "z_lats.append(latent_cont_in)\n",
    "z_lat = tf.concat(z_lats, axis=1)\n",
    "\n",
    "Gz = generator(z_lat) # Generates images from random z vectors\n",
    "Dx, _, _ = discriminator(real_in, categorical_list, number_continuous) # Produces probabilities for real images\n",
    "Dg, QgCat, QgCont = discriminator(Gz, categorical_list, number_continuous, reuse=True) # Produces probabilities for generator images\n",
    "\n",
    "# These functions together define the optimization objective of the GAN.\n",
    "d_loss = -tf.reduce_mean(tf.log(Dx) + tf.log(1. - Dg)) # This optimizes the discriminator.\n",
    "g_loss = -tf.reduce_mean(tf.log(Dg / (1. - Dg))) # KL-divergence-style generator loss: minus the discriminator's logit on generated images.\n",
    "\n",
    "# Combine losses for each of the categorical variables.\n",
    "cat_losses = []\n",
    "for idx, latent_var in enumerate(oh_list):\n",
    "    cat_loss = -tf.reduce_sum(latent_var * tf.log(QgCat[idx]), axis=1)\n",
    "    cat_losses.append(cat_loss)\n",
    "\n",
    "# Combine losses for each of the continuous variables.\n",
    "if number_continuous > 0:\n",
    "    q_cont_loss = tf.reduce_sum(0.5 * tf.square(latent_cont_in - QgCont), axis=1)\n",
    "else:\n",
    "    q_cont_loss = tf.constant(0.0)\n",
    "\n",
    "q_cont_loss = tf.reduce_mean(q_cont_loss)\n",
    "q_cat_loss = tf.reduce_mean(cat_losses)\n",
    "q_loss = tf.add(q_cat_loss, q_cont_loss)\n",
    "tvars = tf.trainable_variables()\n",
    "# Select each network's variables by scope prefix instead of by position in tvars.\n",
    "g_vars = [v for v in tvars if v.name.startswith('g_')]\n",
    "d_vars = [v for v in tvars if v.name.startswith('d_')]\n",
    "\n",
    "# The code below applies gradient descent to update the GAN.\n",
    "trainerD = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5)\n",
    "trainerG = tf.train.AdamOptimizer(learning_rate=0.002, beta1=0.5)\n",
    "trainerQ = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5)\n",
    "d_grads = trainerD.compute_gradients(d_loss, d_vars) # Only update the weights for the discriminator network.\n",
    "g_grads = trainerG.compute_gradients(g_loss, g_vars) # Only update the weights for the generator network.\n",
    "q_grads = trainerQ.compute_gradients(q_loss, tvars) # The Q loss updates the generator, the shared discriminator layers, and the Q layers.\n",
    "\n",
    "update_D = trainerD.apply_gradients(d_grads)\n",
    "update_G = trainerG.apply_gradients(g_grads)\n",
    "update_Q = trainerQ.apply_gradients(q_grads)"
   ]
  },
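  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on the objectives above, the losses can be reproduced in NumPy on stand-in network outputs. This sketch uses random logits and probabilities in place of real discriminator and Q outputs, and all of its names are hypothetical. It shows that the generator loss `-log(D/(1-D))` equals minus the mean pre-sigmoid logit when `D = sigmoid(logit)`. It also shows that the continuous Q loss `0.5*(c - c_hat)^2` is the negative log-likelihood of a unit-variance Gaussian with the constant `0.5*log(2*pi)` per code dropped."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def sigmoid(v):\n",
    "    return 1.0 / (1.0 + np.exp(-v))\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "logits_real = rng.randn(8)   # stand-in discriminator logits on real images\n",
    "logits_fake = rng.randn(8)   # stand-in discriminator logits on generated images\n",
    "d_real, d_fake = sigmoid(logits_real), sigmoid(logits_fake)\n",
    "\n",
    "# Discriminator loss, as in d_loss: -mean(log D(x) + log(1 - D(G(z)))).\n",
    "d_loss_demo = -np.mean(np.log(d_real) + np.log(1.0 - d_fake))\n",
    "\n",
    "# Generator loss, as in g_loss: -mean(log(D / (1 - D))) is minus the mean logit.\n",
    "g_loss_demo = -np.mean(np.log(d_fake / (1.0 - d_fake)))\n",
    "assert np.isclose(g_loss_demo, -np.mean(logits_fake))\n",
    "\n",
    "# Categorical Q loss: cross-entropy between the sampled one-hot code and Q's softmax.\n",
    "onehot = np.eye(10)[rng.randint(0, 10, 8)]\n",
    "q_probs = rng.dirichlet(np.ones(10), size=8)   # stand-in softmax outputs (rows sum to 1)\n",
    "q_cat_demo = np.mean(-np.sum(onehot * np.log(q_probs), axis=1))\n",
    "\n",
    "# Continuous Q loss: 0.5 * squared error, as in q_cont_loss.\n",
    "c_true = rng.uniform(-1.0, 1.0, [8, 2])\n",
    "c_pred = np.tanh(rng.randn(8, 2))              # stand-in tanh outputs\n",
    "q_cont_demo = np.mean(np.sum(0.5 * np.square(c_true - c_pred), axis=1))\n",
    "# Full unit-variance Gaussian NLL differs only by 0.5*log(2*pi) per code (2 codes here).\n",
    "gauss_nll = np.mean(np.sum(0.5 * np.square(c_true - c_pred) + 0.5 * np.log(2 * np.pi), axis=1))\n",
    "assert np.isclose(gauss_nll - q_cont_demo, np.log(2 * np.pi))\n",
    "\n",
    "print('D loss %.3f, G loss %.3f, Q loss %.3f' % (d_loss_demo, g_loss_demo, q_cat_demo + q_cont_demo))"
   ]
  },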
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Training the network\n",
    "Now that we have fully defined our network, it is time to train it!"
   ]
  },
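  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each training step rescales MNIST pixels from [0, 1] to [-1, 1] to match the generator's tanh output, then pads the 28x28 digits to 32x32 with -1, the background value. A NumPy sketch of that preprocessing on a random stand-in batch follows. `flat_batch` is hypothetical; in the loop the data comes from `mnist.train.next_batch`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical stand-in for mnist.train.next_batch(3): flat 784-pixel rows in [0, 1].\n",
    "flat_batch = np.random.uniform(0.0, 1.0, size=[3, 784])\n",
    "\n",
    "imgs = (np.reshape(flat_batch, [3, 28, 28, 1]) - 0.5) * 2.0   # rescale to [-1, 1]\n",
    "imgs = np.pad(imgs, ((0, 0), (2, 2), (2, 2), (0, 0)), 'constant', constant_values=-1)\n",
    "print(imgs.shape)  # (3, 32, 32, 1)\n",
    "assert imgs.min() >= -1.0 and imgs.max() <= 1.0\n",
    "assert np.all(imgs[:, :2, :, :] == -1.0)   # the 2-pixel border holds the background value"
   ]
  },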
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "batch_size = 64 # Size of image batch to apply at each iteration.\n",
    "iterations = 500000 # Total number of iterations to use.\n",
    "sample_directory = './figsTut' # Directory to save sample images from generator in.\n",
    "model_directory = './models' # Directory to save trained model to.\n",
    "\n",
    "init = tf.global_variables_initializer()\n",
    "saver = tf.train.Saver()\n",
    "with tf.Session() as sess:\n",
    "    sess.run(init)\n",
    "    for i in range(iterations):\n",
    "        zs = np.random.uniform(-1.0, 1.0, size=[batch_size, z_size]).astype(np.float32) # Generate a random z batch\n",
    "        lcat = np.random.randint(0, 10, [batch_size, len(categorical_list)]) # Generate a random categorical c batch\n",
    "        lcont = np.random.uniform(-1, 1, [batch_size, number_continuous]) # Generate a random continuous c batch\n",
    "\n",
    "        xs, _ = mnist.train.next_batch(batch_size) # Draw a sample batch from the MNIST dataset.\n",
    "        xs = (np.reshape(xs, [batch_size, 28, 28, 1]) - 0.5) * 2.0 # Transform it to be between -1 and 1\n",
    "        xs = np.pad(xs, ((0, 0), (2, 2), (2, 2), (0, 0)), 'constant', constant_values=(-1, -1)) # Pad the images so they are 32x32\n",
    "\n",
    "        _, dLoss = sess.run([update_D, d_loss], feed_dict={z_in: zs, real_in: xs, latent_cat_in: lcat, latent_cont_in: lcont}) # Update the discriminator\n",
    "        _, gLoss = sess.run([update_G, g_loss], feed_dict={z_in: zs, latent_cat_in: lcat, latent_cont_in: lcont}) # Update the generator\n",
    "        _, qLoss, qK, qC = sess.run([update_Q, q_loss, q_cont_loss, q_cat_loss], feed_dict={z_in: zs, latent_cat_in: lcat, latent_cont_in: lcont}) # Update to maximize mutual information\n",
    "        if i % 100 == 0:\n",
    "            print(\"Gen Loss: \" + str(gLoss) + \" Disc Loss: \" + str(dLoss) + \" Q Losses: \" + str([qK, qC]))\n",
    "            z_sample = np.random.uniform(-1.0, 1.0, size=[100, z_size]).astype(np.float32) # Generate another z batch\n",
    "            # Row r of the 10x10 sample grid uses category r.\n",
    "            lcat_sample = np.reshape(np.array([e for e in range(10) for _ in range(10)]), [100, 1])\n",
    "            # Columns sweep the first continuous code from -1 to 1; the second stays at 0.\n",
    "            a = np.reshape(np.array([[(e/4.5 - 1.)] for e in range(10) for _ in range(10)]), [10, 10]).T\n",
    "            b = np.reshape(a, [100, 1])\n",
    "            c = np.zeros_like(b)\n",
    "            lcont_sample = np.hstack([b, c])\n",
    "            samples = sess.run(Gz, feed_dict={z_in: z_sample, latent_cat_in: lcat_sample, latent_cont_in: lcont_sample}) # Use new z to get sample images from generator.\n",
    "            if not os.path.exists(sample_directory):\n",
    "                os.makedirs(sample_directory)\n",
    "            # Save sample generator images for viewing training progress.\n",
    "            save_images(np.reshape(samples[0:100], [100, 32, 32]), [10, 10], sample_directory+'/fig'+str(i)+'.png')\n",
    "        if i % 1000 == 0 and i != 0:\n",
    "            if not os.path.exists(model_directory):\n",
    "                os.makedirs(model_directory)\n",
    "            saver.save(sess, model_directory+'/model-'+str(i)+'.ckpt')\n",
    "            print(\"Saved Model\")"
   ]
  },
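  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Every 100 iterations the loop renders a 10x10 grid whose latent codes are arranged so that each row uses one category and each column sweeps the first continuous code from -1 to 1, with the second code held at 0. The sketch below rebuilds those arrays in NumPy and checks that layout against how `merge` places image `idx` at row `idx // 10` and column `idx % 10`. The `_grid` and `_layout` names are hypothetical, so the notebook's `a`, `b`, and `c` are not touched."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Same construction as lcat_sample / lcont_sample in the training loop.\n",
    "lcat_grid = np.reshape(np.array([e for e in range(10) for _ in range(10)]), [100, 1])\n",
    "a_grid = np.reshape(np.array([[(e / 4.5 - 1.)] for e in range(10) for _ in range(10)]), [10, 10]).T\n",
    "lcont_grid = np.hstack([np.reshape(a_grid, [100, 1]), np.zeros([100, 1])])\n",
    "\n",
    "# merge() places image idx at row idx // 10 and column idx % 10, so reshape to that layout.\n",
    "cat_layout = lcat_grid.reshape(10, 10)\n",
    "c1_layout = lcont_grid[:, 0].reshape(10, 10)\n",
    "\n",
    "assert all((cat_layout[r] == r).all() for r in range(10))   # row r uses category r\n",
    "assert np.allclose(c1_layout[:, 0], -1.0) and np.allclose(c1_layout[:, 9], 1.0)\n",
    "assert np.allclose(c1_layout, c1_layout[0])                 # the same sweep in every row\n",
    "assert np.allclose(lcont_grid[:, 1], 0.0)                   # second continuous code fixed at 0"
   ]
  },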
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using a trained network\n",
    "Once we have a trained model saved, we may want to use it to generate new images and explore the representation it has learned."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "sample_directory = './figsTut' # Directory to save sample images from generator in.\n",
    "model_directory = './models' # Directory to load trained model from.\n",
    "\n",
    "init = tf.global_variables_initializer()\n",
    "saver = tf.train.Saver()\n",
    "with tf.Session() as sess:\n",
    "    sess.run(init)\n",
    "    # Reload the model.\n",
    "    print('Loading Model...')\n",
    "    ckpt = tf.train.get_checkpoint_state(model_directory)\n",
    "    saver.restore(sess, ckpt.model_checkpoint_path)\n",
    "\n",
    "    z_sample = np.random.uniform(-1.0, 1.0, size=[100, z_size]).astype(np.float32) # Generate another z batch\n",
    "    # Row r of the 10x10 sample grid uses category r.\n",
    "    lcat_sample = np.reshape(np.array([e for e in range(10) for _ in range(10)]), [100, 1])\n",
    "    # Columns sweep the first continuous code from -1 to 1; the second stays at 0.\n",
    "    a = np.reshape(np.array([[(e/4.5 - 1.)] for e in range(10) for _ in range(10)]), [10, 10]).T\n",
    "    b = np.reshape(a, [100, 1])\n",
    "    c = np.zeros_like(b)\n",
    "    lcont_sample = np.hstack([b, c])\n",
    "    samples = sess.run(Gz, feed_dict={z_in: z_sample, latent_cat_in: lcat_sample, latent_cont_in: lcont_sample}) # Use new z to get sample images from generator.\n",
    "    if not os.path.exists(sample_directory):\n",
    "        os.makedirs(sample_directory)\n",
    "    # Save the grid of sample images.\n",
    "    save_images(np.reshape(samples[0:100], [100, 32, 32]), [10, 10], sample_directory+'/fig_test.png')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}