Standalone code for TensorFlow issue #38518
import tensorflow_addons as tfa
import tensorflow as tf
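
# NOTE: the original gist uses CONV_KERNEL_INITIALIZER without defining it.
# A minimal sketch of a definition, assuming the RandomNormal(stddev=0.02)
# initializer conventionally used in CycleGAN implementations:
CONV_KERNEL_INITIALIZER = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)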
def get_norm_layer(norm):
    """Utility function to get the normalization layer."""
    if norm is None:
        # identity factory, so callers can always do `norm()(x)`
        return lambda: lambda x: x
    elif norm == 'batch_norm':
        return tf.keras.layers.BatchNormalization
    elif norm == 'instance_norm':
        return tfa.layers.InstanceNormalization
    elif norm == 'layer_norm':
        return tf.keras.layers.LayerNormalization
    else:
        raise ValueError('unknown normalization: {}'.format(norm))
class LinearDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Holds the learning rate at `init_learning_rate` until `epoch_decay`,
    then decays it linearly to zero at `epochs`."""

    def __init__(self, init_learning_rate, epochs, epoch_decay, beta_1):
        super(LinearDecay, self).__init__()
        self.init_learning_rate = init_learning_rate
        self.epochs = epochs
        self.epoch_decay = epoch_decay
        self.beta_1 = beta_1
        self.current_learning_rate = tf.Variable(
            initial_value=init_learning_rate,
            trainable=False, dtype=tf.float32
        )
        # define functions to call: linear decay after `epoch_decay`,
        # constant rate before it
        self.true_fn = lambda step: self.init_learning_rate * (
            1 - (step - self.epoch_decay) / (self.epochs - self.epoch_decay)
        )
        self.false_fn = lambda: self.init_learning_rate

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        # tf.cond expects callables, so close over `step` instead of
        # passing the already-evaluated tensor
        self.current_learning_rate.assign(tf.cond(
            step >= self.epoch_decay,
            true_fn=lambda: self.true_fn(step),
            false_fn=self.false_fn
        ))
        return self.current_learning_rate
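
# Quick sanity check for the schedule (a sketch; the step values are
# arbitrary): the rate holds at init_learning_rate until epoch_decay,
# then decays linearly toward zero at `epochs`.
_schedule = LinearDecay(0.0001, epochs=200, epoch_decay=100, beta_1=0.5)
for _step in (0, 100, 150, 199):
    tf.print(_step, _schedule(tf.constant(_step)))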
class ConvDiscriminator:
    """Utility class to build the discriminator.

    Per section 4 of the [paper](https://arxiv.org/abs/1703.10593v6), the
    discriminator is a 70x70 PatchGAN, while the generator architecture is
    adopted from [Johnson et al.](https://arxiv.org/abs/1603.08155).
    """
    def __init__(self,
                 input_shape=(256, 256, 3),
                 dim=64,
                 num_downsamplings=3,
                 norm='instance_norm',
                 lr_scheduler=LinearDecay(0.0001, 200, 100, 0.5)):
        self.norm = get_norm_layer(norm)
        self.input_shape = input_shape
        self.dim = dim
        self.dim_ = dim  # base dim, used to cap channel growth at dim_ * 8
        self.num_downsamplings = num_downsamplings
        self.lr_scheduler = lr_scheduler

        # build model
        self.optimizer = tf.keras.optimizers.Adam(
            learning_rate=self.lr_scheduler,
            beta_1=self.lr_scheduler.beta_1
        )
        self.model = self.build()
    def build(self):
        x = inputs = tf.keras.Input(shape=self.input_shape)

        # 1: first downsampling block, no normalization
        x = tf.keras.layers.Conv2D(self.dim, 4, strides=2, padding='same',
                                   kernel_initializer=CONV_KERNEL_INITIALIZER)(x)
        x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)

        for _ in range(self.num_downsamplings - 1):
            self.dim = min(self.dim * 2, self.dim_ * 8)
            x = tf.keras.layers.Conv2D(self.dim, 4, strides=2, padding='same',
                                       kernel_initializer=CONV_KERNEL_INITIALIZER,
                                       use_bias=False)(x)
            x = self.norm()(x)
            x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)

        # 2: stride-1 block
        self.dim = min(self.dim * 2, self.dim_ * 8)
        x = tf.keras.layers.Conv2D(self.dim, 4, strides=1, padding='same',
                                   kernel_initializer=CONV_KERNEL_INITIALIZER,
                                   use_bias=False)(x)
        x = self.norm()(x)
        x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)

        # 3: single-channel patch output
        x = tf.keras.layers.Conv2D(1, 4, strides=1, padding='same',
                                   kernel_initializer=CONV_KERNEL_INITIALIZER)(x)

        model = tf.keras.Model(inputs=inputs, outputs=x)
        model.compile(
            optimizer=self.optimizer,
            loss='mse',
            metrics=['accuracy']
        )
        return model
discriminator = ConvDiscriminator()
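
# Minimal smoke test (a sketch; the random data and batch size are
# assumptions, not part of the original report): one epoch on random
# inputs exercises the LinearDecay schedule through the Adam optimizer.
# With a 256x256 input and the layers above, the patch output is 32x32x1.
images = tf.random.normal((2, 256, 256, 3))
labels = tf.ones((2, 32, 32, 1))
discriminator.model.fit(images, labels, epochs=1, verbose=2)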