Generative adversarial network (GAN) for 1-dimensional Gaussian
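
The script below is a minimal GAN in Chainer: the "real" data are 1000 scalar samples from a Gaussian with mean 10 and standard deviation 2, the generator is a small MLP that maps 100-dimensional noise z ~ U(-1, 1) to a single scalar, and the discriminator is an MLP that outputs a raw logit trained with F.sigmoid_cross_entropy. During training it periodically plots histograms of real versus generated samples and the discriminator's sigmoid output over the x axis, and prints the sample mean and standard deviation of a large batch of generated samples.
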
# -*- coding: utf-8 -*-
"""
Created on Fri Dec 15 15:22:04 2017
@author: sakurai
"""
import matplotlib.pyplot as plt
import numpy as np
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import cuda
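
# Note (not in the original gist): this code registers links by passing them
# as keyword arguments to super().__init__(), the older Chain constructor
# style from the Chainer v1/v2 era. Later Chainer releases prefer
# `with self.init_scope():`, but the keyword-argument form used here was kept
# for backward compatibility.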


class Generator(chainer.Chain):
    def __init__(self, c_in=10, c_out=1):
        c_1 = 200
        c_2 = 200
        super(Generator, self).__init__(
            fc1=L.Linear(c_in, c_1),
            fc2=L.Linear(c_1, c_2),
            fc3=L.Linear(c_2, c_out))
        self.dim_z = c_in

    def __call__(self, z):
        h = F.leaky_relu(self.fc1(z))
        h = F.leaky_relu(self.fc2(h))
        return self.fc3(h)

    def draw_sample(self, size=100):
        z = self.xp.random.uniform(-1, 1, size=(size, self.dim_z)).astype('f')
        return self(z)
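
# A quick sanity check of the generator on its own (illustrative only, not
# part of the original gist): draw_sample() feeds z ~ U(-1, 1)^dim_z through
# the MLP, so an untrained Generator(c_in=10) should return a float32
# chainer.Variable of shape (n, 1), e.g.
#
#     g = Generator(c_in=10)
#     samples = g.draw_sample(5)   # Variable with .shape == (5, 1)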


class Discriminator(chainer.Chain):
    def __init__(self, c_in=1):
        c_1 = 200
        c_2 = 200
        super(Discriminator, self).__init__(
            fc1=L.Linear(c_in, c_1),
            fc2=L.Linear(c_1, c_2),
            fc3=L.Linear(c_2, 1))

    def __call__(self, x):
        h = F.elu(self.fc1(x))
        h = F.elu(self.fc2(h))
        return self.fc3(h)
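
# Note (not in the original gist): the discriminator returns a raw logit
# rather than a probability. The sigmoid is applied implicitly inside
# F.sigmoid_cross_entropy during training, and explicitly (F.sigmoid) only
# when its output is plotted.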


if __name__ == '__main__':
    real_mean = 10
    real_std = 2
    N = 1000
    dim_z = 100
    update_generator_interval = 5
    plot_x_scale = 20
    num_epochs = 10000
    batch_size = 100
    test_size = 50000
    alpha = 0.001

    x_train = real_std * np.random.randn(N, 1).astype(np.float32) + real_mean
    ds_train = chainer.datasets.TupleDataset(x_train)
    it_real = chainer.iterators.SerialIterator(ds_train, batch_size)

    generator = Generator(dim_z)
    discriminator = Discriminator()
    opt_g = chainer.optimizers.Adam(alpha)
    opt_d = chainer.optimizers.Adam(alpha)
    opt_g.setup(generator)
    opt_d.setup(discriminator)
    # opt_d.add_hook(chainer.optimizer.WeightDecay(0.01))

    t_fake = np.zeros((batch_size, 1), dtype=np.int32)
    t_real = np.ones((batch_size, 1), dtype=np.int32)
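
    # Label convention (added note, not in the original gist): the
    # discriminator is trained to output 1 ("real") for data samples and
    # 0 ("fake") for generated samples, so
    #     loss_d = SCE(D(G(z)), 0) + SCE(D(x_real), 1),
    # while the generator uses the non-saturating heuristic
    #     loss_g = SCE(D(G(z)), 1) = -log sigmoid(D(G(z))),
    # where SCE denotes F.sigmoid_cross_entropy applied to raw logits.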

    for epoch in range(num_epochs + 1):
        x_fake = generator.draw_sample(batch_size)
        y_fake = discriminator(x_fake)
        loss_g = F.sigmoid_cross_entropy(y_fake, t_real)
        loss_d = F.sigmoid_cross_entropy(y_fake, t_fake)

        x_real = chainer.dataset.concat_examples(next(it_real))[0]
        x_real = chainer.Variable(x_real)
        y_real = discriminator(x_real)
        loss_d += F.sigmoid_cross_entropy(y_real, t_real)

        discriminator.cleargrads()
        loss_d.backward()
        opt_d.update()

        if epoch % update_generator_interval == 0:
            generator.cleargrads()
            loss_g.backward()
            opt_g.update()
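
        # Reading note (not in the original gist): the discriminator is
        # updated every iteration, while the generator is updated only every
        # `update_generator_interval` (= 5) iterations. Gradients that
        # loss_d.backward() leaves in the generator (through x_fake), and that
        # loss_g.backward() leaves in the discriminator, are discarded by the
        # corresponding cleargrads() call before the next update, so the two
        # optimizers never consume each other's gradients.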

        # plot
        if epoch % 10 == 0 and epoch != 0:
            print(f'# {epoch}')

            # drawn samples
            plt.hist(cuda.to_cpu(x_real.data).ravel(), color='b')
            plt.hist(cuda.to_cpu(x_fake.data).ravel(), color='r', alpha=0.5)
            plt.title('Histograms of real and fake samples (mini-batch)')
            plt.legend(['Real', 'Fake'])
            plt.xlabel('x')
            plt.ylabel('Counts')
            r = plot_x_scale
            plt.xlim(real_mean - r * real_std, real_mean + r * real_std)
            plt.ylim(0, 30)
            plt.grid()
            plt.show()

            # discriminator's value
            x = np.linspace(real_mean - r * real_std, real_mean + r * real_std,
                            100, dtype='f').reshape(-1, 1)
            d = F.sigmoid(discriminator(x))
            plt.plot(x, d.data.ravel())
            plt.plot(x, np.full_like(x, 0.5), '--')
            plt.title("Discriminator's values for each x")
            plt.legend([r'$\sigma(D(x))$', '0.5'])
            plt.xlabel('x')
            plt.ylabel("Discriminator's prediction (with sigmoid)")
            plt.xlim(real_mean - r * real_std, real_mean + r * real_std)
            plt.ylim(0 - 0.05, 1 + 0.05)
            plt.grid()
            plt.show()
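
        # Periodic evaluation (added note, not in the original gist): every
        # 100 epochs, 50000 fresh samples are drawn from the generator and
        # compared against 50000 samples from the true N(10, 2^2)
        # distribution. If training succeeds, the reported sample mean and
        # std should approach 10 and 2.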

        if epoch % 100 == 0 and epoch != 0:
            x = real_std * np.random.randn(test_size) + real_mean
            plt.hist(x, 50, color='b')

            xs = []
            for i in range(test_size // (batch_size * 10)):
                with chainer.no_backprop_mode():
                    xs.append(generator.draw_sample(batch_size * 10).data)
            x = np.concatenate(xs).ravel()
            plt.hist(cuda.to_cpu(x), 50, color='r', alpha=0.5)
            plt.title(
                'Histograms of real and fake samples ({} examples)'.format(
                    test_size))
            plt.legend(['Real', 'Fake'])
            plt.xlabel('x')
            plt.ylabel('Counts')
            plt.xlim(real_mean - 3 * real_std, real_mean + 3 * real_std)
            plt.grid()
            plt.show()

            print('Sample mean = {}'.format(x.mean()))
            print('Sample std = {}'.format(x.std()))