scturtle/gan.py (@scturtle), last active Jun 21, 2017

GAN a Gaussian distribution: a minimal GAN that learns to sample from N(2, 0.5). The discriminator is pretrained on the true density, and minibatch discrimination is used to discourage mode collapse.
from __future__ import print_function
import numpy as np
from scipy.stats import norm
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow.contrib.layers as tfl
import tensorflow.contrib.framework as tff
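
# NOTE: this script targets TensorFlow 1.x; tf.contrib (used below for the
# layer and global-step helpers) was removed in TensorFlow 2.x.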

def minibatch(features):
    # Minibatch discrimination (Salimans et al. 2016): project each sample's
    # features into size_b kernels of size_c dimensions, then append per-kernel
    # closeness statistics across the batch so D can detect mode collapse.
    size_a, size_b, size_c = features.get_shape()[1], 3, 3
    t = tf.get_variable('T', (size_a, size_b * size_c),
                        initializer=tf.random_normal_initializer())
    m = tf.reshape(tf.matmul(features, t), (-1, size_b, size_c))
    # Pairwise L1 distances between all samples in the batch, per kernel.
    diff = tf.expand_dims(m, 3) - tf.expand_dims(tf.transpose(m, (1, 2, 0)), 0)
    l1 = tf.reduce_mean(tf.abs(diff), 2)
    o = tf.reduce_mean(tf.exp(-l1), 2)
    return tf.concat(values=[features, o], axis=1)
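
# Shape sketch for minibatch() (illustrative, assuming a feature width of 10
# and a batch of N samples, as in network() below):
#   features: (N, 10) -> m: (N, 3, 3) -> diff: (N, 3, 3, N)
#   l1: (N, 3, N) -> o: (N, 3) -> output: (N, 13)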

def network(inputs, *, use_minibatch=False, last_fn=tf.nn.sigmoid):
    w_init = tf.random_normal_initializer()
    b_init = tf.constant_initializer(0.)
    h0 = tfl.fully_connected(
        inputs, 10, scope='h0', activation_fn=tf.nn.tanh,
        weights_initializer=w_init, biases_initializer=b_init)
    h1 = tfl.fully_connected(
        h0, 10, scope='h1', activation_fn=tf.nn.tanh,
        weights_initializer=w_init, biases_initializer=b_init)
    if use_minibatch:
        h1 = minibatch(h1)
    h2 = tfl.fully_connected(
        h1, 1, scope='h2', activation_fn=last_fn,
        weights_initializer=w_init, biases_initializer=b_init)
    return h2

def optimizer(loss, lr, var_list):
    # Momentum SGD with a learning rate decayed by x0.95 every 1000 steps.
    lr_decayed = tf.train.exponential_decay(
        lr, tff.get_or_create_global_step(), 1000, 0.95, staircase=True)
    return tf.train.MomentumOptimizer(lr_decayed, 0.5).minimize(
        loss, tff.get_or_create_global_step(), var_list)
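
# With staircase decay the effective learning rate at global step s is
# lr * 0.95 ** (s // 1000). Note that opt_dp, opt_d and opt_g below all
# share (and each increment) the same global step.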

def main():
    with tf.variable_scope('G'):
        z_input = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        G = network(z_input, last_fn=None)  # generator: noise -> sample
    with tf.variable_scope('D') as scope:
        x_input = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        use_minibatch = True
        D1 = network(x_input, use_minibatch=use_minibatch)  # D on real data
        scope.reuse_variables()
        D2 = network(G, use_minibatch=use_minibatch)  # same D on generated data
    x_label = tf.placeholder(shape=(None, 1), dtype=tf.float32)
    # Pretraining loss: regress D towards the true density values.
    loss_dp = tf.reduce_mean(tf.square(D1 - x_label))
    loss_d = tf.reduce_mean(-tf.log(D1) - tf.log(1 - D2))
    loss_g = tf.reduce_mean(-tf.log(D2))
    var_d = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')
    opt_dp = optimizer(loss_dp, 0.3, var_d)
    opt_d = optimizer(loss_d, 0.03, var_d)
    var_g = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')
    opt_g = optimizer(loss_g, 0.01, var_g)
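    # The losses above are the original GAN objective (Goodfellow et al. 2014):
    # D maximizes E[log D(x)] + E[log(1 - D(G(z)))], while G minimizes
    # -E[log D(G(z))], the non-saturating variant that gives stronger
    # gradients early in training than minimizing E[log(1 - D(G(z)))].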
    with tf.Session() as sess:
        mu = 2
        sigma = 0.5
        tf.global_variables_initializer().run()
        batch_size = 128
        pre_train = True
        if pre_train:
            # Pretrain D to approximate the target pdf before adversarial training.
            for step in range(2000):
                x = np.random.uniform(-2, 6, batch_size)
                x.sort()
                x = x.reshape((-1, 1))
                y = norm.pdf(x, loc=mu, scale=sigma).reshape((-1, 1))
                ldp, _ = sess.run([loss_dp, opt_dp], {x_input: x, x_label: y})
                print('[D_pre] step: {} loss: {}'.format(step, ldp))
            # Visual check: pretrained D against the true density.
            tx = np.linspace(0, 4, 200)
            ty = norm.pdf(tx, loc=mu, scale=sigma)
            pred = sess.run(D1, {x_input: tx.reshape((-1, 1))}).ravel()
            plt.plot(tx, ty, label='real')
            plt.plot(tx, pred, label='pred')
            plt.legend()
            plt.show()
        plt.ion()
        batch_size = 32
        for step in range(3000):
            # Real samples from N(mu, sigma); noise from a sorted uniform grid.
            x = np.random.normal(mu, sigma, batch_size)
            x.sort()
            x = x.reshape((-1, 1))
            z = np.random.uniform(0, 4, batch_size)
            z.sort()
            z = z.reshape((-1, 1))
            ld, _ = sess.run([loss_d, opt_d], {x_input: x, z_input: z})
            lg, _ = sess.run([loss_g, opt_g], {z_input: z})
            print('[GAN] step: {} lg: {} ld: {}'.format(step, lg, ld))
            if step and step % 100 == 0:
                # Every 100 steps, compare histograms of real vs. generated samples.
                xs = []
                gxs = []
                for i in range(3000):
                    z = np.random.uniform(0, 4, batch_size)
                    z.sort()
                    z = z.reshape((-1, 1))
                    gxs.append(sess.run(G, {z_input: z}))
                    xs.append(np.random.normal(mu, sigma, batch_size))
                gxs = np.concatenate(gxs)
                xs = np.concatenate(xs)
                plt.cla()
                plt.hist(xs, bins=100, alpha=0.8, label='real')
                plt.hist(gxs, bins=100, alpha=0.8, label='gan')
                plt.legend()
                plt.pause(0.1)
        plt.ioff()
        plt.show()


if __name__ == '__main__':
    main()
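
# To run (a sketch, assuming a Python environment where TensorFlow 1.x is
# still installable; the pip spec below is an assumption, not pinned by the
# gist, and TF 1.x is unavailable on recent Python versions):
#     pip install 'tensorflow<2' scipy matplotlib
#     python gan.py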