Skip to content

Instantly share code, notes, and snippets.

@Nelthirion
Created December 31, 2017 16:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save Nelthirion/603f5aa3fb4b0f6eb06421fc172325c2 to your computer and use it in GitHub Desktop.
Save Nelthirion/603f5aa3fb4b0f6eb06421fc172325c2 to your computer and use it in GitHub Desktop.
import os
import cv2
import numpy as np
from keras.layers import Input, Dense, Flatten, Reshape, ZeroPadding2D, Convolution2D, K, concatenate
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D
from keras.models import Model
from keras.optimizers import Adam
from pixel_shuffler import PixelShuffler
from umeyama import umeyama
import matplotlib.pyplot as plt
from keras.utils import plot_model
# Augmentation settings forwarded as keyword arguments to random_transform
# (see get_training_data): rotation range in degrees, zoom and shift ranges as
# fractions of the image size, and the probability of a horizontal flip.
random_transform_args = dict(
    rotation_range=10,
    zoom_range=0.05,
    shift_range=0.05,
    random_flip=0.4,
)
# One Adam instance, passed to every compile() call further down the script.
optimizer = Adam(beta_1=0.5, beta_2=0.999, lr=3e-5)
def get_training_data(image_paths, batch_size):
    """Sample a random batch of (warped, target) training pairs.

    Paths are drawn uniformly with replacement from ``image_paths``. Each image
    is read with ``cv2.imread``, scaled by 1/255, augmented with
    ``random_transform`` (using the module-level ``random_transform_args``) and
    then split by ``random_warp`` into a distorted input and its target.

    Args:
        image_paths: sequence of image file paths.
        batch_size: number of pairs to sample; must be positive.

    Returns:
        ``(warped_images, target_images)``: two arrays whose first axis has
        length ``batch_size``, each keeping the dtype of its own samples.

    Raises:
        ValueError: if ``batch_size`` is not positive.
        OSError: if an image file cannot be read.
    """
    if batch_size <= 0:
        # Without this the loop never runs and the return hits unbound names.
        raise ValueError("batch_size must be positive, got %r" % (batch_size,))
    indices = np.random.randint(len(image_paths), size=batch_size)
    for i, index in enumerate(indices):
        image_path = image_paths[index]
        raw = cv2.imread(image_path)
        if raw is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise OSError("could not read image: %s" % image_path)
        image = random_transform(raw / 255.0, **random_transform_args)
        warped_img, target_img = random_warp(image)
        if i == 0:
            # Allocate the batch once the per-sample shapes are known.
            warped_images = np.empty((batch_size,) + warped_img.shape, warped_img.dtype)
            target_images = np.empty((batch_size,) + target_img.shape, target_img.dtype)
        warped_images[i] = warped_img
        target_images[i] = target_img
    return warped_images, target_images
# Height, width and channel count of the face crops the networks consume.
IMAGE_SHAPE = (64,) * 2 + (3,)
# Width of the bottleneck Dense layer inside Encoder.
ENCODER_DIM = 2 ** 10
def conv(filters):
    """Return a downsampling block as a tensor -> tensor callable.

    The block applies a 5x5 convolution with ``filters`` output channels,
    stride 2 and 'same' padding (so spatial size shrinks by a factor of 2,
    rounded up), followed by LeakyReLU with slope 0.1.
    """
    def apply_block(tensor):
        convolved = Conv2D(filters, kernel_size=5, strides=2, padding='same')(tensor)
        return LeakyReLU(0.1)(convolved)
    return apply_block
def upscale(filters):
    """Return an upsampling block as a tensor -> tensor callable.

    The block applies a 3x3 'same' convolution with ``filters * 4`` output
    channels, LeakyReLU with slope 0.1, then ``PixelShuffler``.
    NOTE(review): assumes PixelShuffler defaults to a 2x2 depth-to-space
    shuffle, so the result has ``filters`` channels at twice the spatial
    size — confirm in pixel_shuffler.
    """
    def apply_block(tensor):
        expanded = Conv2D(filters * 4, kernel_size=3, padding='same')(tensor)
        activated = LeakyReLU(0.1)(expanded)
        return PixelShuffler()(activated)
    return apply_block
def Encoder():
    """Build the encoder model shared by both identities.

    Four stride-2 ``conv`` stages (128, 256, 512, 1024 filters) reduce the
    ``IMAGE_SHAPE`` input (64x64 -> 4x4 for the default shape). The result is
    flattened, projected to an ``ENCODER_DIM``-wide code, expanded back to a
    (4, 4, 1024) tensor and passed through one ``upscale(512)`` stage. The
    output feeds ``Decoder``, whose input is declared as (8, 8, 512).
    """
    image_in = Input(shape=IMAGE_SHAPE)
    features = image_in
    for width in (128, 256, 512, 1024):
        features = conv(width)(features)
    flat = Flatten()(features)
    code = Dense(ENCODER_DIM)(flat)
    expanded = Dense(4 * 4 * 1024)(code)
    grid = Reshape((4, 4, 1024))(expanded)
    return Model(image_in, upscale(512)(grid))
def Decoder():
    """Build one identity-specific decoder model.

    Takes an (8, 8, 512) feature map, applies three ``upscale`` stages
    (256, 128, 64 filters) and a final 5x5 'same' convolution with a sigmoid,
    producing 3 channels with values in [0, 1]. Presumably the output is
    64x64 if each upscale doubles the resolution — see ``upscale``.
    """
    latent = Input(shape=(8, 8, 512))
    features = latent
    for width in (256, 128, 64):
        features = upscale(width)(features)
    rgb = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(features)
    return Model(latent, rgb)
def Discriminator():
    """Build the pair discriminator.

    Input: an image-sized tensor with twice ``IMAGE_SHAPE``'s channels, i.e.
    (64, 64, 6) by default. The caller concatenates the warped input with
    either the real target or the generated face along the channel axis.
    Output: one sigmoid probability that the pair is real.

    Uses the Keras 2 ``Conv2D(filters, kernel_size, strides=...)`` signature
    (as the rest of the file does) instead of the deprecated Keras 1
    ``Convolution2D(filters, rows, cols, subsample=...)`` form. Both describe
    the same 4x4 'valid' convolutions, so the layers and their weights are
    unchanged.
    """
    pair_shape = IMAGE_SHAPE[:2] + (IMAGE_SHAPE[2] * 2,)
    inputs = Input(shape=pair_shape)
    d = inputs
    # Three downsampling stages: pad by 1, 4x4 conv with stride 2, LeakyReLU(0.2).
    for filters in (64, 128, 256):
        d = ZeroPadding2D(padding=(1, 1))(d)
        d = Conv2D(filters, kernel_size=4, strides=2)(d)
        d = LeakyReLU(alpha=0.2)(d)
    d = ZeroPadding2D(padding=(1, 1))(d)
    d = Conv2D(512, kernel_size=4, strides=1)(d)
    # NOTE(review): no activation between this conv and the Dense head, so the
    # two compose linearly; kept as-is to preserve the trained model — confirm intent.
    d = Flatten()(d)
    d = Dense(1, activation='sigmoid')(d)
    return Model(inputs, d)
def Generator_Containing_Discriminator(generator, discriminator):
    """Chain ``generator`` into ``discriminator`` for the generator update step.

    The returned model maps a warped ``IMAGE_SHAPE`` input to two outputs:
    the generated image and the discriminator's score for the pair
    (warped input, generated image) concatenated on the channel axis.

    Side effect: sets ``discriminator.trainable = False``. Keras reads this
    flag when the combined model is compiled, which keeps the combined model
    from updating the discriminator's weights.
    """
    source = Input(IMAGE_SHAPE)
    generated = generator(source)
    pair = concatenate([source, generated], axis=-1)
    discriminator.trainable = False
    verdict = discriminator(pair)
    return Model(source, [generated, verdict])
