@Nelthirion
Created December 31, 2017 16:43
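# Face-swap GAN training script: a shared encoder feeds two per-identity
# decoders (the classic deepfake autoencoder pair), and each generator is
# additionally trained adversarially against a pix2pix-style conditional
# discriminator on top of an L1 reconstruction loss.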
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from keras import backend as K
from keras.layers import Input, Dense, Flatten, Reshape, Conv2D, ZeroPadding2D, concatenate
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Model
from keras.optimizers import Adam
from keras.utils import plot_model
from pixel_shuffler import PixelShuffler
from umeyama import umeyama
random_transform_args = {
    'rotation_range': 10,
    'zoom_range': 0.05,
    'shift_range': 0.05,
    'random_flip': 0.4,
}
optimizer = Adam(lr=3e-5, beta_1=0.5, beta_2=0.999)
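# Sample a random batch of face images, apply the augmentations configured
# above, and return (warped, target) pairs: the network learns to reconstruct
# the undistorted target face from its randomly warped version.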
def get_training_data(image_paths, batch_size):
    indices = np.random.randint(len(image_paths), size=batch_size)
    for i, index in enumerate(indices):
        image_path = image_paths[index]
        image = cv2.imread(image_path) / 255.0
        image = random_transform(image, **random_transform_args)
        warped_img, target_img = random_warp(image)
        if i == 0:
            warped_images = np.empty((batch_size,) + warped_img.shape, warped_img.dtype)
            target_images = np.empty((batch_size,) + target_img.shape, target_img.dtype)
        warped_images[i] = warped_img
        target_images[i] = target_img
    return warped_images, target_images
IMAGE_SHAPE = (64, 64, 3)
ENCODER_DIM = 1024
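# Building blocks: conv() halves spatial resolution with a strided 5x5
# convolution; upscale() doubles it via sub-pixel convolution (a 3x3 conv
# producing 4x the filters, rearranged by PixelShuffler).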
def conv(filters):
    def block(x):
        x = Conv2D(filters, kernel_size=5, strides=2, padding='same')(x)
        x = LeakyReLU(0.1)(x)
        return x
    return block

def upscale(filters):
    def block(x):
        x = Conv2D(filters * 4, kernel_size=3, padding='same')(x)
        x = LeakyReLU(0.1)(x)
        x = PixelShuffler()(x)
        return x
    return block
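# Shared encoder: four stride-2 convolutions take the 64x64x3 input down to
# 4x4x1024, a dense bottleneck compresses it to ENCODER_DIM, and a final
# upscale emits an 8x8x512 feature map for the decoders.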
def Encoder():
    input_ = Input(shape=IMAGE_SHAPE)
    x = input_
    x = conv(128)(x)
    x = conv(256)(x)
    x = conv(512)(x)
    x = conv(1024)(x)
    x = Dense(ENCODER_DIM)(Flatten()(x))
    x = Dense(4 * 4 * 1024)(x)
    x = Reshape((4, 4, 1024))(x)
    x = upscale(512)(x)
    return Model(input_, x)
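# Per-identity decoder: three upscales take the shared 8x8x512 code back up
# to 64x64, and a sigmoid convolution maps it to RGB in [0, 1].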
def Decoder():
    input_ = Input(shape=(8, 8, 512))
    x = input_
    x = upscale(256)(x)
    x = upscale(128)(x)
    x = upscale(64)(x)
    x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
    return Model(input_, x)
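# Conditional discriminator in the pix2pix style: it sees the warped input
# concatenated channel-wise with either the real target or the generator's
# output, and predicts a single real/fake probability.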
def Discriminator():
    inputs = Input(shape=(64, 64, 3 * 2))
    d = ZeroPadding2D(padding=(1, 1))(inputs)
    d = Conv2D(64, kernel_size=4, strides=2)(d)
    d = LeakyReLU(alpha=0.2)(d)
    d = ZeroPadding2D(padding=(1, 1))(d)
    d = Conv2D(128, kernel_size=4, strides=2)(d)
    d = LeakyReLU(alpha=0.2)(d)
    d = ZeroPadding2D(padding=(1, 1))(d)
    d = Conv2D(256, kernel_size=4, strides=2)(d)
    d = LeakyReLU(alpha=0.2)(d)
    d = ZeroPadding2D(padding=(1, 1))(d)
    d = Conv2D(512, kernel_size=4, strides=1)(d)
    # PatchGAN-style alternative head:
    # d = ZeroPadding2D(padding=(1, 1))(d)
    # d = Conv2D(1, kernel_size=4, strides=1, activation='sigmoid')(d)
    d = Flatten()(d)
    d = Dense(1, activation='sigmoid')(d)
    model = Model(inputs, d)
    return model
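# Stacked model used to train the generator: the discriminator's weights are
# frozen inside this graph, so gradients from its real/fake output only
# update the generator.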
def Generator_Containing_Discriminator(generator, discriminator):
    warped_input = Input(IMAGE_SHAPE)
    generator_fake_output = generator(warped_input)
    merged = concatenate([warped_input, generator_fake_output], axis=-1)
    discriminator.trainable = False
    discriminator_output = discriminator(merged)
    model = Model(warped_input, [generator_fake_output, discriminator_output])
    return model
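# Dataset helpers: list the .jpg/.png files of an actor's directory, and
# optionally preload them all into one array.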
def get_image_paths(directory):
    return [x.path for x in os.scandir(directory) if x.name.endswith(".jpg") or x.name.endswith(".png")]

def load_images(image_paths, batch_size=5):
    iter_all_images = (cv2.imread(fn) for fn in image_paths)
    for i, image in enumerate(iter_all_images):
        if i == 0:
            all_images = np.empty((len(image_paths),) + image.shape, dtype=image.dtype)
        all_images[i] = image
    return all_images
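# stack_images flattens an N-D grid of images (e.g. rows of sample triplets)
# into a single 2-D montage by interleaving the grid axes with the pixel axes
# before reshaping.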
def get_transpose_axes(n):
    if n % 2 == 0:
        y_axes = list(range(1, n - 1, 2))
        x_axes = list(range(0, n - 1, 2))
    else:
        y_axes = list(range(0, n - 1, 2))
        x_axes = list(range(1, n - 1, 2))
    return y_axes, x_axes, [n - 1]

def stack_images(images):
    images_shape = np.array(images.shape)
    new_axes = get_transpose_axes(len(images_shape))
    new_shape = [np.prod(images_shape[x]) for x in new_axes]
    return np.transpose(
        images,
        axes=np.concatenate(new_axes)
    ).reshape(new_shape)
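# Augmentation: random rotation, zoom, and shift within the configured
# ranges, plus an optional horizontal flip.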
def random_transform(image, rotation_range, zoom_range, shift_range, random_flip):
    h, w = image.shape[0:2]
    rotation = np.random.uniform(-rotation_range, rotation_range)
    scale = np.random.uniform(1 - zoom_range, 1 + zoom_range)
    tx = np.random.uniform(-shift_range, shift_range) * w
    ty = np.random.uniform(-shift_range, shift_range) * h
    mat = cv2.getRotationMatrix2D((w // 2, h // 2), rotation, scale)
    mat[:, 2] += (tx, ty)
    result = cv2.warpAffine(image, mat, (w, h), borderMode=cv2.BORDER_REPLICATE)
    if np.random.random() < random_flip:
        result = result[:, ::-1]
    return result
# Get a pair of randomly warped images from an aligned 256x256 face image:
# a 5x5 grid of control points is jittered and interpolated into a dense
# remap for the warped image, while umeyama fits the best affine transform
# from the jittered points back to a regular grid to produce the 64x64 target.
def random_warp(image):
    assert image.shape == (256, 256, 3)
    range_ = np.linspace(128 - 80, 128 + 80, 5)
    mapx = np.broadcast_to(range_, (5, 5))
    mapy = mapx.T
    mapx = mapx + np.random.normal(size=(5, 5), scale=5)
    mapy = mapy + np.random.normal(size=(5, 5), scale=5)
    interp_mapx = cv2.resize(mapx, (80, 80))[8:72, 8:72].astype('float32')
    interp_mapy = cv2.resize(mapy, (80, 80))[8:72, 8:72].astype('float32')
    warped_image = cv2.remap(image, interp_mapx, interp_mapy, cv2.INTER_LINEAR)
    src_points = np.stack([mapx.ravel(), mapy.ravel()], axis=-1)
    dst_points = np.mgrid[0:65:16, 0:65:16].T.reshape(-1, 2)
    mat = umeyama(src_points, dst_points, True)[0:2]
    target_image = cv2.warpAffine(image, mat, (64, 64))
    return warped_image, target_image
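# Losses: an L1 pixel loss for reconstruction and plain binary cross-entropy
# for the adversarial term; the compile calls below weight them 100:1, the
# same weighting pix2pix uses.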
def generator_l1_loss(y_true, y_pred):
    return K.mean(K.abs(y_pred - y_true), axis=(1, 2, 3))

def discriminator_on_generator_loss(y_true, y_pred):
    # Per-patch variant: K.mean(K.binary_crossentropy(y_true, y_pred), axis=(1, 2, 3))
    return K.mean(K.binary_crossentropy(y_true, y_pred))
source_actor = 'SOURCE_ACTOR'
target_actor = 'TARGET_ACTOR'
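# One shared encoder plus two decoders gives two autoencoders, one per actor.
# Because the latent space is shared, decoding actor A's code with decoder B
# performs the face swap.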
encoder = Encoder()
decoder_A = Decoder()
decoder_B = Decoder()
x = Input(shape=IMAGE_SHAPE)
generator_A = Model(x, decoder_A(encoder(x)))
generator_B = Model(x, decoder_B(encoder(x)))
generator_A.compile(optimizer=optimizer, loss='mean_absolute_error')
generator_B.compile(optimizer=optimizer, loss='mean_absolute_error')
discriminator_A = Discriminator()
discriminator_B = Discriminator()
discriminator_A.compile(loss=discriminator_on_generator_loss, optimizer=optimizer)
discriminator_B.compile(loss=discriminator_on_generator_loss, optimizer=optimizer)
GAN_A = Generator_Containing_Discriminator(generator_A, discriminator_A)
GAN_A.compile(loss=[generator_l1_loss, discriminator_on_generator_loss], loss_weights=[100.0, 1.0], optimizer=optimizer)
GAN_B = Generator_Containing_Discriminator(generator_B, discriminator_B)
GAN_B.compile(loss=[generator_l1_loss, discriminator_on_generator_loss], loss_weights=[100.0, 1.0], optimizer=optimizer)
# print('autoencoder_A.summary()')
# generator_A.summary()
# plot_model(generator_A, to_file='./Data/Models/[Progress]/%s to %s/Autoencoder_A.png' % (source_actor, target_actor), show_shapes=True)
# print('autoencoder_B.summary()')
# generator_B.summary()
# plot_model(generator_B, to_file='./Data/Models/[Progress]/%s to %s/Autoencoder_B.png' % (source_actor, target_actor), show_shapes=True)
#
# print('GAN_A.summary()')
# GAN_A.summary()
# plot_model(GAN_A, to_file='./Data/Models/[Progress]/%s to %s/GAN_A.png' % (source_actor, target_actor), show_shapes=True)
# print('GAN_B.summary()')
# GAN_B.summary()
# plot_model(GAN_B, to_file='./Data/Models/[Progress]/%s to %s/GAN_B.png' % (source_actor, target_actor), show_shapes=True)
#
# print('Discriminator_A.summary()')
# discriminator_A.summary()
# plot_model(discriminator_A, to_file='./Data/Models/[Progress]/%s to %s/Discriminator_A.png' % (source_actor, target_actor), show_shapes=True)
# print('discriminator_B.summary()')
# discriminator_B.summary()
# plot_model(discriminator_B, to_file='./Data/Models/[Progress]/%s to %s/Discriminator_B.png' % (source_actor, target_actor), show_shapes=True)
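# Checkpointing: weights for every sub-model are saved and restored
# individually, keyed by the actor names.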
def load_model_weights():
    try:
        print("Loading Previous Models...")
        encoder.load_weights("./Models/encoder_%s_%s.h5" % (source_actor, target_actor))
        generator_A.load_weights("./Models/generator_A_%s.h5" % source_actor)
        generator_B.load_weights("./Models/generator_B_%s.h5" % target_actor)
        discriminator_A.load_weights("./Models/discriminator_A_%s.h5" % source_actor)
        discriminator_B.load_weights("./Models/discriminator_B_%s.h5" % target_actor)
        GAN_A.load_weights("./Models/GAN_A_%s.h5" % source_actor)
        GAN_B.load_weights("./Models/GAN_B_%s.h5" % target_actor)
    except Exception as e:
        print("Couldn't Load Some or All of Models => %s" % e)
load_model_weights()
def save_model_weights():
    encoder.save_weights("./Models/encoder_%s_%s.h5" % (source_actor, target_actor))
    generator_A.save_weights("./Models/generator_A_%s.h5" % source_actor)
    generator_B.save_weights("./Models/generator_B_%s.h5" % target_actor)
    discriminator_A.save_weights("./Models/discriminator_A_%s.h5" % source_actor)
    discriminator_B.save_weights("./Models/discriminator_B_%s.h5" % target_actor)
    GAN_A.save_weights("./Models/GAN_A_%s.h5" % source_actor)
    GAN_B.save_weights("./Models/GAN_B_%s.h5" % target_actor)
image_paths_A = get_image_paths("./Actors/%s" % source_actor)
image_paths_B = get_image_paths("./Actors/%s" % target_actor)
# images_A += images_B.mean(axis=(0, 1, 2)) - images_A.mean(axis=(0, 1, 2))
print("press 'q' to stop training and save model")
dloss_A = []
dloss_B = []
gloss_A = []
gloss_B = []
batch_size = 15
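# Training loop: for each actor, first train the discriminator on a batch of
# real (warped, target) and fake (warped, generated) pairs, then train the
# stacked GAN with the discriminator frozen so the generator chases both the
# L1 target and "real" labels.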
for epoch in range(1000000):
    warped_A, target_A = get_training_data(image_paths_A, batch_size)
    warped_B, target_B = get_training_data(image_paths_B, batch_size)
    predict_A = generator_A.predict(warped_A, batch_size)
    predict_B = generator_B.predict(warped_B, batch_size)

    # Train discriminator A
    real_A = np.concatenate((warped_A, target_A), axis=-1)
    fake_A = np.concatenate((warped_A, predict_A), axis=-1)
    real_fake_A = np.concatenate((real_A, fake_A), axis=0)
    real_fake_A_labels = np.zeros(2 * batch_size)
    real_fake_A_labels[:batch_size] = 1.0
    discriminator_A.trainable = True
    dloss_A.append(discriminator_A.train_on_batch(real_fake_A, real_fake_A_labels))
    discriminator_A.trainable = False

    # Train GAN A
    GAN_A_Labels = np.ones(batch_size)
    gloss_A.append(GAN_A.train_on_batch(warped_A, [target_A, GAN_A_Labels])[1])

    # Train discriminator B
    real_B = np.concatenate((warped_B, target_B), axis=-1)
    fake_B = np.concatenate((warped_B, predict_B), axis=-1)
    real_fake_B = np.concatenate((real_B, fake_B), axis=0)
    real_fake_B_labels = np.zeros(2 * batch_size)
    real_fake_B_labels[:batch_size] = 1.0
    discriminator_B.trainable = True
    dloss_B.append(discriminator_B.train_on_batch(real_fake_B, real_fake_B_labels))
    discriminator_B.trainable = False

    # Train GAN B
    GAN_B_Labels = np.ones(batch_size)
    gloss_B.append(GAN_B.train_on_batch(warped_B, [target_B, GAN_B_Labels])[1])
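    # Every 500 epochs checkpoint the weights; every 100 epochs write a
    # preview montage (input, same-identity reconstruction, swap) and a
    # running loss plot.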
    if epoch % 500 == 0 and epoch > 0:
        print("Saving Model -- Epoch %s" % str(epoch))
        save_model_weights()
        print("Saving Model Finished")

    if epoch % 100 == 0:
        test_A = target_A[0:14]
        test_B = target_B[0:14]
        figure_A = np.stack([
            test_A,
            generator_A.predict(test_A),
            generator_B.predict(test_A),
        ], axis=1)
        figure_B = np.stack([
            test_B,
            generator_B.predict(test_B),
            generator_A.predict(test_B),
        ], axis=1)
        figure = np.concatenate([figure_A, figure_B], axis=0)
        # figure = figure.reshape((4, 7) + figure.shape[1:])
        figure = stack_images(figure)
        figure = np.clip(figure * 255, 0, 255).astype('uint8')
        cv2.imwrite('./Data/Models/[Progress]/%s to %s/%s. %s to %s.jpg' % (source_actor, target_actor, epoch, source_actor, target_actor), figure)
        plt.plot(dloss_A, label='D_Loss A')
        plt.plot(gloss_A, label='G_Loss A')
        plt.plot(dloss_B, label='D_Loss B')
        plt.plot(gloss_B, label='G_Loss B')
        plt.legend(loc='upper right')
        plt.savefig('./Data/Models/[Progress]/%s to %s/%s to %s - Plot.png' % (source_actor, target_actor, source_actor, target_actor))
        plt.clf()