Last active
April 8, 2020 18:45
-
-
Save duhaime/90b6f359bc48b0763f092b249619e609 to your computer and use it in GitHub Desktop.
Vanilla GAN.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
%load_ext autoreload | |
%autoreload 2 | |
from keras.layers import Input, Dense, Reshape, Flatten, BatchNormalization, Concatenate, multiply, advanced_activations | |
import matplotlib.pyplot as plt | |
import tensorflow as tf | |
import numpy as np | |
import sys, os, warnings, keras | |
# Silence all Python warnings (this includes warnings raised by Keras/TensorFlow).
warnings.filterwarnings('ignore')
# Output directory for the sample grids written by GAN.plot_images(save_to_disk=True).
os.makedirs('images', exist_ok=True)
# allow dynamic GPU memory allocation
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
# A ConfigProto does nothing on its own: it only takes effect once a session built
# from it is installed as the session Keras uses. Do this before any model is built.
keras.backend.set_session(tf.Session(config=config))
class GAN:
    """Fully connected ("vanilla") generative adversarial network.

    Construction builds and compiles three Keras models:

    * ``self.G``: generator mapping a noise vector of length ``LATENT_DIM`` to an
      image of shape ``SHAPE`` with values in [-1, 1] (tanh output layer).
    * ``self.D``: discriminator mapping an image of shape ``SHAPE`` to a single
      sigmoid output, trained with label 1 for real and 0 for synthetic images.
    * ``self.stacked_G_D``: ``D(G(noise))`` with D frozen, used to train G.

    Args:
        width, height, channels: image dimensions; ``SHAPE`` is
            ``(width, height, channels)``.
        latent_dim: length of the generator's noise input.
        lr: Adam learning rate used by every compiled model.
    """

    def __init__(self, width=28, height=28, channels=1, latent_dim=250, lr=0.0002):
        self.WIDTH = int(width)  # width of input images
        self.HEIGHT = int(height)  # height of input images
        self.CHANNELS = int(channels)  # n color channels in images
        self.SHAPE = (self.WIDTH, self.HEIGHT, self.CHANNELS)
        self.LATENT_DIM = latent_dim  # length of vector used to model latent space (= noise)
        self._lr = lr
        # Each compiled model gets its own optimizer instance. A Keras optimizer keeps
        # per-instance state (e.g. the iteration counter that drives ``decay``), and
        # sharing one instance across models mixes that state between them.
        # ``OPTIMIZER`` is kept (as the discriminator's optimizer) for compatibility.
        self.OPTIMIZER = self._make_optimizer()
        # generator; train() updates it through stacked_G_D, not through this
        # standalone compiled model
        self.G = self.generator()
        self.G.compile(loss='binary_crossentropy', optimizer=self._make_optimizer())
        # discriminator; compiled while still trainable, so D.train_on_batch updates it
        self.D = self.discriminator()
        self.D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER, metrics=['accuracy'])
        # stacked generator + discriminator.
        # NOTE: this rebinding shadows the stacked_G_D builder method; from here on
        # self.stacked_G_D is the compiled model, not the method.
        self.stacked_G_D = self.stacked_G_D()
        self.stacked_G_D.compile(loss='binary_crossentropy', optimizer=self._make_optimizer())

    def _make_optimizer(self):
        """Return a fresh Adam optimizer using this GAN's learning rate."""
        return keras.optimizers.Adam(lr=self._lr, decay=8e-9)

    def generator(self):
        """Build the generator: noise of length LATENT_DIM -> image of shape SHAPE.

        Three Dense(256/512/1024) -> LeakyReLU(0.2) -> BatchNormalization(0.8)
        stages feed a tanh Dense layer (outputs in [-1, 1]) reshaped to SHAPE.
        Prints the model summary and returns the (uncompiled) Model.
        """
        i = Input((self.LATENT_DIM,))  # noise input - allows generator to create different outputs
        h = i
        for units in (256, 512, 1024):
            h = Dense(units)(h)
            h = advanced_activations.LeakyReLU(alpha=0.2)(h)
            h = BatchNormalization(momentum=0.8)(h)
        h = Dense(self.WIDTH * self.HEIGHT * self.CHANNELS, activation='tanh')(h)
        o = Reshape(self.SHAPE)(h)
        model = keras.models.Model(inputs=[i], outputs=[o])
        model.summary()
        return model

    def discriminator(self):
        """Build the discriminator: image of shape SHAPE -> one sigmoid output.

        The image is flattened and passed through two Dense layers (each as wide
        as the number of input values) with LeakyReLU(0.2), then Dense(1, sigmoid).
        Prints the model summary and returns the (uncompiled) Model.
        """
        n_values = self.WIDTH * self.HEIGHT * self.CHANNELS
        i = Input(self.SHAPE)
        h = Flatten()(i)
        for _ in range(2):
            h = Dense(n_values)(h)
            h = advanced_activations.LeakyReLU(alpha=0.2)(h)
        o = Dense(1, activation='sigmoid')(h)
        model = keras.models.Model(inputs=[i], outputs=[o])
        model.summary()
        return model

    def stacked_G_D(self):
        """Build the combined model D(G(noise)) used to train the generator.

        Sets ``self.D.trainable = False`` before the combined model is built. In
        Keras the flag is applied at compile time, so the already-compiled
        ``self.D`` keeps training normally while the stacked model (compiled after
        this call) leaves D's weights untouched and only updates G.
        """
        self.D.trainable = False  # prevent gradients from influencing discriminator's weights
        i = Input((self.LATENT_DIM,))
        h = self.G(i)
        o = self.D(h)
        model = keras.models.Model(inputs=[i], outputs=[o])
        return model

    def train(self, X_train, X_labels, epochs=20000, batch=32, save_interval=100):
        """Alternately train D and G.

        Note that each "epoch" here is a single training step on one batch, not a
        full pass over ``X_train``.

        Args:
            X_train: training images; a contiguous window of ``batch`` rows is
                drawn each step and reshaped to ``(batch,) + SHAPE``. Because
                windows are contiguous, X_train is presumably expected to be in
                random order (e.g. not sorted by class). TODO confirm for new data.
            X_labels: unused; accepted for backward compatibility.
            epochs: number of training steps.
            batch: number of real (and of synthetic) images per D step, and of
                noise vectors per G step.
            save_interval: every ``save_interval`` steps the losses are printed
                and a sample grid is saved via ``plot_images``; a falsy value
                (0 / None) disables this.

        Raises:
            ValueError: if ``X_train`` has fewer than ``batch`` rows.
        """
        n_samples = len(X_train)
        if n_samples < batch:
            raise ValueError(
                'X_train has {0} samples, fewer than batch={1}'.format(n_samples, batch))
        # Label arrays are the same every step, so build them once.
        real_labels = np.ones((batch, 1))
        y_combined_batch = np.concatenate((real_labels, np.zeros((batch, 1))))
        for idx in range(epochs):
            # train the discriminator on one real batch followed by one synthetic batch.
            # randint's upper bound is exclusive; +1 keeps the last full window
            # (starting at n_samples - batch) reachable and allows n_samples == batch.
            start = np.random.randint(0, n_samples - batch + 1)
            legit_images = X_train[start:start + batch].reshape((batch,) + self.SHAPE)
            gen_noise = np.random.normal(0, 1, (batch, self.LATENT_DIM))
            synthetic_images = self.G.predict(gen_noise)
            x_combined_batch = np.concatenate((legit_images, synthetic_images))
            d_loss = self.D.train_on_batch(x_combined_batch, y_combined_batch)
            # train the generator through the frozen discriminator: labelling the
            # synthetic images as real (1) pushes G toward images D scores as real.
            noise = np.random.normal(0, 1, (batch, self.LATENT_DIM))
            g_loss = self.stacked_G_D.train_on_batch(noise, real_labels)
            if save_interval and idx % save_interval == 0:
                # d_loss is [loss, accuracy] because D was compiled with metrics=['accuracy']
                print('epoch: {0} - discriminator loss: {1}, generator loss: {2}'.format(idx, d_loss[0], g_loss))
                self.plot_images(save_to_disk=True, step=idx)

    def plot_images(self, save_to_disk=False, n_images=16, step=0, rows=4, size_scalar=4):
        """Plot a grid of ``n_images`` generator samples.

        The grid has ``rows`` rows and ``ceil(n_images / rows)`` columns; each cell
        is ``size_scalar`` wide and tall in matplotlib figsize units.

        Args:
            save_to_disk: if true, save to ``./images/mnist_<step>.png`` (creating
                the directory if needed) and close the figure; otherwise show it.
            n_images: number of samples to draw.
            step: training step, used only in the output filename.
            rows: number of grid rows.
            size_scalar: figure size per grid cell.
        """
        filename = './images/mnist_{0}.png'.format(step)
        noise = np.random.normal(0, 1, (n_images, self.LATENT_DIM))
        images = self.G.predict(noise)
        # add_subplot requires integer grid dimensions; np.ceil returns a float.
        cols = int(np.ceil(n_images / rows))  # n_cols in grid
        fig = plt.figure(figsize=(cols * size_scalar, rows * size_scalar))
        for i, image in enumerate(images):
            ax = fig.add_subplot(rows, cols, i + 1)
            if self.CHANNELS == 1:
                # drop the singleton channel axis: imshow expects 2-D grayscale data
                image = image.reshape(self.WIDTH, self.HEIGHT)
            else:
                # colour images: map the generator's tanh range [-1, 1] onto the
                # [0, 1] range imshow expects for float RGB(A) data
                image = np.clip((image + 1.0) / 2.0, 0.0, 1.0)
            ax.imshow(image)
        fig.subplots_adjust(hspace=0, wspace=0)
        if save_to_disk:
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            fig.savefig(filename)
            plt.close(fig)  # close only this figure, not every open figure
        else:
            fig.show()
# Only the MNIST training split is used; the test split is unpacked and ignored.
(X_train, X_labels), (_, _) = keras.datasets.mnist.load_data()
# Scale uint8 pixels from [0, 255] to [-1, 1] to match the generator's tanh output.
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
# Append a trailing channel axis, e.g. (N, 28, 28) -> (N, 28, 28, 1).
X_train = X_train[..., np.newaxis]
gan = GAN()
gan.train(X_train, X_labels)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment