Skip to content

Instantly share code, notes, and snippets.

@lixin9311
Last active April 15, 2019 14:37
Show Gist options
  • Save lixin9311/4cf69dc01632c27131e1b6f5b9124c8a to your computer and use it in GitHub Desktop.
Save lixin9311/4cf69dc01632c27131e1b6f5b9124c8a to your computer and use it in GitHub Desktop.
from keras.models import Model
from keras.layers import Conv2D, Input, UpSampling2D, Lambda, Layer
from keras.optimizers import *
from keras import backend as K
from keras.applications import VGG19
from ops import *
class UNet(object):
    """AdaIN style-transfer network: frozen VGG19 encoder + decoder.

    Content and style images are both encoded with VGG19 up to
    ``block4_conv1``. The content features are re-normalized to the style
    features' statistics with ``AdaIN`` and then decoded back to an image.
    Despite the class name, the decoder has no skip connections from the
    encoder.

    Attributes (set in ``__init__``):
        imgshape: shape of each image input, ``(None, None, 3)``, so any
            height/width with 3 channels is accepted.
        alpha: scalar placeholder defaulting to 1.0. It is passed to
            ``AdaIN`` as the blend factor between re-normalized (1.0) and
            original (0.0) content features.
        encoder: VGG19 feature extractor, marked non-trainable.
        model: the full ``[content_input, style_input] -> image`` model.
    """

    def __init__(self):
        self.imgshape = (None, None, 3)
        self.alpha = tf.placeholder_with_default(1., shape=[], name='alpha')
        self.encoder = self.build_encoder()
        # Only the decoder is meant to be trained.
        self.encoder.trainable = False
        self.model = self.build_model()
        # summary() prints the table itself and returns None; wrapping it in
        # print() only emitted an extra "None" line.
        self.model.summary()

    def build_model(self):
        """Build the end-to-end ``[content, style] -> stylized image`` model.

        Returns:
            A Keras ``Model`` with inputs ``content_input`` and ``style_input``.
        """
        cinput = Input(self.imgshape, name='content_input')
        sinput = Input(self.imgshape, name='style_input')
        # The same encoder instance (shared weights) encodes both images.
        content_encoded = self.encoder(cinput)
        style_encoded = self.encoder(sinput)
        # self.alpha is a raw TF placeholder, not a Keras Input; it is passed
        # through to AdaIN as the third element of its argument list.
        intermediate = Lambda(lambda x: AdaIN(x))(
            [content_encoded, style_encoded, self.alpha])
        decoder = self.build_decoder()
        output = decoder(intermediate)
        return Model([cinput, sinput], output)

    def build_encoder(self):
        """Return VGG19 (ImageNet weights, no top) truncated at ``block4_conv1``."""
        vgg19_model = VGG19(include_top=False, weights='imagenet')
        content_layer = vgg19_model.get_layer('block4_conv1').output
        return Model(inputs=vgg19_model.input, outputs=content_layer,
                     name='encoder_model')

    def build_decoder(self):
        """Build the decoder mapping 512-channel features back to a 3-channel image.

        Three ``UpSampling2D`` stages each double height and width. The
        reflect-padded 3x3 convolutions keep the spatial size unchanged.
        The HxW values in the comments assume a 32x32 feature input.
        """
        layers = [  # HxW / InC->OutC
            Conv2DReflect(256, 3, padding='valid', activation='relu'),  # 32x32 / 512->256
            UpSampling2D(),  # 32x32 -> 64x64
            Conv2DReflect(256, 3, padding='valid', activation='relu'),  # 64x64 / 256->256
            Conv2DReflect(256, 3, padding='valid', activation='relu'),  # 64x64 / 256->256
            Conv2DReflect(256, 3, padding='valid', activation='relu'),  # 64x64 / 256->256
            Conv2DReflect(128, 3, padding='valid', activation='relu'),  # 64x64 / 256->128
            UpSampling2D(),  # 64x64 -> 128x128
            Conv2DReflect(128, 3, padding='valid', activation='relu'),  # 128x128 / 128->128
            Conv2DReflect(64, 3, padding='valid', activation='relu'),  # 128x128 / 128->64
            UpSampling2D(),  # 128x128 -> 256x256
            Conv2DReflect(64, 3, padding='valid', activation='relu'),  # 256x256 / 64->64
            Conv2DReflect(3, 3, padding='valid', activation=None),  # 256x256 / 64->3
        ]
        # Renamed from `input` to avoid shadowing the builtin.
        decoder_input = Input((None, None, 512))
        x = decoder_input
        # Variables created while applying the layers get the 'decoder_vars'
        # name prefix.
        with tf.variable_scope('decoder_vars'):
            for layer in layers:
                x = layer(x)
        return Model(decoder_input, x, name='decoder_model')
from __future__ import division, print_function
import tensorflow as tf
from keras.layers import Conv2D, Lambda, Layer
import keras.backend as K
def pad_reflect(x, padding=1):
    """Reflect-pad axes 1 and 2 of a rank-4 tensor by ``padding`` on each side.

    Axes 0 and 3 are left unpadded. For channels-last (NHWC) input, these are
    the batch and channel axes, so only height and width grow by
    ``2 * padding``.

    Args:
        x: rank-4 tensor to pad.
        padding: number of mirrored rows/columns added before and after
            axes 1 and 2. With ``'REFLECT'`` mode, ``tf.pad`` requires this to
            be smaller than the size of those axes.

    Returns:
        The padded tensor produced by ``tf.pad(..., mode='REFLECT')``.
    """
    untouched = [0, 0]
    both_sides = [padding, padding]
    paddings = [untouched, both_sides, both_sides, untouched]
    return tf.pad(x, paddings, mode='REFLECT')
def Conv2DReflect(*args, **kwargs):
    """Return a Keras ``Lambda`` layer: 1-pixel reflect pad, then ``Conv2D``.

    ``*args`` / ``**kwargs`` are forwarded unchanged to ``keras.layers.Conv2D``.
    ``pad_reflect`` is called with its default padding of 1. A 3x3 kernel
    with ``padding='valid'`` therefore preserves height and width.

    The ``Conv2D`` is constructed exactly once, here, and captured by the
    returned lambda. If it were constructed inside the lambda, every
    invocation of the ``Lambda`` would create a new convolution with fresh,
    independent weights. Keras invokes the wrapped function more than once:
    when a Model containing the layer is built, and again when that Model is
    called on a new tensor. Building it inside the lambda would silently
    leave the network using untied weights.

    NOTE(review): the wrapped ``Conv2D``'s weights are presumably still not
    reported by the ``Lambda`` layer's ``trainable_weights`` / ``get_weights``.
    Confirm whether training and checkpointing rely on TF variable collections.
    If not, replace this with ``Lambda(pad_reflect)`` followed by a plain
    ``Conv2D`` layer.
    """
    conv = Conv2D(*args, **kwargs)
    return Lambda(lambda x: conv(pad_reflect(x)))
def _adaptive_normalization(args, style_axes, epsilon):
    """Shared core of AdaIN/AdaBN: re-normalize content features to style statistics.

    Args:
        args: sequence ``(content_features, style_features, alpha)``.
        style_axes: axes over which the style mean/variance are pooled.
        epsilon: small constant added to both the content and the style
            variance before taking square roots.

    Returns:
        ``alpha * renormalized + (1 - alpha) * content_features``.
    """
    content_features, style_features, alpha = args[0], args[1], args[2]
    style_mean, style_variance = tf.nn.moments(style_features, style_axes, keep_dims=True)
    # Content statistics are always per sample and per channel: axes 1 and 2,
    # presumably the spatial axes of channels-last input.
    content_mean, content_variance = tf.nn.moments(content_features, [1, 2], keep_dims=True)
    # The AdaIN paper defines sigma = sqrt(var + eps). Adding eps here also
    # keeps the sqrt gradient finite for style channels whose variance is
    # exactly 0 (e.g. an all-zero ReLU feature map). The content side already
    # gets eps via batch_normalization's variance_epsilon.
    style_std = tf.sqrt(style_variance + epsilon)
    normalized_content_features = tf.nn.batch_normalization(
        content_features, content_mean, content_variance,
        style_mean, style_std, epsilon)
    # Linear blend between re-normalized and original content features.
    return alpha * normalized_content_features + (1 - alpha) * content_features


def AdaIN(args, epsilon=1e-5):
    '''
    Borrowed from https://github.com/jonrei/tf-AdaIN
    Normalizes the `content_features` with scaling and offset from `style_features`.
    See "5. Adaptive Instance Normalization" in https://arxiv.org/abs/1703.06868 for details.

    Args:
        args: sequence ``(content_features, style_features, alpha)``.
            Statistics are taken over axes 1 and 2, so the feature tensors are
            assumed to be channels-last (NHWC) -- TODO confirm. ``alpha`` is a
            scalar (float or tensor): 1.0 gives fully re-normalized features,
            0.0 returns the content features unchanged.
        epsilon: variance epsilon for both content and style statistics.

    Returns:
        Tensor with the same shape as ``content_features``.
    '''
    # Style statistics per sample (instance norm): axes 1 and 2 only.
    return _adaptive_normalization(args, [1, 2], epsilon)


def AdaBN(args, epsilon=1e-5):
    '''
    Batch variant of AdaIN: like ``AdaIN``, but the style mean/variance are
    pooled over the batch axis as well (axes 0, 1 and 2). Every content
    sample therefore receives the same style statistics. Content statistics
    remain per sample.

    Args:
        args: sequence ``(content_features, style_features, alpha)``, as in
            ``AdaIN``.
        epsilon: variance epsilon for both content and style statistics.

    Returns:
        Tensor with the same shape as ``content_features``.
    '''
    return _adaptive_normalization(args, [0, 1, 2], epsilon)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment