Skip to content

Instantly share code, notes, and snippets.

@asterisk37n asterisk37n/gan-mnist.py
Last active Jan 21, 2019

Embed
What would you like to do?
A simple generative adversarial network (GAN) for MNIST
import matplotlib.pyplot as plt
import numpy as np
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.layers import Conv2D, MaxPooling2D, Reshape, UpSampling2D, InputLayer
from keras.optimizers import Adam
import os
class GAN:
    """Small convolutional GAN trained on the MNIST digits.

    Three Keras models are kept on the instance:

    * ``discriminator`` -- maps an image to a single sigmoid output and is
      trained against targets of 1 (real) and 0 (generated).
    * ``generator`` -- maps a ``z_dim``-long noise vector to an image with a
      ``tanh`` output.
    * ``combined`` -- the generator followed by the discriminator; it is the
      model trained on noise with "real" targets to update the generator.

    Sample grids are written to the ``images`` directory during training.
    """

    def __init__(self):
        """Build and compile all models and create the ``images`` folder."""
        self.img_shape = (28, 28, 1)  # MNIST
        self.z_dim = 100
        optimizer = Adam(0.0002, 0.5)

        # Discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        # Generator + Discriminator
        self.generator = self.build_generator()
        # The flag is flipped after the discriminator's own compile() and
        # before the combined model's compile(); the intent is that training
        # `combined` only updates the generator.
        self.discriminator.trainable = False
        self.combined = Sequential([self.generator, self.discriminator])
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

        os.makedirs('images', exist_ok=True)
        # Fixed latent vectors for snapshots; created lazily by sample_images.
        self.noise = None

    @staticmethod
    def _seed_shape(img_shape, channels=16):
        """Return the (height, width, channels) tensor the generator starts from.

        The generator upsamples by 2 twice, so the seed is a quarter of the
        target height and width.
        """
        return (img_shape[0] // 4, img_shape[1] // 4, channels)

    def build_generator(self):
        """Create the generator: noise vector -> Dense -> 2x (upsample, conv).

        Returns:
            An uncompiled ``Sequential`` model whose input is ``(z_dim,)``.
        """
        seed_shape = self._seed_shape(self.img_shape)
        model = Sequential()
        # Derive the Dense width from the Reshape target so the two can never
        # disagree.  np.prod replaces the np.product alias (removed in
        # NumPy 2.0); int() hands Keras a plain Python integer.
        model.add(Dense(int(np.prod(seed_shape)), input_shape=(self.z_dim,)))
        model.add(Reshape(seed_shape))
        model.add(UpSampling2D(size=2))
        model.add(Conv2D(32, kernel_size=2, padding='same', activation='relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(UpSampling2D(size=2))
        # tanh keeps outputs in [-1, 1], the range train() rescales MNIST to.
        model.add(Conv2D(1, kernel_size=2, padding='same', activation='tanh'))
        model.summary()
        return model

    def build_discriminator(self):
        """Create the discriminator: two strided convs -> Flatten -> sigmoid.

        Returns:
            An uncompiled ``Sequential`` model whose input is ``img_shape``.
        """
        model = Sequential()
        model.add(Conv2D(32, kernel_size=2, strides=2, padding='same',
                         input_shape=self.img_shape))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Conv2D(64, kernel_size=2, strides=2, padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Flatten())
        model.add(Dense(1, activation='sigmoid'))
        model.summary()
        return model

    def train(self, epochs, batch_size=128, sample_interval=50):
        """Run the adversarial training loop on the MNIST training images.

        Args:
            epochs: Number of loop iterations.  Each iteration draws ONE
                random batch, so this counts batches, not passes over the
                data set.
            batch_size: Number of real images (and of generated images) used
                per iteration.
            sample_interval: A snapshot grid is saved whenever the iteration
                index is a multiple of this value.
        """
        (X_train, _), (_, _) = mnist.load_data()
        X_train = X_train / 127.5 - 1.  # Rescale -1 to 1
        X_train = np.expand_dims(X_train, axis=3)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            real_imgs = X_train[idx]
            noise = np.random.normal(0, 1, (batch_size, self.z_dim))
            fake_imgs = self.generator.predict(noise)

            # Train the discriminator and generator
            d_score_real = self.discriminator.train_on_batch(real_imgs, valid)
            d_score_fake = self.discriminator.train_on_batch(fake_imgs, fake)
            # Generator step: the same noise is labelled "real".
            g_score = self.combined.train_on_batch(noise, valid)

            # Plot the progress.  The discriminator is compiled with an
            # accuracy metric, so each d_score_* is indexed as [loss, acc].
            d_score = 0.5 * np.add(d_score_real, d_score_fake)
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]"
                  % (epoch, d_score[0], 100 * d_score[1], g_score))

            # Save fake image snapshot
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch, r=10, c=10):
        """Save an ``r`` x ``c`` grid of generated digits to ``images/<epoch>.png``.

        The latent vectors are cached in ``self.noise`` and reused on later
        calls, so successive snapshots show the same latent points.

        Args:
            epoch: Iteration index, used only for the output file name.
            r: Number of grid rows.
            c: Number of grid columns.
        """
        cells = r * c
        # Re-draw the cached noise if a larger grid than before is requested;
        # otherwise indexing the generated batch would run out of images.
        if self.noise is None or len(self.noise) < cells:
            self.noise = np.random.normal(0, 1, (cells, self.z_dim))
        gen_imgs = self.generator.predict(self.noise)
        gen_imgs = 0.5 * gen_imgs + 0.5  # Rescale [-1, 1] images into [0, 1]

        # squeeze=False keeps `axs` two-dimensional even when r or c is 1.
        fig, axs = plt.subplots(r, c, squeeze=False)
        # axs.flat walks the grid row by row, matching the image order.
        for img, ax in zip(gen_imgs, axs.flat):
            ax.imshow(img[:, :, 0], cmap='gray', vmin=0, vmax=1)
            ax.axis('off')
        fig.savefig("images/%d.png" % epoch)
        # Close this specific figure so figures do not pile up in memory.
        plt.close(fig)
if __name__ == '__main__':
    # Script entry point: build the models, then run the training loop,
    # saving a sample grid every 100 iterations.
    run_config = dict(epochs=100001, batch_size=128, sample_interval=100)
    gan = GAN()
    gan.train(**run_config)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.