def get_image_paths(directory, extensions=(".jpg", ".png")):
    """Return the paths of image files directly inside ``directory``.

    Only regular files whose name ends with one of ``extensions`` are kept.
    The match is case-insensitive, so ``FACE.JPG`` is included. The search is
    not recursive. Paths are sorted so the result does not depend on the
    filesystem's enumeration order.

    Args:
        directory: directory to scan.
        extensions: accepted filename suffixes, compared case-insensitively.

    Returns:
        Sorted list of path strings (``os.DirEntry.path``).

    Raises:
        OSError: if ``directory`` cannot be listed (e.g. FileNotFoundError).
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    # sorted() consumes the scandir iterator to the end, which releases its
    # directory handle. The former debug print created a second iterator that
    # was never consumed.
    return sorted(
        entry.path
        for entry in os.scandir(directory)
        if entry.is_file() and entry.name.lower().endswith(suffixes)
    )
def load_images(image_paths, batch_size=5):
    """Read every image in ``image_paths`` into one stacked array.

    Args:
        image_paths: non-empty sequence of image file paths. The images are
            expected to share the first image's shape.
        batch_size: unused. Kept only for backward compatibility with
            existing callers.

    Returns:
        Array of shape ``(len(image_paths),) + first_image.shape`` with the
        first image's dtype, as returned by ``cv2.imread``.

    Raises:
        ValueError: if ``image_paths`` is empty.
        OSError: if an image file cannot be read.
    """
    if len(image_paths) == 0:
        raise ValueError("load_images needs at least one image path")
    all_images = None
    for i, path in enumerate(image_paths):
        image = cv2.imread(path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise OSError("could not read image: %s" % path)
        if all_images is None:
            all_images = np.empty((len(image_paths),) + image.shape, dtype=image.dtype)
        all_images[i] = image
    return all_images
def get_transpose_axes(n):
    """Group the axes of an ``n``-dimensional image grid for tiling.

    The last axis (``n - 1``) is the channel axis. The remaining axes are dealt
    out alternately, walking backwards from the channel axis: the axis just
    before it goes to the column group, the one before that to the row group,
    and so on.

    Returns:
        ``(row_axes, col_axes, [n - 1])`` as three lists of ints.
    """
    channel_axis = n - 1
    # The parity of n decides whether axis 0 lands in the row or column group.
    row_start = 1 if n % 2 == 0 else 0
    row_axes = list(range(row_start, channel_axis, 2))
    col_axes = list(range(1 - row_start, channel_axis, 2))
    return row_axes, col_axes, [channel_axis]
def stack_images(images):
    """Tile an N-d grid of images into one mosaic with a trailing channel axis.

    For a 5-d input laid out as ``(rows, cols, height, width, channels)``, the
    result has shape ``(rows * height, cols * width, channels)``. Each axis
    group from ``get_transpose_axes`` is moved together and collapsed.
    """
    dims = np.array(images.shape)
    groups = get_transpose_axes(dims.size)
    order = np.concatenate(groups)
    collapsed = [np.prod(dims[group]) for group in groups]
    return np.transpose(images, axes=order).reshape(collapsed)
def random_transform(image, rotation_range, zoom_range, shift_range, random_flip):
    """Apply a random rotation, zoom, shift and optional horizontal flip.

    Args:
        image: array whose first two axes are height and width.
        rotation_range: maximum absolute rotation in degrees, sampled uniformly.
        zoom_range: scale is sampled uniformly from [1 - zoom_range, 1 + zoom_range].
        shift_range: maximum shift as a fraction of width (x) and height (y).
        random_flip: probability of mirroring the result left-to-right.

    Returns:
        Array of the same height and width. Borders are filled by replicating
        edge pixels.
    """
    height, width = image.shape[0:2]
    # Random draws happen in a fixed order: angle, scale, x shift, y shift, flip.
    angle = np.random.uniform(-rotation_range, rotation_range)
    zoom = np.random.uniform(1 - zoom_range, 1 + zoom_range)
    shift_x = width * np.random.uniform(-shift_range, shift_range)
    shift_y = height * np.random.uniform(-shift_range, shift_range)
    affine = cv2.getRotationMatrix2D((width // 2, height // 2), angle, zoom)
    affine[:, 2] += (shift_x, shift_y)
    transformed = cv2.warpAffine(image, affine, (width, height), borderMode=cv2.BORDER_REPLICATE)
    flip = np.random.random() < random_flip
    return transformed[:, ::-1] if flip else transformed
def random_warp(image):
    """Return a randomly warped 64x64 input and its matching 64x64 target.

    Both images come from one aligned 256x256 face image. A 5x5 grid of
    control points spanning pixels 48..208 (128 +/- 80) on each axis is jittered
    with Gaussian noise (scale 5 px):

    * The jittered grid is upsampled to 80x80 and its central 64x64 crop is
      used as the sampling map for ``cv2.remap``, giving the warped image.
    * A transform fitted by ``umeyama`` from the jittered points to a regular
      5x5 grid over 0..64 (step 16) is applied with ``cv2.warpAffine``, giving
      the target image.

    Args:
        image: array of shape (256, 256, 3).

    Returns:
        ``(warped_image, target_image)``, both 64x64.

    Raises:
        ValueError: if ``image`` does not have shape (256, 256, 3).
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if image.shape != (256, 256, 3):
        raise ValueError(
            "random_warp expects an image of shape (256, 256, 3), got %r" % (image.shape,))
    range_ = np.linspace(128 - 80, 128 + 80, 5)
    mapx = np.broadcast_to(range_, (5, 5))
    mapy = mapx.T
    # Jitter every control point independently.
    mapx = mapx + np.random.normal(size=(5, 5), scale=5)
    mapy = mapy + np.random.normal(size=(5, 5), scale=5)
    # Upsample the 5x5 maps to 80x80 and keep the central 64x64 sampling grid.
    interp_mapx = cv2.resize(mapx, (80, 80))[8:72, 8:72].astype('float32')
    interp_mapy = cv2.resize(mapy, (80, 80))[8:72, 8:72].astype('float32')
    warped_image = cv2.remap(image, interp_mapx, interp_mapy, cv2.INTER_LINEAR)
    src_points = np.stack([mapx.ravel(), mapy.ravel()], axis=-1)
    dst_points = np.mgrid[0:65:16, 0:65:16].T.reshape(-1, 2)
    # NOTE(review): assumes umeyama(src, dst, True) returns a 3x3 homogeneous
    # matrix (third argument presumably enables scale estimation); only its
    # top 2x3 affine part is used here — confirm in umeyama.
    mat = umeyama(src_points, dst_points, True)[0:2]
    target_image = cv2.warpAffine(image, mat, (64, 64))
    return warped_image, target_image
def generator_l1_loss(y_true, y_pred):
    """Per-sample mean absolute error between target and generated images.

    The absolute difference is averaged over axes 1, 2 and 3 (presumably
    height, width, channels — TODO confirm), leaving one value per batch element.
    """
    abs_error = K.abs(y_pred - y_true)
    per_sample = K.mean(abs_error, axis=(1, 2, 3))
    return per_sample
def discriminator_on_generator_loss(y_true, y_pred):
    """Mean binary cross-entropy between real/fake labels and discriminator outputs.

    Keras 2's ``K.binary_crossentropy`` takes ``(target, output)``. The previous
    call passed ``(y_pred, y_true)``, the Keras 1 order. That treated the 0/1
    labels as predicted probabilities (clipped before the log) and the network
    output as the target, which is not a cross-entropy on the predictions.
    """
    return K.mean(K.binary_crossentropy(y_true, y_pred))
# Identity names. They are used as sub-directory names under ./Actors for the
# training images and inside weight/preview file names. The values look like
# placeholders to be replaced before running.
source_actor = 'SOURCE_ACTOR'
target_actor = 'TARGET_ACTOR'
# One encoder shared by both identities, plus one decoder per identity.
encoder = Encoder()
decoder_A = Decoder()
decoder_B = Decoder()
# Both generators call the same `encoder` object on a shared input, so
# training either generator updates the shared encoder weights.
x = Input(shape=IMAGE_SHAPE)
generator_A = Model(x, decoder_A(encoder(x)))
generator_B = Model(x, decoder_B(encoder(x)))
# NOTE(review): the single `optimizer` instance is passed to every compile()
# below; confirm that sharing one Adam object across models is intended.
generator_A.compile(optimizer=optimizer, loss='mean_absolute_error')
generator_B.compile(optimizer=optimizer, loss='mean_absolute_error')
discriminator_A = Discriminator()
discriminator_B = Discriminator()
# Order matters: the discriminators are compiled here while still trainable.
# Generator_Containing_Discriminator then sets discriminator.trainable = False
# before the GAN models are compiled, so (Keras reading the flag at compile
# time) the GAN steps update only the generator side.
discriminator_A.compile(loss=discriminator_on_generator_loss, optimizer=optimizer)
discriminator_B.compile(loss=discriminator_on_generator_loss, optimizer=optimizer)
# GAN outputs are [generated image, discriminator score]. The L1
# reconstruction loss is weighted 100x relative to the adversarial loss.
GAN_A = Generator_Containing_Discriminator(generator_A, discriminator_A)
GAN_A.compile(loss=[generator_l1_loss, discriminator_on_generator_loss], loss_weights=[100.0, 1.0], optimizer=optimizer)
GAN_B = Generator_Containing_Discriminator(generator_B, discriminator_B)
GAN_B.compile(loss=[generator_l1_loss, discriminator_on_generator_loss], loss_weights=[100.0, 1.0], optimizer=optimizer)
# Optional diagnostics: uncomment to print model summaries and save
# architecture diagrams into the progress directory.
# print('autoencoder_A.summary()')
# generator_A.summary()
# plot_model(generator_A, to_file='./Data/Models/[Progress]/%s to %s/Autoencoder_A.png' % (source_actor, target_actor), show_shapes=True)
# print('autoencoder_B.summary()')
# generator_B.summary()
# plot_model(generator_B, to_file='./Data/Models/[Progress]/%s to %s/Autoencoder_B.png' % (source_actor, target_actor), show_shapes=True)
#
# print('GAN_A.summary()')
# GAN_A.summary()
# plot_model(GAN_A, to_file='./Data/Models/[Progress]/%s to %s/GAN_A.png' % (source_actor, target_actor), show_shapes=True)
# print('GAN_B.summary()')
# GAN_B.summary()
# plot_model(GAN_B, to_file='./Data/Models/[Progress]/%s to %s/GAN_B.png' % (source_actor, target_actor), show_shapes=True)
#
# print('Discriminator_A.summary()')
# discriminator_A.summary()
# plot_model(discriminator_A, to_file='./Data/Models/[Progress]/%s to %s/Discriminator_A.png' % (source_actor, target_actor), show_shapes=True)
# print('Discriminator_B.summary()')
# discriminator_B.summary()
# plot_model(discriminator_B, to_file='./Data/Models/[Progress]/%s to %s/Discriminator_B.png' % (source_actor, target_actor), show_shapes=True)
# Directory, relative to the working directory, that holds the weight files.
MODEL_DIR = "Models"


def _model_weight_files():
    """Return the (model, weight-file path) pairs shared by load and save.

    File names are unchanged from earlier versions, including the double
    underscore in ``GAN_B__``, so existing checkpoints keep matching. Paths
    are built with os.path.join instead of the former ".\\Models/..."
    literals, which contained an invalid ``\\M`` escape and pointed at a
    directory literally named ``.\\Models`` on POSIX systems.
    """
    return [
        (encoder, os.path.join(MODEL_DIR, "encoder_%s_%s.h5" % (source_actor, target_actor))),
        (generator_A, os.path.join(MODEL_DIR, "generator_A_%s.h5" % source_actor)),
        (generator_B, os.path.join(MODEL_DIR, "generator_B_%s.h5" % target_actor)),
        (discriminator_A, os.path.join(MODEL_DIR, "discriminator_A_%s.h5" % source_actor)),
        (discriminator_B, os.path.join(MODEL_DIR, "discriminator_B_%s.h5" % target_actor)),
        (GAN_A, os.path.join(MODEL_DIR, "GAN_A_%s.h5" % source_actor)),
        (GAN_B, os.path.join(MODEL_DIR, "GAN_B__%s.h5" % target_actor)),
    ]


def load_model_weights():
    """Load previously saved weights for every model, best effort.

    Each file is tried independently, so one missing or incompatible file
    (e.g. on the first run) no longer prevents the rest from loading. Failures
    are printed and training continues with the current weights. Files are
    loaded in the listed order, so the combined GAN files (whose models
    contain the generator and discriminator) are applied last.
    """
    print("Loading Previous Models...")
    for model, path in _model_weight_files():
        try:
            model.load_weights(path)
        except Exception as e:
            # Deliberate best-effort boundary: report and keep going.
            print("Couldn't Load %s => %s" % (path, e))


load_model_weights()


def save_model_weights():
    """Save every model's weights into MODEL_DIR, creating it if necessary."""
    os.makedirs(MODEL_DIR, exist_ok=True)
    for model, path in _model_weight_files():
        model.save_weights(path)
image_paths_A = get_image_paths("./Actors/%s" % source_actor)
image_paths_B = get_image_paths("./Actors/%s" % target_actor)

# Preview mosaics and the loss plot are written here. cv2.imwrite returns
# False instead of raising when the directory is missing, and plt.savefig
# raises, so create the directory up front.
progress_dir = './Data/Models/[Progress]/%s to %s' % (source_actor, target_actor)
os.makedirs(progress_dir, exist_ok=True)


def _train_pair(discriminator, gan, warped, target, predicted, batch_size):
    """Run one discriminator update and one GAN (generator) update for one identity.

    The discriminator sees ``batch_size`` real pairs ``(warped, target)``
    labelled 1, followed by ``batch_size`` fake pairs ``(warped, predicted)``
    labelled 0. Each pair is concatenated on the channel axis. The GAN is then
    trained to reproduce ``target`` while driving the discriminator output to 1.

    Returns:
        ``(discriminator_loss, gan_loss_1)``. The second value is element [1]
        of ``gan.train_on_batch``; with Keras's [total, per-output...]
        ordering that is the L1 reconstruction loss.
    """
    real = np.concatenate((warped, target), axis=-1)
    fake = np.concatenate((warped, predicted), axis=-1)
    pairs = np.concatenate((real, fake), axis=0)
    labels = np.zeros(2 * batch_size)
    labels[:batch_size] = 1.0
    # Keras fixes each model's trainable weights at compile() time, so these
    # toggles do not change what gets updated. They presumably keep the flag
    # in line with each model's compile-time state — TODO confirm still needed.
    discriminator.trainable = True
    d_loss = discriminator.train_on_batch(pairs, labels)
    discriminator.trainable = False
    g_loss = gan.train_on_batch(warped, [target, np.ones(batch_size)])[1]
    return d_loss, g_loss


def _save_progress(epoch, target_A, target_B):
    """Write a preview mosaic and the loss-curve plot into ``progress_dir``.

    Each mosaic row shows a face, its reconstruction by its own generator, and
    its swap through the other identity's generator. The first rows (up to 14)
    come from identity A and the next rows (up to 14) from identity B.
    """
    test_A = target_A[0:14]
    test_B = target_B[0:14]
    figure_A = np.stack([
        test_A,
        generator_A.predict(test_A),
        generator_B.predict(test_A),
    ], axis=1)
    figure_B = np.stack([
        test_B,
        generator_B.predict(test_B),
        generator_A.predict(test_B),
    ], axis=1)
    figure = np.concatenate([figure_A, figure_B], axis=0)
    figure = stack_images(figure)
    figure = np.clip(figure * 255, 0, 255).astype('uint8')
    preview_path = '%s/%s. %s to %s.jpg' % (progress_dir, epoch, source_actor, target_actor)
    if not cv2.imwrite(preview_path, figure):
        print("Couldn't write preview image %s" % preview_path)
    plt.plot(dloss_A, label='D_Loss A')
    plt.plot(gloss_A, label='G_Loss A')
    plt.plot(dloss_B, label='D_Loss B')
    plt.plot(gloss_B, label='G_Loss B')
    plt.legend(loc='upper right')
    plt.savefig('%s/%s to %s - Plot.png' % (progress_dir, source_actor, target_actor))
    plt.clf()


# The old message promised a 'q' key that was never handled; Ctrl+C is now
# caught below and triggers a final save.
print("press Ctrl+C to stop training and save model")
dloss_A = []
dloss_B = []
gloss_A = []
gloss_B = []
batch_size = 15
try:
    for epoch in range(1000000):
        warped_A, target_A = get_training_data(image_paths_A, batch_size)
        warped_B, target_B = get_training_data(image_paths_B, batch_size)
        # Predict both sides before any update: generator_A and generator_B
        # share `encoder`, so training one side changes the other's output.
        predict_A = generator_A.predict(warped_A, batch_size)
        predict_B = generator_B.predict(warped_B, batch_size)
        d_loss, g_loss = _train_pair(discriminator_A, GAN_A, warped_A, target_A, predict_A, batch_size)
        dloss_A.append(d_loss)
        gloss_A.append(g_loss)
        d_loss, g_loss = _train_pair(discriminator_B, GAN_B, warped_B, target_B, predict_B, batch_size)
        dloss_B.append(d_loss)
        gloss_B.append(g_loss)
        if epoch % 500 == 0 and epoch > 0:
            print("Saving Model -- Epoch %s" % str(epoch))
            save_model_weights()
            print("Saving Model Finished")
        if epoch % 100 == 0:
            _save_progress(epoch, target_A, target_B)
except KeyboardInterrupt:
    print("Training interrupted -- saving model")
    save_model_weights()
    print("Saving Model Finished")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment