This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Generator examples gathered while training.
samples = []

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(epochs):
        # Whole batches per epoch; floor division drops any remainder.
        num_batches = mnist.train.num_examples // batch_size
        for i in range(num_batches):
            batch = mnist.train.next_batch(batch_size)
            # First element of the batch is reshaped to one 784-value row
            # per example (presumably the images -- confirm against
            # next_batch's return value).
            batch_images = batch[0].reshape((batch_size, 784))
            # Rescale with x*2-1. Presumably maps pixels from [0, 1] to
            # [-1, 1] so real images share the range of the generator's
            # tanh output -- confirm the input range.
            batch_images = batch_images * 2 - 1
            # NOTE(review): this excerpt ends here; no optimizer step is
            # visible in this fragment.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
lr = 0.001

# When multiple networks interact, each optimizer must update only its own
# network's weights, so split the variables by the scope name they carry.
tvars = tf.trainable_variables()  # all variables that were created as trainable
d_vars = [var for var in tvars if 'dis' in var.name]
g_vars = [var for var in tvars if 'gen' in var.name]

# One Adam optimizer per network; var_list restricts which variables each
# minimize() call is allowed to change.
D_trainer = tf.train.AdamOptimizer(lr).minimize(D_loss, var_list=d_vars)
G_trainer = tf.train.AdamOptimizer(lr).minimize(G_loss, var_list=g_vars)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def loss_func(logits_in, labels_in):
    """Return the mean sigmoid cross-entropy between logits and labels.

    Args:
        logits_in: Raw scores, passed as ``logits`` to
            ``tf.nn.sigmoid_cross_entropy_with_logits``.
        labels_in: Target values, passed as ``labels`` to the same op.

    Returns:
        The element-wise cross-entropy reduced to its mean over all elements.
    """
    per_element = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=logits_in, labels=labels_in)
    return tf.reduce_mean(per_element)
# Discriminator on real images: target 0.9 rather than 1.0 (label smoothing
# for generalization).
D_real_loss = loss_func(D_logits_real, tf.ones_like(D_logits_real) * 0.9)
# Discriminator on generated images: target 0. The labels are shaped from the
# fake logits they are paired with, so the two stay consistent even when the
# real and generated batches differ in size.
D_fake_loss = loss_func(D_logits_fake, tf.zeros_like(D_logits_fake))
D_loss = D_real_loss + D_fake_loss

# Generator: scored against target 1 on the discriminator's fake logits,
# i.e. it is rewarded when its images are classified as real.
G_loss = loss_func(D_logits_fake, tf.ones_like(D_logits_fake))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Start from an empty default graph so re-running this section does not
# collide with variables created by an earlier run.
tf.reset_default_graph()

# Inputs: rows of 784 values for real images and rows of 100 values for the
# generator's input z; the batch dimension is left open (None).
real_images = tf.placeholder(tf.float32, shape=[None, 784])
z = tf.placeholder(tf.float32, shape=[None, 100])

G = generator(z)

# The discriminator is applied twice. reuse=True on the second call makes it
# share the variables created by the first call instead of creating new ones.
D_output_real, D_logits_real = discriminator(real_images)
D_output_fake, D_logits_fake = discriminator(G, reuse=True)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def generator(z, reuse=None):
    """Build the generator network under the 'gen' variable scope.

    Two 128-unit dense layers with leaky-ReLU activations feed a 784-unit
    dense output layer with tanh activation, so every output value lies in
    [-1, 1].

    Args:
        z: Input tensor fed to the first dense layer.
        reuse: Forwarded to ``tf.variable_scope``; controls whether the
            variables in scope 'gen' are reused.

    Returns:
        The output tensor of the final 784-unit tanh layer.
    """
    with tf.variable_scope('gen', reuse=reuse):
        hidden1 = tf.layers.dense(
            inputs=z, units=128, activation=tf.nn.leaky_relu)
        hidden2 = tf.layers.dense(
            inputs=hidden1, units=128, activation=tf.nn.leaky_relu)
        output = tf.layers.dense(
            inputs=hidden2, units=784, activation=tf.nn.tanh)
        return output
def discriminator(X,reuse=None): | |
with tf.variable_scope('dis',reuse=reuse): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf  # TensorFlow is a machine learning library
import numpy as np  # NumPy is useful for matrix multiplication
import matplotlib.pyplot as plt  # visualization tool
from tensorflow.examples.tutorials.mnist import input_data  # MNIST dataset loader

mnist = input_data.read_data_sets("MNIST_data")  # read in the MNIST images