Conditional WGAN GP
'''
original source: https://github.com/kongyanye/cwgan-gp/blob/master/cwgan_gp.py
'''
from __future__ import print_function, division

import datetime
import math
import os
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf

import keras.backend as K
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply, Add, Concatenate
from keras.layers import Activation, Embedding, MaxPool1D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv1D
from keras.models import Sequential, Model
from keras.optimizers import RMSprop

from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from scipy.stats import mode

# this script targets TF1-style graph execution
tf.compat.v1.disable_v2_behavior()

print(datetime.datetime.now())
class RandomWeightedAverage(Add):
    """Provides a (random) weighted average between real and generated samples."""
    def __init__(self, batch_size, **kwargs):
        super(RandomWeightedAverage, self).__init__(**kwargs)
        # the batch size is needed to draw one interpolation weight per sample
        self.batch_size = batch_size

    def _merge_function(self, inputs):
        input1, input2 = inputs
        alpha = K.random_uniform((self.batch_size, 1, 1))
        return (alpha * input1) + ((1 - alpha) * input2)
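
# A minimal NumPy sketch (not part of the original gist) of the interpolation that
# feeds the gradient penalty: x_hat = alpha * x_real + (1 - alpha) * x_fake with one
# alpha ~ U(0, 1) per sample; the function name and sizes here are illustrative only.
def _check_interpolation(batch_size=4, window=100):
    real = np.ones((batch_size, window, 1))
    fake = np.zeros((batch_size, window, 1))
    alpha = np.random.uniform(size=(batch_size, 1, 1))
    x_hat = alpha * real + (1 - alpha) * fake
    # every interpolated point lies on the segment between the real and fake samples
    assert ((0.0 <= x_hat) & (x_hat <= 1.0)).all()
    return x_hat
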
class CWGANGP():
    def __init__(self, epochs=100, batch_size=32, sample_interval=5):
        self.img_rows = 100
        self.img_cols = 1
        self.nclasses = 5
        self.img_shape = (self.img_rows, self.img_cols)
        self.latent_dim = 100
        self.losslog = []
        self.epochs = epochs
        self.batch_size = batch_size
        self.sample_interval = sample_interval
        # Following parameter and optimizer set as recommended in the WGAN-GP paper
        self.n_critic = 5
        optimizer = RMSprop(lr=0.00005)

        # Build the generator and critic
        self.generator = self.build_generator()
        self.critic = self.build_critic()

        #-------------------------------
        # Construct Computational Graph
        #       for the Critic
        #-------------------------------

        # Freeze generator's layers while training the critic
        self.generator.trainable = False

        # Image input (real sample)
        real_img = Input(shape=self.img_shape)

        # Noise input
        z_disc = Input(shape=(self.latent_dim,))

        # Generate image based on noise (fake sample), conditioned on the label
        label = Input(shape=(1,))
        fake_img = self.generator([z_disc, label])

        # Critic determines validity of the real and fake images
        fake = self.critic([fake_img, label])
        valid = self.critic([real_img, label])

        # Construct weighted average between real and fake images
        interpolated_img = RandomWeightedAverage(self.batch_size)([real_img, fake_img])
        # Determine validity of weighted sample
        validity_interpolated = self.critic([interpolated_img, label])

        # Use Python partial to provide the loss function with the additional
        # 'averaged_samples' argument
        partial_gp_loss = partial(self.gradient_penalty_loss,
                                  averaged_samples=interpolated_img)
        partial_gp_loss.__name__ = 'gradient_penalty'  # Keras requires function names

        self.critic_model = Model(inputs=[real_img, label, z_disc],
                                  outputs=[valid, fake, validity_interpolated])
        self.critic_model.compile(loss=[self.wasserstein_loss,
                                        self.wasserstein_loss,
                                        partial_gp_loss],
                                  optimizer=optimizer,
                                  loss_weights=[1, 1, 10])

        #-------------------------------
        # Construct Computational Graph
        #         for Generator
        #-------------------------------

        # For the generator we freeze the critic's layers
        self.critic.trainable = False
        self.generator.trainable = True

        # Sampled noise for input to the generator
        z_gen = Input(shape=(self.latent_dim,))
        # add label to the input
        label = Input(shape=(1,))
        # Generate images based on noise
        img = self.generator([z_gen, label])
        # Critic determines validity
        valid = self.critic([img, label])
        # Defines generator model
        self.generator_model = Model([z_gen, label], valid)
        self.generator_model.compile(loss=self.wasserstein_loss, optimizer=optimizer)
    def load_dataset(self):
        # load TRAIN data
        df29 = pd.read_csv('/home/furanzu/FURANZU/Datasets/MHEALTH/mHEALTH_filtered_MEDBUT_train.csv')
        df29 = df29.drop(['subject_id'], axis=1)
        # encode the label first, because the raw labels do not start from 0
        le = LabelEncoder()
        df29['Label'] = le.fit_transform(df29['Label'])
        print('Data Shape: ', df29.shape)
        print('Class Label: ', df29.iloc[:, -1].unique())
        # make array of data
        dataset = df29.iloc[:, 3:].values
        dataset = dataset.astype('float64')
        dataxy = dataset[:, 0:]
        # get the maximum value from the label column (index 1) as an integer
        maxer = np.amax(dataset[:, 1])
        maxchannels = maxer.astype('int')
        # make array of labels --> dimension (n_samples,)
        idataset = dataset[:, 1].astype(int)
        # init normalization with scale 0-1
        scaler = MinMaxScaler(copy=False)
        # define window size and 50% overlap
        window = 100
        overlap = int(window * 0.5)
        # n = largest multiple of the window size that still contains data
        n = ((np.where(np.any(dataxy, axis=1))[0][-1] + 1) // window) * window
        # perform normalization --> xx has dimension (n, 1)
        xx = scaler.fit_transform(dataxy[:n, 0].reshape(-1, 1))
        yy = idataset.reshape(-1, 1)
        # segment the data with the window size and 50% overlap;
        # each window's label is the mode of the per-sample labels it covers
        X_train = np.asarray([xx[i:i + window] for i in range(0, (n - window), overlap)])
        y_train = np.asarray([mode(yy[i:i + window])[0] for i in range(0, (n - window), overlap)]).reshape(-1, 1)
        print('Segmented Data shape: ', X_train.shape, y_train.shape)
        return (X_train, y_train)
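
    def load_synthetic_dataset(self, n_windows=512):
        # Hedged stand-in (not in the original gist): returns random windows with
        # the same shapes as load_dataset, i.e. X of (n_windows, 100, 1) and
        # y of (n_windows, 1), so the training loop can be smoke-tested without
        # the mHEALTH CSV; the data itself is meaningless noise.
        X = np.random.uniform(0, 1, (n_windows, self.img_rows, self.img_cols))
        y = np.random.randint(0, self.nclasses, (n_windows, 1))
        return X, y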
    def gradient_penalty_loss(self, y_true, y_pred, averaged_samples):
        """
        Computes gradient penalty based on prediction and weighted real / fake samples
        """
        gradients = K.gradients(y_pred, averaged_samples)[0]
        # compute the euclidean norm by squaring ...
        gradients_sqr = K.square(gradients)
        # ... summing over the rows ...
        gradients_sqr_sum = K.sum(gradients_sqr,
                                  axis=np.arange(1, len(gradients_sqr.shape)))
        # ... and taking the square root
        gradient_l2_norm = K.sqrt(gradients_sqr_sum)
        # compute lambda * (1 - ||grad||)^2 still for each single sample
        gradient_penalty = K.square(1 - gradient_l2_norm)
        # return the mean as loss over all the batch samples
        return K.mean(gradient_penalty)

    def wasserstein_loss(self, y_true, y_pred):
        return K.mean(y_true * y_pred)
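
    @staticmethod
    def _gradient_penalty_numpy_example():
        # Hedged NumPy illustration (not in the original gist) of the penalty
        # term (1 - ||grad||_2)^2 averaged over the batch: each fake gradient
        # below has L2 norm 2, so the per-sample penalty is (1 - 2)^2 = 1.
        grads = np.full((8, 100, 1), 2.0 / np.sqrt(100.0))
        l2 = np.sqrt((grads ** 2).sum(axis=(1, 2)))
        return ((1.0 - l2) ** 2).mean()  # ~= 1.0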
    def build_generator(self):
        model = Sequential()
        model.add(Dense(10 * 10, activation="relu", input_dim=self.latent_dim))
        model.add(Reshape((-1, 1)))
        model.add(Conv1D(16, 5, strides=1, padding='same'))
        model.add(Activation('relu'))
        model.add(Dense(1))
        model.add(Activation('linear'))
        # model.summary()

        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,), dtype='int32')
        label_embedding = Flatten()(Embedding(self.nclasses, self.latent_dim)(label))
        # condition the generator by multiplying the noise with the label embedding
        model_input = multiply([noise, label_embedding])
        img = model(model_input)
        # output shape (100, 1)
        return Model([noise, label], img)
    def build_critic(self):
        # label input, embedded to the window length so it can be a channel
        in_label = Input(shape=(1,), dtype='int32')
        li = Embedding(self.nclasses, self.img_rows)(in_label)
        li = Reshape((self.img_shape[0], 1))(li)
        # image input
        in_image = Input(shape=self.img_shape)
        # concat label as a channel
        merge = Concatenate()([in_image, li])
        # 1D convolutions
        fe = Conv1D(64, 5)(merge)
        fe = LeakyReLU(alpha=0.2)(fe)
        fe = Dropout(0.4)(fe)
        fe = Conv1D(128, 3)(fe)
        fe = LeakyReLU(alpha=0.2)(fe)
        fe = Dropout(0.4)(fe)
        # max pooling
        fe = MaxPool1D(pool_size=2)(fe)
        fe = Flatten()(fe)
        fe = Dense(100, activation='relu')(fe)
        # output: an unbounded critic score, so the activation must be linear
        # (a sigmoid here would break the Wasserstein loss)
        out_layer = Dense(1, activation='linear')(fe)
        # define model
        model = Model([in_image, in_label], out_layer)
        # output (1): critic score for real vs. fake
        return model
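
    def check_shapes(self):
        # Hedged smoke test (not in the original gist): pushes random noise and
        # labels through both networks and asserts the expected tensor shapes.
        z = np.random.normal(0, 1, (self.batch_size, self.latent_dim))
        y = np.random.randint(0, self.nclasses, (self.batch_size, 1))
        fake = self.generator.predict([z, y])
        assert fake.shape == (self.batch_size, self.img_rows, self.img_cols)
        score = self.critic.predict([fake, y])
        assert score.shape == (self.batch_size, 1)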
    def train(self):
        # load mHEALTH dataset
        (X_train, y_train) = self.load_dataset()

        # Adversarial ground truths
        valid = -np.ones((self.batch_size, 1))
        fake = np.ones((self.batch_size, 1))
        # Dummy ground truth for the gradient penalty output
        dummy = np.zeros((self.batch_size, 1))

        for epoch in range(self.epochs):
            for _ in range(self.n_critic):
                # ---------------------
                #  Train Critic
                # ---------------------
                # Select a random batch of windows
                idx = np.random.randint(0, X_train.shape[0], self.batch_size)
                imgs, labels = X_train[idx], y_train[idx]
                # Sample generator input
                noise = np.random.normal(0, 1, (self.batch_size, self.latent_dim))
                # Train the critic
                d_loss = self.critic_model.train_on_batch([imgs, labels, noise],
                                                          [valid, fake, dummy])

            # ---------------------
            #  Train Generator
            # ---------------------
            sampled_labels = np.random.randint(0, self.nclasses, self.batch_size).reshape(-1, 1)
            g_loss = self.generator_model.train_on_batch([noise, sampled_labels], valid)

            # Report progress every 10 epochs
            if epoch % 10 == 0:
                print("%d [D loss: %f] [G loss: %f]" % (epoch, d_loss[0], g_loss))
                self.losslog.append([d_loss[0], g_loss])

            # If at save interval => save generated samples and weights
            if epoch % self.sample_interval == 0:
                # make FAKE SAMPLES
                self.sample_images(epoch)
                self.generator.save_weights('wgan-generator', overwrite=True)
                self.critic.save_weights('wgan-discriminator', overwrite=True)
                self.generator.save('wgan-generator.h5')

        with open('loss.log', 'w') as f:
            f.writelines('d_loss, g_loss\n')
            for each in self.losslog:
                f.writelines('%s, %s\n' % (each[0], each[1]))
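
    def plot_losses(self):
        # Hedged convenience helper (not in the original gist): plots the critic
        # and generator losses collected in self.losslog during training.
        losses = np.array(self.losslog)
        os.makedirs('images', exist_ok=True)
        plt.figure(figsize=(6, 4))
        plt.plot(losses[:, 0], label='d_loss')
        plt.plot(losses[:, 1], label='g_loss')
        plt.legend()
        plt.savefig('images/loss_curve.png')
        plt.close()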
    def sample_images(self, epoch):
        r, c = 10, 10
        # sample NOISE from the latent space
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        sampled_labels = np.random.randint(0, self.nclasses, r * c).reshape(-1, 1)
        # generate FAKE SAMPLES
        gen_imgs = self.generator.predict([noise, sampled_labels])
        # plot the first generated window; indexing by epoch would fail
        # once epoch >= r * c
        os.makedirs('images', exist_ok=True)
        plt.figure(figsize=(5, 5))
        plt.plot(gen_imgs[0])
        plt.savefig("images/mhealth_%d.png" % epoch)
        plt.close()
    def combine_images(self, generated_images):
        num = generated_images.shape[0]
        width = int(math.sqrt(num))
        height = int(math.ceil(float(num) / width))
        shape = generated_images.shape[1:]
        image = np.zeros((height * shape[0], width * shape[1]),
                         dtype=generated_images.dtype)
        for index, img in enumerate(generated_images):
            i = int(index / width)
            j = index % width
            image[i * shape[0]:(i + 1) * shape[0], j * shape[1]:(j + 1) * shape[1]] = img[:, :]
        return image
    def generate_images(self, label):
        self.generator.load_weights('../CWGAN GP/wgan-generator')
        noise = np.random.normal(0, 1, (1, self.latent_dim))
        gen_imgs = self.generator.predict([noise, np.array(label).reshape(-1, 1)])
        plt.figure(figsize=(5, 5))
        # predict returns shape (1, 100, 1); select the single window before plotting
        plt.plot(gen_imgs[0])
        plt.show()
        plt.close()
if __name__ == '__main__':
    epochs = 100
    batch_size = 32
    sample_interval = 10
    wgan = CWGANGP(epochs, batch_size, sample_interval)
    wgan.train()
    # generate for a specific class
    # wgan.generate_images(1)
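    # Hedged usage sketch (not in the original gist): after training, one window
    # per class can be generated from the saved 'wgan-generator' weights:
    # for c in range(wgan.nclasses):
    #     wgan.generate_images(c)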