import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from keras import backend as K
from keras.layers import Input, Dense, Flatten, Reshape, ZeroPadding2D, Conv2D, concatenate
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Model
from keras.optimizers import Adam
from keras.utils import plot_model
from pixel_shuffler import PixelShuffler
from umeyama import umeyama
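
# Face-swap GAN: a shared Encoder learns a common latent face representation,
# and one Decoder per actor (A and B) reconstructs that actor's face from it.
# Each generator (encoder + decoder) is trained both as an unwarping
# autoencoder (L1 loss) and adversarially against a conditional discriminator
# that sees (warped input, output) pairs.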
random_transform_args = {
    'rotation_range': 10,
    'zoom_range': 0.05,
    'shift_range': 0.05,
    'random_flip': 0.4,
}

optimizer = Adam(lr=3e-5, beta_1=0.5, beta_2=0.999)
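# beta_1=0.5 follows the common GAN recipe (e.g. DCGAN), which found that the
# default momentum of 0.9 can destabilize adversarial training.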
def get_training_data(image_paths, batch_size):
    indices = np.random.randint(len(image_paths), size=batch_size)
    for i, index in enumerate(indices):
        image_path = image_paths[index]
        image = cv2.imread(image_path) / 255.0
        image = random_transform(image, **random_transform_args)
        warped_img, target_img = random_warp(image)
        if i == 0:
            warped_images = np.empty((batch_size,) + warped_img.shape, warped_img.dtype)
            target_images = np.empty((batch_size,) + target_img.shape, target_img.dtype)
        warped_images[i] = warped_img
        target_images[i] = target_img
    return warped_images, target_images
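
# get_training_data samples a random batch with replacement, augments each
# 256x256 aligned face, and returns (warped 64x64 inputs, aligned 64x64 targets).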
IMAGE_SHAPE = (64, 64, 3)
ENCODER_DIM = 1024
def conv(filters):
    def block(x):
        x = Conv2D(filters, kernel_size=5, strides=2, padding='same')(x)
        x = LeakyReLU(0.1)(x)
        return x
    return block

def upscale(filters):
    def block(x):
        x = Conv2D(filters * 4, kernel_size=3, padding='same')(x)
        x = LeakyReLU(0.1)(x)
        x = PixelShuffler()(x)
        return x
    return block
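
# upscale() is a subpixel convolution: the conv produces 4x the target filter
# count, then PixelShuffler rearranges (H, W, 4*filters) -> (2H, 2W, filters),
# doubling spatial resolution without transposed convolutions.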
def Encoder():
    input_ = Input(shape=IMAGE_SHAPE)
    x = input_
    x = conv(128)(x)
    x = conv(256)(x)
    x = conv(512)(x)
    x = conv(1024)(x)
    x = Dense(ENCODER_DIM)(Flatten()(x))
    x = Dense(4 * 4 * 1024)(x)
    x = Reshape((4, 4, 1024))(x)
    x = upscale(512)(x)
    return Model(input_, x)
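
# Encoder shape trace: (64, 64, 3) -> strided convs -> (4, 4, 1024)
# -> Flatten -> Dense(1024) bottleneck -> Dense(16384) -> Reshape (4, 4, 1024)
# -> upscale -> (8, 8, 512). The Dense bottleneck forces a compact face
# encoding that both decoders share.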
def Decoder():
    input_ = Input(shape=(8, 8, 512))
    x = input_
    x = upscale(256)(x)
    x = upscale(128)(x)
    x = upscale(64)(x)
    x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
    return Model(input_, x)
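
# Decoder shape trace: (8, 8, 512) -> (16, 16, 256) -> (32, 32, 128)
# -> (64, 64, 64) -> sigmoid conv -> (64, 64, 3) in [0, 1], matching the
# 1/255-scaled training images.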
def Discriminator():
    inputs = Input(shape=(64, 64, 3 * 2))
    d = ZeroPadding2D(padding=(1, 1))(inputs)
    d = Conv2D(64, kernel_size=4, strides=2)(d)
    d = LeakyReLU(alpha=0.2)(d)
    d = ZeroPadding2D(padding=(1, 1))(d)
    d = Conv2D(128, kernel_size=4, strides=2)(d)
    d = LeakyReLU(alpha=0.2)(d)
    d = ZeroPadding2D(padding=(1, 1))(d)
    d = Conv2D(256, kernel_size=4, strides=2)(d)
    d = LeakyReLU(alpha=0.2)(d)
    d = ZeroPadding2D(padding=(1, 1))(d)
    d = Conv2D(512, kernel_size=4, strides=1)(d)
    # d = ZeroPadding2D(padding=(1, 1))(d)
    # d = Conv2D(1, kernel_size=4, strides=1, activation='sigmoid')(d)
    d = Flatten()(d)
    d = Dense(1, activation='sigmoid')(d)
    model = Model(inputs, d)
    return model
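
# Conditional discriminator in the pix2pix style: it scores a 6-channel
# concatenation of (warped input, candidate output) rather than the output
# alone. The commented-out lines are a leftover PatchGAN-style per-patch
# head; this version flattens to a single real/fake score instead.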
def Generator_Containing_Discriminator(generator, discriminator):
    warped_input = Input(IMAGE_SHAPE)
    generator_fake_output = generator(warped_input)
    merged = concatenate([warped_input, generator_fake_output], axis=-1)
    discriminator.trainable = False
    discriminator_output = discriminator(merged)
    model = Model(warped_input, [generator_fake_output, discriminator_output])
    return model
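
# trainable=False is captured when the combined GAN is compiled, so its
# gradient updates only touch the generator. The standalone discriminator is
# compiled before the GAN, while still trainable, and keeps learning from its
# own batches. (In Keras, trainable changes only take effect at compile time.)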
def get_image_paths(directory):
    print("Scanning %s" % directory)
    return [x.path for x in os.scandir(directory) if x.name.endswith(".jpg") or x.name.endswith(".png")]
def load_images(image_paths, batch_size=5):
    iter_all_images = (cv2.imread(fn) for fn in image_paths)
    for i, image in enumerate(iter_all_images):
        if i == 0:
            all_images = np.empty((len(image_paths),) + image.shape, dtype=image.dtype)
        all_images[i] = image
    return all_images
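
# Note: load_images (and its unused batch_size argument) is a leftover helper;
# the training loop below samples batches via get_training_data instead.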
def get_transpose_axes(n):
    if n % 2 == 0:
        y_axes = list(range(1, n - 1, 2))
        x_axes = list(range(0, n - 1, 2))
    else:
        y_axes = list(range(0, n - 1, 2))
        x_axes = list(range(1, n - 1, 2))
    return y_axes, x_axes, [n - 1]

def stack_images(images):
    images_shape = np.array(images.shape)
    new_axes = get_transpose_axes(len(images_shape))
    new_shape = [np.prod(images_shape[x]) for x in new_axes]
    return np.transpose(
        images,
        axes=np.concatenate(new_axes)
    ).reshape(new_shape)
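
# stack_images tiles a batch into one image grid: e.g. a (28, 3, 64, 64, 3)
# array of 28 rows x 3 columns of 64x64 images becomes a single
# (28*64, 3*64, 3) image by interleaving the row/column axes with the
# height/width axes before reshaping.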
def random_transform(image, rotation_range, zoom_range, shift_range, random_flip):
    h, w = image.shape[0:2]
    rotation = np.random.uniform(-rotation_range, rotation_range)
    scale = np.random.uniform(1 - zoom_range, 1 + zoom_range)
    tx = np.random.uniform(-shift_range, shift_range) * w
    ty = np.random.uniform(-shift_range, shift_range) * h
    mat = cv2.getRotationMatrix2D((w // 2, h // 2), rotation, scale)
    mat[:, 2] += (tx, ty)
    result = cv2.warpAffine(image, mat, (w, h), borderMode=cv2.BORDER_REPLICATE)
    if np.random.random() < random_flip:
        result = result[:, ::-1]
    return result
# get pair of random warped images from aligned face image
def random_warp(image):
    assert image.shape == (256, 256, 3)
    range_ = np.linspace(128 - 80, 128 + 80, 5)
    mapx = np.broadcast_to(range_, (5, 5))
    mapy = mapx.T
    mapx = mapx + np.random.normal(size=(5, 5), scale=5)
    mapy = mapy + np.random.normal(size=(5, 5), scale=5)
    interp_mapx = cv2.resize(mapx, (80, 80))[8:72, 8:72].astype('float32')
    interp_mapy = cv2.resize(mapy, (80, 80))[8:72, 8:72].astype('float32')
    warped_image = cv2.remap(image, interp_mapx, interp_mapy, cv2.INTER_LINEAR)
    src_points = np.stack([mapx.ravel(), mapy.ravel()], axis=-1)
    dst_points = np.mgrid[0:65:16, 0:65:16].T.reshape(-1, 2)
    mat = umeyama(src_points, dst_points, True)[0:2]
    target_image = cv2.warpAffine(image, mat, (64, 64))
    return warped_image, target_image
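
# The warp jitters a 5x5 control grid over the face center and remaps the
# image through the interpolated grid to get a distorted 64x64 input;
# umeyama fits the similarity transform from the jittered grid back to a
# regular 5x5 grid, giving the matching undistorted 64x64 target. The
# generator thus learns to reconstruct an aligned face from a warped one.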
def generator_l1_loss(y_true, y_pred):
    return K.mean(K.abs(y_pred - y_true), axis=(1, 2, 3))

def discriminator_on_generator_loss(y_true, y_pred):
    # return K.mean(K.binary_crossentropy(y_pred, y_true), axis=(1, 2, 3))
    # Keras 2 argument order is binary_crossentropy(target, output)
    return K.mean(K.binary_crossentropy(y_true, y_pred))
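
# generator_l1_loss is a per-sample mean absolute error over H, W, C;
# discriminator_on_generator_loss is plain binary cross-entropy on the
# discriminator's sigmoid score.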
source_actor = 'SOURCE_ACTOR'
target_actor = 'TARGET_ACTOR'

encoder = Encoder()
decoder_A = Decoder()
decoder_B = Decoder()

x = Input(shape=IMAGE_SHAPE)
generator_A = Model(x, decoder_A(encoder(x)))
generator_B = Model(x, decoder_B(encoder(x)))
generator_A.compile(optimizer=optimizer, loss='mean_absolute_error')
generator_B.compile(optimizer=optimizer, loss='mean_absolute_error')

discriminator_A = Discriminator()
discriminator_B = Discriminator()
discriminator_A.compile(loss=discriminator_on_generator_loss, optimizer=optimizer)
discriminator_B.compile(loss=discriminator_on_generator_loss, optimizer=optimizer)

GAN_A = Generator_Containing_Discriminator(generator_A, discriminator_A)
GAN_A.compile(loss=[generator_l1_loss, discriminator_on_generator_loss], loss_weights=[100.0, 1.0], optimizer=optimizer)
GAN_B = Generator_Containing_Discriminator(generator_B, discriminator_B)
GAN_B.compile(loss=[generator_l1_loss, discriminator_on_generator_loss], loss_weights=[100.0, 1.0], optimizer=optimizer)
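
# Combined objective per GAN: 100 * L1(reconstruction) + 1 * BCE(adversarial),
# the same heavy reconstruction weighting used by pix2pix to keep outputs
# close to the target while the adversarial term sharpens detail.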
# print('autoencoder_A.summary()')
# generator_A.summary()
# plot_model(generator_A, to_file='./Data/Models/[Progress]/%s to %s/Autoencoder_A.png' % (source_actor, target_actor), show_shapes=True)
# print('autoencoder_B.summary()')
# generator_B.summary()
# plot_model(generator_B, to_file='./Data/Models/[Progress]/%s to %s/Autoencoder_B.png' % (source_actor, target_actor), show_shapes=True)
#
# print('GAN_A.summary()')
# GAN_A.summary()
# plot_model(GAN_A, to_file='./Data/Models/[Progress]/%s to %s/GAN_A.png' % (source_actor, target_actor), show_shapes=True)
# print('GAN_B.summary()')
# GAN_B.summary()
# plot_model(GAN_B, to_file='./Data/Models/[Progress]/%s to %s/GAN_B.png' % (source_actor, target_actor), show_shapes=True)
#
# print('Discriminator_A.summary()')
# discriminator_A.summary()
# plot_model(discriminator_A, to_file='./Data/Models/[Progress]/%s to %s/Discriminator_A.png' % (source_actor, target_actor), show_shapes=True)
# print('Discriminator_B.summary()')
# discriminator_B.summary()
# plot_model(discriminator_B, to_file='./Data/Models/[Progress]/%s to %s/Discriminator_B.png' % (source_actor, target_actor), show_shapes=True)
def load_model_weights():
    try:
        print("Loading Previous Models...")
        encoder.load_weights("./Models/encoder_%s_%s.h5" % (source_actor, target_actor))
        generator_A.load_weights("./Models/generator_A_%s.h5" % source_actor)
        generator_B.load_weights("./Models/generator_B_%s.h5" % target_actor)
        discriminator_A.load_weights("./Models/discriminator_A_%s.h5" % source_actor)
        discriminator_B.load_weights("./Models/discriminator_B_%s.h5" % target_actor)
        GAN_A.load_weights("./Models/GAN_A_%s.h5" % source_actor)
        GAN_B.load_weights("./Models/GAN_B_%s.h5" % target_actor)
    except Exception as e:
        print("Couldn't load some or all of the models => %s" % e)

load_model_weights()
def save_model_weights():
    encoder.save_weights("./Models/encoder_%s_%s.h5" % (source_actor, target_actor))
    generator_A.save_weights("./Models/generator_A_%s.h5" % source_actor)
    generator_B.save_weights("./Models/generator_B_%s.h5" % target_actor)
    discriminator_A.save_weights("./Models/discriminator_A_%s.h5" % source_actor)
    discriminator_B.save_weights("./Models/discriminator_B_%s.h5" % target_actor)
    GAN_A.save_weights("./Models/GAN_A_%s.h5" % source_actor)
    GAN_B.save_weights("./Models/GAN_B_%s.h5" % target_actor)
image_paths_A = get_image_paths("./Actors/%s" % source_actor)
image_paths_B = get_image_paths("./Actors/%s" % target_actor)
# images_A += images_B.mean(axis=(0, 1, 2)) - images_A.mean(axis=(0, 1, 2))
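
# The save/progress directories below are assumed to exist; a minimal sketch
# to create them up front (layout inferred from the paths used in
# save_model_weights and the cv2.imwrite/plt.savefig calls):
progress_dir = './Data/Models/[Progress]/%s to %s' % (source_actor, target_actor)
os.makedirs('./Models', exist_ok=True)
os.makedirs(progress_dir, exist_ok=True)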
print("press 'q' to stop training and save model") | |
dloss_A = [] | |
dloss_B = [] | |
gloss_A = [] | |
gloss_B = [] | |
batch_size = 15 | |
for epoch in range(1000000): | |
warped_A, target_A = get_training_data(image_paths_A, batch_size) | |
warped_B, target_B = get_training_data(image_paths_B, batch_size) | |
predict_A = generator_A.predict(warped_A, batch_size) | |
predict_B = generator_B.predict(warped_B, batch_size) | |
# Train discriminator A | |
real_A = np.concatenate((warped_A, target_A), axis=-1) | |
fake_A = np.concatenate((warped_A, predict_A), axis=-1) | |
real_fake_A = np.concatenate((real_A, fake_A), axis=0) | |
real_fake_A_labels = np.zeros(2 * batch_size) | |
real_fake_A_labels[:batch_size] = 1.0 | |
discriminator_A.trainable = True | |
dloss_A.append(discriminator_A.train_on_batch(real_fake_A, real_fake_A_labels)) | |
discriminator_A.trainable = False | |
# Train GAN A | |
GAN_A_Labels = np.ones(batch_size) | |
gloss_A.append(GAN_A.train_on_batch(warped_A, [target_A, GAN_A_Labels])[1]) | |
# Train discriminator B | |
real_B = np.concatenate((warped_B, target_B), axis=-1) | |
fake_B = np.concatenate((warped_B, predict_B), axis=-1) | |
real_fake_B = np.concatenate((real_B, fake_B), axis=0) | |
real_fake_B_labels = np.zeros(2 * batch_size) | |
real_fake_B_labels[:batch_size] = 1.0 | |
discriminator_B.trainable = True | |
dloss_B.append(discriminator_B.train_on_batch(real_fake_B, real_fake_B_labels)) | |
discriminator_B.trainable = False | |
# Train GAN B | |
GAN_B_Labels = np.ones(batch_size) | |
gloss_B.append(GAN_B.train_on_batch(warped_B, [target_B, GAN_B_Labels])[1]) | |
if epoch % 500 == 0 and epoch > 0: | |
print("Saving Model -- Epoch %s" % str(epoch)) | |
save_model_weights() | |
print("Saving Model Finished") | |
if epoch % 100 == 0: | |
test_A = target_A[0:14] | |
test_B = target_B[0:14] | |
figure_A = np.stack([ | |
test_A, | |
generator_A.predict(test_A), | |
generator_B.predict(test_A), | |
], axis=1) | |
figure_B = np.stack([ | |
test_B, | |
generator_B.predict(test_B), | |
generator_A.predict(test_B), | |
], axis=1) | |
figure = np.concatenate([figure_A, figure_B], axis=0) | |
# figure = figure.reshape((4, 7) + figure.shape[1:]) | |
figure = stack_images(figure) | |
figure = np.clip(figure * 255, 0, 255).astype('uint8') | |
cv2.imwrite('./Data/Models/[Progress]/%s to %s/%s. %s to %s.jpg' % (source_actor, target_actor, epoch, source_actor, target_actor), figure) | |
plt.plot(dloss_A, label='D_Loss A') | |
plt.plot(gloss_A, label='G_Loss A') | |
plt.plot(dloss_B, label='D_Loss B') | |
plt.plot(gloss_B, label='G_Loss B') | |
plt.legend(loc='upper right') | |
plt.savefig('./Data/Models/[Progress]/%s to %s/%s to %s - Plot.png' % (source_actor, target_actor, source_actor, target_actor)) | |
plt.clf() |