@jkjung-avt
Created October 28, 2018 09:44
A simple DCGAN with MNIST
"""dcgan_mnist.py
This script was originally written by Rowel Atienza (see below), and
was modified by JK Jung <jkjung13@gmail.com>.
------
DCGAN on MNIST using Keras
Author: Rowel Atienza
Project: https://github.com/roatienza/Deep-Learning-Experiments
Dependencies: tensorflow 1.0 and keras 2.0
Usage: python3 dcgan_mnist.py
"""
import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from keras import backend as K
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten, Reshape
from keras.layers import Conv2D, Conv2DTranspose, UpSampling2D
from keras.layers import LeakyReLU, Dropout
from keras.layers import BatchNormalization
from keras.optimizers import Adam, RMSprop
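# EMBEDDING_DIM is the length of the random noise (latent) vector fed to the
# generator, DROPOUT_RATE is the dropout rate used in both networks, and
# GPU_MEM_FRACTION caps the share of GPU memory TensorFlow may allocate
# (applied in __main__ below).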
EMBEDDING_DIM = 1024
DROPOUT_RATE = 0.4
GPU_MEM_FRACTION = 0.2

class ElapsedTimer(object):
    def __init__(self):
        self.start_time = time.time()

    def elapsed(self, sec):
        if sec < 60:
            return str(sec) + " sec"
        elif sec < (60 * 60):
            return str(sec / 60) + " min"
        else:
            return str(sec / (60 * 60)) + " hr"

    def elapsed_time(self):
        print("Elapsed: %s" % self.elapsed(time.time() - self.start_time))

class DCGAN(object):
    def __init__(self, img_rows=28, img_cols=28, channel=1):
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.channel = channel
        self.D = None   # discriminator
        self.G = None   # generator
        self.AM = None  # adversarial model
        self.DM = None  # discriminator model

    # (W−F+2P)/S+1
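    # e.g. the first stride-2 conv below: floor((28 - 5 + 2*2)/2) + 1 = 14.
    # With 'same' padding each stride-2 layer halves the spatial size
    # (rounding up), so the discriminator's shapes run roughly
    # 28x28x1 -> 14x14x64 -> 7x7x128 -> 4x4x256 -> 4x4x512 -> 8192 -> 1.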
    def discriminator(self):
        if self.D:
            return self.D
        self.D = Sequential()
        depth = 64
        dropout = DROPOUT_RATE
        # In:  28 x 28 x 1 (depth = 1)
        # Out: 14 x 14 x 64
        input_shape = (self.img_rows, self.img_cols, self.channel)
        self.D.add(Conv2D(depth*1, 5, strides=2,
                          input_shape=input_shape, padding='same'))
        self.D.add(LeakyReLU(alpha=0.2))
        self.D.add(Dropout(dropout))
        self.D.add(Conv2D(depth*2, 5, strides=2, padding='same'))
        self.D.add(LeakyReLU(alpha=0.2))
        self.D.add(Dropout(dropout))
        self.D.add(Conv2D(depth*4, 5, strides=2, padding='same'))
        self.D.add(LeakyReLU(alpha=0.2))
        self.D.add(Dropout(dropout))
        self.D.add(Conv2D(depth*8, 5, strides=1, padding='same'))
        self.D.add(LeakyReLU(alpha=0.2))
        self.D.add(Dropout(dropout))
        # Out: 1-dim probability
        self.D.add(Flatten())
        self.D.add(Dense(1))
        self.D.add(Activation('sigmoid'))
        self.D.summary()
        return self.D

    def generator(self):
        if self.G:
            return self.G
        self.G = Sequential()
        dropout = DROPOUT_RATE
        depth = 64+64+64+64
        dim = 7
        # In:  EMBEDDING_DIM
        # Out: dim x dim x depth
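        # Rough shape flow through the generator (Conv2DTranspose below uses
        # stride 1 with 'same' padding, so only the UpSampling2D layers change
        # the spatial size): 1024 -> 7x7x256 -> 14x14x128 -> 28x28x64
        #                    -> 28x28x32 -> 28x28x1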
        self.G.add(Dense(dim*dim*depth, input_dim=EMBEDDING_DIM))
        self.G.add(BatchNormalization(momentum=0.9))
        self.G.add(Activation('relu'))
        self.G.add(Reshape((dim, dim, depth)))
        self.G.add(Dropout(dropout))
        # In:  dim x dim x depth
        # Out: 2*dim x 2*dim x depth/2
        self.G.add(UpSampling2D())
        self.G.add(Conv2DTranspose(int(depth/2), 5, padding='same'))
        self.G.add(BatchNormalization(momentum=0.9))
        self.G.add(Activation('relu'))
        self.G.add(Conv2DTranspose(int(depth/2), 5, padding='same'))
        self.G.add(BatchNormalization(momentum=0.9))
        self.G.add(Activation('relu'))
        self.G.add(UpSampling2D())
        self.G.add(Conv2DTranspose(int(depth/4), 5, padding='same'))
        self.G.add(BatchNormalization(momentum=0.9))
        self.G.add(Activation('relu'))
        self.G.add(Conv2DTranspose(int(depth/4), 5, padding='same'))
        self.G.add(BatchNormalization(momentum=0.9))
        self.G.add(Activation('relu'))
        self.G.add(Conv2DTranspose(int(depth/8), 5, padding='same'))
        self.G.add(BatchNormalization(momentum=0.9))
        self.G.add(Activation('relu'))
        # Out: 28 x 28 x 1 grayscale image, [0.0, 1.0] per pixel
        self.G.add(Conv2DTranspose(1, 5, padding='same'))
        self.G.add(Activation('sigmoid'))
        self.G.summary()
        return self.G

    def discriminator_model(self):
        if self.DM:
            return self.DM
        optimizer = RMSprop(lr=0.0002, decay=6e-8)
        self.DM = Sequential()
        self.DM.add(self.discriminator())
        self.DM.compile(loss='binary_crossentropy',
                        optimizer=optimizer,
                        metrics=['accuracy'])
        return self.DM

    def adversarial_model(self):
        if self.AM:
            return self.AM
        optimizer = RMSprop(lr=0.0001, decay=3e-8)
        self.AM = Sequential()
        self.AM.add(self.generator())
        self.AM.add(self.discriminator())
        self.AM.compile(loss='binary_crossentropy',
                        optimizer=optimizer,
                        metrics=['accuracy'])
        return self.AM
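
# Note on how the two compiled models are used (see MNIST_DCGAN.train below):
# the discriminator model (RMSprop, lr=2e-4) is trained directly on batches of
# real and generated images, while the adversarial model (generator stacked on
# discriminator, lr=1e-4) is what trains the generator: noise is pushed
# through G then D, and the loss is computed against "real" labels while the
# discriminator's layers are frozen.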

class MNIST_DCGAN(object):
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channel = 1
        mnist = input_data.read_data_sets('mnist', one_hot=True)
        self.x_train = mnist.train.images.reshape(
            -1, self.img_rows, self.img_cols, 1)
        self.x_train = self.x_train.astype(np.float32)
        self.DCGAN = DCGAN()
        self.generator = self.DCGAN.generator()
        self.discriminator = self.DCGAN.discriminator()
        self.discriminator_model = self.DCGAN.discriminator_model()
        self.adversarial_model = self.DCGAN.adversarial_model()

    def set_discriminator_trainable(self):
        for layer in self.discriminator.layers:
            layer.trainable = True

    def set_discriminator_untrainable(self):
        for layer in self.discriminator.layers:
            layer.trainable = False

    def train(self, train_steps=2000, batch_size=256, save_interval=0):
        noise_input = None
        if save_interval > 0:
            noise_input = np.random.uniform(-1.0, 1.0,
                                            size=[16, EMBEDDING_DIM])
        for i in range(train_steps):
            # sample some real images
            images_train = self.x_train[
                np.random.randint(0, self.x_train.shape[0],
                                  size=batch_size), :, :, :]
            # use the generator to generate the same number of fake images
            noise = np.random.uniform(-1.0, 1.0,
                                      size=[batch_size, EMBEDDING_DIM])
            images_fake = self.generator.predict(noise)
            # stack real and fake images together
            x = np.concatenate((images_train, images_fake))
            # label real images as 1, fake images as 0
            y = np.ones([batch_size*2, 1])
            y[batch_size:, :] = 0
            # train the discriminator for 1 step
            self.set_discriminator_trainable()
            d_loss = self.discriminator_model.train_on_batch(x, y)
            # then train the whole GAN for 1 step; note that we freeze
            # the weights in the discriminator and set the labels to 1
            # here, so we are (hopefully) effectively training the
            # generator for 1 step
            y = np.ones([batch_size, 1])
            noise = np.random.uniform(-1.0, 1.0,
                                      size=[batch_size, EMBEDDING_DIM])
            self.set_discriminator_untrainable()
            a_loss = self.adversarial_model.train_on_batch(noise, y)
            log_mesg = "%d: [D loss: %f, acc: %f]" % \
                       (i, d_loss[0], d_loss[1])
            log_mesg = "%s [A loss: %f, acc: %f]" % \
                       (log_mesg, a_loss[0], a_loss[1])
            print(log_mesg)
            if save_interval > 0:
                if (i+1) % save_interval == 0:
                    self.plot_images(save2file=True,
                                     samples=noise_input.shape[0],
                                     noise=noise_input,
                                     step=(i+1))

    def plot_images(self, save2file=False, fake=True, samples=16,
                    noise=None, step=0):
        filename = 'sample.png'
        if fake:
            if noise is None:
                noise = np.random.uniform(-1.0, 1.0,
                                          size=[samples, EMBEDDING_DIM])
            else:
                filename = "mnist_%d.png" % step
            images = self.generator.predict(noise)
        else:
            i = np.random.randint(0, self.x_train.shape[0], samples)
            images = self.x_train[i, :, :, :]
        plt.figure(figsize=(10, 10))
        for i in range(images.shape[0]):
            plt.subplot(4, 4, i+1)
            image = images[i, :, :, :]
            image = np.reshape(image, [self.img_rows, self.img_cols])
            plt.imshow(image, cmap='gray')
            plt.axis('off')
        plt.tight_layout()
        if save2file:
            plt.savefig(filename)
            plt.close('all')
        else:
            plt.show()

if __name__ == '__main__':
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    session = tf.Session(config=config)
    K.set_session(session)

    mnist_dcgan = MNIST_DCGAN()
    timer = ElapsedTimer()
    mnist_dcgan.train(train_steps=10000, batch_size=256, save_interval=500)
    timer.elapsed_time()
    # mnist_dcgan.plot_images(fake=True, save2file=True)
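
# Expected behavior when run as above: each training step prints one log line
# with the discriminator ("D") and adversarial ("A") loss/accuracy, every 500
# steps a 4x4 grid of generated digits is written to mnist_500.png,
# mnist_1000.png, ..., mnist_10000.png in the current directory, and the total
# elapsed time is printed at the end.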