Skip to content

Instantly share code, notes, and snippets.

@machinaut
Last active March 18, 2018 02:36
Show Gist options
  • Save machinaut/97ed73c145b123ed8965805a4a377974 to your computer and use it in GitHub Desktop.
Save machinaut/97ed73c145b123ed8965805a4a377974 to your computer and use it in GitHub Desktop.
Little GAN
#!/usr/bin/env python
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from scipy.stats import norm, gaussian_kde
def get_generator(*,
                  x_size,
                  z_size,
                  hidden_units=(100, 100),
                  activation=tf.nn.relu,
                  scope='generator'):
    """Build the generator MLP ``z -> g`` in the default graph.

    Creates a float32 noise placeholder ``z`` of shape ``[None, z_size]``,
    passes it through one dense layer per entry of ``hidden_units`` (using
    ``activation``), then a final linear dense layer of width ``x_size``.

    Graph collections written:
        ``'z'``: the noise placeholder.
        ``'g'``: the generator output tensor.
        ``'g_var'``: every trainable variable created under ``scope``.

    Args:
        x_size: width of the generated sample.
        z_size: width of the noise input.
        hidden_units: sizes of the hidden layers (any iterable of ints; a
            tuple default avoids the shared-mutable-default pitfall).
        activation: activation for the hidden layers.
        scope: variable scope wrapping the generator layers.

    Returns:
        The generator output tensor ``g`` with shape ``[None, x_size]``.
    """
    # Noise input lives outside the variable scope so it is named plain 'z'.
    z = tf.placeholder(tf.float32, shape=[None, z_size], name='z')
    tf.add_to_collection('z', z)

    with tf.variable_scope(scope):
        net = z
        for i, units in enumerate(hidden_units):
            net = tf.layers.dense(net, units=units, activation=activation,
                                  name='dense%d' % i)
        # Linear output layer: unbounded real-valued samples.
        g = tf.layers.dense(net, units=x_size, name='g')
    tf.add_to_collection('g', g)

    # Register the generator's weights so the optimizer can restrict
    # its var_list to them.
    for var in tf.trainable_variables(scope=scope):
        tf.add_to_collection('g_var', var)
    return g
def get_discriminator(*,
                      x_size,
                      hidden_units=(100, 100),
                      activation=tf.nn.relu,
                      scope='discriminator'):
    """Build the discriminator MLP in the default graph.

    The discriminator sees, row by row, either a real sample from the new
    placeholder ``x`` (where the label ``y`` is True) or the generator output
    ``g`` taken from the ``'g'`` collection (where ``y`` is False). The
    generator must therefore be built first.

    Graph collections written:
        ``'y'``: bool label placeholder, shape ``[None, 1]``.
        ``'x'``: float32 real-input placeholder, shape ``[None, x_size]``.
        ``'d'``: sigmoid output tensor, shape ``[None, 1]``.
        ``'d_var'``: every trainable variable created under ``scope``.

    Args:
        x_size: width of a sample; must match the generator's output width.
        hidden_units: sizes of the hidden layers (any iterable of ints).
        activation: activation for the hidden layers.
        scope: variable scope wrapping the discriminator layers.

    Returns:
        The discriminator output tensor ``d`` (probability the row is real).
    """
    # Correct label: True -> feed real x, False -> feed generated g.
    y = tf.placeholder(tf.bool, shape=[None, 1], name='y')
    tf.add_to_collection('y', y)
    # Real input
    x = tf.placeholder(tf.float32, shape=[None, x_size], name='x')
    tf.add_to_collection('x', x)
    # Generator output (built earlier by get_generator).
    g = tf.get_collection('g')[0]

    with tf.variable_scope(scope):
        # v1 tf.where needs the condition to match x's shape (or be a vector);
        # a [None, 1] condition only works when x_size == 1, so tile it across
        # the feature axis to support any x_size.
        select_real = tf.tile(y, [1, x_size])
        net = tf.where(select_real, x, g, name='input')
        for i, units in enumerate(hidden_units):
            net = tf.layers.dense(net, units=units, activation=activation,
                                  name='dense%d' % i)
        d = tf.nn.sigmoid(tf.layers.dense(net, units=1), name='d')
    tf.add_to_collection('d', d)

    # Register the discriminator's weights for its optimizer's var_list.
    for var in tf.trainable_variables(scope=scope):
        tf.add_to_collection('d_var', var)
    return d
def _add_train_op(*, prefix, loss, optimizer, learning_rate,
                  learning_rate_decay):
    """Build the step counter, decayed learning rate and train op for one player.

    Creates, all named with ``prefix`` ('g' or 'd'):
        ``<prefix>_step``: non-trainable int counter, incremented by the op.
        ``<prefix>_lr``: ``learning_rate * (1 - learning_rate_decay) ** step``.
        ``<prefix>_opt``: the optimizer instance.
        ``<prefix>_train``: op minimizing ``loss`` over the ``<prefix>_var``
            collection.
    and adds the lr, loss and train op to the ``<prefix>_lr``,
    ``<prefix>_loss`` and ``<prefix>_train`` collections.

    Returns:
        The train op.
    """
    step = tf.Variable(0, trainable=False, name=prefix + '_step')
    lr = tf.multiply(learning_rate,
                     tf.pow(1 - learning_rate_decay,
                            tf.cast(step, tf.float32)),
                     name=prefix + '_lr')
    opt = optimizer(learning_rate=lr, name=prefix + '_opt')
    train_op = opt.minimize(loss,
                            var_list=tf.get_collection(prefix + '_var'),
                            global_step=step,
                            name=prefix + '_train')
    tf.add_to_collection(prefix + '_lr', lr)
    tf.add_to_collection(prefix + '_loss', loss)
    tf.add_to_collection(prefix + '_train', train_op)
    return train_op


def get_gan(*,
            generator_optimizer=tf.train.GradientDescentOptimizer,
            generator_learning_rate=0.01,
            generator_learning_rate_decay=0.0,
            discriminator_optimizer=tf.train.GradientDescentOptimizer,
            discriminator_learning_rate=0.01,
            discriminator_learning_rate_decay=0.0,
            epsilon=1e-8):
    """Add the GAN losses and train ops on top of the generator/discriminator.

    Reads ``y`` and ``d`` from their collections, so the generator and
    discriminator must already be built. Registers ``g_lr``/``g_loss``/
    ``g_train`` and ``d_lr``/``d_loss``/``d_train`` collections.

    Args:
        generator_optimizer / discriminator_optimizer: optimizer classes
            constructed as ``cls(learning_rate=..., name=...)``.
        *_learning_rate: initial learning rates.
        *_learning_rate_decay: per-step multiplicative decay; the rate is
            ``lr * (1 - decay) ** step``.
        epsilon: added inside the logs so a saturated sigmoid (exactly 0 or 1
            in float32) does not give ``-inf`` losses or NaN gradients.
    """
    # Correct label
    y = tf.get_collection('y')[0]
    # Discriminator output
    d = tf.get_collection('d')[0]

    # tf.where back-propagates through BOTH branches (masking with zeros), so
    # an infinite derivative in the unselected log turns into 0 * inf = NaN.
    # The epsilon keeps both logs finite.
    log_d = tf.log(d + epsilon)
    log_not_d = tf.log(1 - d + epsilon)

    # Non-saturating generator loss: maximize log D(G(z)).
    # NOTE: d is evaluated on the batch selected by y, so rows with y=True
    # (real inputs) enter the mean but carry no gradient to the generator.
    g_loss = -tf.reduce_mean(log_d)
    _add_train_op(prefix='g',
                  loss=g_loss,
                  optimizer=generator_optimizer,
                  learning_rate=generator_learning_rate,
                  learning_rate_decay=generator_learning_rate_decay)

    # Binary cross-entropy: log D(x) for real rows, log(1 - D(G(z))) for fakes.
    d_loss = -tf.reduce_mean(tf.where(y, log_d, log_not_d))
    _add_train_op(prefix='d',
                  loss=d_loss,
                  optimizer=discriminator_optimizer,
                  learning_rate=discriminator_learning_rate,
                  learning_rate_decay=discriminator_learning_rate_decay)
def get_data(*,
             batch_size,
             z_size,
             x_size,
             x_scale=(0.5, 1.0),
             x_shift=(-10.0, 10.0),
             n_modes=3,
             p_generated=0.5):
    """Sample one batch of real data, generator noise and real/fake labels.

    Real data come from an equal-weight mixture of ``n_modes`` axis-aligned
    Gaussians. Per-mode scales ``get_data.m ~ U(*x_scale)`` and centres
    ``get_data.b ~ U(*x_shift)``, each of shape ``(1, n_modes, x_size)``, are
    drawn once and cached on the function. The cache is rebuilt when it is
    missing, reset to ``None``, or has a shape that does not match the
    requested ``(n_modes, x_size)``. ``x_scale``/``x_shift`` only take effect
    when the cache is (re)built.

    Returns:
        x: ``(batch_size, x_size)`` real samples, each from a uniformly chosen
            mode.
        z: ``(batch_size, z_size)`` Gaussian noise normalized to unit L2 norm.
        y: ``(batch_size, 1)`` bool labels; ``False`` (generated input) with
            probability ``p_generated``, ``True`` (real input) otherwise.
    """
    shape = (1, n_modes, x_size)
    # Rebuild on shape mismatch: a stale cache would otherwise fail to
    # broadcast, or silently broadcast to the wrong mixture.
    if (getattr(get_data, 'm', None) is None
            or np.shape(get_data.m) != shape
            or np.shape(getattr(get_data, 'b', None)) != shape):
        get_data.m = np.random.uniform(*x_scale, shape)
        get_data.b = np.random.uniform(*x_shift, shape)

    # One draw per (sample, mode), then pick a single mode per sample.
    r = np.random.randn(batch_size, n_modes, x_size) * get_data.m + get_data.b
    n = np.random.choice(n_modes, size=batch_size)
    x = r[np.arange(batch_size), n]

    q = np.random.randn(batch_size, z_size)
    z = q / np.sqrt(np.sum(np.square(q), axis=1, keepdims=True))

    probs = [p_generated, 1 - p_generated]
    y = np.random.choice(a=[False, True], size=(batch_size, 1), p=probs)
    return x, z, y
def get_density(*, x):
    """Evaluate the target mixture density at every element of ``x``.

    Reads the mode centres ``get_data.b`` and scales ``get_data.m`` cached by
    ``get_data``, so ``get_data`` must have been called first. Each element
    of ``x`` is flattened into a scalar point, the normal pdf of every mode is
    evaluated there, and the result is averaged over axis 1. That axis indexes
    the modes when ``x_size == 1`` (assumed here; TODO confirm for other
    sizes). The output has the same shape as ``x``.
    """
    points = x.reshape(-1)
    per_mode_pdf = norm.pdf(points, loc=get_data.b, scale=get_data.m)
    mixture = np.mean(per_mode_pdf, axis=1)
    return mixture.reshape(x.shape)
