Skip to content

Instantly share code, notes, and snippets.

@NMZivkovic

NMZivkovic/gan.py

Created Dec 16, 2018
Embed
What would you like to do?
from __future__ import print_function, division
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Keras modules
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, BatchNormalization, LeakyReLU
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
class GAN(object):
    """Fully connected generative adversarial network built with Keras.

    Three models are created on construction:

    * ``generator_model`` maps a noise vector of length
      ``generator_input_dim`` to an image of shape ``img_shape``.
    * ``discriminator_model`` maps an image to a single sigmoid score.
    * ``gan`` chains generator and discriminator and is used to train the
      generator against the discriminator's verdict.
    """

    def __init__(self, image_shape, generator_input_dim, image_hepler):
        """Build and compile the generator, discriminator and combined model.

        Args:
            image_shape: Shape of a single image; used as the discriminator
                input shape and as the generator output shape.
            generator_input_dim: Length of the noise vector fed to the
                generator.
            image_hepler: Object providing ``save_image(images, epoch,
                directory)`` and ``makegif(directory)``.  The misspelled
                parameter name is kept so existing keyword callers still work.
        """
        self._image_helper = image_hepler
        self.img_shape = image_shape
        self.generator_input_dim = generator_input_dim

        # Build models.  Each compiled model gets its own Adam instance
        # (learning rate 0.0002, beta_1 0.5): the two models update different
        # sets of weights, and newer Keras optimizers are tied to the
        # variables they were first built with.
        self._build_generator_model()
        self._build_and_compile_discriminator_model(Adam(0.0002, 0.5))
        self._build_and_compile_gan(Adam(0.0002, 0.5))

    def train(self, epochs, train_data, batch_size,
              sample_interval=100, output_dir="generated/"):
        """Alternate discriminator and generator updates for ``epochs`` steps.

        Every iteration draws one random batch (with replacement) from
        ``train_data``, trains the discriminator on it and on an equally
        sized generated batch, then trains the generator through the combined
        model.  Losses are printed each iteration, plotted once at the end,
        and the image helper is asked to assemble a GIF from ``output_dir``.

        Args:
            epochs: Number of training iterations (one batch each).
            train_data: Array of real images, indexable by an integer array
                and exposing ``shape[0]``.  Presumably scaled to [-1, 1] to
                match the generator's tanh output -- confirm against caller.
            batch_size: Number of real (and generated) samples per iteration.
            sample_interval: Save sample images every this many iterations.
                ``0`` or ``None`` disables image saving and the final GIF.
            output_dir: Directory passed to the image helper.
        """
        real = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))
        history = []

        for epoch in range(epochs):
            # Train Discriminator
            batch_indexes = np.random.randint(0, train_data.shape[0], batch_size)
            batch = train_data[batch_indexes]
            generated = self._predict_noise(batch_size)
            loss_real = self.discriminator_model.train_on_batch(batch, real)
            loss_fake = self.discriminator_model.train_on_batch(generated, fake)
            # The discriminator is compiled with an accuracy metric, so each
            # result is indexable; element 0 is used as the loss below.
            discriminator_loss = 0.5 * np.add(loss_real, loss_fake)

            # Train Generator: label generated samples as "real" so the
            # generator is pushed towards fooling the discriminator.
            noise = np.random.normal(0, 1, (batch_size, self.generator_input_dim))
            generator_loss = self.gan.train_on_batch(noise, real)

            # Report the progress
            print("---------------------------------------------------------")
            print("******************Epoch {}***************************".format(epoch))
            print("Discriminator loss: {}".format(discriminator_loss[0]))
            print("Generator loss: {}".format(generator_loss))
            print("---------------------------------------------------------")

            history.append({"D": discriminator_loss[0], "G": generator_loss})

            # Save generated images every `sample_interval`-th epoch; the
            # truthiness guard also avoids a modulo-by-zero.
            if sample_interval and epoch % sample_interval == 0:
                self._save_images(epoch, output_dir)

        self._plot_loss(history)
        if sample_interval:
            self._image_helper.makegif(output_dir)

    def _build_generator_model(self):
        """Create ``self.generator_model``: noise vector -> tanh image."""
        generator_input = Input(shape=(self.generator_input_dim,))
        # np.prod returns a NumPy integer; pass a plain int as the unit count.
        pixel_count = int(np.prod(self.img_shape))

        generator_sequence = Sequential(
            [Dense(256, input_dim=self.generator_input_dim),
             LeakyReLU(alpha=0.2),
             BatchNormalization(momentum=0.8),
             Dense(512),
             LeakyReLU(alpha=0.2),
             BatchNormalization(momentum=0.8),
             Dense(1024),
             LeakyReLU(alpha=0.2),
             BatchNormalization(momentum=0.8),
             Dense(pixel_count, activation='tanh'),
             Reshape(self.img_shape)])

        generator_output_tensor = generator_sequence(generator_input)
        self.generator_model = Model(generator_input, generator_output_tensor)

    def _build_and_compile_discriminator_model(self, optimizer):
        """Create and compile ``self.discriminator_model``, then freeze it.

        Args:
            optimizer: Keras optimizer used for the discriminator's own
                training steps.
        """
        discriminator_input = Input(shape=self.img_shape)
        discriminator_sequence = Sequential(
            [Flatten(input_shape=self.img_shape),
             Dense(512),
             LeakyReLU(alpha=0.2),
             Dense(256),
             LeakyReLU(alpha=0.2),
             Dense(1, activation='sigmoid')])

        discriminator_tensor = discriminator_sequence(discriminator_input)
        self.discriminator_model = Model(discriminator_input, discriminator_tensor)
        self.discriminator_model.compile(loss='binary_crossentropy',
                                         optimizer=optimizer,
                                         metrics=['accuracy'])
        # Frozen only after its own compile, so that the combined model built
        # later does not update discriminator weights.
        # NOTE(review): relies on Keras fixing trainability at compile time --
        # confirm for the Keras version in use.
        self.discriminator_model.trainable = False

    def _build_and_compile_gan(self, optimizer):
        """Create and compile ``self.gan`` (generator followed by discriminator).

        Args:
            optimizer: Keras optimizer used for the generator updates.
        """
        noise_input = Input(shape=(self.generator_input_dim,))
        generator_output = self.generator_model(noise_input)
        discriminator_output = self.discriminator_model(generator_output)

        self.gan = Model(noise_input, discriminator_output)
        self.gan.compile(loss='binary_crossentropy', optimizer=optimizer)

    def _save_images(self, epoch, output_dir="generated/"):
        """Generate 25 samples and hand them to the image helper.

        Args:
            epoch: Iteration number forwarded to the helper.
            output_dir: Directory forwarded to the helper.
        """
        generated = self._predict_noise(25)
        # Map the generator's tanh range [-1, 1] onto [0, 1].
        generated = 0.5 * generated + 0.5
        self._image_helper.save_image(generated, epoch, output_dir)

    def _predict_noise(self, size):
        """Return generator output for ``size`` standard-normal noise vectors."""
        noise = np.random.normal(0, 1, (size, self.generator_input_dim))
        # verbose=0: this runs every training iteration, so suppress the
        # per-call progress output.
        return self.generator_model.predict(noise, verbose=0)

    def _plot_loss(self, history):
        """Plot every recorded loss series against the iteration index.

        Args:
            history: List of dicts, one per iteration, as built in ``train``
                (keys ``"D"`` and ``"G"``).
        """
        hist = pd.DataFrame(history)
        plt.figure(figsize=(20, 5))
        for colnm in hist.columns:
            plt.plot(hist[colnm], label=colnm)
        plt.legend()
        plt.ylabel("loss")
        plt.xlabel("epochs")
        plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment