Vanilla GAN.ipynb
%load_ext autoreload
%autoreload 2
from keras.layers import Input, Dense, Reshape, Flatten, BatchNormalization, LeakyReLU
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import sys, os, warnings, keras
warnings.filterwarnings('ignore')
if not os.path.exists('images'): os.makedirs('images')
# allow dynamic GPU memory allocation (TF1-style session config)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
keras.backend.set_session(tf.Session(config=config))
class GAN():
    def __init__(self, width=28, height=28, channels=1, latent_dim=250, lr=0.0002):
        self.WIDTH = int(width) # width of input images
        self.HEIGHT = int(height) # height of input images
        self.CHANNELS = int(channels) # n color channels in images
        self.SHAPE = (self.WIDTH, self.HEIGHT, self.CHANNELS)
        self.LATENT_DIM = latent_dim # length of vector used to model latent space (= noise)
        self.OPTIMIZER = keras.optimizers.Adam(lr=lr, decay=8e-9)
        # generator
        self.G = self.generator()
        self.G.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER)
        # discriminator
        self.D = self.discriminator()
        self.D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER, metrics=['accuracy'])
        # stacked generator + discriminator
        self.stacked_G_D = self.stacked_G_D()
        self.stacked_G_D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER)
    def generator(self):
        i = Input((self.LATENT_DIM,)) # noise input - allows generator to create different outputs
        h = Dense(256)(i)
        h = LeakyReLU(alpha=0.2)(h)
        h = BatchNormalization(momentum=0.8)(h)
        h = Dense(512)(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = BatchNormalization(momentum=0.8)(h)
        h = Dense(1024)(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = BatchNormalization(momentum=0.8)(h)
        h = Dense(self.WIDTH * self.HEIGHT * self.CHANNELS, activation='tanh')(h)
        o = Reshape((self.WIDTH, self.HEIGHT, self.CHANNELS))(h)
        model = keras.models.Model(inputs=[i], outputs=[o])
        model.summary()
        return model
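    # note: the tanh output above keeps generated pixels in [-1, 1],
    # matching the rescaling applied to X_train before training below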
    def discriminator(self):
        i = Input(self.SHAPE)
        h = Flatten()(i)
        h = Dense(self.WIDTH * self.HEIGHT * self.CHANNELS)(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = Dense(self.WIDTH * self.HEIGHT * self.CHANNELS)(h)
        h = LeakyReLU(alpha=0.2)(h)
        o = Dense(1, activation='sigmoid')(h)
        model = keras.models.Model(inputs=[i], outputs=[o])
        model.summary()
        return model
    def stacked_G_D(self):
        self.D.trainable = False # prevent generator updates from influencing the discriminator's weights
        i = Input((self.LATENT_DIM,))
        h = self.G(i)
        o = self.D(h)
        model = keras.models.Model(inputs=[i], outputs=[o])
        return model
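    # because self.D was compiled before being frozen here, it still learns
    # when trained directly; only the copy inside the stacked model is fixed,
    # so stacked_G_D.train_on_batch updates the generator's weights alone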
    def train(self, X_train, X_labels, epochs=20000, batch=32, save_interval=100):
        # each "epoch" here is a single batch step, not a full pass over X_train;
        # X_labels is unused because a vanilla GAN is unconditional
        for idx in range(epochs):
            # train the discriminator on a half-real, half-synthetic batch
            random_index = np.random.randint(0, len(X_train) - batch)
            legit_images = X_train[random_index : random_index + batch].reshape(batch, self.WIDTH, self.HEIGHT, self.CHANNELS)
            gen_noise = np.random.normal(0, 1, (batch, self.LATENT_DIM))
            synthetic_images = self.G.predict(gen_noise)
            x_combined_batch = np.concatenate((legit_images, synthetic_images))
            y_combined_batch = np.concatenate((np.ones((batch, 1)), np.zeros((batch, 1))))
            d_loss = self.D.train_on_batch(x_combined_batch, y_combined_batch)
            # train the generator: label its noise-driven outputs as "real" (1)
            noise = np.random.normal(0, 1, (batch, self.LATENT_DIM))
            y_mislabeled = np.ones((batch, 1))
            g_loss = self.stacked_G_D.train_on_batch(noise, y_mislabeled)
            if idx % save_interval == 0:
                print('epoch: {0} - discriminator loss: {1}, generator loss: {2}'.format(idx, d_loss[0], g_loss))
                self.plot_images(save_to_disk=True, step=idx)
    def plot_images(self, save_to_disk=False, n_images=16, step=0, rows=4, size_scalar=4):
        filename = './images/mnist_{0}.png'.format(step)
        noise = np.random.normal(0, 1, (n_images, self.LATENT_DIM))
        images = self.G.predict(noise)
        cols = int(np.ceil(n_images / rows)) # n_cols in grid; must be an int for add_subplot
        fig = plt.figure(figsize=(cols * size_scalar, rows * size_scalar))
        for i in range(n_images):
            ax = fig.add_subplot(rows, cols, i + 1)
            image = np.reshape(images[i], [self.WIDTH, self.HEIGHT])
            ax.imshow(image)
            ax.axis('off')
        fig.subplots_adjust(hspace=0, wspace=0)
        if save_to_disk:
            fig.savefig(filename)
            plt.close('all')
        else:
            fig.show()
(X_train, X_labels), (_, _) = keras.datasets.mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5) / 127.5 # rescale pixel values to [-1, 1]
X_train = np.expand_dims(X_train, axis=3)
gan = GAN()
gan.train(X_train, X_labels)
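After training, the generator can be used on its own to synthesize new digits. A minimal sketch, assuming the `gan` object from above is still in scope (the `generator.h5` filename and sample count are illustrative, not part of the original):

# save the trained generator, reload it, and sample fresh digits
gan.G.save('generator.h5')
G = keras.models.load_model('generator.h5')
noise = np.random.normal(0, 1, (16, gan.LATENT_DIM))
samples = G.predict(noise)  # shape (16, 28, 28, 1), pixel values in [-1, 1]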