def draw(*, data_x, data_z, data_y, sess=None):
    """Redraw the current matplotlib axes with a snapshot of training.

    Over x in [-20, 20] it plots:
      * ``p``: the target mixture density from ``get_density``.
      * ``x``: the real batch, as points at their true density plus a KDE
        curve.
      * ``g``: generator samples for ``data_z``, as points at their KDE
        height (clipped to [0, 1]) plus a KDE curve.
      * ``d``: the discriminator output when fed the grid as real input.

    Samples are flattened to 1-D, so the plot assumes ``x_size == 1``.

    Args:
        data_x: real samples, e.g. from ``get_data``.
        data_z: noise batch fed to the generator placeholder ``z``.
        data_y: unused; accepted so a whole ``get_data`` batch can be passed.
        sess: session to evaluate the graph in. Defaults to the default
            session.

    Raises:
        ValueError: if ``sess`` is None and there is no default session.
    """
    if sess is None:
        sess = tf.get_default_session()
        if sess is None:
            raise ValueError('draw() needs a session: pass sess=... or call '
                             'it inside a default session')

    lin = np.linspace(-20, 20, 200)
    plt.cla()
    plt.ylim([0, 1])
    plt.plot(lin, get_density(x=lin), color=(0, 1, 0), label='p')

    # Real samples: points at their true density, plus a KDE (bw factor 0.15).
    x_kde = gaussian_kde(data_x.reshape(-1), 0.15)
    plt.plot(data_x, get_density(x=data_x), '.', color=(0, 0, 0), label='x')
    plt.plot(lin, x_kde(lin), color=(0, 0.5, 0, 0.5), label='x')

    # Generated samples: no closed-form density, so place points on their KDE.
    g = tf.get_collection('g')[0]
    z = tf.get_collection('z')[0]
    data_g = sess.run(g, feed_dict={z: data_z}).reshape(-1)
    g_kde = gaussian_kde(data_g, 0.15)
    plt.plot(data_g, np.clip(g_kde(data_g), 0., 1.),
             '.', color=(1, 0, 0), label='g')
    plt.plot(lin, g_kde(lin), color=(0.5, 0, 0, 0.5), label='g')

    # Discriminator curve: y=True selects the real-input branch (the grid).
    # z must still be fed because tf.where evaluates both branches.
    x = tf.get_collection('x')[0]
    y = tf.get_collection('y')[0]
    d = tf.get_collection('d')[0]
    n_points = lin.shape[0]
    data_d = sess.run(d, feed_dict={x: lin.reshape(-1, 1),
                                    y: np.ones((n_points, 1), dtype=bool),
                                    z: np.zeros((n_points, int(z.shape[1])))})
    plt.plot(lin, data_d, color=(0, 0, 1), label='d')
    plt.axhline(y=0.5, color=(0, 0, 0, 0.5), linestyle='-')
    plt.legend(loc="lower right")
    plt.pause(.01)
def train(*,
          train_steps=10000,
          step_render=100,
          generator_steps=1,
          discriminator_steps=1,
          batch_size=100,
          x_size=1,
          z_size=10,
          render=True):
    """Build the GAN in the default graph and train it with alternating updates.

    Each of ``train_steps`` iterations runs ``discriminator_steps`` updates of
    ``d_train`` and then ``generator_steps`` updates of ``g_train``, each on a
    fresh ``get_data`` batch. When ``render`` is true, every ``step_render``
    iterations a separate batch is drawn and plotted with ``draw`` before the
    updates.

    NOTE: the graph is built in the default graph and tensors are looked up
    with ``tf.get_collection(...)[0]``. Calling this twice in the same graph
    would therefore pick up the first call's tensors.
    """
    get_generator(x_size=x_size, z_size=z_size)
    get_discriminator(x_size=x_size)
    get_gan()
    x = tf.get_collection('x')[0]
    y = tf.get_collection('y')[0]
    z = tf.get_collection('z')[0]
    g_train = tf.get_collection('g_train')[0]
    d_train = tf.get_collection('d_train')[0]

    def sample_batch():
        # One (x, z, y) batch with this run's sizes.
        return get_data(x_size=x_size, z_size=z_size, batch_size=batch_size)

    def run_updates(sess, op, n_updates):
        # Apply `op` n_updates times, each on a fresh batch.
        for _ in range(n_updates):
            data_x, data_z, data_y = sample_batch()
            sess.run(op, feed_dict={x: data_x, y: data_y, z: data_z})

    # The context manager closes the session (releasing its resources) when
    # training finishes or is interrupted.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(train_steps):
            if render and i % step_render == 0:
                data_x, data_z, data_y = sample_batch()
                draw(data_x=data_x, data_z=data_z, data_y=data_y, sess=sess)
            run_updates(sess, d_train, discriminator_steps)
            run_updates(sess, g_train, generator_steps)
if __name__ == '__main__':
    # Script entry point: build the graph and train with default settings.
    train()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment